[MLIR][MHLO] Apply patterns in MoveUpDynamicBroadcastsForFusionPass greedily
PiperOrigin-RevId: 365556488
This commit is contained in:
parent
238c1d8a92
commit
fb819c1de8
|
@ -33,76 +33,60 @@ limitations under the License.
|
|||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
namespace {
|
||||
|
||||
bool IsShapeOfOpMovable(Value arg) {
|
||||
return arg.getDefiningOp<InferShapedTypeOpInterface>();
|
||||
}
|
||||
|
||||
struct ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> {
|
||||
explicit ShapeOfOpConversion(MLIRContext *context)
|
||||
: OpConversionPattern<shape::ShapeOfOp>(context) {
|
||||
struct ShapeReificationPattern : public OpRewritePattern<shape::ShapeOfOp> {
|
||||
explicit ShapeReificationPattern(MLIRContext *context)
|
||||
: OpRewritePattern<shape::ShapeOfOp>(context) {
|
||||
// Recursively reify until we hit an op that doesn't support it.
|
||||
setHasBoundedRewriteRecursion();
|
||||
}
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
shape::ShapeOfOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
shape::ShapeOfOp::Adaptor transformed(operands);
|
||||
|
||||
LogicalResult matchAndRewrite(shape::ShapeOfOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Only reify shape computation if operand allows for it.
|
||||
if (!IsShapeOfOpMovable(transformed.arg())) return failure();
|
||||
auto shape_origin = op.arg().getDefiningOp<InferShapedTypeOpInterface>();
|
||||
if (!shape_origin) return failure();
|
||||
|
||||
auto shape_origin =
|
||||
transformed.arg().getDefiningOp<InferShapedTypeOpInterface>();
|
||||
llvm::SmallVector<Value, 1> reified_shapes;
|
||||
if (failed(shape_origin.reifyReturnTypeShapes(rewriter, reified_shapes)))
|
||||
llvm::SmallVector<Value, 1> reifications;
|
||||
if (failed(shape_origin.reifyReturnTypeShapes(rewriter, reifications)))
|
||||
return failure();
|
||||
assert(reifications.size() == 1);
|
||||
Value reified_shape = reifications.front();
|
||||
|
||||
assert(reified_shapes.size() == 1);
|
||||
Value reified_shape = reified_shapes.front();
|
||||
// Insert cast if needed.
|
||||
if (reified_shape.getType() != op.getType()) {
|
||||
reified_shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
|
||||
reified_shape);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, reified_shapes.front());
|
||||
rewriter.replaceOp(op, reified_shape);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// We can only move up broadcasting ops that apply to the result of a
|
||||
// shape-preserving operation.
|
||||
bool isDynamicBroadcastInDimOpMovable(Value operand) {
|
||||
Operation *producer_op = operand.getDefiningOp();
|
||||
return producer_op != nullptr &&
|
||||
producer_op->hasTrait<OpTrait::SameOperandsAndResultShape>() &&
|
||||
producer_op->hasTrait<OpTrait::Elementwise>();
|
||||
}
|
||||
|
||||
// TODO(frgossen): Only move up broadcasting operations if there is a consumer.
|
||||
struct MoveUpBroadcastInDimOpConversion
|
||||
: public OpConversionPattern<DynamicBroadcastInDimOp> {
|
||||
explicit MoveUpBroadcastInDimOpConversion(MLIRContext *context)
|
||||
: OpConversionPattern<DynamicBroadcastInDimOp>(context) {}
|
||||
struct MoveUpBroadcastInDimOpPattern
|
||||
: public OpRewritePattern<DynamicBroadcastInDimOp> {
|
||||
using OpRewritePattern<DynamicBroadcastInDimOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
DynamicBroadcastInDimOp bcast_op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
DynamicBroadcastInDimOp::Adaptor transformed(operands);
|
||||
if (!isDynamicBroadcastInDimOpMovable(transformed.operand()))
|
||||
LogicalResult matchAndRewrite(DynamicBroadcastInDimOp bcast_op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Operation *producer_op = bcast_op.operand().getDefiningOp();
|
||||
if (!producer_op ||
|
||||
!producer_op->hasTrait<OpTrait::SameOperandsAndResultShape>() ||
|
||||
!producer_op->hasTrait<OpTrait::Elementwise>()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Materialize broadcast on operands.
|
||||
SmallVector<Value, 2> bcasted_operands;
|
||||
Location loc = bcast_op.getLoc();
|
||||
ArrayRef<int64_t> ty_shape = bcast_op.getType().getShape();
|
||||
Operation *producer_op = transformed.operand().getDefiningOp();
|
||||
for (Value operand : producer_op->getOperands()) {
|
||||
// The broadcast only works on ranked operations.
|
||||
auto operand_ty = operand.getType().dyn_cast<RankedTensorType>();
|
||||
|
@ -114,7 +98,7 @@ struct MoveUpBroadcastInDimOpConversion
|
|||
auto bcasted_operand_ty =
|
||||
RankedTensorType::get(ty_shape, operand_ty.getElementType());
|
||||
bcasted_operands.push_back(rewriter.create<DynamicBroadcastInDimOp>(
|
||||
loc, bcasted_operand_ty, operand, transformed.output_dimensions(),
|
||||
loc, bcasted_operand_ty, operand, bcast_op.output_dimensions(),
|
||||
bcast_op.broadcast_dimensions()));
|
||||
}
|
||||
|
||||
|
@ -140,18 +124,14 @@ struct MoveUpDynamicBroadcastsForFusionPass
|
|||
}
|
||||
|
||||
void runOnFunction() override {
|
||||
// Setup target legality.
|
||||
MLIRContext &ctx = getContext();
|
||||
ConversionTarget target(ctx);
|
||||
PopulateMoveUpDynamicBroadcastsForFusionLegality(&target);
|
||||
|
||||
// Populate rewrite patterns.
|
||||
OwningRewritePatternList patterns(&ctx);
|
||||
mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(&ctx, &patterns);
|
||||
MLIRContext *ctx = &getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(ctx, &patterns);
|
||||
|
||||
// Apply transformation.
|
||||
if (failed(applyPartialConversion(getFunction(), target,
|
||||
std::move(patterns)))) {
|
||||
if (failed(
|
||||
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
@ -159,23 +139,11 @@ struct MoveUpDynamicBroadcastsForFusionPass
|
|||
|
||||
} // namespace
|
||||
|
||||
void PopulateMoveUpDynamicBroadcastsForFusionLegality(
|
||||
ConversionTarget *target) {
|
||||
target->addLegalDialect<MhloDialect, StandardOpsDialect, shape::ShapeDialect,
|
||||
tensor::TensorDialect>();
|
||||
target->addDynamicallyLegalOp<shape::ShapeOfOp>(
|
||||
[](shape::ShapeOfOp op) { return !IsShapeOfOpMovable(op.arg()); });
|
||||
target->addDynamicallyLegalOp<DynamicBroadcastInDimOp>(
|
||||
[](DynamicBroadcastInDimOp op) {
|
||||
return !isDynamicBroadcastInDimOpMovable(op.operand());
|
||||
});
|
||||
}
|
||||
|
||||
void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
|
||||
MLIRContext *context, OwningRewritePatternList *patterns) {
|
||||
// clang-format off
|
||||
patterns->insert<ShapeOfOpConversion,
|
||||
MoveUpBroadcastInDimOpConversion>(context);
|
||||
patterns->insert<ShapeReificationPattern,
|
||||
MoveUpBroadcastInDimOpPattern>(context);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
|
|
|
@ -5,7 +5,8 @@
|
|||
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>)
|
||||
func @shape_of_unary(%arg : tensor<?x32xi16>) {
|
||||
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<?x32xi16> -> tensor<2xindex>
|
||||
// CHECK: "use"(%[[SHAPE]])
|
||||
// CHECK: %[[CASTED:.*]] = tensor.cast %[[SHAPE]] : tensor<2xindex> to tensor<?xindex>
|
||||
// CHECK: "use"(%[[CASTED]])
|
||||
%0 = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16>
|
||||
%1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
|
||||
"use"(%1) : (tensor<?xindex>) -> ()
|
||||
|
@ -19,7 +20,8 @@ func @shape_of_unary(%arg : tensor<?x32xi16>) {
|
|||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>)
|
||||
func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) {
|
||||
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]] : tensor<?x32xf16> -> tensor<2xindex>
|
||||
// CHECK: "use"(%[[SHAPE]])
|
||||
// CHECK: %[[CASTED:.*]] = tensor.cast %[[SHAPE]] : tensor<2xindex> to tensor<?xindex>
|
||||
// CHECK: "use"(%[[CASTED]])
|
||||
%0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf16>
|
||||
%1 = mhlo.subtract %0, %arg1 : tensor<?x32xf16>
|
||||
%2 = shape.shape_of %1 : tensor<?x32xf16> -> tensor<?xindex>
|
||||
|
|
Loading…
Reference in New Issue