Support complex types when converting HLO compare op (EQ/NE).
We can lower it to the EqualOp / NotEqualOp in the complex dialect. PiperOrigin-RevId: 375655092
This commit is contained in:
parent
5504f82f11
commit
758ae7da6b
|
@ -279,6 +279,16 @@ inline Value MapCompareOpToStdScalarOp(Location loc,
|
||||||
return b->create<ScalarFOp<CompareOpTy>>(loc, predicate.getValue(), lhs,
|
return b->create<ScalarFOp<CompareOpTy>>(loc, predicate.getValue(), lhs,
|
||||||
rhs);
|
rhs);
|
||||||
}
|
}
|
||||||
|
if (auto complex_type = element_type.dyn_cast<ComplexType>()) {
|
||||||
|
if (complex_type.getElementType().isa<FloatType>()) {
|
||||||
|
if (comparison_direction == "EQ") {
|
||||||
|
return b->create<complex::EqualOp>(loc, lhs, rhs);
|
||||||
|
}
|
||||||
|
if (comparison_direction == "NE") {
|
||||||
|
return b->create<complex::NotEqualOp>(loc, lhs, rhs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -188,7 +188,7 @@ func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
||||||
|
|
||||||
// CHECK-LABEL: func @int_cmp
|
// CHECK-LABEL: func @int_cmp
|
||||||
func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
|
func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
|
||||||
%result: memref<2x2xi1>) {
|
%result: memref<2x2xi1>) {
|
||||||
"lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"}
|
"lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"}
|
||||||
: (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi1>) -> ()
|
: (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi1>) -> ()
|
||||||
return
|
return
|
||||||
|
@ -200,6 +200,34 @@ func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @complex_cmp_eq
|
||||||
|
func @complex_cmp_eq(%lhs: memref<2xcomplex<f32>>, %rhs: memref<2xcomplex<f32>>,
|
||||||
|
%result: memref<2xi1>) {
|
||||||
|
"lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"}
|
||||||
|
: (memref<2xcomplex<f32>>, memref<2xcomplex<f32>>, memref<2xi1>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex<f32>, %[[RHS_IN:.*]]: complex<f32>, %[[RESULT_OUT:.*]]: i1):
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = complex.eq %[[LHS_IN]], %[[RHS_IN]] : complex<f32>
|
||||||
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @complex_cmp_neq
|
||||||
|
func @complex_cmp_neq(%lhs: memref<2xcomplex<f64>>, %rhs: memref<2xcomplex<f64>>,
|
||||||
|
%result: memref<2xi1>) {
|
||||||
|
"lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "NE"}
|
||||||
|
: (memref<2xcomplex<f64>>, memref<2xcomplex<f64>>, memref<2xi1>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex<f64>, %[[RHS_IN:.*]]: complex<f64>, %[[RESULT_OUT:.*]]: i1):
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = complex.neq %[[LHS_IN]], %[[RHS_IN]] : complex<f64>
|
||||||
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @select
|
// CHECK-LABEL: func @select
|
||||||
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
|
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
|
||||||
%rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
%rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
|
|
Loading…
Reference in New Issue