diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc
index 20a255f..a0f63c9 100644
--- a/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -898,6 +898,7 @@ static LogicalResult Verify(DynamicBroadcastInDimOp op) {
   return success();
 }
 
+namespace {
 // If a DynamicBroadCastInDimOp is not actually dynamic, use an ordinary
 // BroadcastInDimOp.
 class DynamicBroadcastInDimOpNotActuallyDynamic
@@ -915,9 +916,40 @@ class DynamicBroadcastInDimOpNotActuallyDynamic
   }
 };
 
+class ChainedDynamicBroadcastInDimCanonicalization
+    : public OpRewritePattern<DynamicBroadcastInDimOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(DynamicBroadcastInDimOp bcast,
+                                PatternRewriter& rewriter) const override {
+    auto preceding_bcast =
+        bcast.operand().getDefiningOp<DynamicBroadcastInDimOp>();
+    if (!preceding_bcast) return failure();
+
+    // Compose broadcast dimensions.
+    DenseIntElementsAttr preceding_bcast_dims =
+        preceding_bcast.broadcast_dimensions();
+    DenseIntElementsAttr bcast_dims = bcast.broadcast_dimensions();
+    SmallVector<APInt, 4> composition;
+    for (APInt preceding_dim : preceding_bcast_dims) {
+      auto composed_dim = bcast_dims.getValue({preceding_dim.getZExtValue()})
+                              .cast<IntegerAttr>();
+      composition.push_back(composed_dim.getValue());
+    }
+    auto composed_bcast_dims =
+        DenseIntElementsAttr::get(preceding_bcast_dims.getType(), composition);
+
+    rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
+        bcast, bcast.getType(), preceding_bcast.operand(),
+        bcast.output_dimensions(), composed_bcast_dims);
+    return success();
+  }
+};
+}  // namespace
+
 void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
     OwningRewritePatternList& results, MLIRContext* context) {
-  results.insert<DynamicBroadcastInDimOpNotActuallyDynamic>(context);
+  results.insert<ChainedDynamicBroadcastInDimCanonicalization,
+                 DynamicBroadcastInDimOpNotActuallyDynamic>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/tests/move_up_dynamic_broadcasts_for_fusion.mlir b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
index 86e73b9..8f0401d 100644
--- a/tests/move_up_dynamic_broadcasts_for_fusion.mlir
+++ b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
@@ -300,12 +300,10 @@ func @sub_sub(%arg0: tensor<?x?xf32>, %arg1 :
tensor<?x?xf32>,
   // CHECK:     %[[ASSUMING_RESULTS:.*]]:4 = shape.assuming %[[WITNESS]]
   // CHECK-SAME: {
   // CHECK:     %[[BCASTED_SHAPE01:.*]] = shape.broadcast %[[SHAPE0]], %[[SHAPE1]]
-  // CHECK:     %[[PARTIALLY_BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[BCASTED_SHAPE01]]) {broadcast_dimensions = dense<[0, 1]>
-  // CHECK:     %[[PARTIALLY_BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[BCASTED_SHAPE01]]) {broadcast_dimensions = dense<[0, 1]>
   // CHECK:     %[[BCASTED_SHAPE012:.*]] = shape.broadcast %[[SHAPE2]], %[[BCASTED_SHAPE01]]
   // CHECK:     %[[BCASTED_ARG2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG2]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[0, 1, 2]>
-  // CHECK:     %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[PARTIALLY_BCASTED_ARG0]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
-  // CHECK:     %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[PARTIALLY_BCASTED_ARG1]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
+  // CHECK:     %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
+  // CHECK:     %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
   // CHECK:     %[[TMP:.*]] = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]]
   // CHECK:     %[[RESULT:.*]] = mhlo.subtract %[[BCASTED_ARG2]], %[[TMP]]
   // CHECK:     shape.assuming_yield %{{.*}}, %{{.*}}, %{{.*}}, %[[RESULT]]