diff --git a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
index 7707155..3bac0ce 100644
--- a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
+++ b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
@@ -162,6 +162,36 @@ struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
   }
 };
 
+/// Move operation out of assuming op. This is only valid for
+/// constraint-independent ops, like `cstr_broadcastable` and `shape_of`. It
+/// will eventually allow making assuming regions' constraints independent
+/// from each other.
+template <typename OpTy>
+struct MoveOutOfAssumingOpPattern : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Must be inside of an assuming op.
+    auto assuming_op = op->template getParentOfType<shape::AssumingOp>();
+    if (!assuming_op) return failure();
+
+    // Operands must not be defined within the assuming op.
+    Block *body = assuming_op.getBody();
+    auto is_available = [&](Value v) {
+      Operation *def = v.getDefiningOp();
+      return def == nullptr || def->getBlock() != body;
+    };
+    if (!llvm::all_of(op->getOperands(), is_available)) return failure();
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(assuming_op);
+    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
+                                      op->getOperands(), op->getAttrs());
+    return success();
+  }
+};
+
 // TODO(frgossen): Only move up broadcasting operations if there is a consumer.
 struct MoveUpBroadcastInDimOpPattern
     : public OpRewritePattern<DynamicBroadcastInDimOp> {
@@ -236,6 +266,8 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
       InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
       MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
       MoveIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
+      MoveOutOfAssumingOpPattern<shape::CstrBroadcastableOp>,
+      MoveOutOfAssumingOpPattern<shape::ShapeOfOp>,
       MoveUpBroadcastInDimOpPattern,
       ShapeReificationPattern>(context);
   // clang-format on
diff --git a/tests/move_up_dynamic_broadcasts_for_fusion.mlir b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
index 9713004..894d254 100644
--- a/tests/move_up_dynamic_broadcasts_for_fusion.mlir
+++ b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
@@ -114,36 +114,40 @@ func @inline_bcasted_shape_operands(%a : tensor<?xindex>, %b : tensor<?xindex>,
 // -----
 
 // CHECK-LABEL: @move_shape_of_into_assuming
-// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<?x?x?xf32>, %[[ARG2:.*]]: tensor<?x?x?xf32>)
+// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<?x?x?xf32>)
 func @move_shape_of_into_assuming(%arg0 : !shape.witness,
-    %arg1 : tensor<?x?x?xf32>, %arg2 : tensor<?x?x?xf32>) -> tensor<3xindex> {
+    %arg1 : tensor<?x?x?xf32>) -> tensor<3xindex> {
   // CHECK: %[[ASSUMING_RESULTS:.*]]:3 = shape.assuming %[[ARG0]] -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<3xindex>) {
-  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG2]]
-  // CHECK: shape.assuming_yield %[[ARG1]], %[[ARG2]], %[[SHAPE]]
+  // CHECK: %[[DUMMY_TENSOR:.*]] = "dummy.tensor"() : () -> tensor<?x?x?xf32>
+  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[DUMMY_TENSOR]]
+  // CHECK: shape.assuming_yield %[[ARG1]], %[[DUMMY_TENSOR]], %[[SHAPE]]
   // CHECK: }
   // CHECK-NOT: shape_of
   // CHECK: return %[[ASSUMING_RESULTS]]#2
   %0:2 = shape.assuming %arg0 -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
-    shape.assuming_yield %arg1, %arg2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
+    %1 = "dummy.tensor"() : () -> tensor<?x?x?xf32>
+    shape.assuming_yield %arg1, %1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
   }
-  %1 = shape.shape_of %0#1 : tensor<?x?x?xf32> -> tensor<3xindex>
-  return %1 : tensor<3xindex>
+  %2 = shape.shape_of %0#1 : tensor<?x?x?xf32> -> tensor<3xindex>
+  return %2 : tensor<3xindex>
 }
 
 // -----
 
 // CHECK-LABEL: @move_cstr_broadcastable_into_assuming
-// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2xindex>, %[[ARG2:.*]]: tensor<3xindex>)
+// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2xindex>)
 func @move_cstr_broadcastable_into_assuming(%arg0 : !shape.witness,
-    %arg1 : tensor<2xindex>, %arg2 : tensor<3xindex>) -> !shape.witness {
+    %arg1 : tensor<2xindex>) -> !shape.witness {
   // CHECK: %[[ASSUMING_RESULTS:.*]]:3 = shape.assuming %[[ARG0]] -> (tensor<2xindex>, tensor<3xindex>, !shape.witness) {
-  // CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[ARG2]]
-  // CHECK: shape.assuming_yield %[[ARG1]], %[[ARG2]], %[[WITNESS]]
+  // CHECK: %[[DUMMY_TENSOR:.*]] = "dummy.tensor"() : () -> tensor<3xindex>
+  // CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[DUMMY_TENSOR]]
+  // CHECK: shape.assuming_yield %[[ARG1]], %[[DUMMY_TENSOR]], %[[WITNESS]]
   // CHECK: }
   // CHECK-NOT: cstr_broadcastable
   // CHECK: return %[[ASSUMING_RESULTS]]#2
   %0:2 = shape.assuming %arg0 -> (tensor<2xindex>, tensor<3xindex>) {
-    shape.assuming_yield %arg1, %arg2 : tensor<2xindex>, tensor<3xindex>
+    %1 = "dummy.tensor"() : () -> tensor<3xindex>
+    shape.assuming_yield %arg1, %1 : tensor<2xindex>, tensor<3xindex>
   }
   %1 = shape.cstr_broadcastable %arg1, %0#1 : tensor<2xindex>, tensor<3xindex>
   return %1 : !shape.witness
@@ -167,3 +171,61 @@ func @not_move_shape_of_into_assuming(%arg0 : !shape.witness,
   %2 = shape.shape_of %0#1 : tensor<?x?x?xf32> -> tensor<3xindex>
   return %2 : tensor<3xindex>
 }
+
+// -----
+
+// CHECK-LABEL: @move_cstr_broadcastable_out_of_assuming
+// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2xindex>, %[[ARG2:.*]]: tensor<3xindex>)
+func @move_cstr_broadcastable_out_of_assuming(%arg0 : !shape.witness,
+    %arg1 : tensor<2xindex>, %arg2 : tensor<3xindex>) -> !shape.witness {
+  // CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[ARG2]]
+  // CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[ARG0]] -> (!shape.witness) {
+  // CHECK-NOT: cstr_broadcastable
+  // CHECK: shape.assuming_yield %[[WITNESS]]
+  // CHECK: }
+  // CHECK: return %[[ASSUMING_RESULT]]
+  %0 = shape.assuming %arg0 -> (!shape.witness) {
+    %1 = shape.cstr_broadcastable %arg1, %arg2 : tensor<2xindex>, tensor<3xindex>
+    shape.assuming_yield %1 : !shape.witness
+  }
+  return %0 : !shape.witness
+}
+
+// -----
+
+// CHECK-LABEL: @move_shape_of_out_of_assuming
+// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2x?xf32>)
+func @move_shape_of_out_of_assuming(%arg0 : !shape.witness,
+    %arg1 : tensor<2x?xf32>) -> tensor<2xindex> {
+  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG1]]
+  // CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[ARG0]] -> (tensor<2xindex>) {
+  // CHECK-NOT: shape_of
+  // CHECK: shape.assuming_yield %[[SHAPE]]
+  // CHECK: }
+  // CHECK: return %[[ASSUMING_RESULT]]
+  %0 = shape.assuming %arg0 -> (tensor<2xindex>) {
+    %1 = shape.shape_of %arg1 : tensor<2x?xf32> -> tensor<2xindex>
+    shape.assuming_yield %1 : tensor<2xindex>
+  }
+  return %0 : tensor<2xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @not_move_shape_of_out_of_assuming
+// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2x?xf32>)
+func @not_move_shape_of_out_of_assuming(%arg0 : !shape.witness,
+    %arg1 : tensor<2x?xf32>) -> tensor<2xindex> {
+  // CHECK-NOT: shape_of
+  // CHECK: shape.assuming
+  // CHECK-SAME: {
+  // CHECK: "some.tensor"
+  // CHECK: shape_of
+  // CHECK: }
+  %0 = shape.assuming %arg0 -> (tensor<2xindex>) {
+    %1 = "some.tensor"() : () -> tensor<2x?xf32>
+    %2 = shape.shape_of %1 : tensor<2x?xf32> -> tensor<2xindex>
+    shape.assuming_yield %2 : tensor<2xindex>
+  }
+  return %0 : tensor<2xindex>
+}
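
Note (illustrative sketch, not part of the patch): the new MoveOutOfAssumingOpPattern recreates a constraint-independent op directly in front of the shape.assuming region that contained it, so the region merely forwards the result. The value names below (%w, %a, %b) are hypothetical; the IR mirrors what the tests above check for.

  // Before: the witness is computed inside the assuming region.
  %0 = shape.assuming %w -> (!shape.witness) {
    %1 = shape.cstr_broadcastable %a, %b : tensor<2xindex>, tensor<3xindex>
    shape.assuming_yield %1 : !shape.witness
  }

  // After: the op is hoisted above the region; with such constraints at the
  // top level, assuming regions no longer depend on each other's results.
  %1 = shape.cstr_broadcastable %a, %b : tensor<2xindex>, tensor<3xindex>
  %0 = shape.assuming %w -> (!shape.witness) {
    shape.assuming_yield %1 : !shape.witness
  }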