Support complex types when converting HLO divide op.

We can lower it to the DivOp in the complex dialect.
Also add tests to hlo-legalize-to-linalg.mlir for CompareOp lowering of complex
types. These were forgotten in a previous commit.

PiperOrigin-RevId: 375669125
This commit is contained in:
Adrian Kuegel 2021-05-25 03:42:25 -07:00 committed by TensorFlow MLIR Team
parent 8e28008e38
commit 5816920258
3 changed files with 72 additions and 2 deletions

View File

@ -58,6 +58,7 @@ struct LhloToScalarOp<lmhlo::DivOp> {
using FOp = ::mlir::DivFOp;
using IOp = ::mlir::SignedDivIOp;
using UOp = ::mlir::UnsignedDivIOp;
using COp = ::mlir::complex::DivOp;
};
template <>
struct LhloToScalarOp<lmhlo::MulOp> {
@ -192,6 +193,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
}
return nullptr;
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::AddOp>(Location loc,
ArrayRef<Type> result_types,
@ -301,6 +303,19 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::CopyOp>(Location loc,
return args.front();
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::DivOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Type> arg_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToScalarOpImpl<isSignedIntegerType, ScalarIOp<lmhlo::DivOp>,
isUnsignedIntegerType, ScalarUOp<lmhlo::DivOp>,
isFloatType, ScalarFOp<lmhlo::DivOp>,
isComplexType, ScalarCOp<lmhlo::DivOp>>{}(
loc, result_types, arg_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ExpOp>(Location loc,
ArrayRef<Type> result_types,

View File

@ -323,6 +323,36 @@ func @int_cmp(%lhs: tensor<2x2xi32>,
// -----
// CHECK-LABEL: func @complex_cmp_eq
func @complex_cmp_eq(%lhs: tensor<2xcomplex<f32>>,
%rhs: tensor<2xcomplex<f32>>) -> tensor<2xi1> {
%0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"}
: (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xi1>)
return %0 : tensor<2xi1>
}
// CHECK: linalg.init_tensor [2] : tensor<2xi1>
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex<f32>, %[[RHS_IN:.*]]: complex<f32>, %[[RESULT_OUT:.*]]: i1):
// CHECK-NEXT: %[[RESULT:.*]] = complex.eq %[[LHS_IN]], %[[RHS_IN]] : complex<f32>
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
// -----
// CHECK-LABEL: func @complex_cmp_neq
func @complex_cmp_neq(%lhs: tensor<2xcomplex<f64>>,
%rhs: tensor<2xcomplex<f64>>) -> tensor<2xi1> {
%0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "NE"}
: (tensor<2xcomplex<f64>>, tensor<2xcomplex<f64>>) -> (tensor<2xi1>)
return %0 : tensor<2xi1>
}
// CHECK: linalg.init_tensor [2] : tensor<2xi1>
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex<f64>, %[[RHS_IN:.*]]: complex<f64>, %[[RESULT_OUT:.*]]: i1):
// CHECK-NEXT: %[[RESULT:.*]] = complex.neq %[[LHS_IN]], %[[RHS_IN]] : complex<f64>
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
// -----
// CHECK-LABEL: func @float_cos
func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
@ -2352,6 +2382,17 @@ func @unsigned_divide(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<
// -----
// CHECK-LABEL: complex_divide
func @complex_divide(%lhs: tensor<2xcomplex<f32>>,
%rhs: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
// CHECK: linalg.generic
// CHECK: complex.div
%0 = "mhlo.divide"(%lhs, %rhs) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
return %0 : tensor<2xcomplex<f32>>
}
// -----
// CHECK-LABEL: unsigned_remainder
func @unsigned_remainder(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xui32> {
// CHECK: linalg.generic

View File

@ -158,8 +158,8 @@ func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) {
// -----
// CHECK-LABEL: func @is_finte
func @is_finte(%input: memref<2x2xf32>, %result: memref<2x2xi1>) {
// CHECK-LABEL: func @is_finite
func @is_finite(%input: memref<2x2xf32>, %result: memref<2x2xi1>) {
"lmhlo.is_finite"(%input, %result) : (memref<2x2xf32>, memref<2x2xi1>) -> ()
return
}
@ -228,6 +228,20 @@ func @complex_cmp_neq(%lhs: memref<2xcomplex<f64>>, %rhs: memref<2xcomplex<f64>>
// -----
// CHECK-LABEL: func @complex_divide
func @complex_divide(%lhs: memref<2xcomplex<f64>>, %rhs: memref<2xcomplex<f64>>,
%result: memref<2xcomplex<f64>>) {
"lmhlo.divide"(%lhs, %rhs, %result)
: (memref<2xcomplex<f64>>, memref<2xcomplex<f64>>, memref<2xcomplex<f64>>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex<f64>, %[[RHS_IN:.*]]: complex<f64>, %[[RESULT_OUT:.*]]: complex<f64>):
// CHECK-NEXT: %[[RESULT:.*]] = complex.div %[[LHS_IN]], %[[RHS_IN]] : complex<f64>
// CHECK-NEXT: linalg.yield %[[RESULT]] : complex<f64>
// -----
// CHECK-LABEL: func @select
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {