diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 122a90c..efb041a 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -549,6 +549,8 @@ def HLO_IfOp: HLO_Op<"if", [ // TODO(b/129422361): ConditionalOp has special conversion logic to HLO. let hasCustomHLOConverter = 1; + + let hasCanonicalizer = 1; } // Xla Client API has two separate calls for indexed and predicated conditional, @@ -569,6 +571,8 @@ def HLO_CaseOp: HLO_Op<"case", [ let results = (outs Variadic); let hasCustomHLOConverter = 1; + + let hasCanonicalizer = 1; } diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index bdc296d..1ccfed8 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -45,6 +45,7 @@ limitations under the License. #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Location.h" @@ -131,6 +132,19 @@ DenseIntElementsAttr BuildSliceLimits(DenseIntElementsAttr start_indices, return GetI64ElementsAttr(slice_limits, builder); } +/// Replaces the given op with the contents of the given single-block region, +/// using the operands of the block terminator to replace operation results. 
+static void ReplaceOpWithRegion(PatternRewriter& rewriter, Operation* op, + Region& region, ValueRange blockArgs = {}) { + assert(llvm::hasSingleElement(region) && "expected single-block region"); + Block* block = &region.front(); + Operation* terminator = block->getTerminator(); + ValueRange results = terminator->getOperands(); + rewriter.mergeBlockBefore(block, op, blockArgs); + rewriter.replaceOp(op, results); + rewriter.eraseOp(terminator); +} + #include "mhlo_canonicalize.inc" } // namespace @@ -2129,6 +2143,24 @@ static LogicalResult Verify(IfOp op) { return success(); } +static LogicalResult InlineIfConstantCondition(IfOp ifOp, + PatternRewriter& rewriter) { + DenseIntElementsAttr pred_attr; + if (!matchPattern(ifOp.pred(), m_Constant(&pred_attr))) return failure(); + + if (pred_attr.getSplatValue<BoolAttr>().getValue()) { + ReplaceOpWithRegion(rewriter, ifOp, ifOp.true_branch(), ifOp.true_arg()); + } else { + ReplaceOpWithRegion(rewriter, ifOp, ifOp.false_branch(), ifOp.false_arg()); + } + return success(); +} + +void IfOp::getCanonicalizationPatterns(OwningRewritePatternList& results, + MLIRContext* context) { + results.add(&InlineIfConstantCondition); +} + //===----------------------------------------------------------------------===// // Case Op //===----------------------------------------------------------------------===// @@ -2150,6 +2182,31 @@ static LogicalResult Verify(CaseOp op) { return success(); } +static LogicalResult InlineCaseConstantCondition(CaseOp caseOp, + PatternRewriter& rewriter) { + DenseIntElementsAttr index_attr; + if (!matchPattern(caseOp.index(), m_Constant(&index_attr))) { + return failure(); + } + int64_t index = + index_attr.getSplatValue<IntegerAttr>().getValue().getSExtValue(); + // For an OOB index, the last branch is executed as the default branch: + // https://www.tensorflow.org/xla/operation_semantics#conditional + if (index < 0 || index >= caseOp.getNumRegions()) + index = caseOp.getNumRegions() - 1; + + Region& region = 
caseOp.getRegion(index); + if (!llvm::hasSingleElement(region)) return failure(); + ReplaceOpWithRegion(rewriter, caseOp, region, + caseOp.branch_operands()[index]); + return success(); +} + +void CaseOp::getCanonicalizationPatterns(OwningRewritePatternList& results, + MLIRContext* context) { + results.add(&InlineCaseConstantCondition); +} + //===----------------------------------------------------------------------===// // SqrtOp //===----------------------------------------------------------------------===// diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index c234d32..70a86e1 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -1280,6 +1280,108 @@ func @not_fold_sqrt_neg_constants() -> tensor<4xf32> { return %1 : tensor<4xf32> } +// CHECK-LABEL: func @fold_if_true( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] +// CHECK-SAME: ) +func @fold_if_true(%arg0 : tensor, %arg1 : tensor) -> tensor { + // CHECK-NOT: mhlo.if + // CHECK: return %[[ARG0]] + %true = mhlo.constant dense : tensor + %0 = "mhlo.if"(%true, %arg0, %arg1) ( { + ^bb0(%bbarg0: tensor): + "mhlo.return"(%bbarg0) : (tensor) -> () + }, { + ^bb0(%bbarg1: tensor): + "mhlo.return"(%bbarg1) : (tensor) -> () + }) : (tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @fold_if_false( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] +// CHECK-SAME: ) +func @fold_if_false(%arg0 : tensor, %arg1 : tensor) -> tensor { + // CHECK-NOT: mhlo.if + // CHECK: return %[[ARG1]] + %false = mhlo.constant dense : tensor + %0 = "mhlo.if"(%false, %arg0, %arg1) ( { + ^bb0(%bbarg0: tensor): + "mhlo.return"(%bbarg0) : (tensor) -> () + }, { + ^bb0(%bbarg1: tensor): + "mhlo.return"(%bbarg1) : (tensor) -> () + }) : (tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @fold_case( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] +// CHECK-SAME: 
%[[ARG2:[a-zA-Z0-9_]+]] +// CHECK-SAME: ) +func @fold_case(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + // CHECK-NOT: mhlo.case + // CHECK: return %[[ARG1]] + %c1 = mhlo.constant dense<1> : tensor + %0 = "mhlo.case"(%c1, %arg0, %arg1, %arg2) ( { + ^bb0(%bbarg0: tensor): + "mhlo.return"(%bbarg0) : (tensor) -> () + }, { + ^bb0(%bbarg1: tensor): + "mhlo.return"(%bbarg1) : (tensor) -> () + }, { + ^bb0(%bbarg2: tensor): + "mhlo.return"(%bbarg2) : (tensor) -> () + }) : (tensor, tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @fold_case_negative_index( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]] +// CHECK-SAME: ) +func @fold_case_negative_index(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + // CHECK-NOT: mhlo.case + // CHECK: return %[[ARG2]] + %m1000 = mhlo.constant dense<-1000> : tensor + %0 = "mhlo.case"(%m1000, %arg0, %arg1, %arg2) ( { + ^bb0(%bbarg0: tensor): + "mhlo.return"(%bbarg0) : (tensor) -> () + }, { + ^bb0(%bbarg1: tensor): + "mhlo.return"(%bbarg1) : (tensor) -> () + }, { + ^bb0(%bbarg2: tensor): + "mhlo.return"(%bbarg2) : (tensor) -> () + }) : (tensor, tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @fold_case_oob_index( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]] +// CHECK-SAME: ) +func @fold_case_oob_index(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + // CHECK-NOT: mhlo.case + // CHECK: return %[[ARG2]] + %c1000 = mhlo.constant dense<1000> : tensor + %0 = "mhlo.case"(%c1000, %arg0, %arg1, %arg2) ( { + ^bb0(%bbarg0: tensor): + "mhlo.return"(%bbarg0) : (tensor) -> () + }, { + ^bb0(%bbarg1: tensor): + "mhlo.return"(%bbarg1) : (tensor) -> () + }, { + ^bb0(%bbarg2: tensor): + "mhlo.return"(%bbarg2) : (tensor) -> () + }) : (tensor, tensor, tensor, tensor) -> tensor + return %0 : tensor +} + // 
CHECK-LABEL: @tensor_flow_scatter_v1_update func @tensor_flow_scatter_v1_update() -> tensor<3x3xi32> { %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>