From b6d81606118a3377faf2519fb62946bd00b831d2 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Wed, 9 Jun 2021 05:08:36 -0700 Subject: [PATCH] Add Broadcasting and BroadcastingElementwise traits to ConstantLikeOp. This allows to include such ops in rank specialization clusters. PiperOrigin-RevId: 378380915 --- include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td | 4 ++-- tests/rank-specialization.mlir | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index 538a475..fd806d1 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -581,8 +581,8 @@ def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan", } def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like", - [NoSideEffect, SameOperandsAndResultShape, - InferTypeOpInterface, + [NoSideEffect, HLOClient_Broadcasting, HLOClient_BroadcastingElementwise, + SameOperandsAndResultShape, InferTypeOpInterface, DeclareOpInterfaceMethods, NativeOpTrait<"InferTensorType">]> { diff --git a/tests/rank-specialization.mlir b/tests/rank-specialization.mlir index e6c2476..117fbe7 100644 --- a/tests/rank-specialization.mlir +++ b/tests/rank-specialization.mlir @@ -19,6 +19,22 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, return %1 : tensor<*xf32> } +// CHECK-LABEL: @compare_const_like +// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>) +func @compare_const_like(%arg0 : tensor<*xf32>) -> tensor<*xi1> { + // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]]) ( { + // CHECK: ^bb0(%[[ARG1:.*]]: tensor<*xf32>): + // CHECK: %[[ZERO:.*]] = "chlo.constant_like"(%[[ARG1]]) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: %[[CMP_GT:.*]] = chlo.broadcast_compare %[[ARG1]], %[[ZERO]] {comparison_direction = "GT"} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1> + // CHECK: "chlo.rank_specialization_cluster_yield"(%[[CMP_GT]]) : (tensor<*xi1>) -> () + // CHECK: }) : (tensor<*xf32>) -> tensor<*xi1> + // CHECK: return %[[RES]] : tensor<*xi1> + %0 = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> + %1 = chlo.broadcast_compare %arg0, %0 {comparison_direction = "GT"} + : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1> + return %1 : tensor<*xi1> +} + // CHECK-SCF-LABEL: @add_mul // CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>) // CHECK-SCF-DAG: %[[C1:.*]] = constant 1