Support complex types when converting HLO multiply op.
We can lower it to the MulOp in the complex dialect. PiperOrigin-RevId: 375675079
This commit is contained in:
		
							parent
							
								
									5816920258
								
							
						
					
					
						commit
						a847109ac7
					
				| 
						 | 
					@ -65,6 +65,7 @@ struct LhloToScalarOp<lmhlo::MulOp> {
 | 
				
			||||||
  using FOp = ::mlir::MulFOp;
 | 
					  using FOp = ::mlir::MulFOp;
 | 
				
			||||||
  using IOp = ::mlir::MulIOp;
 | 
					  using IOp = ::mlir::MulIOp;
 | 
				
			||||||
  using UOp = ::mlir::MulIOp;
 | 
					  using UOp = ::mlir::MulIOp;
 | 
				
			||||||
 | 
					  using COp = ::mlir::complex::MulOp;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
template <>
 | 
					template <>
 | 
				
			||||||
struct LhloToScalarOp<lmhlo::RemOp> {
 | 
					struct LhloToScalarOp<lmhlo::RemOp> {
 | 
				
			||||||
| 
						 | 
					@ -631,6 +632,19 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc,
 | 
				
			||||||
      args, loc, b);
 | 
					      args, loc, b);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					inline Value MapLhloOpToStdScalarOp<lmhlo::MulOp>(Location loc,
 | 
				
			||||||
 | 
					                                                  ArrayRef<Type> result_types,
 | 
				
			||||||
 | 
					                                                  ArrayRef<Type> arg_types,
 | 
				
			||||||
 | 
					                                                  ArrayRef<Value> args,
 | 
				
			||||||
 | 
					                                                  OpBuilder* b) {
 | 
				
			||||||
 | 
					  return MapLhloOpToScalarOpImpl<isSignedIntegerType, ScalarIOp<lmhlo::MulOp>,
 | 
				
			||||||
 | 
					                                 isUnsignedIntegerType, ScalarUOp<lmhlo::MulOp>,
 | 
				
			||||||
 | 
					                                 isFloatType, ScalarFOp<lmhlo::MulOp>,
 | 
				
			||||||
 | 
					                                 isComplexType, ScalarCOp<lmhlo::MulOp>>{}(
 | 
				
			||||||
 | 
					      loc, result_types, arg_types, args, b);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <>
 | 
					template <>
 | 
				
			||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ClampOp>(Location loc,
 | 
					inline Value MapLhloOpToStdScalarOp<lmhlo::ClampOp>(Location loc,
 | 
				
			||||||
                                                    ArrayRef<Type> result_types,
 | 
					                                                    ArrayRef<Type> result_types,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -65,6 +65,19 @@ func @integer_mul(%lhs: tensor<2x2xi32>,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// -----
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: func @complex_mul
 | 
				
			||||||
 | 
					func @complex_mul(%lhs: tensor<2x2xcomplex<f32>>,
 | 
				
			||||||
 | 
					                  %rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
 | 
				
			||||||
 | 
					  // CHECK: linalg.generic
 | 
				
			||||||
 | 
					  // CHECK: complex.mul
 | 
				
			||||||
 | 
					  %0 = "mhlo.multiply"(%lhs, %rhs)
 | 
				
			||||||
 | 
					          : (tensor<2x2xcomplex<f32>>, tensor<2x2xcomplex<f32>>)
 | 
				
			||||||
 | 
					          -> tensor<2x2xcomplex<f32>>
 | 
				
			||||||
 | 
					  return %0 : tensor<2x2xcomplex<f32>>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CHECK-LABEL: func @float_remainder
 | 
					// CHECK-LABEL: func @float_remainder
 | 
				
			||||||
func @float_remainder(%lhs: tensor<2x2xf32>,
 | 
					func @float_remainder(%lhs: tensor<2x2xf32>,
 | 
				
			||||||
                      %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
 | 
					                      %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -242,6 +242,20 @@ func @complex_divide(%lhs: memref<2xcomplex<f64>>, %rhs: memref<2xcomplex<f64>>,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// -----
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: func @complex_multiply
 | 
				
			||||||
 | 
					func @complex_multiply(%lhs: memref<2xcomplex<f64>>, %rhs: memref<2xcomplex<f64>>,
 | 
				
			||||||
 | 
					                       %result: memref<2xcomplex<f64>>) {
 | 
				
			||||||
 | 
					  "lmhlo.multiply"(%lhs, %rhs, %result)
 | 
				
			||||||
 | 
					      : (memref<2xcomplex<f64>>, memref<2xcomplex<f64>>, memref<2xcomplex<f64>>) -> ()
 | 
				
			||||||
 | 
					  return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					// CHECK: linalg.generic
 | 
				
			||||||
 | 
					// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex<f64>, %[[RHS_IN:.*]]: complex<f64>, %[[RESULT_OUT:.*]]: complex<f64>):
 | 
				
			||||||
 | 
					// CHECK-NEXT:   %[[RESULT:.*]] = complex.mul %[[LHS_IN]], %[[RHS_IN]] : complex<f64>
 | 
				
			||||||
 | 
					// CHECK-NEXT:   linalg.yield %[[RESULT]] : complex<f64>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CHECK-LABEL: func @select
 | 
					// CHECK-LABEL: func @select
 | 
				
			||||||
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
 | 
					func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
 | 
				
			||||||
             %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
 | 
					             %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue