[MLIR][HLO] Canonicalize chained broadcasts
Compose two subsequent `dynamic_broadcast_in_dim` ops into one. PiperOrigin-RevId: 367630360
This commit is contained in:
parent
fdb653788c
commit
6d2209e301
|
@ -898,6 +898,7 @@ static LogicalResult Verify(DynamicBroadcastInDimOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
// If a DynamicBroadCastInDimOp is not actually dynamic, use an ordinary
|
||||
// BroadcastInDimOp.
|
||||
class DynamicBroadcastInDimOpNotActuallyDynamic
|
||||
|
@ -915,9 +916,40 @@ class DynamicBroadcastInDimOpNotActuallyDynamic
|
|||
}
|
||||
};
|
||||
|
||||
class ChainedDynamicBroadcastInDimCanonicalization
|
||||
: public OpRewritePattern<DynamicBroadcastInDimOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(DynamicBroadcastInDimOp bcast,
|
||||
PatternRewriter& rewriter) const override {
|
||||
auto preceding_bcast =
|
||||
bcast.operand().getDefiningOp<DynamicBroadcastInDimOp>();
|
||||
if (!preceding_bcast) return failure();
|
||||
|
||||
// Compose broadcast dimensions.
|
||||
DenseIntElementsAttr preceding_bcast_dims =
|
||||
preceding_bcast.broadcast_dimensions();
|
||||
DenseIntElementsAttr bcast_dims = bcast.broadcast_dimensions();
|
||||
SmallVector<APInt, 4> composition;
|
||||
for (APInt preceding_dim : preceding_bcast_dims) {
|
||||
auto composed_dim = bcast_dims.getValue({preceding_dim.getZExtValue()})
|
||||
.cast<IntegerAttr>();
|
||||
composition.push_back(composed_dim.getValue());
|
||||
}
|
||||
auto composed_bcast_dims =
|
||||
DenseIntElementsAttr::get(preceding_bcast_dims.getType(), composition);
|
||||
|
||||
rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
|
||||
bcast, bcast.getType(), preceding_bcast.operand(),
|
||||
bcast.output_dimensions(), composed_bcast_dims);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList& results, MLIRContext* context) {
|
||||
results.insert<DynamicBroadcastInDimOpNotActuallyDynamic,
|
||||
results.insert<ChainedDynamicBroadcastInDimCanonicalization,
|
||||
DynamicBroadcastInDimOpNotActuallyDynamic,
|
||||
DynamicBroadcastToOwnShape_1, DynamicBroadcastToOwnShape_2,
|
||||
DynamicBroadcastToOwnShape_3, DynamicBroadcastToOwnShape_4>(
|
||||
context);
|
||||
|
|
|
@ -300,12 +300,10 @@ func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
|
|||
// CHECK: %[[ASSUMING_RESULTS:.*]]:4 = shape.assuming %[[WITNESS]]
|
||||
// CHECK-SAME: {
|
||||
// CHECK: %[[BCASTED_SHAPE01:.*]] = shape.broadcast %[[SHAPE0]], %[[SHAPE1]]
|
||||
// CHECK: %[[PARTIALLY_BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[BCASTED_SHAPE01]]) {broadcast_dimensions = dense<[0, 1]>
|
||||
// CHECK: %[[PARTIALLY_BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[BCASTED_SHAPE01]]) {broadcast_dimensions = dense<[0, 1]>
|
||||
// CHECK: %[[BCASTED_SHAPE012:.*]] = shape.broadcast %[[SHAPE2]], %[[BCASTED_SHAPE01]]
|
||||
// CHECK: %[[BCASTED_ARG2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG2]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[0, 1, 2]>
|
||||
// CHECK: %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[PARTIALLY_BCASTED_ARG0]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
|
||||
// CHECK: %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[PARTIALLY_BCASTED_ARG1]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
|
||||
// CHECK: %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
|
||||
// CHECK: %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
|
||||
// CHECK: %[[TMP:.*]] = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]]
|
||||
// CHECK: %[[RESULT:.*]] = mhlo.subtract %[[BCASTED_ARG2]], %[[TMP]]
|
||||
// CHECK: shape.assuming_yield %{{.*}}, %{{.*}}, %{{.*}}, %[[RESULT]]
|
||||
|
|
Loading…
Reference in New Issue