[MLIR][HLO] Generalize merged witnesses in `move-up-dynamic-broadcasts-for-fusion`

PiperOrigin-RevId: 368012460
This commit is contained in:
A. Unique TensorFlower 2021-04-12 08:53:02 -07:00 committed by TensorFlow MLIR Team
parent 1007995ea2
commit 0ec0a23e61
2 changed files with 12 additions and 21 deletions

View File

@ -250,23 +250,12 @@ struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
if (!preceding_op) return failure();
// For now, both witnesses must be cstr_broadcastable.
// TODO(frgossen): Generalize this.
auto bcastable_a = op.witness().getDefiningOp<shape::CstrBroadcastableOp>();
auto bcastable_b =
preceding_op.witness().getDefiningOp<shape::CstrBroadcastableOp>();
if (!bcastable_a || !bcastable_b) return failure();
// Merge witnesses.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(preceding_op);
SmallVector<Value, 8> new_operands;
new_operands.append(bcastable_a->getOperands().begin(),
bcastable_a->getOperands().end());
new_operands.append(bcastable_b->getOperands().begin(),
bcastable_b->getOperands().end());
Value new_witness = rewriter.create<shape::CstrBroadcastableOp>(
bcastable_a.getLoc(), new_operands);
Value new_witness = rewriter.create<shape::AssumingAllOp>(
op.witness().getDefiningOp()->getLoc(),
ValueRange{preceding_op.witness(), op.witness()});
// Merge assuming ops.
Block *body_a = preceding_op.getBody();

View File

@ -260,8 +260,8 @@ func @merge_assuming_ops(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
// CHECK: %[[SHAPE2:.*]] = shape.shape_of %[[ARG2]]
// CHECK: %[[WITNESS0:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]]
// CHECK: %[[WITNESS1:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE2]]
// CHECK: %[[WITNESS_MERGED:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE2]]
// CHECK: %[[MERGED:.*]]:2 = shape.assuming %[[WITNESS_MERGED]]
// CHECK: %[[COMBINED_WITNESS:.*]] = shape.assuming_all %[[WITNESS0]], %[[WITNESS1]]
// CHECK: %[[MERGED:.*]]:2 = shape.assuming %[[COMBINED_WITNESS]]
// CHECK-SAME: {
// CHECK: "some.op"
// CHECK: %[[RESULT0:.*]] = "some.producer"
@ -297,11 +297,13 @@ func @merge_assuming_ops(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>)
func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
%arg2: tensor<?x?x32xf16>) -> tensor<?x?x32xf16> {
// CHECK: %[[SHAPE0:.*]] = shape.shape_of %[[ARG0]]
// CHECK: %[[SHAPE1:.*]] = shape.shape_of %[[ARG1]]
// CHECK: %[[SHAPE2:.*]] = shape.shape_of %[[ARG2]]
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE2]], %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE0]], %[[SHAPE1]]
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]]
// CHECK-DAG: %[[SHAPE0:.*]] = shape.shape_of %[[ARG0]]
// CHECK-DAG: %[[SHAPE1:.*]] = shape.shape_of %[[ARG1]]
// CHECK-DAG: %[[SHAPE2:.*]] = shape.shape_of %[[ARG2]]
// CHECK-DAG: %[[WITNESS0:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]]
// CHECK-DAG: %[[WITNESS1:.*]] = shape.cstr_broadcastable %[[SHAPE2]], %[[SHAPE0]], %[[SHAPE1]]
// CHECK-DAG: %[[COMBINED_WITNESS:.*]] = shape.assuming_all %[[WITNESS0]], %[[WITNESS1]]
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[COMBINED_WITNESS]]
// CHECK-SAME: {
// CHECK: %[[BCASTED_SHAPE01:.*]] = shape.broadcast %[[SHAPE0]], %[[SHAPE1]]
// CHECK: %[[BCASTED_SHAPE012:.*]] = shape.broadcast %[[SHAPE2]], %[[BCASTED_SHAPE01]]