[MLIR][HLO] Generalize merged witnesses in `move-up-dynamic-broadcasts-for-fusion`
PiperOrigin-RevId: 368012460
This commit is contained in:
parent
1007995ea2
commit
0ec0a23e61
|
@ -250,23 +250,12 @@ struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
|
||||||
llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
|
llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
|
||||||
if (!preceding_op) return failure();
|
if (!preceding_op) return failure();
|
||||||
|
|
||||||
// For now, both witnesses must be cstr_broadcastable.
|
|
||||||
// TODO(frgossen): Generalize this.
|
|
||||||
auto bcastable_a = op.witness().getDefiningOp<shape::CstrBroadcastableOp>();
|
|
||||||
auto bcastable_b =
|
|
||||||
preceding_op.witness().getDefiningOp<shape::CstrBroadcastableOp>();
|
|
||||||
if (!bcastable_a || !bcastable_b) return failure();
|
|
||||||
|
|
||||||
// Merge witnesses.
|
// Merge witnesses.
|
||||||
OpBuilder::InsertionGuard guard(rewriter);
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
rewriter.setInsertionPoint(preceding_op);
|
rewriter.setInsertionPoint(preceding_op);
|
||||||
SmallVector<Value, 8> new_operands;
|
Value new_witness = rewriter.create<shape::AssumingAllOp>(
|
||||||
new_operands.append(bcastable_a->getOperands().begin(),
|
op.witness().getDefiningOp()->getLoc(),
|
||||||
bcastable_a->getOperands().end());
|
ValueRange{preceding_op.witness(), op.witness()});
|
||||||
new_operands.append(bcastable_b->getOperands().begin(),
|
|
||||||
bcastable_b->getOperands().end());
|
|
||||||
Value new_witness = rewriter.create<shape::CstrBroadcastableOp>(
|
|
||||||
bcastable_a.getLoc(), new_operands);
|
|
||||||
|
|
||||||
// Merge assuming ops.
|
// Merge assuming ops.
|
||||||
Block *body_a = preceding_op.getBody();
|
Block *body_a = preceding_op.getBody();
|
||||||
|
|
|
@ -260,8 +260,8 @@ func @merge_assuming_ops(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
|
||||||
// CHECK: %[[SHAPE2:.*]] = shape.shape_of %[[ARG2]]
|
// CHECK: %[[SHAPE2:.*]] = shape.shape_of %[[ARG2]]
|
||||||
// CHECK: %[[WITNESS0:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]]
|
// CHECK: %[[WITNESS0:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]]
|
||||||
// CHECK: %[[WITNESS1:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE2]]
|
// CHECK: %[[WITNESS1:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE2]]
|
||||||
// CHECK: %[[WITNESS_MERGED:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE2]]
|
// CHECK: %[[COMBINED_WITNESS:.*]] = shape.assuming_all %[[WITNESS0]], %[[WITNESS1]]
|
||||||
// CHECK: %[[MERGED:.*]]:2 = shape.assuming %[[WITNESS_MERGED]]
|
// CHECK: %[[MERGED:.*]]:2 = shape.assuming %[[COMBINED_WITNESS]]
|
||||||
// CHECK-SAME: {
|
// CHECK-SAME: {
|
||||||
// CHECK: "some.op"
|
// CHECK: "some.op"
|
||||||
// CHECK: %[[RESULT0:.*]] = "some.producer"
|
// CHECK: %[[RESULT0:.*]] = "some.producer"
|
||||||
|
@ -297,11 +297,13 @@ func @merge_assuming_ops(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
|
||||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>)
|
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>)
|
||||||
func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
|
func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
|
||||||
%arg2: tensor<?x?x32xf16>) -> tensor<?x?x32xf16> {
|
%arg2: tensor<?x?x32xf16>) -> tensor<?x?x32xf16> {
|
||||||
// CHECK: %[[SHAPE0:.*]] = shape.shape_of %[[ARG0]]
|
// CHECK-DAG: %[[SHAPE0:.*]] = shape.shape_of %[[ARG0]]
|
||||||
// CHECK: %[[SHAPE1:.*]] = shape.shape_of %[[ARG1]]
|
// CHECK-DAG: %[[SHAPE1:.*]] = shape.shape_of %[[ARG1]]
|
||||||
// CHECK: %[[SHAPE2:.*]] = shape.shape_of %[[ARG2]]
|
// CHECK-DAG: %[[SHAPE2:.*]] = shape.shape_of %[[ARG2]]
|
||||||
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE2]], %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE0]], %[[SHAPE1]]
|
// CHECK-DAG: %[[WITNESS0:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]]
|
||||||
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]]
|
// CHECK-DAG: %[[WITNESS1:.*]] = shape.cstr_broadcastable %[[SHAPE2]], %[[SHAPE0]], %[[SHAPE1]]
|
||||||
|
// CHECK-DAG: %[[COMBINED_WITNESS:.*]] = shape.assuming_all %[[WITNESS0]], %[[WITNESS1]]
|
||||||
|
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[COMBINED_WITNESS]]
|
||||||
// CHECK-SAME: {
|
// CHECK-SAME: {
|
||||||
// CHECK: %[[BCASTED_SHAPE01:.*]] = shape.broadcast %[[SHAPE0]], %[[SHAPE1]]
|
// CHECK: %[[BCASTED_SHAPE01:.*]] = shape.broadcast %[[SHAPE0]], %[[SHAPE1]]
|
||||||
// CHECK: %[[BCASTED_SHAPE012:.*]] = shape.broadcast %[[SHAPE2]], %[[BCASTED_SHAPE01]]
|
// CHECK: %[[BCASTED_SHAPE012:.*]] = shape.broadcast %[[SHAPE2]], %[[BCASTED_SHAPE01]]
|
||||||
|
|
Loading…
Reference in New Issue