[MLIR][MHLO] Apply patterns in MoveUpDynamicBroadcastsForFusionPass greedily

PiperOrigin-RevId: 365556488
Authored by A. Unique TensorFlower on 2021-03-29 06:00:09 -07:00; committed by TensorFlow MLIR Team
parent 238c1d8a92
commit fb819c1de8
2 changed files with 36 additions and 66 deletions
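In short, the pass drops its dialect-conversion setup (ConversionTarget plus applyPartialConversion) and instead registers the same rewrites as ordinary patterns and runs them with the greedy rewrite driver. A minimal sketch of the resulting runOnFunction body, assuming the pass boilerplate and the PopulateMoveUpDynamicBroadcastsForFusionPatterns helper shown in the diff below:

// Sketch only: greedy pattern application as introduced by this change.
// Requires "mlir/Transforms/GreedyPatternRewriteDriver.h".
void runOnFunction() override {
  MLIRContext *ctx = &getContext();

  // Collect the broadcast-movement / shape-reification patterns.
  RewritePatternSet patterns(ctx);
  mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(ctx, &patterns);

  // The greedy driver re-applies the patterns (and folds) until a fixed
  // point, so no ConversionTarget / legality setup is needed anymore.
  if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
    return signalPassFailure();
}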


@@ -33,76 +33,60 @@ limitations under the License.
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

 namespace mlir {
 namespace mhlo {
 namespace {

-bool IsShapeOfOpMovable(Value arg) {
-  return arg.getDefiningOp<InferShapedTypeOpInterface>();
-}
-
-struct ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> {
-  explicit ShapeOfOpConversion(MLIRContext *context)
-      : OpConversionPattern<shape::ShapeOfOp>(context) {
+struct ShapeReificationPattern : public OpRewritePattern<shape::ShapeOfOp> {
+  explicit ShapeReificationPattern(MLIRContext *context)
+      : OpRewritePattern<shape::ShapeOfOp>(context) {
     // Recursively reify until we hit an op that doesn't support it.
     setHasBoundedRewriteRecursion();
   }

-  LogicalResult matchAndRewrite(
-      shape::ShapeOfOp op, ArrayRef<Value> operands,
-      ConversionPatternRewriter &rewriter) const override {
-    shape::ShapeOfOp::Adaptor transformed(operands);
-
+  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
+                                PatternRewriter &rewriter) const override {
     // Only reify shape computation if operand allows for it.
-    if (!IsShapeOfOpMovable(transformed.arg())) return failure();
-
-    auto shape_origin =
-        transformed.arg().getDefiningOp<InferShapedTypeOpInterface>();
-    llvm::SmallVector<Value, 1> reified_shapes;
-    if (failed(shape_origin.reifyReturnTypeShapes(rewriter, reified_shapes)))
+    auto shape_origin = op.arg().getDefiningOp<InferShapedTypeOpInterface>();
+    if (!shape_origin) return failure();
+
+    llvm::SmallVector<Value, 1> reifications;
+    if (failed(shape_origin.reifyReturnTypeShapes(rewriter, reifications)))
       return failure();
-
-    assert(reified_shapes.size() == 1);
-    Value reified_shape = reified_shapes.front();
+    assert(reifications.size() == 1);
+    Value reified_shape = reifications.front();

     // Insert cast if needed.
     if (reified_shape.getType() != op.getType()) {
       reified_shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
                                                       reified_shape);
     }

-    rewriter.replaceOp(op, reified_shapes.front());
+    rewriter.replaceOp(op, reified_shape);
     return success();
   }
 };

-// We can only move up broadcasting ops that apply to the result of a
-// shape-preserving operation.
-bool isDynamicBroadcastInDimOpMovable(Value operand) {
-  Operation *producer_op = operand.getDefiningOp();
-  return producer_op != nullptr &&
-         producer_op->hasTrait<OpTrait::SameOperandsAndResultShape>() &&
-         producer_op->hasTrait<OpTrait::Elementwise>();
-}
-
 // TODO(frgossen): Only move up broadcasting operations if there is a consumer.
-struct MoveUpBroadcastInDimOpConversion
-    : public OpConversionPattern<DynamicBroadcastInDimOp> {
-  explicit MoveUpBroadcastInDimOpConversion(MLIRContext *context)
-      : OpConversionPattern<DynamicBroadcastInDimOp>(context) {}
+struct MoveUpBroadcastInDimOpPattern
+    : public OpRewritePattern<DynamicBroadcastInDimOp> {
+  using OpRewritePattern<DynamicBroadcastInDimOp>::OpRewritePattern;

-  LogicalResult matchAndRewrite(
-      DynamicBroadcastInDimOp bcast_op, ArrayRef<Value> operands,
-      ConversionPatternRewriter &rewriter) const override {
-    DynamicBroadcastInDimOp::Adaptor transformed(operands);
-    if (!isDynamicBroadcastInDimOpMovable(transformed.operand()))
+  LogicalResult matchAndRewrite(DynamicBroadcastInDimOp bcast_op,
+                                PatternRewriter &rewriter) const override {
+    Operation *producer_op = bcast_op.operand().getDefiningOp();
+    if (!producer_op ||
+        !producer_op->hasTrait<OpTrait::SameOperandsAndResultShape>() ||
+        !producer_op->hasTrait<OpTrait::Elementwise>()) {
       return failure();
+    }

     // Materialize broadcast on operands.
     SmallVector<Value, 2> bcasted_operands;
     Location loc = bcast_op.getLoc();
     ArrayRef<int64_t> ty_shape = bcast_op.getType().getShape();
-    Operation *producer_op = transformed.operand().getDefiningOp();
     for (Value operand : producer_op->getOperands()) {
       // The broadcast only works on ranked operations.
       auto operand_ty = operand.getType().dyn_cast<RankedTensorType>();
@@ -114,7 +98,7 @@ struct MoveUpBroadcastInDimOpConversion
       auto bcasted_operand_ty =
           RankedTensorType::get(ty_shape, operand_ty.getElementType());
       bcasted_operands.push_back(rewriter.create<DynamicBroadcastInDimOp>(
-          loc, bcasted_operand_ty, operand, transformed.output_dimensions(),
+          loc, bcasted_operand_ty, operand, bcast_op.output_dimensions(),
           bcast_op.broadcast_dimensions()));
     }
@@ -140,18 +124,14 @@ struct MoveUpDynamicBroadcastsForFusionPass
   }

   void runOnFunction() override {
-    // Setup target legality.
-    MLIRContext &ctx = getContext();
-    ConversionTarget target(ctx);
-    PopulateMoveUpDynamicBroadcastsForFusionLegality(&target);
-
-    // Populate rewrite patterns.
-    OwningRewritePatternList patterns(&ctx);
-    mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(&ctx, &patterns);
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(ctx, &patterns);

     // Apply transformation.
-    if (failed(applyPartialConversion(getFunction(), target,
-                                      std::move(patterns)))) {
+    if (failed(
+            applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
       return signalPassFailure();
     }
   }
@@ -159,23 +139,11 @@ struct MoveUpDynamicBroadcastsForFusionPass

 }  // namespace

-void PopulateMoveUpDynamicBroadcastsForFusionLegality(
-    ConversionTarget *target) {
-  target->addLegalDialect<MhloDialect, StandardOpsDialect, shape::ShapeDialect,
-                          tensor::TensorDialect>();
-  target->addDynamicallyLegalOp<shape::ShapeOfOp>(
-      [](shape::ShapeOfOp op) { return !IsShapeOfOpMovable(op.arg()); });
-  target->addDynamicallyLegalOp<DynamicBroadcastInDimOp>(
-      [](DynamicBroadcastInDimOp op) {
-        return !isDynamicBroadcastInDimOpMovable(op.operand());
-      });
-}
-
 void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
   // clang-format off
-  patterns->insert<ShapeOfOpConversion,
-                   MoveUpBroadcastInDimOpConversion>(context);
+  patterns->insert<ShapeReificationPattern,
+                   MoveUpBroadcastInDimOpPattern>(context);
   // clang-format on
 }


@@ -5,7 +5,8 @@
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>)
 func @shape_of_unary(%arg : tensor<?x32xi16>) {
   // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<?x32xi16> -> tensor<2xindex>
-  // CHECK: "use"(%[[SHAPE]])
+  // CHECK: %[[CASTED:.*]] = tensor.cast %[[SHAPE]] : tensor<2xindex> to tensor<?xindex>
+  // CHECK: "use"(%[[CASTED]])
   %0 = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16>
   %1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
   "use"(%1) : (tensor<?xindex>) -> ()
@@ -19,7 +20,8 @@ func @shape_of_unary(%arg : tensor<?x32xi16>) {
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>)
 func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) {
   // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]] : tensor<?x32xf16> -> tensor<2xindex>
-  // CHECK: "use"(%[[SHAPE]])
+  // CHECK: %[[CASTED:.*]] = tensor.cast %[[SHAPE]] : tensor<2xindex> to tensor<?xindex>
+  // CHECK: "use"(%[[CASTED]])
   %0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf16>
   %1 = mhlo.subtract %0, %arg1 : tensor<?x32xf16>
   %2 = shape.shape_of %1 : tensor<?x32xf16> -> tensor<?xindex>