diff --git a/lib/Dialect/mhlo/transforms/lower_complex_patterns.td b/lib/Dialect/mhlo/transforms/lower_complex_patterns.td index d132297..f67c786 100644 --- a/lib/Dialect/mhlo/transforms/lower_complex_patterns.td +++ b/lib/Dialect/mhlo/transforms/lower_complex_patterns.td @@ -87,3 +87,11 @@ def : Pat<(HLO_ExpOp HLO_ComplexTensor:$val), (HLO_CosOp (HLO_ImagOp:$imag $val)), (HLO_ExpOp:$exp (HLO_RealOp:$real $val))), (HLO_MulOp (HLO_SinOp $imag), $exp))>; + +foreach pair = [[HLO_COMPARISON_DIRECTION_NE, HLO_OrOp], + [HLO_COMPARISON_DIRECTION_EQ, HLO_AndOp]] in { + def : Pat<(HLO_CompareOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs, pair[0], $compare_type), + (pair[1] + (HLO_CompareOp (HLO_RealOp $lhs), (HLO_RealOp $rhs), pair[0], $compare_type), + (HLO_CompareOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs), pair[0], $compare_type))>; +} diff --git a/tests/lower-complex.mlir b/tests/lower-complex.mlir index 141c238..ed9e997 100644 --- a/tests/lower-complex.mlir +++ b/tests/lower-complex.mlir @@ -238,3 +238,35 @@ func @exp_unranked(%arg0 : tensor<*xcomplex>) -> (tensor<*xcomplex>) { // CHECK: [[OUT]] return %0 : tensor<*xcomplex> } + +// CHECK-LABEL: @compare_eq +// CHECK: ([[LHS:%.+]]: tensor<2xcomplex>, [[RHS:%.+]]: tensor<2xcomplex>) +func @compare_eq(%lhs : tensor<2xcomplex>, %rhs: tensor<2xcomplex>) -> (tensor<2xi1>) { + // CHECK-DAG: [[REAL_LHS:%.+]] = "mhlo.real"([[LHS]]) + // CHECK-DAG: [[REAL_RHS:%.+]] = "mhlo.real"([[RHS]]) + // CHECK-DAG: [[OUTR:%.+]] = "mhlo.compare"([[REAL_LHS]], [[REAL_RHS]]) {comparison_direction = "EQ"} + // CHECK-DAG: [[IMAG_LHS:%.+]] = "mhlo.imag"([[LHS]]) + // CHECK-DAG: [[IMAG_RHS:%.+]] = "mhlo.imag"([[RHS]]) + // CHECK-DAG: [[OUTI:%.+]] = "mhlo.compare"([[IMAG_LHS]], [[IMAG_RHS]]) {comparison_direction = "EQ"} + // CHECK-DAG: [[OUT:%.+]] = mhlo.and [[OUTR]], [[OUTI]] + %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<2xcomplex>, tensor<2xcomplex>) -> tensor<2xi1> + + // CHECK: return [[OUT]] + return %0 : tensor<2xi1> +} + +// CHECK-LABEL: @compare_ne +// CHECK: ([[LHS:%.+]]: tensor<2xcomplex>, [[RHS:%.+]]: tensor<2xcomplex>) +func @compare_ne(%lhs : tensor<2xcomplex>, %rhs: tensor<2xcomplex>) -> (tensor<2xi1>) { + // CHECK-DAG: [[REAL_LHS:%.+]] = "mhlo.real"([[LHS]]) + // CHECK-DAG: [[REAL_RHS:%.+]] = "mhlo.real"([[RHS]]) + // CHECK-DAG: [[OUTR:%.+]] = "mhlo.compare"([[REAL_LHS]], [[REAL_RHS]]) {comparison_direction = "NE"} + // CHECK-DAG: [[IMAG_LHS:%.+]] = "mhlo.imag"([[LHS]]) + // CHECK-DAG: [[IMAG_RHS:%.+]] = "mhlo.imag"([[RHS]]) + // CHECK-DAG: [[OUTI:%.+]] = "mhlo.compare"([[IMAG_LHS]], [[IMAG_RHS]]) {comparison_direction = "NE"} + // CHECK-DAG: [[OUT:%.+]] = mhlo.or [[OUTR]], [[OUTI]] + %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "NE"} : (tensor<2xcomplex>, tensor<2xcomplex>) -> tensor<2xi1> + + // CHECK: return [[OUT]] + return %0 : tensor<2xi1> +}