Generate Equal and NotEqual kernels for complex types.

PiperOrigin-RevId: 368586877
This commit is contained in:
Adrian Kuegel 2021-04-15 00:34:29 -07:00 committed by TensorFlow MLIR Team
parent bfc8cca38f
commit db9f298505
2 changed files with 40 additions and 0 deletions

View File

@ -87,3 +87,11 @@ def : Pat<(HLO_ExpOp HLO_ComplexTensor:$val),
(HLO_CosOp (HLO_ImagOp:$imag $val)),
(HLO_ExpOp:$exp (HLO_RealOp:$real $val))),
(HLO_MulOp (HLO_SinOp $imag), $exp))>;
foreach pair = [[HLO_COMPARISON_DIRECTION_NE, HLO_OrOp],
[HLO_COMPARISON_DIRECTION_EQ, HLO_AndOp]] in {
def : Pat<(HLO_CompareOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs, pair[0], $compare_type),
(pair[1]
(HLO_CompareOp (HLO_RealOp $lhs), (HLO_RealOp $rhs), pair[0], $compare_type),
(HLO_CompareOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs), pair[0], $compare_type))>;
}

View File

@ -238,3 +238,35 @@ func @exp_unranked(%arg0 : tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>) {
// CHECK: [[OUT]]
return %0 : tensor<*xcomplex<f32>>
}
// CHECK-LABEL: @compare_eq
// CHECK: ([[LHS:%.+]]: tensor<2xcomplex<f32>>, [[RHS:%.+]]: tensor<2xcomplex<f32>>)
func @compare_eq(%lhs : tensor<2xcomplex<f32>>, %rhs: tensor<2xcomplex<f32>>) -> (tensor<2xi1>) {
// CHECK-DAG: [[REAL_LHS:%.+]] = "mhlo.real"([[LHS]])
// CHECK-DAG: [[REAL_RHS:%.+]] = "mhlo.real"([[RHS]])
// CHECK-DAG: [[OUTR:%.+]] = "mhlo.compare"([[REAL_LHS]], [[REAL_RHS]]) {comparison_direction = "EQ"}
// CHECK-DAG: [[IMAG_LHS:%.+]] = "mhlo.imag"([[LHS]])
// CHECK-DAG: [[IMAG_RHS:%.+]] = "mhlo.imag"([[RHS]])
// CHECK-DAG: [[OUTI:%.+]] = "mhlo.compare"([[IMAG_LHS]], [[IMAG_RHS]]) {comparison_direction = "EQ"}
// CHECK-DAG: [[OUT:%.+]] = mhlo.and [[OUTR]], [[OUTI]]
%0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xi1>
// CHECK: return [[OUT]]
return %0 : tensor<2xi1>
}
// CHECK-LABEL: @compare_ne
// CHECK: ([[LHS:%.+]]: tensor<2xcomplex<f32>>, [[RHS:%.+]]: tensor<2xcomplex<f32>>)
func @compare_ne(%lhs : tensor<2xcomplex<f32>>, %rhs: tensor<2xcomplex<f32>>) -> (tensor<2xi1>) {
// CHECK-DAG: [[REAL_LHS:%.+]] = "mhlo.real"([[LHS]])
// CHECK-DAG: [[REAL_RHS:%.+]] = "mhlo.real"([[RHS]])
// CHECK-DAG: [[OUTR:%.+]] = "mhlo.compare"([[REAL_LHS]], [[REAL_RHS]]) {comparison_direction = "NE"}
// CHECK-DAG: [[IMAG_LHS:%.+]] = "mhlo.imag"([[LHS]])
// CHECK-DAG: [[IMAG_RHS:%.+]] = "mhlo.imag"([[RHS]])
// CHECK-DAG: [[OUTI:%.+]] = "mhlo.compare"([[IMAG_LHS]], [[IMAG_RHS]]) {comparison_direction = "NE"}
// CHECK-DAG: [[OUT:%.+]] = mhlo.or [[OUTR]], [[OUTI]]
%0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "NE"} : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xi1>
// CHECK: return [[OUT]]
return %0 : tensor<2xi1>
}