[MLIR][HLO] Canonicalize chained broadcasts

Compose two subsequent `dynamic_broadcast_in_dim` ops into one.

PiperOrigin-RevId: 367630360
parent fdb653788c
commit 6d2209e301
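
To make the composition concrete: a broadcast_in_dim-style op with `broadcast_dimensions = dims` maps operand dimension i to result dimension dims[i], so chaining two broadcasts composes these index maps as composed[i] = outer[inner[i]]. Below is a minimal standalone sketch of that computation in plain C++ (std::vector instead of the DenseIntElementsAttr/APInt types the pattern uses; the helper name ComposeBroadcastDims and the sample values are illustrative, not from the commit):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // A broadcast with dimensions `dims` sends operand dim i to result dim
    // dims[i]; chaining two broadcasts composes these maps.
    std::vector<int64_t> ComposeBroadcastDims(const std::vector<int64_t>& inner,
                                              const std::vector<int64_t>& outer) {
      std::vector<int64_t> composed;
      composed.reserve(inner.size());
      for (int64_t d : inner) composed.push_back(outer[d]);
      return composed;
    }

    int main() {
      // Mirrors the updated test below: the first broadcast uses dimensions
      // [0, 1], the second [1, 2], so the fused broadcast uses [1, 2].
      for (int64_t d : ComposeBroadcastDims({0, 1}, {1, 2})) std::cout << d << ' ';
      std::cout << '\n';  // prints: 1 2
    }

The loop over `preceding_bcast_dims` in the new pattern below computes exactly this, one APInt element at a time.
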
@@ -898,6 +898,7 @@ static LogicalResult Verify(DynamicBroadcastInDimOp op) {
   return success();
 }
 
+namespace {
 // If a DynamicBroadCastInDimOp is not actually dynamic, use an ordinary
 // BroadcastInDimOp.
 class DynamicBroadcastInDimOpNotActuallyDynamic
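
For context, the pre-existing pattern named in this hunk rewrites a `dynamic_broadcast_in_dim` whose result type is fully static into an ordinary `broadcast_in_dim`; its body sits outside the diff. A hedged sketch of how such a rewrite can look (an approximation, not the file's exact code; the struct name is mine):

    // Sketch only: if the result shape is fully static, the dynamic
    // output_dimensions operand adds no information, so a static
    // broadcast_in_dim suffices.
    struct NotActuallyDynamicSketch
        : public OpRewritePattern<DynamicBroadcastInDimOp> {
      using OpRewritePattern::OpRewritePattern;
      LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op,
                                    PatternRewriter& rewriter) const override {
        auto result_type = op.getType().dyn_cast<RankedTensorType>();
        if (!result_type || !result_type.hasStaticShape()) return failure();
        rewriter.replaceOpWithNewOp<BroadcastInDimOp>(
            op, result_type, op.operand(), op.broadcast_dimensions());
        return success();
      }
    };
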
@@ -915,9 +916,40 @@ class DynamicBroadcastInDimOpNotActuallyDynamic
   }
 };
 
+class ChainedDynamicBroadcastInDimCanonicalization
+    : public OpRewritePattern<DynamicBroadcastInDimOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(DynamicBroadcastInDimOp bcast,
+                                PatternRewriter& rewriter) const override {
+    auto preceding_bcast =
+        bcast.operand().getDefiningOp<DynamicBroadcastInDimOp>();
+    if (!preceding_bcast) return failure();
+
+    // Compose broadcast dimensions.
+    DenseIntElementsAttr preceding_bcast_dims =
+        preceding_bcast.broadcast_dimensions();
+    DenseIntElementsAttr bcast_dims = bcast.broadcast_dimensions();
+    SmallVector<APInt, 4> composition;
+    for (APInt preceding_dim : preceding_bcast_dims) {
+      auto composed_dim = bcast_dims.getValue({preceding_dim.getZExtValue()})
+                              .cast<IntegerAttr>();
+      composition.push_back(composed_dim.getValue());
+    }
+    auto composed_bcast_dims =
+        DenseIntElementsAttr::get(preceding_bcast_dims.getType(), composition);
+
+    rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
+        bcast, bcast.getType(), preceding_bcast.operand(),
+        bcast.output_dimensions(), composed_bcast_dims);
+    return success();
+  }
+};
+} // namespace
+
 void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
     OwningRewritePatternList& results, MLIRContext* context) {
-  results.insert<DynamicBroadcastInDimOpNotActuallyDynamic,
+  results.insert<ChainedDynamicBroadcastInDimCanonicalization,
+                 DynamicBroadcastInDimOpNotActuallyDynamic,
                  DynamicBroadcastToOwnShape_1, DynamicBroadcastToOwnShape_2,
                  DynamicBroadcastToOwnShape_3, DynamicBroadcastToOwnShape_4>(
       context);
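
Registering ChainedDynamicBroadcastInDimCanonicalization here only adds it to the op's canonicalization pattern list; it fires whenever the canonicalizer (or any greedy rewrite driver fed these patterns) runs. A hedged driver sketch follows; the includes, the pattern-container constructor, and the FuncOp entry point vary across MLIR revisions of this period, so treat every name not present in the diff as an assumption:

    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
    #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

    // Sketch: collect the op's canonicalization patterns and apply them
    // greedily to a function, mirroring the OwningRewritePatternList API
    // used in the diff above.
    mlir::LogicalResult CanonicalizeBroadcasts(mlir::FuncOp func) {
      mlir::MLIRContext* context = func.getContext();
      mlir::OwningRewritePatternList patterns(context);
      mlir::mhlo::DynamicBroadcastInDimOp::getCanonicalizationPatterns(
          patterns, context);
      return mlir::applyPatternsAndFoldGreedily(func, std::move(patterns));
    }

The FileCheck expectations that follow come from the accompanying canonicalization test, which exercises these patterns through the canonicalize pass.
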
@@ -300,12 +300,10 @@ func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
 // CHECK: %[[ASSUMING_RESULTS:.*]]:4 = shape.assuming %[[WITNESS]]
 // CHECK-SAME: {
 // CHECK: %[[BCASTED_SHAPE01:.*]] = shape.broadcast %[[SHAPE0]], %[[SHAPE1]]
-// CHECK: %[[PARTIALLY_BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[BCASTED_SHAPE01]]) {broadcast_dimensions = dense<[0, 1]>
-// CHECK: %[[PARTIALLY_BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[BCASTED_SHAPE01]]) {broadcast_dimensions = dense<[0, 1]>
 // CHECK: %[[BCASTED_SHAPE012:.*]] = shape.broadcast %[[SHAPE2]], %[[BCASTED_SHAPE01]]
 // CHECK: %[[BCASTED_ARG2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG2]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[0, 1, 2]>
-// CHECK: %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[PARTIALLY_BCASTED_ARG0]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
+// CHECK: %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
-// CHECK: %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[PARTIALLY_BCASTED_ARG1]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
+// CHECK: %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
 // CHECK: %[[TMP:.*]] = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]]
 // CHECK: %[[RESULT:.*]] = mhlo.subtract %[[BCASTED_ARG2]], %[[TMP]]
 // CHECK: shape.assuming_yield %{{.*}}, %{{.*}}, %{{.*}}, %[[RESULT]]
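
The updated CHECK lines show the effect: the intermediate %[[PARTIALLY_BCASTED_ARG0]]/%[[PARTIALLY_BCASTED_ARG1]] values disappear, and %[[BCASTED_ARG0]]/%[[BCASTED_ARG1]] are built directly from %[[ARG0]]/%[[ARG1]] with the composed dimensions [1, 2]. The standalone toy program below (plain C++, static shapes only, ignoring the dynamic output_dimensions operand; all helper names are mine) checks on a small example that broadcasting twice agrees element-wise with broadcasting once along the composed dimensions:

    #include <cassert>
    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <numeric>
    #include <vector>

    using Index = std::vector<int64_t>;

    int64_t NumElements(const Index& shape) {
      return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                             std::multiplies<int64_t>());
    }

    // Decompose a row-major linear index into a multi-index for `shape`.
    Index Delinearize(int64_t linear, const Index& shape) {
      Index idx(shape.size());
      for (int64_t i = static_cast<int64_t>(shape.size()) - 1; i >= 0; --i) {
        idx[i] = linear % shape[i];
        linear /= shape[i];
      }
      return idx;
    }

    // Row-major linearization of a multi-index.
    int64_t Linearize(const Index& idx, const Index& shape) {
      int64_t linear = 0;
      for (size_t i = 0; i < shape.size(); ++i) linear = linear * shape[i] + idx[i];
      return linear;
    }

    // Toy static model of broadcast_in_dim: operand dimension i becomes result
    // dimension dims[i]; size-1 operand dimensions are expanded.
    std::vector<int> BroadcastInDim(const std::vector<int>& operand,
                                    const Index& operand_shape,
                                    const Index& result_shape, const Index& dims) {
      std::vector<int> result(NumElements(result_shape));
      for (int64_t linear = 0; linear < NumElements(result_shape); ++linear) {
        Index out_idx = Delinearize(linear, result_shape);
        Index in_idx(operand_shape.size());
        for (size_t i = 0; i < operand_shape.size(); ++i)
          in_idx[i] = operand_shape[i] == 1 ? 0 : out_idx[dims[i]];
        result[linear] = operand[Linearize(in_idx, operand_shape)];
      }
      return result;
    }

    int main() {
      // Operand of shape [2], broadcast to [3, 2] along dims [1], then to
      // [3, 4, 2] along dims [0, 2]. Composed dims: outer[inner[0]] = outer[1] = 2.
      std::vector<int> operand = {7, 9};
      auto step1 = BroadcastInDim(operand, {2}, {3, 2}, {1});
      auto chained = BroadcastInDim(step1, {3, 2}, {3, 4, 2}, {0, 2});
      auto fused = BroadcastInDim(operand, {2}, {3, 4, 2}, {2});
      assert(chained == fused);
      std::cout << "chained and fused broadcasts agree\n";
    }

Here the inner broadcast uses dimensions [1] and the outer [0, 2], so the composed dimensions are [outer[1]] = [2], matching the fused call.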