[MLIR][MHLO] Apply patterns in MoveUpDynamicBroadcastsForFusionPass greedily
PiperOrigin-RevId: 365556488
parent 238c1d8a92
commit fb819c1de8
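This change switches the pass from the dialect-conversion driver (applyPartialConversion over a ConversionTarget) to the greedy pattern rewrite driver, so its patterns become plain OpRewritePatterns and the separate legality specification can be dropped. Below is a minimal sketch of that driver shape, assuming the MLIR APIs used at this revision (FunctionPass, RewritePatternSet, applyPatternsAndFoldGreedily); MyGreedyPass and MyPattern are placeholder names, not code from this commit.

// Sketch only: a function pass that applies rewrite patterns greedily.
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
struct MyGreedyPass
    : public mlir::PassWrapper<MyGreedyPass, mlir::FunctionPass> {
  void runOnFunction() override {
    mlir::MLIRContext *ctx = &getContext();
    mlir::RewritePatternSet patterns(ctx);
    // No ConversionTarget or legality rules: the greedy driver applies the
    // registered patterns repeatedly until a fixpoint is reached.
    // patterns.insert<MyPattern>(ctx);  // MyPattern is a placeholder.
    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getFunction(),
                                                        std::move(patterns))))
      signalPassFailure();
  }
};
}  // namespace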
@@ -33,76 +33,60 @@ limitations under the License.
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
 namespace mhlo {
 namespace {
 
-bool IsShapeOfOpMovable(Value arg) {
-  return arg.getDefiningOp<InferShapedTypeOpInterface>();
-}
-
-struct ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> {
-  explicit ShapeOfOpConversion(MLIRContext *context)
-      : OpConversionPattern<shape::ShapeOfOp>(context) {
+struct ShapeReificationPattern : public OpRewritePattern<shape::ShapeOfOp> {
+  explicit ShapeReificationPattern(MLIRContext *context)
+      : OpRewritePattern<shape::ShapeOfOp>(context) {
     // Recursively reify until we hit an op that doesn't support it.
     setHasBoundedRewriteRecursion();
   }
 
-  LogicalResult matchAndRewrite(
-      shape::ShapeOfOp op, ArrayRef<Value> operands,
-      ConversionPatternRewriter &rewriter) const override {
-    shape::ShapeOfOp::Adaptor transformed(operands);
-
+  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
+                                PatternRewriter &rewriter) const override {
     // Only reify shape computation if operand allows for it.
-    if (!IsShapeOfOpMovable(transformed.arg())) return failure();
-
-    auto shape_origin =
-        transformed.arg().getDefiningOp<InferShapedTypeOpInterface>();
-    llvm::SmallVector<Value, 1> reified_shapes;
-    if (failed(shape_origin.reifyReturnTypeShapes(rewriter, reified_shapes)))
+    auto shape_origin = op.arg().getDefiningOp<InferShapedTypeOpInterface>();
+    if (!shape_origin) return failure();
+
+    llvm::SmallVector<Value, 1> reifications;
+    if (failed(shape_origin.reifyReturnTypeShapes(rewriter, reifications)))
       return failure();
+    assert(reifications.size() == 1);
+    Value reified_shape = reifications.front();
 
-    assert(reified_shapes.size() == 1);
-    Value reified_shape = reified_shapes.front();
+    // Insert cast if needed.
     if (reified_shape.getType() != op.getType()) {
       reified_shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
                                                       reified_shape);
     }
 
-    rewriter.replaceOp(op, reified_shapes.front());
+    rewriter.replaceOp(op, reified_shape);
     return success();
   }
 };
 
-// We can only move up broadcasting ops that apply to the result of a
-// shape-preserving operation.
-bool isDynamicBroadcastInDimOpMovable(Value operand) {
-  Operation *producer_op = operand.getDefiningOp();
-  return producer_op != nullptr &&
-         producer_op->hasTrait<OpTrait::SameOperandsAndResultShape>() &&
-         producer_op->hasTrait<OpTrait::Elementwise>();
-}
-
 // TODO(frgossen): Only move up broadcasting operations if there is a consumer.
-struct MoveUpBroadcastInDimOpConversion
-    : public OpConversionPattern<DynamicBroadcastInDimOp> {
-  explicit MoveUpBroadcastInDimOpConversion(MLIRContext *context)
-      : OpConversionPattern<DynamicBroadcastInDimOp>(context) {}
-
-  LogicalResult matchAndRewrite(
-      DynamicBroadcastInDimOp bcast_op, ArrayRef<Value> operands,
-      ConversionPatternRewriter &rewriter) const override {
-    DynamicBroadcastInDimOp::Adaptor transformed(operands);
-    if (!isDynamicBroadcastInDimOpMovable(transformed.operand()))
-      return failure();
+struct MoveUpBroadcastInDimOpPattern
+    : public OpRewritePattern<DynamicBroadcastInDimOp> {
+  using OpRewritePattern<DynamicBroadcastInDimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DynamicBroadcastInDimOp bcast_op,
+                                PatternRewriter &rewriter) const override {
+    Operation *producer_op = bcast_op.operand().getDefiningOp();
+    if (!producer_op ||
+        !producer_op->hasTrait<OpTrait::SameOperandsAndResultShape>() ||
+        !producer_op->hasTrait<OpTrait::Elementwise>()) {
+      return failure();
+    }
 
     // Materialize broadcast on operands.
     SmallVector<Value, 2> bcasted_operands;
     Location loc = bcast_op.getLoc();
     ArrayRef<int64_t> ty_shape = bcast_op.getType().getShape();
-    Operation *producer_op = transformed.operand().getDefiningOp();
    for (Value operand : producer_op->getOperands()) {
       // The broadcast only works on ranked operations.
       auto operand_ty = operand.getType().dyn_cast<RankedTensorType>();
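The pattern-side change above is largely mechanical: OpConversionPattern's matchAndRewrite(op, ArrayRef<Value> operands, ConversionPatternRewriter &) collapses into OpRewritePattern's matchAndRewrite(op, PatternRewriter &), since without type conversion there is no separate operand adaptor. A small, self-contained sketch of that shape follows, using a trivial tensor.cast cleanup as a stand-in pattern; it is illustrative only and not part of this commit.

// Sketch only: the OpRewritePattern shape used after this change. The pattern
// removes tensor.cast ops whose source and result types already match.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"

struct DropTrivialTensorCast
    : public mlir::OpRewritePattern<mlir::tensor::CastOp> {
  using mlir::OpRewritePattern<mlir::tensor::CastOp>::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      mlir::tensor::CastOp op,
      mlir::PatternRewriter &rewriter) const override {
    // The matched op is passed directly; inspect its operands through its
    // accessors rather than through an ArrayRef<Value> adaptor.
    if (op.source().getType() != op.getType()) return mlir::failure();
    rewriter.replaceOp(op, op.source());
    return mlir::success();
  }
};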
@@ -114,7 +98,7 @@ struct MoveUpBroadcastInDimOpConversion
       auto bcasted_operand_ty =
           RankedTensorType::get(ty_shape, operand_ty.getElementType());
       bcasted_operands.push_back(rewriter.create<DynamicBroadcastInDimOp>(
-          loc, bcasted_operand_ty, operand, transformed.output_dimensions(),
+          loc, bcasted_operand_ty, operand, bcast_op.output_dimensions(),
           bcast_op.broadcast_dimensions()));
     }
 
@@ -140,18 +124,14 @@ struct MoveUpDynamicBroadcastsForFusionPass
   }
 
   void runOnFunction() override {
-    // Setup target legality.
-    MLIRContext &ctx = getContext();
-    ConversionTarget target(ctx);
-    PopulateMoveUpDynamicBroadcastsForFusionLegality(&target);
-
     // Populate rewrite patterns.
-    OwningRewritePatternList patterns(&ctx);
-    mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(&ctx, &patterns);
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(ctx, &patterns);
 
     // Apply transformation.
-    if (failed(applyPartialConversion(getFunction(), target,
-                                      std::move(patterns)))) {
+    if (failed(
+            applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
       return signalPassFailure();
     }
   }
@@ -159,23 +139,11 @@ struct MoveUpDynamicBroadcastsForFusionPass
 
 }  // namespace
 
-void PopulateMoveUpDynamicBroadcastsForFusionLegality(
-    ConversionTarget *target) {
-  target->addLegalDialect<MhloDialect, StandardOpsDialect, shape::ShapeDialect,
-                          tensor::TensorDialect>();
-  target->addDynamicallyLegalOp<shape::ShapeOfOp>(
-      [](shape::ShapeOfOp op) { return !IsShapeOfOpMovable(op.arg()); });
-  target->addDynamicallyLegalOp<DynamicBroadcastInDimOp>(
-      [](DynamicBroadcastInDimOp op) {
-        return !isDynamicBroadcastInDimOpMovable(op.operand());
-      });
-}
-
 void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
   // clang-format off
-  patterns->insert<ShapeOfOpConversion,
-                   MoveUpBroadcastInDimOpConversion>(context);
+  patterns->insert<ShapeReificationPattern,
+                   MoveUpBroadcastInDimOpPattern>(context);
   // clang-format on
 }
 
@@ -5,7 +5,8 @@
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>)
 func @shape_of_unary(%arg : tensor<?x32xi16>) {
   // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<?x32xi16> -> tensor<2xindex>
-  // CHECK: "use"(%[[SHAPE]])
+  // CHECK: %[[CASTED:.*]] = tensor.cast %[[SHAPE]] : tensor<2xindex> to tensor<?xindex>
+  // CHECK: "use"(%[[CASTED]])
   %0 = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16>
   %1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
   "use"(%1) : (tensor<?xindex>) -> ()
@@ -19,7 +20,8 @@ func @shape_of_unary(%arg : tensor<?x32xi16>) {
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>)
 func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) {
   // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]] : tensor<?x32xf16> -> tensor<2xindex>
-  // CHECK: "use"(%[[SHAPE]])
+  // CHECK: %[[CASTED:.*]] = tensor.cast %[[SHAPE]] : tensor<2xindex> to tensor<?xindex>
+  // CHECK: "use"(%[[CASTED]])
   %0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf16>
   %1 = mhlo.subtract %0, %arg1 : tensor<?x32xf16>
   %2 = shape.shape_of %1 : tensor<?x32xf16> -> tensor<?xindex>