[MLIR][HLO] Canonicalize chained broadcasts

Compose two subsequent `dynamic_broadcast_in_dim` ops into one.

PiperOrigin-RevId: 367630360
This commit is contained in:
A. Unique TensorFlower 2021-04-09 07:34:32 -07:00 committed by TensorFlow MLIR Team
parent fdb653788c
commit 6d2209e301
2 changed files with 35 additions and 5 deletions

View File

@ -898,6 +898,7 @@ static LogicalResult Verify(DynamicBroadcastInDimOp op) {
return success();
}
namespace {
// If a DynamicBroadCastInDimOp is not actually dynamic, use an ordinary
// BroadcastInDimOp.
class DynamicBroadcastInDimOpNotActuallyDynamic
@ -915,9 +916,40 @@ class DynamicBroadcastInDimOpNotActuallyDynamic
}
};
class ChainedDynamicBroadcastInDimCanonicalization
: public OpRewritePattern<DynamicBroadcastInDimOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DynamicBroadcastInDimOp bcast,
PatternRewriter& rewriter) const override {
auto preceding_bcast =
bcast.operand().getDefiningOp<DynamicBroadcastInDimOp>();
if (!preceding_bcast) return failure();
// Compose broadcast dimensions.
DenseIntElementsAttr preceding_bcast_dims =
preceding_bcast.broadcast_dimensions();
DenseIntElementsAttr bcast_dims = bcast.broadcast_dimensions();
SmallVector<APInt, 4> composition;
for (APInt preceding_dim : preceding_bcast_dims) {
auto composed_dim = bcast_dims.getValue({preceding_dim.getZExtValue()})
.cast<IntegerAttr>();
composition.push_back(composed_dim.getValue());
}
auto composed_bcast_dims =
DenseIntElementsAttr::get(preceding_bcast_dims.getType(), composition);
rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
bcast, bcast.getType(), preceding_bcast.operand(),
bcast.output_dimensions(), composed_bcast_dims);
return success();
}
};
} // namespace
void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<DynamicBroadcastInDimOpNotActuallyDynamic,
results.insert<ChainedDynamicBroadcastInDimCanonicalization,
DynamicBroadcastInDimOpNotActuallyDynamic,
DynamicBroadcastToOwnShape_1, DynamicBroadcastToOwnShape_2,
DynamicBroadcastToOwnShape_3, DynamicBroadcastToOwnShape_4>(
context);

View File

@ -300,12 +300,10 @@ func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
// CHECK: %[[ASSUMING_RESULTS:.*]]:4 = shape.assuming %[[WITNESS]]
// CHECK-SAME: {
// CHECK: %[[BCASTED_SHAPE01:.*]] = shape.broadcast %[[SHAPE0]], %[[SHAPE1]]
// CHECK: %[[PARTIALLY_BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[BCASTED_SHAPE01]]) {broadcast_dimensions = dense<[0, 1]>
// CHECK: %[[PARTIALLY_BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[BCASTED_SHAPE01]]) {broadcast_dimensions = dense<[0, 1]>
// CHECK: %[[BCASTED_SHAPE012:.*]] = shape.broadcast %[[SHAPE2]], %[[BCASTED_SHAPE01]]
// CHECK: %[[BCASTED_ARG2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG2]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[0, 1, 2]>
// CHECK: %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[PARTIALLY_BCASTED_ARG0]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
// CHECK: %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[PARTIALLY_BCASTED_ARG1]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
// CHECK: %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
// CHECK: %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
// CHECK: %[[TMP:.*]] = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]]
// CHECK: %[[RESULT:.*]] = mhlo.subtract %[[BCASTED_ARG2]], %[[TMP]]
// CHECK: shape.assuming_yield %{{.*}}, %{{.*}}, %{{.*}}, %[[RESULT]]