Implement InferShapedTypeOpInterface for mhlo.complex
Binary companion for 8bcd33e4b7
PiperOrigin-RevId: 334651523
			
			
This commit is contained in:
		
							parent
							
								
									019c5ef106
								
							
						
					
					
						commit
						dfe64d3958
					
				|  | @ -194,7 +194,8 @@ def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor", | |||
|     [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp; | ||||
| 
 | ||||
| def HLO_ImagOp: HLO_UnaryElementwiseOp<"imag", | ||||
|     [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>], | ||||
|     [NoSideEffect, SameOperandsAndResultShape, | ||||
|      DeclareOpInterfaceMethods<InferTypeOpInterface>], | ||||
|     HLO_ComplexTensor>, BASE_HLO_ImagOp { | ||||
|   let results = (outs HLO_FpTensor); | ||||
|   let hasFolder = 1; | ||||
|  | @ -235,7 +236,8 @@ def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt", | |||
|     BASE_HLO_PopulationCountOp; | ||||
| 
 | ||||
| def HLO_RealOp: HLO_UnaryElementwiseOp<"real", | ||||
|     [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>], | ||||
|     [NoSideEffect, SameOperandsAndResultShape, | ||||
|      DeclareOpInterfaceMethods<InferTypeOpInterface>], | ||||
|     HLO_ComplexTensor>, BASE_HLO_RealOp { | ||||
|   let results = (outs HLO_FpTensor); | ||||
|   let hasFolder = 1; | ||||
|  | @ -315,12 +317,10 @@ def HLO_AddOp : HLO_BinaryElementwiseOp<"add", | |||
| def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2", | ||||
|       [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_Atan2Op; | ||||
| 
 | ||||
| def HLO_ComplexOp: HLO_Op<"complex", | ||||
|     [NoSideEffect, SameOperandsAndResultShape]>, | ||||
| def HLO_ComplexOp: HLO_BinaryElementwiseOp<"complex", | ||||
|     [NoSideEffect, SameOperandsAndResultShape, | ||||
|      DeclareOpInterfaceMethods<InferTypeOpInterface>]>, | ||||
|     BASE_HLO_ComplexOp { | ||||
|   let builders = [OpBuilder< | ||||
|     "OpBuilder &, OperationState &tblgen_state, Value lhs, Value rhs">]; | ||||
| 
 | ||||
|   let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs); | ||||
|   let results = (outs HLO_ComplexTensor); | ||||
|   let hasFolder = 1; | ||||
|  |  | |||
|  | @ -889,9 +889,10 @@ static LogicalResult Verify(ClampOp op) { | |||
| // ComplexOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs, | ||||
|                       Value rhs) { | ||||
|   auto type = lhs.getType(); | ||||
| LogicalResult ComplexOp::inferReturnTypes( | ||||
|     MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr, | ||||
|     RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) { | ||||
|   auto type = operands[0].getType(); | ||||
|   auto element_ty = ComplexType::get(getElementTypeOrSelf(type)); | ||||
|   Type result_ty; | ||||
|   if (auto ranked_type = type.dyn_cast<RankedTensorType>()) { | ||||
|  | @ -901,8 +902,8 @@ void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs, | |||
|   } else { | ||||
|     result_ty = element_ty; | ||||
|   } | ||||
| 
 | ||||
|   build(builder, state, result_ty, lhs, rhs); | ||||
|   inferredReturnTypes.push_back(result_ty); | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) { | ||||
|  |  | |||
|  | @ -236,6 +236,21 @@ func @complex(%real: memref<2x2xf32>, | |||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // BOTH-LABEL: func @complex_dyn | ||||
| func @complex_dyn(%real: memref<?xf32>, | ||||
|                   %imag: memref<?xf32>, | ||||
|                   %result: memref<?xcomplex<f32>>) { | ||||
|   %tensor_real = tensor_load %real : memref<?xf32> | ||||
|   %tensor_imag = tensor_load %imag : memref<?xf32> | ||||
|   %tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag) | ||||
|       : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xcomplex<f32>> | ||||
|   // BOTH: "lmhlo.complex"(%{{.*}}, %{{.*}}) | ||||
|   tensor_store %tensor_result, %result : memref<?xcomplex<f32>> | ||||
|   return | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // BOTH-LABEL: func @real | ||||
| func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) { | ||||
|   %tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>> | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue