From 7c93352a40789688e7c88e0cc96873ee2f42392a Mon Sep 17 00:00:00 2001 From: Robert Suderman Date: Tue, 1 Sep 2020 18:33:08 -0700 Subject: [PATCH] Scalar / Trivial folding for mhlo.select This covers the case where the predicate is a splat or the on_true/on_false values are the same. PiperOrigin-RevId: 329622785 --- include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td | 2 ++ lib/Dialect/mhlo/IR/hlo_ops.cc | 23 +++++++++++++++ tests/canonicalize.mlir | 32 +++++++++++++++++++++ 3 files changed, 57 insertions(+) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 12b9f5a..f4bb33d 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -1153,6 +1153,8 @@ def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, DeclareOpInterfaceMethods operands) { + if (on_true() == on_false()) { + return on_true(); + } + + auto predicate = operands[0].dyn_cast_or_null(); + if (!predicate) { + return {}; + } + + auto predicateTy = predicate.getType().cast(); + if (!predicateTy.getElementType().isInteger(1)) { + return {}; + } + + if (predicate.isSplat()) { + return predicate.getSplatValue().getBoolValue() ? on_true() + : on_false(); + } + + return {}; +} + // Makes it such that a SelectOp that is a non-root operation in a DRR infers // the return type based on operand type. LogicalResult SelectOp::inferReturnTypes( diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index 6ee6586..9771fba 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -633,3 +633,35 @@ func @fold_get_dimension_size(%I : tensor<1x128x512xf32>) -> tensor { // CHECK-NEXT: %[[C:.*]] = mhlo.constant dense<512> : tensor // CHECK-NEXT: return %[[C]] } + +// CHECK-LABEL: func @fold_select_same +func @fold_select_same(%arg0 : tensor, %arg1 : tensor) -> tensor { + %1 = "mhlo.select"(%arg1, %arg0, %arg0) : (tensor, tensor, tensor) -> tensor + // CHECK: return %arg0 + return %1 : tensor +} + +// CHECK-LABEL: func @fold_select_first +func @fold_select_first(%arg0 : tensor, %arg1 : tensor) -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor, tensor, tensor) -> tensor + // CHECK: return %arg0 + return %1 : tensor +} + +// CHECK-LABEL: func @fold_select_second +func @fold_select_second(%arg0 : tensor, %arg1 : tensor) -> tensor { + %0 = mhlo.constant dense<0> : tensor + %1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor, tensor, tensor) -> tensor + // CHECK: return %arg1 + return %1 : tensor +} + +// CHECK-LABEL: func @fold_select_vector +func @fold_select_vector(%arg0 : tensor<4xf32>, %arg1 : tensor<4xf32>) -> tensor<4xf32> { + %0 = mhlo.constant dense<1> : tensor<4xi1> + %1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK: return %arg0 + return %1 : tensor<4xf32> +} +