Support complex types when converting HLO multiply op.
We can lower it to the MulOp in the complex dialect. PiperOrigin-RevId: 375675079
This commit is contained in:
parent
5816920258
commit
a847109ac7
|
@ -65,6 +65,7 @@ struct LhloToScalarOp<lmhlo::MulOp> {
|
|||
using FOp = ::mlir::MulFOp;
|
||||
using IOp = ::mlir::MulIOp;
|
||||
using UOp = ::mlir::MulIOp;
|
||||
using COp = ::mlir::complex::MulOp;
|
||||
};
|
||||
template <>
|
||||
struct LhloToScalarOp<lmhlo::RemOp> {
|
||||
|
@ -631,6 +632,19 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc,
|
|||
args, loc, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::MulOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Type> arg_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToScalarOpImpl<isSignedIntegerType, ScalarIOp<lmhlo::MulOp>,
|
||||
isUnsignedIntegerType, ScalarUOp<lmhlo::MulOp>,
|
||||
isFloatType, ScalarFOp<lmhlo::MulOp>,
|
||||
isComplexType, ScalarCOp<lmhlo::MulOp>>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ClampOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
|
|
|
@ -65,6 +65,19 @@ func @integer_mul(%lhs: tensor<2x2xi32>,
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @complex_mul
|
||||
func @complex_mul(%lhs: tensor<2x2xcomplex<f32>>,
|
||||
%rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: complex.mul
|
||||
%0 = "mhlo.multiply"(%lhs, %rhs)
|
||||
: (tensor<2x2xcomplex<f32>>, tensor<2x2xcomplex<f32>>)
|
||||
-> tensor<2x2xcomplex<f32>>
|
||||
return %0 : tensor<2x2xcomplex<f32>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @float_remainder
|
||||
func @float_remainder(%lhs: tensor<2x2xf32>,
|
||||
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
|
|
|
@ -242,6 +242,20 @@ func @complex_divide(%lhs: memref<2xcomplex<f64>>, %rhs: memref<2xcomplex<f64>>,
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @complex_multiply
|
||||
func @complex_multiply(%lhs: memref<2xcomplex<f64>>, %rhs: memref<2xcomplex<f64>>,
|
||||
%result: memref<2xcomplex<f64>>) {
|
||||
"lmhlo.multiply"(%lhs, %rhs, %result)
|
||||
: (memref<2xcomplex<f64>>, memref<2xcomplex<f64>>, memref<2xcomplex<f64>>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex<f64>, %[[RHS_IN:.*]]: complex<f64>, %[[RESULT_OUT:.*]]: complex<f64>):
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = complex.mul %[[LHS_IN]], %[[RHS_IN]] : complex<f64>
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : complex<f64>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @select
|
||||
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
|
||||
%rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
|
|
Loading…
Reference in New Issue