diff --git a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc index 49d4652..7e4506c 100644 --- a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc +++ b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc @@ -250,23 +250,12 @@ struct MergeAssumingOpsPattern : public OpRewritePattern { llvm::dyn_cast_or_null(op->getPrevNode()); if (!preceding_op) return failure(); - // For now, both witnesses must be cstr_broadcastable. - // TODO(frgossen): Generalize this. - auto bcastable_a = op.witness().getDefiningOp(); - auto bcastable_b = - preceding_op.witness().getDefiningOp(); - if (!bcastable_a || !bcastable_b) return failure(); - // Merge witnesses. OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(preceding_op); - SmallVector new_operands; - new_operands.append(bcastable_a->getOperands().begin(), - bcastable_a->getOperands().end()); - new_operands.append(bcastable_b->getOperands().begin(), - bcastable_b->getOperands().end()); - Value new_witness = rewriter.create( - bcastable_a.getLoc(), new_operands); + Value new_witness = rewriter.create( + op.witness().getDefiningOp()->getLoc(), + ValueRange{preceding_op.witness(), op.witness()}); // Merge assuming ops. Block *body_a = preceding_op.getBody(); diff --git a/tests/move_up_dynamic_broadcasts_for_fusion.mlir b/tests/move_up_dynamic_broadcasts_for_fusion.mlir index 62b773f..1c6c305 100644 --- a/tests/move_up_dynamic_broadcasts_for_fusion.mlir +++ b/tests/move_up_dynamic_broadcasts_for_fusion.mlir @@ -260,8 +260,8 @@ func @merge_assuming_ops(%arg0: tensor, %arg1 : tensor, // CHECK: %[[SHAPE2:.*]] = shape.shape_of %[[ARG2]] // CHECK: %[[WITNESS0:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]] // CHECK: %[[WITNESS1:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE2]] - // CHECK: %[[WITNESS_MERGED:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE2]] - // CHECK: %[[MERGED:.*]]:2 = shape.assuming %[[WITNESS_MERGED]] + // CHECK: %[[COMBINED_WITNESS:.*]] = shape.assuming_all %[[WITNESS0]], %[[WITNESS1]] + // CHECK: %[[MERGED:.*]]:2 = shape.assuming %[[COMBINED_WITNESS]] // CHECK-SAME: { // CHECK: "some.op" // CHECK: %[[RESULT0:.*]] = "some.producer" @@ -297,11 +297,13 @@ func @merge_assuming_ops(%arg0: tensor, %arg1 : tensor, // CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor) func @sub_sub(%arg0: tensor, %arg1 : tensor, %arg2: tensor) -> tensor { - // CHECK: %[[SHAPE0:.*]] = shape.shape_of %[[ARG0]] - // CHECK: %[[SHAPE1:.*]] = shape.shape_of %[[ARG1]] - // CHECK: %[[SHAPE2:.*]] = shape.shape_of %[[ARG2]] - // CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE2]], %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE0]], %[[SHAPE1]] - // CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] + // CHECK-DAG: %[[SHAPE0:.*]] = shape.shape_of %[[ARG0]] + // CHECK-DAG: %[[SHAPE1:.*]] = shape.shape_of %[[ARG1]] + // CHECK-DAG: %[[SHAPE2:.*]] = shape.shape_of %[[ARG2]] + // CHECK-DAG: %[[WITNESS0:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]] + // CHECK-DAG: %[[WITNESS1:.*]] = shape.cstr_broadcastable %[[SHAPE2]], %[[SHAPE0]], %[[SHAPE1]] + // CHECK-DAG: %[[COMBINED_WITNESS:.*]] = shape.assuming_all %[[WITNESS0]], %[[WITNESS1]] + // CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[COMBINED_WITNESS]] // CHECK-SAME: { // CHECK: %[[BCASTED_SHAPE01:.*]] = shape.broadcast %[[SHAPE0]], %[[SHAPE1]] // CHECK: %[[BCASTED_SHAPE012:.*]] = shape.broadcast %[[SHAPE2]], %[[BCASTED_SHAPE01]]