Generate Equal and NotEqual kernels for complex types.
PiperOrigin-RevId: 368586877
This commit is contained in:
parent
bfc8cca38f
commit
db9f298505
|
@ -87,3 +87,11 @@ def : Pat<(HLO_ExpOp HLO_ComplexTensor:$val),
|
|||
(HLO_CosOp (HLO_ImagOp:$imag $val)),
|
||||
(HLO_ExpOp:$exp (HLO_RealOp:$real $val))),
|
||||
(HLO_MulOp (HLO_SinOp $imag), $exp))>;
|
||||
|
||||
foreach pair = [[HLO_COMPARISON_DIRECTION_NE, HLO_OrOp],
|
||||
[HLO_COMPARISON_DIRECTION_EQ, HLO_AndOp]] in {
|
||||
def : Pat<(HLO_CompareOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs, pair[0], $compare_type),
|
||||
(pair[1]
|
||||
(HLO_CompareOp (HLO_RealOp $lhs), (HLO_RealOp $rhs), pair[0], $compare_type),
|
||||
(HLO_CompareOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs), pair[0], $compare_type))>;
|
||||
}
|
||||
|
|
|
@ -238,3 +238,35 @@ func @exp_unranked(%arg0 : tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>) {
|
|||
// CHECK: [[OUT]]
|
||||
return %0 : tensor<*xcomplex<f32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @compare_eq
|
||||
// CHECK: ([[LHS:%.+]]: tensor<2xcomplex<f32>>, [[RHS:%.+]]: tensor<2xcomplex<f32>>)
|
||||
func @compare_eq(%lhs : tensor<2xcomplex<f32>>, %rhs: tensor<2xcomplex<f32>>) -> (tensor<2xi1>) {
|
||||
// CHECK-DAG: [[REAL_LHS:%.+]] = "mhlo.real"([[LHS]])
|
||||
// CHECK-DAG: [[REAL_RHS:%.+]] = "mhlo.real"([[RHS]])
|
||||
// CHECK-DAG: [[OUTR:%.+]] = "mhlo.compare"([[REAL_LHS]], [[REAL_RHS]]) {comparison_direction = "EQ"}
|
||||
// CHECK-DAG: [[IMAG_LHS:%.+]] = "mhlo.imag"([[LHS]])
|
||||
// CHECK-DAG: [[IMAG_RHS:%.+]] = "mhlo.imag"([[RHS]])
|
||||
// CHECK-DAG: [[OUTI:%.+]] = "mhlo.compare"([[IMAG_LHS]], [[IMAG_RHS]]) {comparison_direction = "EQ"}
|
||||
// CHECK-DAG: [[OUT:%.+]] = mhlo.and [[OUTR]], [[OUTI]]
|
||||
%0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xi1>
|
||||
|
||||
// CHECK: return [[OUT]]
|
||||
return %0 : tensor<2xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @compare_ne
|
||||
// CHECK: ([[LHS:%.+]]: tensor<2xcomplex<f32>>, [[RHS:%.+]]: tensor<2xcomplex<f32>>)
|
||||
func @compare_ne(%lhs : tensor<2xcomplex<f32>>, %rhs: tensor<2xcomplex<f32>>) -> (tensor<2xi1>) {
|
||||
// CHECK-DAG: [[REAL_LHS:%.+]] = "mhlo.real"([[LHS]])
|
||||
// CHECK-DAG: [[REAL_RHS:%.+]] = "mhlo.real"([[RHS]])
|
||||
// CHECK-DAG: [[OUTR:%.+]] = "mhlo.compare"([[REAL_LHS]], [[REAL_RHS]]) {comparison_direction = "NE"}
|
||||
// CHECK-DAG: [[IMAG_LHS:%.+]] = "mhlo.imag"([[LHS]])
|
||||
// CHECK-DAG: [[IMAG_RHS:%.+]] = "mhlo.imag"([[RHS]])
|
||||
// CHECK-DAG: [[OUTI:%.+]] = "mhlo.compare"([[IMAG_LHS]], [[IMAG_RHS]]) {comparison_direction = "NE"}
|
||||
// CHECK-DAG: [[OUT:%.+]] = mhlo.or [[OUTR]], [[OUTI]]
|
||||
%0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "NE"} : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xi1>
|
||||
|
||||
// CHECK: return [[OUT]]
|
||||
return %0 : tensor<2xi1>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue