Support complex types when converting HLO compare op (EQ/NE).

We can lower it to the EqualOp / NotEqualOp in the complex dialect.

PiperOrigin-RevId: 375655092
This commit is contained in:
Adrian Kuegel 2021-05-25 01:53:00 -07:00 committed by TensorFlow MLIR Team
parent 5504f82f11
commit 758ae7da6b
2 changed files with 39 additions and 1 deletions

View File

@ -279,6 +279,16 @@ inline Value MapCompareOpToStdScalarOp(Location loc,
return b->create<ScalarFOp<CompareOpTy>>(loc, predicate.getValue(), lhs,
rhs);
}
if (auto complex_type = element_type.dyn_cast<ComplexType>()) {
if (complex_type.getElementType().isa<FloatType>()) {
if (comparison_direction == "EQ") {
return b->create<complex::EqualOp>(loc, lhs, rhs);
}
if (comparison_direction == "NE") {
return b->create<complex::NotEqualOp>(loc, lhs, rhs);
}
}
}
return nullptr;
}

View File

@ -188,7 +188,7 @@ func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
// CHECK-LABEL: func @int_cmp
func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
%result: memref<2x2xi1>) {
%result: memref<2x2xi1>) {
"lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"}
: (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi1>) -> ()
return
@ -200,6 +200,34 @@ func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
// -----
// CHECK-LABEL: func @complex_cmp_eq
func @complex_cmp_eq(%lhs: memref<2xcomplex<f32>>, %rhs: memref<2xcomplex<f32>>,
%result: memref<2xi1>) {
"lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"}
: (memref<2xcomplex<f32>>, memref<2xcomplex<f32>>, memref<2xi1>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex<f32>, %[[RHS_IN:.*]]: complex<f32>, %[[RESULT_OUT:.*]]: i1):
// CHECK-NEXT: %[[RESULT:.*]] = complex.eq %[[LHS_IN]], %[[RHS_IN]] : complex<f32>
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
// -----
// CHECK-LABEL: func @complex_cmp_neq
func @complex_cmp_neq(%lhs: memref<2xcomplex<f64>>, %rhs: memref<2xcomplex<f64>>,
%result: memref<2xi1>) {
"lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "NE"}
: (memref<2xcomplex<f64>>, memref<2xcomplex<f64>>, memref<2xi1>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex<f64>, %[[RHS_IN:.*]]: complex<f64>, %[[RESULT_OUT:.*]]: i1):
// CHECK-NEXT: %[[RESULT:.*]] = complex.neq %[[LHS_IN]], %[[RHS_IN]] : complex<f64>
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
// -----
// CHECK-LABEL: func @select
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {