Support complex types when converting HLO compare op (EQ/NE).
We can lower it to the EqualOp / NotEqualOp in the complex dialect. PiperOrigin-RevId: 375655092
This commit is contained in:
parent
5504f82f11
commit
758ae7da6b
|
@ -279,6 +279,16 @@ inline Value MapCompareOpToStdScalarOp(Location loc,
|
|||
return b->create<ScalarFOp<CompareOpTy>>(loc, predicate.getValue(), lhs,
|
||||
rhs);
|
||||
}
|
||||
if (auto complex_type = element_type.dyn_cast<ComplexType>()) {
|
||||
if (complex_type.getElementType().isa<FloatType>()) {
|
||||
if (comparison_direction == "EQ") {
|
||||
return b->create<complex::EqualOp>(loc, lhs, rhs);
|
||||
}
|
||||
if (comparison_direction == "NE") {
|
||||
return b->create<complex::NotEqualOp>(loc, lhs, rhs);
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -188,7 +188,7 @@ func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
|||
|
||||
// CHECK-LABEL: func @int_cmp
|
||||
func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
|
||||
%result: memref<2x2xi1>) {
|
||||
%result: memref<2x2xi1>) {
|
||||
"lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"}
|
||||
: (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi1>) -> ()
|
||||
return
|
||||
|
@ -200,6 +200,34 @@ func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @complex_cmp_eq
|
||||
func @complex_cmp_eq(%lhs: memref<2xcomplex<f32>>, %rhs: memref<2xcomplex<f32>>,
|
||||
%result: memref<2xi1>) {
|
||||
"lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"}
|
||||
: (memref<2xcomplex<f32>>, memref<2xcomplex<f32>>, memref<2xi1>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex<f32>, %[[RHS_IN:.*]]: complex<f32>, %[[RESULT_OUT:.*]]: i1):
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = complex.eq %[[LHS_IN]], %[[RHS_IN]] : complex<f32>
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @complex_cmp_neq
|
||||
func @complex_cmp_neq(%lhs: memref<2xcomplex<f64>>, %rhs: memref<2xcomplex<f64>>,
|
||||
%result: memref<2xi1>) {
|
||||
"lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "NE"}
|
||||
: (memref<2xcomplex<f64>>, memref<2xcomplex<f64>>, memref<2xi1>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex<f64>, %[[RHS_IN:.*]]: complex<f64>, %[[RESULT_OUT:.*]]: i1):
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = complex.neq %[[LHS_IN]], %[[RHS_IN]] : complex<f64>
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @select
|
||||
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
|
||||
%rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
|
|
Loading…
Reference in New Issue