[MLIR][MHLO] Apply patterns in MoveUpDynamicBroadcastsForFusionPass greedily

PiperOrigin-RevId: 365556488
A. Unique TensorFlower 2021-03-29 06:00:09 -07:00 committed by TensorFlow MLIR Team
parent 238c1d8a92
commit fb819c1de8
2 changed files with 36 additions and 66 deletions
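
In short, the pass drops its ConversionTarget / applyPartialConversion setup and hands the same patterns to the greedy pattern rewrite driver instead. Below is a minimal, hedged sketch of that driver idiom as it stood around this revision; the pass name ExampleGreedyPass and the empty pattern set are illustrative only and not part of this change.

// Sketch only: a FunctionPass that applies a RewritePatternSet greedily,
// mirroring the applyPatternsAndFoldGreedily call this commit switches to.
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
// Hypothetical pass used only to illustrate the driver invocation.
struct ExampleGreedyPass
    : public mlir::PassWrapper<ExampleGreedyPass, mlir::FunctionPass> {
  void runOnFunction() override {
    mlir::MLIRContext *ctx = &getContext();
    mlir::RewritePatternSet patterns(ctx);
    // A real pass would populate `patterns` here, e.g. via a
    // Populate...Patterns hook such as the one kept in this commit.
    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getFunction(),
                                                        std::move(patterns)))) {
      signalPassFailure();
    }
  }
};
}  // namespace

Unlike a partial conversion, the greedy driver needs no legality target, which is why the Populate...Legality helper below is deleted outright.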


@@ -33,76 +33,60 @@ limitations under the License.
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

 namespace mlir {
 namespace mhlo {
 namespace {

-bool IsShapeOfOpMovable(Value arg) {
-  return arg.getDefiningOp<InferShapedTypeOpInterface>();
-}
-
-struct ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> {
-  explicit ShapeOfOpConversion(MLIRContext *context)
-      : OpConversionPattern<shape::ShapeOfOp>(context) {
+struct ShapeReificationPattern : public OpRewritePattern<shape::ShapeOfOp> {
+  explicit ShapeReificationPattern(MLIRContext *context)
+      : OpRewritePattern<shape::ShapeOfOp>(context) {
     // Recursively reify until we hit an op that doesn't support it.
     setHasBoundedRewriteRecursion();
   }

-  LogicalResult matchAndRewrite(
-      shape::ShapeOfOp op, ArrayRef<Value> operands,
-      ConversionPatternRewriter &rewriter) const override {
-    shape::ShapeOfOp::Adaptor transformed(operands);
-
+  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
+                                PatternRewriter &rewriter) const override {
     // Only reify shape computation if operand allows for it.
-    if (!IsShapeOfOpMovable(transformed.arg())) return failure();
-
-    auto shape_origin =
-        transformed.arg().getDefiningOp<InferShapedTypeOpInterface>();
-    llvm::SmallVector<Value, 1> reified_shapes;
-    if (failed(shape_origin.reifyReturnTypeShapes(rewriter, reified_shapes)))
+    auto shape_origin = op.arg().getDefiningOp<InferShapedTypeOpInterface>();
+    if (!shape_origin) return failure();
+
+    llvm::SmallVector<Value, 1> reifications;
+    if (failed(shape_origin.reifyReturnTypeShapes(rewriter, reifications)))
       return failure();
+    assert(reifications.size() == 1);
+    Value reified_shape = reifications.front();

-    assert(reified_shapes.size() == 1);
-    Value reified_shape = reified_shapes.front();
+    // Insert cast if needed.
     if (reified_shape.getType() != op.getType()) {
       reified_shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
                                                       reified_shape);
     }

-    rewriter.replaceOp(op, reified_shapes.front());
+    rewriter.replaceOp(op, reified_shape);
     return success();
   }
 };

-// We can only move up broadcasting ops that apply to the result of a
-// shape-preserving operation.
-bool isDynamicBroadcastInDimOpMovable(Value operand) {
-  Operation *producer_op = operand.getDefiningOp();
-  return producer_op != nullptr &&
-         producer_op->hasTrait<OpTrait::SameOperandsAndResultShape>() &&
-         producer_op->hasTrait<OpTrait::Elementwise>();
-}
-
 // TODO(frgossen): Only move up broadcasting operations if there is a consumer.
-struct MoveUpBroadcastInDimOpConversion
-    : public OpConversionPattern<DynamicBroadcastInDimOp> {
-  explicit MoveUpBroadcastInDimOpConversion(MLIRContext *context)
-      : OpConversionPattern<DynamicBroadcastInDimOp>(context) {}
+struct MoveUpBroadcastInDimOpPattern
+    : public OpRewritePattern<DynamicBroadcastInDimOp> {
+  using OpRewritePattern<DynamicBroadcastInDimOp>::OpRewritePattern;

-  LogicalResult matchAndRewrite(
-      DynamicBroadcastInDimOp bcast_op, ArrayRef<Value> operands,
-      ConversionPatternRewriter &rewriter) const override {
-    DynamicBroadcastInDimOp::Adaptor transformed(operands);
-    if (!isDynamicBroadcastInDimOpMovable(transformed.operand()))
+  LogicalResult matchAndRewrite(DynamicBroadcastInDimOp bcast_op,
+                                PatternRewriter &rewriter) const override {
+    Operation *producer_op = bcast_op.operand().getDefiningOp();
+    if (!producer_op ||
+        !producer_op->hasTrait<OpTrait::SameOperandsAndResultShape>() ||
+        !producer_op->hasTrait<OpTrait::Elementwise>()) {
       return failure();
+    }

     // Materialize broadcast on operands.
     SmallVector<Value, 2> bcasted_operands;
     Location loc = bcast_op.getLoc();
     ArrayRef<int64_t> ty_shape = bcast_op.getType().getShape();
-    Operation *producer_op = transformed.operand().getDefiningOp();
     for (Value operand : producer_op->getOperands()) {
       // The broadcast only works on ranked operations.
       auto operand_ty = operand.getType().dyn_cast<RankedTensorType>();
@@ -114,7 +98,7 @@ struct MoveUpBroadcastInDimOpConversion
       auto bcasted_operand_ty =
           RankedTensorType::get(ty_shape, operand_ty.getElementType());
       bcasted_operands.push_back(rewriter.create<DynamicBroadcastInDimOp>(
-          loc, bcasted_operand_ty, operand, transformed.output_dimensions(),
+          loc, bcasted_operand_ty, operand, bcast_op.output_dimensions(),
           bcast_op.broadcast_dimensions()));
     }

@@ -140,18 +124,14 @@ struct MoveUpDynamicBroadcastsForFusionPass
   }

   void runOnFunction() override {
-    // Setup target legality.
-    MLIRContext &ctx = getContext();
-    ConversionTarget target(ctx);
-    PopulateMoveUpDynamicBroadcastsForFusionLegality(&target);
-
     // Populate rewrite patterns.
-    OwningRewritePatternList patterns(&ctx);
-    mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(&ctx, &patterns);
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(ctx, &patterns);

     // Apply transformation.
-    if (failed(applyPartialConversion(getFunction(), target,
-                                      std::move(patterns)))) {
+    if (failed(
+            applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
       return signalPassFailure();
     }
   }
@@ -159,23 +139,11 @@ struct MoveUpDynamicBroadcastsForFusionPass

 }  // namespace

-void PopulateMoveUpDynamicBroadcastsForFusionLegality(
-    ConversionTarget *target) {
-  target->addLegalDialect<MhloDialect, StandardOpsDialect, shape::ShapeDialect,
-                          tensor::TensorDialect>();
-  target->addDynamicallyLegalOp<shape::ShapeOfOp>(
-      [](shape::ShapeOfOp op) { return !IsShapeOfOpMovable(op.arg()); });
-  target->addDynamicallyLegalOp<DynamicBroadcastInDimOp>(
-      [](DynamicBroadcastInDimOp op) {
-        return !isDynamicBroadcastInDimOpMovable(op.operand());
-      });
-}
-
 void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
   // clang-format off
-  patterns->insert<ShapeOfOpConversion,
-                   MoveUpBroadcastInDimOpConversion>(context);
+  patterns->insert<ShapeReificationPattern,
+                   MoveUpBroadcastInDimOpPattern>(context);
   // clang-format on
 }


@@ -5,7 +5,8 @@
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>)
 func @shape_of_unary(%arg : tensor<?x32xi16>) {
   // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<?x32xi16> -> tensor<2xindex>
-  // CHECK: "use"(%[[SHAPE]])
+  // CHECK: %[[CASTED:.*]] = tensor.cast %[[SHAPE]] : tensor<2xindex> to tensor<?xindex>
+  // CHECK: "use"(%[[CASTED]])
   %0 = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16>
   %1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
   "use"(%1) : (tensor<?xindex>) -> ()
@@ -19,7 +20,8 @@ func @shape_of_unary(%arg : tensor<?x32xi16>) {
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>)
 func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) {
   // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]] : tensor<?x32xf16> -> tensor<2xindex>
-  // CHECK: "use"(%[[SHAPE]])
+  // CHECK: %[[CASTED:.*]] = tensor.cast %[[SHAPE]] : tensor<2xindex> to tensor<?xindex>
+  // CHECK: "use"(%[[CASTED]])
   %0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf16>
   %1 = mhlo.subtract %0, %arg1 : tensor<?x32xf16>
   %2 = shape.shape_of %1 : tensor<?x32xf16> -> tensor<?xindex>