Introduce constant folds for ReduceOp with single LogicalAnd or LogicalOr op.

PiperOrigin-RevId: 370551483
This commit is contained in:
A. Unique TensorFlower 2021-04-26 15:10:20 -07:00 committed by TensorFlow MLIR Team
parent 1fff544339
commit e500ab37a1
3 changed files with 127 additions and 1 deletions

View File

@ -646,6 +646,7 @@ def HLO_ReduceOp: HLO_Op<"reduce", [
}]; }];
let hasFolder = 1; let hasFolder = 1;
let hasCanonicalizer = 1;
// TODO(hinsu): Verify that the attached body arguments and results are // TODO(hinsu): Verify that the attached body arguments and results are
// compatible with reduce op's operands. // compatible with reduce op's operands.

View File

@ -1917,9 +1917,82 @@ LogicalResult ReduceOp::fold(ArrayRef<Attribute> operands,
} }
return success(); return success();
} }
// If all returned values in the ReduceOp region exists outside
// the region replace the ReduceOp with those values.
mlir::Block& bb = this->body().front();
SmallVector<Value> replaced_results;
if (auto ret_op = mlir::dyn_cast<ReturnOp>(bb.back())) {
for (Value result : ret_op.results()) {
if (result.getParentRegion() == ret_op->getParentRegion())
return failure();
replaced_results.push_back(result);
}
results.insert(results.end(), replaced_results.begin(),
replaced_results.end());
return success();
}
return failure(); return failure();
} }
// Enable constant folding to occur within the region of the ReduceOp
// by replacing block argument uses with constants if:
// 1. All the ReduceOp operands are splat constants.
// 2. The ReduceOp region consists of a single logical AND or logical OR.
// The pattern leverages the idempotent property of the AND and OR operators
// to determine the value of a reduction on splat constants. Other boolean
// operators do not have this property, and need separate patterns to resolve
// reductions of their splat constants.
struct LowerBoolSplatConstantsIntoRegion : public OpRewritePattern<ReduceOp> {
using OpRewritePattern<ReduceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ReduceOp op,
PatternRewriter& rewriter) const override {
mlir::Block& bb = op.body().front();
// Ensure only a compute op and return op exist and the
// compute op is an AND or OR op.
if (bb.getOperations().size() != 2) return failure();
if (!mlir::isa<mhlo::AndOp, mhlo::OrOp>(bb.front())) return failure();
// Ensure all operands are splat constants.
SmallVector<DenseElementsAttr, 4> barg_cst_attrs;
for (auto inp_and_barg : llvm::zip(op.getOperands(), bb.getArguments())) {
Value inp = std::get<0>(inp_and_barg);
BlockArgument barg = std::get<1>(inp_and_barg);
ConstOp cst = inp.getDefiningOp<ConstOp>();
if (!cst) return failure();
auto cst_attr = cst.value().dyn_cast_or_null<DenseElementsAttr>();
if (!cst_attr.isSplat()) {
return rewriter.notifyMatchFailure(op, "Must be splat constant.");
}
auto barg_shaped_type = barg.getType().dyn_cast<ShapedType>();
if (!barg_shaped_type) return failure();
auto barg_cst_attr =
DenseElementsAttr::get(barg_shaped_type, cst_attr.getSplatValue());
barg_cst_attrs.push_back(barg_cst_attr);
}
// Create new splat constants to replace block arguments.
for (BlockArgument barg : bb.getArguments()) {
int arg_idx = barg.getArgNumber();
mhlo::ConstOp new_cst = rewriter.create<mhlo::ConstOp>(
bb.front().getLoc(), barg.getType(), barg_cst_attrs[arg_idx]);
barg.replaceAllUsesWith(new_cst);
}
return success();
}
};
void ReduceOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
MLIRContext* context) {
results.insert<LowerBoolSplatConstantsIntoRegion>(context);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// RngNormalOp // RngNormalOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -1,4 +1,6 @@
// RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s // RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
// -----
// CHECK-LABEL: func @noop // CHECK-LABEL: func @noop
// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x8xf32>) // CHECK-SAME: (%[[ARG0:.*]]: tensor<4x8xf32>)
@ -12,3 +14,53 @@ func @noop(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
}) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32> }) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
return %2 : tensor<4x8xf32> return %2 : tensor<4x8xf32>
} }
// -----
// CHECK-LABEL: func @and_fold
func @and_fold() -> (tensor<i1>, tensor<i1>) {
%0 = mhlo.constant dense<true> : tensor<2xi1>
%2 = mhlo.constant dense<true> : tensor<i1>
%3 = mhlo.constant dense<false> : tensor<i1>
%4 = "mhlo.reduce"(%0, %2) ( {
^bb0(%arg2: tensor<i1>, %arg3: tensor<i1>):
%11 = mhlo.and %arg2, %arg3 : tensor<i1>
"mhlo.return"(%11) : (tensor<i1>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (tensor<2xi1>, tensor<i1>) -> tensor<i1>
%5 = "mhlo.reduce"(%0, %3) ( {
^bb0(%arg4: tensor<i1>, %arg5: tensor<i1>):
%12 = mhlo.and %arg4, %arg5 : tensor<i1>
"mhlo.return"(%12) : (tensor<i1>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (tensor<2xi1>, tensor<i1>) -> tensor<i1>
return %4, %5 : tensor<i1>, tensor<i1>
// CHECK: %[[CST:.*]] = mhlo.constant dense<true> : tensor<i1>
// CHECK: %[[CST1:.*]] = mhlo.constant dense<false> : tensor<i1>
// CHECK: return %[[CST]], %[[CST1]] : tensor<i1>, tensor<i1>
}
// -----
// CHECK-LABEL: func @or_fold
func @or_fold() -> (tensor<i1>, tensor<i1>) {
%0 = mhlo.constant dense<false> : tensor<2xi1>
%2 = mhlo.constant dense<false> : tensor<i1>
%3 = mhlo.constant dense<true> : tensor<i1>
%4 = "mhlo.reduce"(%0, %2) ( {
^bb0(%arg2: tensor<i1>, %arg3: tensor<i1>):
%11 = mhlo.or %arg2, %arg3 : tensor<i1>
"mhlo.return"(%11) : (tensor<i1>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (tensor<2xi1>, tensor<i1>) -> tensor<i1>
%5 = "mhlo.reduce"(%0, %3) ( {
^bb0(%arg4: tensor<i1>, %arg5: tensor<i1>):
%12 = mhlo.or %arg4, %arg5 : tensor<i1>
"mhlo.return"(%12) : (tensor<i1>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (tensor<2xi1>, tensor<i1>) -> tensor<i1>
return %4, %5 : tensor<i1>, tensor<i1>
// CHECK: %[[CST:.*]] = mhlo.constant dense<false> : tensor<i1>
// CHECK: %[[CST1:.*]] = mhlo.constant dense<true> : tensor<i1>
// CHECK: return %[[CST]], %[[CST1]] : tensor<i1>, tensor<i1>
}