diff --git a/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td b/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td index ea67c05..6ee6f12 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td +++ b/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td @@ -36,6 +36,10 @@ def IsSameSizePred : CPred< def IsSameSizeConstraint : Constraint; +// Unary Lowering Patterns. +def : Pat<(HLO_CeilOp HLO_FpTensor:$i), (CeilFOp $i)>; + +// Binary Lowering Patterns. def : Pat<(HLO_AndOp HLO_PredTensor:$l, HLO_PredTensor:$r), (AndOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; diff --git a/tests/legalize-to-std.mlir b/tests/legalize-to-std.mlir index 37a6149..abe4e87 100644 --- a/tests/legalize-to-std.mlir +++ b/tests/legalize-to-std.mlir @@ -42,6 +42,15 @@ func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32 return %4 : tensor<4xi32> } +// CHECK-LABEL: func @unary_ops_float +func @unary_ops_float(%arg0: tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NEXT: %0 = ceilf %arg0 : tensor<4xf32> + %0 = "mhlo.ceil"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + + // CHECK-NEXT: return %0 : tensor<4xf32> + return %0 : tensor<4xf32> +} + // CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) { func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) { // CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32>