Support complex types when converting HLO multiply op.

We can lower it to the MulOp in the complex dialect.

PiperOrigin-RevId: 375675079
This commit is contained in:
Adrian Kuegel 2021-05-25 04:34:24 -07:00 committed by TensorFlow MLIR Team
parent 5816920258
commit a847109ac7
3 changed files with 41 additions and 0 deletions

View File

@ -65,6 +65,7 @@ struct LhloToScalarOp<lmhlo::MulOp> {
using FOp = ::mlir::MulFOp;
using IOp = ::mlir::MulIOp;
using UOp = ::mlir::MulIOp;
using COp = ::mlir::complex::MulOp;
};
template <>
struct LhloToScalarOp<lmhlo::RemOp> {
@ -631,6 +632,19 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc,
args, loc, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::MulOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Type> arg_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToScalarOpImpl<isSignedIntegerType, ScalarIOp<lmhlo::MulOp>,
isUnsignedIntegerType, ScalarUOp<lmhlo::MulOp>,
isFloatType, ScalarFOp<lmhlo::MulOp>,
isComplexType, ScalarCOp<lmhlo::MulOp>>{}(
loc, result_types, arg_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ClampOp>(Location loc,
ArrayRef<Type> result_types,

View File

@ -65,6 +65,19 @@ func @integer_mul(%lhs: tensor<2x2xi32>,
// -----
// CHECK-LABEL: func @complex_mul
func @complex_mul(%lhs: tensor<2x2xcomplex<f32>>,
%rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
// CHECK: linalg.generic
// CHECK: complex.mul
%0 = "mhlo.multiply"(%lhs, %rhs)
: (tensor<2x2xcomplex<f32>>, tensor<2x2xcomplex<f32>>)
-> tensor<2x2xcomplex<f32>>
return %0 : tensor<2x2xcomplex<f32>>
}
// -----
// CHECK-LABEL: func @float_remainder
func @float_remainder(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {

View File

@ -242,6 +242,20 @@ func @complex_divide(%lhs: memref<2xcomplex<f64>>, %rhs: memref<2xcomplex<f64>>,
// -----
// CHECK-LABEL: func @complex_multiply
func @complex_multiply(%lhs: memref<2xcomplex<f64>>, %rhs: memref<2xcomplex<f64>>,
%result: memref<2xcomplex<f64>>) {
"lmhlo.multiply"(%lhs, %rhs, %result)
: (memref<2xcomplex<f64>>, memref<2xcomplex<f64>>, memref<2xcomplex<f64>>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex<f64>, %[[RHS_IN:.*]]: complex<f64>, %[[RESULT_OUT:.*]]: complex<f64>):
// CHECK-NEXT: %[[RESULT:.*]] = complex.mul %[[LHS_IN]], %[[RHS_IN]] : complex<f64>
// CHECK-NEXT: linalg.yield %[[RESULT]] : complex<f64>
// -----
// CHECK-LABEL: func @select
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {