Support complex types when converting HLO multiply op.

We can lower it to the MulOp in the complex dialect.

PiperOrigin-RevId: 375675079
This commit is contained in:
Adrian Kuegel 2021-05-25 04:34:24 -07:00 committed by TensorFlow MLIR Team
parent 5816920258
commit a847109ac7
3 changed files with 41 additions and 0 deletions

View File

@ -65,6 +65,7 @@ struct LhloToScalarOp<lmhlo::MulOp> {
using FOp = ::mlir::MulFOp; using FOp = ::mlir::MulFOp;
using IOp = ::mlir::MulIOp; using IOp = ::mlir::MulIOp;
using UOp = ::mlir::MulIOp; using UOp = ::mlir::MulIOp;
using COp = ::mlir::complex::MulOp;
}; };
template <> template <>
struct LhloToScalarOp<lmhlo::RemOp> { struct LhloToScalarOp<lmhlo::RemOp> {
@ -631,6 +632,19 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc,
args, loc, b); args, loc, b);
} }
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::MulOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Type> arg_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToScalarOpImpl<isSignedIntegerType, ScalarIOp<lmhlo::MulOp>,
isUnsignedIntegerType, ScalarUOp<lmhlo::MulOp>,
isFloatType, ScalarFOp<lmhlo::MulOp>,
isComplexType, ScalarCOp<lmhlo::MulOp>>{}(
loc, result_types, arg_types, args, b);
}
template <> template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ClampOp>(Location loc, inline Value MapLhloOpToStdScalarOp<lmhlo::ClampOp>(Location loc,
ArrayRef<Type> result_types, ArrayRef<Type> result_types,

View File

@ -65,6 +65,19 @@ func @integer_mul(%lhs: tensor<2x2xi32>,
// ----- // -----
// CHECK-LABEL: func @complex_mul
func @complex_mul(%lhs: tensor<2x2xcomplex<f32>>,
%rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
// CHECK: linalg.generic
// CHECK: complex.mul
%0 = "mhlo.multiply"(%lhs, %rhs)
: (tensor<2x2xcomplex<f32>>, tensor<2x2xcomplex<f32>>)
-> tensor<2x2xcomplex<f32>>
return %0 : tensor<2x2xcomplex<f32>>
}
// -----
// CHECK-LABEL: func @float_remainder // CHECK-LABEL: func @float_remainder
func @float_remainder(%lhs: tensor<2x2xf32>, func @float_remainder(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {

View File

@ -242,6 +242,20 @@ func @complex_divide(%lhs: memref<2xcomplex<f64>>, %rhs: memref<2xcomplex<f64>>,
// ----- // -----
// CHECK-LABEL: func @complex_multiply
func @complex_multiply(%lhs: memref<2xcomplex<f64>>, %rhs: memref<2xcomplex<f64>>,
%result: memref<2xcomplex<f64>>) {
"lmhlo.multiply"(%lhs, %rhs, %result)
: (memref<2xcomplex<f64>>, memref<2xcomplex<f64>>, memref<2xcomplex<f64>>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex<f64>, %[[RHS_IN:.*]]: complex<f64>, %[[RESULT_OUT:.*]]: complex<f64>):
// CHECK-NEXT: %[[RESULT:.*]] = complex.mul %[[LHS_IN]], %[[RHS_IN]] : complex<f64>
// CHECK-NEXT: linalg.yield %[[RESULT]] : complex<f64>
// -----
// CHECK-LABEL: func @select // CHECK-LABEL: func @select
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {