[MLIR][MHLO] Move `cstr_broadcastable` and `shape_of` out of `assuming` regions
Add pattern to move operations out of assuming ops. This is only valid for constraint-independent ops, like `cstr_broadcastable` and `shape_of`. It will eventually allow making assuming regions' constraints independent from each other so that they can be merged. PiperOrigin-RevId: 365993145
This commit is contained in:
		
							parent
							
								
									af2aaa6144
								
							
						
					
					
						commit
						8ade5d78c8
					
				| 
						 | 
				
			
			@ -162,6 +162,36 @@ struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
 | 
			
		|||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
/// Move operation out of assuming op. This is only valid for
 | 
			
		||||
/// constraint-independent ops, like `cstr_broadcastable` and `shape_of`. It
 | 
			
		||||
/// will eventually allow to make assuming regions' constraints independent from
 | 
			
		||||
/// each other.
 | 
			
		||||
template <typename OpTy>
 | 
			
		||||
struct MoveOutOfAssumingOpPattern : public OpRewritePattern<OpTy> {
 | 
			
		||||
  using OpRewritePattern<OpTy>::OpRewritePattern;
 | 
			
		||||
 | 
			
		||||
  LogicalResult matchAndRewrite(OpTy op,
 | 
			
		||||
                                PatternRewriter &rewriter) const override {
 | 
			
		||||
    // Must be inside of an assuming op.
 | 
			
		||||
    auto assuming_op = op->template getParentOfType<shape::AssumingOp>();
 | 
			
		||||
    if (!assuming_op) return failure();
 | 
			
		||||
 | 
			
		||||
    // Operands must not be defined within the assuming op.
 | 
			
		||||
    Block *body = assuming_op.getBody();
 | 
			
		||||
    auto is_available = [&](Value v) {
 | 
			
		||||
      Operation *def = v.getDefiningOp();
 | 
			
		||||
      return def == nullptr || def->getBlock() != body;
 | 
			
		||||
    };
 | 
			
		||||
    if (!llvm::all_of(op->getOperands(), is_available)) return failure();
 | 
			
		||||
 | 
			
		||||
    OpBuilder::InsertionGuard guard(rewriter);
 | 
			
		||||
    rewriter.setInsertionPoint(assuming_op);
 | 
			
		||||
    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
 | 
			
		||||
                                      op->getOperands(), op->getAttrs());
 | 
			
		||||
    return success();
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// TODO(frgossen): Only move up broadcasting operations if there is a consumer.
 | 
			
		||||
struct MoveUpBroadcastInDimOpPattern
 | 
			
		||||
    : public OpRewritePattern<DynamicBroadcastInDimOp> {
 | 
			
		||||
| 
						 | 
				
			
			@ -236,6 +266,8 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
 | 
			
		|||
      InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
 | 
			
		||||
      MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
 | 
			
		||||
      MoveIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
 | 
			
		||||
      MoveOutOfAssumingOpPattern<shape::CstrBroadcastableOp>,
 | 
			
		||||
      MoveOutOfAssumingOpPattern<shape::ShapeOfOp>,
 | 
			
		||||
      MoveUpBroadcastInDimOpPattern,
 | 
			
		||||
      ShapeReificationPattern>(context);
 | 
			
		||||
  // clang-format on
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -114,36 +114,40 @@ func @inline_bcasted_shape_operands(%a : tensor<?xindex>, %b : tensor<?xindex>,
 | 
			
		|||
// -----
 | 
			
		||||
 | 
			
		||||
// A `shape_of` whose operand is yielded by the assuming op is moved into the
// region, next to the operand's producer.
// CHECK-LABEL: @move_shape_of_into_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<?x32xf32>)
func @move_shape_of_into_assuming(%arg0 : !shape.witness,
    %arg1 : tensor<?x32xf32>) -> tensor<3xindex> {
  // CHECK:     %[[ASSUMING_RESULTS:.*]]:3 = shape.assuming %[[ARG0]] -> (tensor<?x32xf32>, tensor<?x32xf32>, tensor<3xindex>) {
  // CHECK:       %[[DUMMY_TENSOR:.*]] = "dummy.tensor"() : () -> tensor<?x32xf32>
  // CHECK:       %[[SHAPE:.*]] = shape.shape_of %[[DUMMY_TENSOR]]
  // CHECK:       shape.assuming_yield %[[ARG1]], %[[DUMMY_TENSOR]], %[[SHAPE]]
  // CHECK:     }
  // CHECK-NOT: shape_of
  // CHECK:     return %[[ASSUMING_RESULTS]]#2
  %0:2 = shape.assuming %arg0 -> (tensor<?x32xf32>, tensor<?x32xf32>) {
    %1 = "dummy.tensor"() : () -> tensor<?x32xf32>
    shape.assuming_yield %arg1, %1 : tensor<?x32xf32>, tensor<?x32xf32>
  }
  %2 = shape.shape_of %0#1 : tensor<?x32xf32> -> tensor<3xindex>
  return %2 : tensor<3xindex>
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: @move_cstr_broadcastable_into_assuming
 | 
			
		||||
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2xindex>, %[[ARG2:.*]]: tensor<3xindex>)
 | 
			
		||||
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2xindex>)
 | 
			
		||||
func @move_cstr_broadcastable_into_assuming(%arg0 : !shape.witness,
 | 
			
		||||
    %arg1 : tensor<2xindex>, %arg2 : tensor<3xindex>) -> !shape.witness {
 | 
			
		||||
    %arg1 : tensor<2xindex>) -> !shape.witness {
 | 
			
		||||
  // CHECK:     %[[ASSUMING_RESULTS:.*]]:3 = shape.assuming %[[ARG0]] -> (tensor<2xindex>, tensor<3xindex>, !shape.witness) {
 | 
			
		||||
  // CHECK:       %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[ARG2]]
 | 
			
		||||
  // CHECK:       shape.assuming_yield %[[ARG1]], %[[ARG2]], %[[WITNESS]]
 | 
			
		||||
  // CHECK:       %[[DUMMY_TENSOR:.*]] = "dummy.tensor"() : () -> tensor<3xindex>
 | 
			
		||||
  // CHECK:       %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[DUMMY_TENSOR]]
 | 
			
		||||
  // CHECK:       shape.assuming_yield %[[ARG1]], %[[DUMMY_TENSOR]], %[[WITNESS]]
 | 
			
		||||
  // CHECK:     }
 | 
			
		||||
  // CHECK-NOT: cstr_broadcastable
 | 
			
		||||
  // CHECK:     return %[[ASSUMING_RESULTS]]#2
 | 
			
		||||
  %0:2 = shape.assuming %arg0 -> (tensor<2xindex>, tensor<3xindex>) {
 | 
			
		||||
    shape.assuming_yield %arg1, %arg2 : tensor<2xindex>, tensor<3xindex>
 | 
			
		||||
    %1 = "dummy.tensor"() : () -> tensor<3xindex>
 | 
			
		||||
    shape.assuming_yield %arg1, %1 : tensor<2xindex>, tensor<3xindex>
 | 
			
		||||
  }
 | 
			
		||||
  %1 = shape.cstr_broadcastable %arg1, %0#1 : tensor<2xindex>, tensor<3xindex>
 | 
			
		||||
  return %1 : !shape.witness
 | 
			
		||||
| 
						 | 
				
			
			@ -167,3 +171,61 @@ func @not_move_shape_of_into_assuming(%arg0 : !shape.witness,
 | 
			
		|||
  %2 = shape.shape_of %0#1 : tensor<?x32xf32> -> tensor<3xindex>
 | 
			
		||||
  return %2 : tensor<3xindex>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// A `cstr_broadcastable` whose operands are all defined outside the assuming
// region is hoisted out in front of the assuming op.
// CHECK-LABEL: @move_cstr_broadcastable_out_of_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2xindex>, %[[ARG2:.*]]: tensor<3xindex>)
func @move_cstr_broadcastable_out_of_assuming(%arg0 : !shape.witness,
    %arg1 : tensor<2xindex>, %arg2 : tensor<3xindex>) -> !shape.witness {
  // CHECK:     %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[ARG2]]
  // CHECK:     %[[ASSUMING_RESULT:.*]] = shape.assuming %[[ARG0]] -> (!shape.witness) {
  // CHECK-NOT:   cstr_broadcastable
  // CHECK:       shape.assuming_yield %[[WITNESS]]
  // CHECK:     }
  // CHECK:     return %[[ASSUMING_RESULT]]
  %0 = shape.assuming %arg0 -> (!shape.witness) {
    %1 = shape.cstr_broadcastable %arg1, %arg2 : tensor<2xindex>, tensor<3xindex>
    shape.assuming_yield %1 : !shape.witness
  }
  return %0 : !shape.witness
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// A `shape_of` whose operand is a function argument (defined outside the
// assuming region) is hoisted out in front of the assuming op.
// CHECK-LABEL: @move_shape_of_out_of_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2x?xf32>)
func @move_shape_of_out_of_assuming(%arg0 : !shape.witness,
    %arg1 : tensor<2x?xf32>) -> tensor<2xindex> {
  // CHECK:     %[[SHAPE:.*]] = shape.shape_of %[[ARG1]]
  // CHECK:     %[[ASSUMING_RESULT:.*]] = shape.assuming %[[ARG0]] -> (tensor<2xindex>) {
  // CHECK-NOT:   shape_of
  // CHECK:       shape.assuming_yield %[[SHAPE]]
  // CHECK:     }
  // CHECK:     return %[[ASSUMING_RESULT]]
  %0 = shape.assuming %arg0 -> (tensor<2xindex>) {
    %1 = shape.shape_of %arg1 : tensor<2x?xf32> -> tensor<2xindex>
    shape.assuming_yield %1 : tensor<2xindex>
  }
  return %0 : tensor<2xindex>
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// Negative case: the `shape_of` operand is produced inside the assuming
// region, so the op must stay where it is.
// CHECK-LABEL: @not_move_shape_of_out_of_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2x?xf32>)
func @not_move_shape_of_out_of_assuming(%arg0 : !shape.witness,
    %arg1 : tensor<2x?xf32>) -> tensor<2xindex> {
  // CHECK-NOT:  shape_of
  // CHECK:      shape.assuming
  // CHECK-SAME: {
  // CHECK:        "some.tensor"
  // CHECK:        shape_of
  // CHECK:      }
  %0 = shape.assuming %arg0 -> (tensor<2xindex>) {
    %1 = "some.tensor"() : () -> tensor<2x?xf32>
    %2 = shape.shape_of %1 : tensor<2x?xf32> -> tensor<2xindex>
    shape.assuming_yield %2 : tensor<2xindex>
  }
  return %0 : tensor<2xindex>
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue