Introduce constant folds for ReduceOp regions containing a single LogicalAnd or LogicalOr op.
PiperOrigin-RevId: 370551483
This commit is contained in:
parent
1fff544339
commit
e500ab37a1
|
@ -646,6 +646,7 @@ def HLO_ReduceOp: HLO_Op<"reduce", [
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
|
|
||||||
// TODO(hinsu): Verify that the attached body arguments and results are
|
// TODO(hinsu): Verify that the attached body arguments and results are
|
||||||
// compatible with reduce op's operands.
|
// compatible with reduce op's operands.
|
||||||
|
|
|
@ -1917,9 +1917,82 @@ LogicalResult ReduceOp::fold(ArrayRef<Attribute> operands,
|
||||||
}
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If all returned values in the ReduceOp region exists outside
|
||||||
|
// the region replace the ReduceOp with those values.
|
||||||
|
mlir::Block& bb = this->body().front();
|
||||||
|
SmallVector<Value> replaced_results;
|
||||||
|
if (auto ret_op = mlir::dyn_cast<ReturnOp>(bb.back())) {
|
||||||
|
for (Value result : ret_op.results()) {
|
||||||
|
if (result.getParentRegion() == ret_op->getParentRegion())
|
||||||
|
return failure();
|
||||||
|
replaced_results.push_back(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
results.insert(results.end(), replaced_results.begin(),
|
||||||
|
replaced_results.end());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Enable constant folding to occur within the region of the ReduceOp
|
||||||
|
// by replacing block argument uses with constants if:
|
||||||
|
// 1. All the ReduceOp operands are splat constants.
|
||||||
|
// 2. The ReduceOp region consists of a single logical AND or logical OR.
|
||||||
|
// The pattern leverages the idempotent property of the AND and OR operators
|
||||||
|
// to determine the value of a reduction on splat constants. Other boolean
|
||||||
|
// operators do not have this property, and need separate patterns to resolve
|
||||||
|
// reductions of their splat constants.
|
||||||
|
struct LowerBoolSplatConstantsIntoRegion : public OpRewritePattern<ReduceOp> {
|
||||||
|
using OpRewritePattern<ReduceOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(ReduceOp op,
|
||||||
|
PatternRewriter& rewriter) const override {
|
||||||
|
mlir::Block& bb = op.body().front();
|
||||||
|
|
||||||
|
// Ensure only a compute op and return op exist and the
|
||||||
|
// compute op is an AND or OR op.
|
||||||
|
if (bb.getOperations().size() != 2) return failure();
|
||||||
|
if (!mlir::isa<mhlo::AndOp, mhlo::OrOp>(bb.front())) return failure();
|
||||||
|
|
||||||
|
// Ensure all operands are splat constants.
|
||||||
|
SmallVector<DenseElementsAttr, 4> barg_cst_attrs;
|
||||||
|
for (auto inp_and_barg : llvm::zip(op.getOperands(), bb.getArguments())) {
|
||||||
|
Value inp = std::get<0>(inp_and_barg);
|
||||||
|
BlockArgument barg = std::get<1>(inp_and_barg);
|
||||||
|
ConstOp cst = inp.getDefiningOp<ConstOp>();
|
||||||
|
if (!cst) return failure();
|
||||||
|
|
||||||
|
auto cst_attr = cst.value().dyn_cast_or_null<DenseElementsAttr>();
|
||||||
|
if (!cst_attr.isSplat()) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "Must be splat constant.");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto barg_shaped_type = barg.getType().dyn_cast<ShapedType>();
|
||||||
|
if (!barg_shaped_type) return failure();
|
||||||
|
|
||||||
|
auto barg_cst_attr =
|
||||||
|
DenseElementsAttr::get(barg_shaped_type, cst_attr.getSplatValue());
|
||||||
|
barg_cst_attrs.push_back(barg_cst_attr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new splat constants to replace block arguments.
|
||||||
|
for (BlockArgument barg : bb.getArguments()) {
|
||||||
|
int arg_idx = barg.getArgNumber();
|
||||||
|
mhlo::ConstOp new_cst = rewriter.create<mhlo::ConstOp>(
|
||||||
|
bb.front().getLoc(), barg.getType(), barg_cst_attrs[arg_idx]);
|
||||||
|
barg.replaceAllUsesWith(new_cst);
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Registers canonicalization patterns for ReduceOp. Currently the only
// pattern lowers splat boolean constants into the reduction region so the
// single AND/OR body op can be constant-folded.
void ReduceOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                           MLIRContext* context) {
  results.insert<LowerBoolSplatConstantsIntoRegion>(context);
}
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// RngNormalOp
|
// RngNormalOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
// RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
|
// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @noop
|
// CHECK-LABEL: func @noop
|
||||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x8xf32>)
|
// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x8xf32>)
|
||||||
|
@ -12,3 +14,53 @@ func @noop(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
|
||||||
}) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
|
}) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
|
||||||
return %2 : tensor<4x8xf32>
|
return %2 : tensor<4x8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Reductions over splat boolean constants with a single AND body op should
// fold to splat constants: true-splat AND true init -> true, and any AND
// with a false init -> false.
// CHECK-LABEL: func @and_fold
func @and_fold() -> (tensor<i1>, tensor<i1>) {
  %0 = mhlo.constant dense<true> : tensor<2xi1>
  %2 = mhlo.constant dense<true> : tensor<i1>
  %3 = mhlo.constant dense<false> : tensor<i1>
  // Reduction of all-true operand with a true init value.
  %4 = "mhlo.reduce"(%0, %2) ( {
  ^bb0(%arg2: tensor<i1>, %arg3: tensor<i1>):
    %11 = mhlo.and %arg2, %arg3 : tensor<i1>
    "mhlo.return"(%11) : (tensor<i1>) -> ()
  }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<2xi1>, tensor<i1>) -> tensor<i1>

  // Reduction of all-true operand with a false init value.
  %5 = "mhlo.reduce"(%0, %3) ( {
  ^bb0(%arg4: tensor<i1>, %arg5: tensor<i1>):
    %12 = mhlo.and %arg4, %arg5 : tensor<i1>
    "mhlo.return"(%12) : (tensor<i1>) -> ()
  }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<2xi1>, tensor<i1>) -> tensor<i1>
  return %4, %5 : tensor<i1>, tensor<i1>

  // CHECK: %[[CST:.*]] = mhlo.constant dense<true> : tensor<i1>
  // CHECK: %[[CST1:.*]] = mhlo.constant dense<false> : tensor<i1>
  // CHECK: return %[[CST]], %[[CST1]] : tensor<i1>, tensor<i1>
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Reductions over splat boolean constants with a single OR body op should
// fold to splat constants: false-splat OR false init -> false, and any OR
// with a true init -> true.
// CHECK-LABEL: func @or_fold
func @or_fold() -> (tensor<i1>, tensor<i1>) {
  %0 = mhlo.constant dense<false> : tensor<2xi1>
  %2 = mhlo.constant dense<false> : tensor<i1>
  %3 = mhlo.constant dense<true> : tensor<i1>
  // Reduction of all-false operand with a false init value.
  %4 = "mhlo.reduce"(%0, %2) ( {
  ^bb0(%arg2: tensor<i1>, %arg3: tensor<i1>):
    %11 = mhlo.or %arg2, %arg3 : tensor<i1>
    "mhlo.return"(%11) : (tensor<i1>) -> ()
  }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<2xi1>, tensor<i1>) -> tensor<i1>

  // Reduction of all-false operand with a true init value.
  %5 = "mhlo.reduce"(%0, %3) ( {
  ^bb0(%arg4: tensor<i1>, %arg5: tensor<i1>):
    %12 = mhlo.or %arg4, %arg5 : tensor<i1>
    "mhlo.return"(%12) : (tensor<i1>) -> ()
  }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<2xi1>, tensor<i1>) -> tensor<i1>
  return %4, %5 : tensor<i1>, tensor<i1>

  // CHECK: %[[CST:.*]] = mhlo.constant dense<false> : tensor<i1>
  // CHECK: %[[CST1:.*]] = mhlo.constant dense<true> : tensor<i1>
  // CHECK: return %[[CST]], %[[CST1]] : tensor<i1>, tensor<i1>
}
|
||||||
|
|
Loading…
Reference in New Issue