[MHLO] Move broadcasts over elementwise ops
Move up dynamic broadcasts and shape computations to allow for more fusion opportunities.

PiperOrigin-RevId: 364514158
parent 98debb127d
commit 54f37abc28
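At a glance, and mirroring the @bcast test added below (SSA names are illustrative), the rewrite moves a dynamic broadcast above a shape-preserving elementwise op:

  // Before: the broadcast consumes the result of the elementwise op.
  %0 = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16>
  %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %out_dims)
      {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
      : (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>

  // After: the operand is broadcast first and the elementwise op is recreated
  // on the broadcasted value, which exposes it to later fusion.
  %2 = "mhlo.dynamic_broadcast_in_dim"(%arg, %out_dims)
      {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
      : (tensor<?x32xi16>, tensor<3xindex>) -> tensor<?x?x32xi16>
  %3 = "mhlo.convert"(%2) : (tensor<?x?x32xi16>) -> tensor<?x?x32xf16>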
@@ -76,6 +76,67 @@ struct ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> {
   }
 };
 
+// We can only move up broadcasting ops that apply to the result of a
+// shape-preserving operation. For now, we restrict this to unary operations.
+// TODO(frgossen): Generalize this to n-ary operations.
+bool isDynamicBroadcastInDimOpMovable(Value operand) {
+  Operation *producer_op = operand.getDefiningOp();
+  return producer_op != nullptr &&
+         producer_op->hasTrait<OpTrait::SameOperandsAndResultShape>() &&
+         producer_op->hasTrait<OpTrait::Elementwise>() &&
+         producer_op->getNumOperands() == 1;
+}
+
+// TODO(frgossen): Only move up broadcasting operations if there is a consumer.
+struct MoveUpBroadcastInDimOpConversion
+    : public OpConversionPattern<DynamicBroadcastInDimOp> {
+  explicit MoveUpBroadcastInDimOpConversion(MLIRContext *context)
+      : OpConversionPattern<DynamicBroadcastInDimOp>(context) {}
+
+  LogicalResult matchAndRewrite(
+      DynamicBroadcastInDimOp bcast_op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    // We can only move up broadcasting ops that apply to the result of a
+    // shape-preserving operation.
+    DynamicBroadcastInDimOp::Adaptor transformed(operands);
+    if (!isDynamicBroadcastInDimOpMovable(transformed.operand()))
+      return failure();
+
+    // Materialize broadcast on operands.
+    SmallVector<Value, 2> bcasted_operands;
+    Location loc = bcast_op.getLoc();
+    ArrayRef<int64_t> ty_shape = bcast_op.getType().getShape();
+    Operation *producer_op = transformed.operand().getDefiningOp();
+    for (Value operand : producer_op->getOperands()) {
+      // The broadcast only works on ranked operations.
+      auto operand_ty = operand.getType().dyn_cast<RankedTensorType>();
+      if (!operand_ty) {
+        return bcast_op.emitError()
+               << "Can only move up broadcasts over ranked tensor operands.";
+      }
+
+      auto bcasted_operand_ty =
+          RankedTensorType::get(ty_shape, operand_ty.getElementType());
+      bcasted_operands.push_back(rewriter.create<DynamicBroadcastInDimOp>(
+          loc, bcasted_operand_ty, operand, transformed.output_dimensions(),
+          bcast_op.broadcast_dimensions()));
+    }
+
+    // Create a copy of the producer op with the new broadcasted operands.
+    OperationState new_producer_op_state(
+        loc, producer_op->getName().getStringRef(), bcasted_operands,
+        bcast_op.getType(), producer_op->getAttrs());
+    Operation *new_producer_op =
+        rewriter.createOperation(new_producer_op_state);
+
+    // The original result of the broadcast now falls directly out of the new
+    // producer op. Use it instead.
+    rewriter.replaceOp(bcast_op, new_producer_op->getResults());
+
+    return success();
+  }
+};
+
 struct MoveUpDynamicBroadcastsForFusionPass
     : public PassWrapper<MoveUpDynamicBroadcastsForFusionPass, FunctionPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -108,12 +169,17 @@ void PopulateMoveUpDynamicBroadcastsForFusionLegality(
                           tensor::TensorDialect>();
   target->addDynamicallyLegalOp<shape::ShapeOfOp>(
       [](shape::ShapeOfOp op) { return !IsShapeOfOpMovable(op.arg()); });
+  target->addDynamicallyLegalOp<DynamicBroadcastInDimOp>(
+      [](DynamicBroadcastInDimOp op) {
+        return !isDynamicBroadcastInDimOpMovable(op.operand());
+      });
 }
 
 void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
   // clang-format off
-  patterns->insert<ShapeOfOpConversion>(context);
+  patterns->insert<ShapeOfOpConversion,
+                   MoveUpBroadcastInDimOpConversion>(context);
   // clang-format on
 }
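For context, and not part of this diff: a minimal sketch of how the two Populate* helpers above would typically be driven inside the pass, assuming PopulateMoveUpDynamicBroadcastsForFusionLegality takes a ConversionTarget pointer (as the target-> calls suggest) and the older OwningRewritePatternList API used in this file:

  // Hypothetical pass body (sketch only): wires the legality and pattern
  // helpers into a partial dialect conversion.
  void runOnFunction() override {
    MLIRContext &ctx = getContext();

    // Movable shape_of / dynamic_broadcast_in_dim ops are dynamically illegal,
    // so the conversion keeps applying the patterns until none remain.
    ConversionTarget target(ctx);
    PopulateMoveUpDynamicBroadcastsForFusionLegality(&target);

    // Collect the rewrite patterns registered above.
    OwningRewritePatternList patterns;
    PopulateMoveUpDynamicBroadcastsForFusionPatterns(&ctx, &patterns);

    if (failed(applyPartialConversion(getFunction(), target,
                                      std::move(patterns)))) {
      signalPassFailure();
    }
  }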
@@ -26,3 +26,54 @@ func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) {
   "use"(%2) : (tensor<?xindex>) -> ()
   return
 }
+
+// -----
+
+// Broadcasts can be moved up over shape-preserving operations.
+// CHECK-LABEL: @bcast
+// CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>, %[[OUT_DIMS:.*]]: tensor<3xindex>)
+func @bcast(%arg : tensor<?x32xi16>, %out_dims : tensor<3xindex>)
+    -> tensor<?x?x32xf16> {
+  // CHECK: %[[BCASTED_OPERAND:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG]], %[[OUT_DIMS]])
+  // CHECK-SAME: broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xi16>, tensor<3xindex>) -> tensor<?x?x32xi16>
+  // CHECK: "mhlo.convert"(%[[BCASTED_OPERAND]]) : (tensor<?x?x32xi16>) -> tensor<?x?x32xf16>
+  %0 = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16>
+  %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %out_dims) {
+      broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } :
+      (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
+  return %1 : tensor<?x?x32xf16>
+}
+
+// -----
+
+// Exemplary IR as it appears in the lowering with `tf.Sub` and `tf.Cast`.
+// CHECK-LABEL: @cast_sub
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xi16>, %[[ARG1:.*]]: tensor<?x?x32xf16>) -> tensor<?x?x32xf16>
+func @cast_sub(%arg0: tensor<?x32xi16>, %arg1: tensor<?x?x32xf16>)
+    -> tensor<?x?x32xf16> {
+  // CHECK-NOT: convert
+  // CHECK: %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %{{.*}})
+  // CHECK: %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %{{.*}})
+  // CHECK: %[[CONVERTED_BCASTED_ARG0:.*]] = "mhlo.convert"(%[[BCASTED_ARG0]]) : (tensor<?x?x32xi16>) -> tensor<?x?x32xf16>
+  // CHECK: %{{.*}} = mhlo.subtract %[[BCASTED_ARG1]], %[[CONVERTED_BCASTED_ARG0]] : tensor<?x?x32xf16>
+  %0 = "mhlo.convert"(%arg0) : (tensor<?x32xi16>) -> tensor<?x32xf16>
+  %1 = shape.shape_of %arg1 : tensor<?x?x32xf16> -> tensor<?xindex>
+  %2 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
+  %3 = shape.cstr_broadcastable %1, %2 : tensor<?xindex>, tensor<?xindex>
+  %4 = shape.assuming %3 -> (tensor<?x?x32xf16>) {
+    %5 = shape.shape_of %arg1 : tensor<?x?x32xf16> -> tensor<?xindex>
+    %6 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
+    %7 = shape.broadcast %5, %6 : tensor<?xindex>, tensor<?xindex>
+        -> tensor<?xindex>
+    %8 = tensor.cast %7 : tensor<?xindex> to tensor<3xindex>
+    %9 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %8) {
+        broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} :
+        (tensor<?x?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
+    %10 = "mhlo.dynamic_broadcast_in_dim"(%0, %8) {
+        broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} :
+        (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
+    %11 = mhlo.subtract %9, %10 : tensor<?x?x32xf16>
+    shape.assuming_yield %11 : tensor<?x?x32xf16>
+  }
+  return %4 : tensor<?x?x32xf16>
+}
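Reconstructed from the CHECK lines above (a sketch only; SSA names are illustrative and the broadcast_dimensions attributes are elided), the IR inside the shape.assuming region after the pass is roughly:

  %b1  = "mhlo.dynamic_broadcast_in_dim"(%arg1, %shape) : (tensor<?x?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
  %b0  = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) : (tensor<?x32xi16>, tensor<3xindex>) -> tensor<?x?x32xi16>
  %c0  = "mhlo.convert"(%b0) : (tensor<?x?x32xi16>) -> tensor<?x?x32xf16>
  %sub = mhlo.subtract %b1, %c0 : tensor<?x?x32xf16>

That is, the stand-alone mhlo.convert on the narrow tensor<?x32xi16> value no longer appears before the broadcasts; the conversion now runs on the already-broadcasted operand, right next to the subtraction it feeds.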