[MLIR][MHLO] Add pattern to move ops into the assuming region
This will eventually allow to make assuming regions' constraints independent from each other. PiperOrigin-RevId: 365985081
This commit is contained in:
parent
5ec66775d4
commit
eade942635
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
@ -96,6 +97,71 @@ struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern<OpTy> {
|
|||
}
|
||||
};
|
||||
|
||||
/// Move operation into a preceeding assuming op. This allows to process
|
||||
/// operations that depend on the assuming op's results. It will eventually
|
||||
/// allow to make assuming regions' constraints independent from each other.
|
||||
template <typename OpTy>
|
||||
struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
|
||||
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(OpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Only move into immediately preceeding `assuming` op.
|
||||
auto assuming_op =
|
||||
llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
|
||||
if (!assuming_op) return failure();
|
||||
|
||||
Block *body = assuming_op.getBody();
|
||||
auto yield_op = cast<shape::AssumingYieldOp>(body->getTerminator());
|
||||
|
||||
// Find the operands to use if the op was within the assuming region. We
|
||||
// will later use their copies, as we copy the assuming op and its body.
|
||||
SmallVector<Value, 8> new_operands_unmapped;
|
||||
for (auto operand : op->getOperands()) {
|
||||
new_operands_unmapped.push_back(operand);
|
||||
for (auto result : llvm::enumerate(assuming_op->getResults())) {
|
||||
if (result.value() == operand)
|
||||
new_operands_unmapped.back() = yield_op->getOperand(result.index());
|
||||
}
|
||||
}
|
||||
|
||||
// Insert the rewritten assuming op right before the old one.
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPoint(assuming_op);
|
||||
auto new_assuming_op = rewriter.create<shape::AssumingOp>(
|
||||
assuming_op.getLoc(), assuming_op.witness(),
|
||||
[&](OpBuilder &b, Location loc) {
|
||||
// Copy body.
|
||||
BlockAndValueMapping mapping;
|
||||
for (auto &nested : body->without_terminator())
|
||||
b.clone(nested, mapping);
|
||||
|
||||
// Copy op into the new body and use the mapped operands.
|
||||
SmallVector<Value, 2> new_operands;
|
||||
for (Value v_unmapped : new_operands_unmapped) {
|
||||
Value v = mapping.lookupOrDefault(v_unmapped);
|
||||
new_operands.push_back(v);
|
||||
}
|
||||
Value new_op = b.create<OpTy>(loc, op->getResultTypes(), new_operands,
|
||||
op->getAttrs());
|
||||
|
||||
// Yield the previous results and also the new one.
|
||||
SmallVector<Value, 2> mapped_results;
|
||||
for (auto result : yield_op.operands())
|
||||
mapped_results.push_back(mapping.lookupOrDefault(result));
|
||||
mapped_results.push_back(new_op);
|
||||
return mapped_results;
|
||||
});
|
||||
|
||||
// Replace the assuming op and the root op with the corresponding result
|
||||
// value.
|
||||
ValueRange new_assuming_op_results = new_assuming_op->getResults();
|
||||
rewriter.replaceOp(assuming_op, new_assuming_op_results.drop_back());
|
||||
rewriter.replaceOp(op, new_assuming_op_results.back());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(frgossen): Only move up broadcasting operations if there is a consumer.
|
||||
struct MoveUpBroadcastInDimOpPattern
|
||||
: public OpRewritePattern<DynamicBroadcastInDimOp> {
|
||||
|
@ -168,6 +234,8 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
|
|||
// clang-format off
|
||||
patterns->insert<
|
||||
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
|
||||
MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
|
||||
MoveIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
|
||||
MoveUpBroadcastInDimOpPattern,
|
||||
ShapeReificationPattern>(context);
|
||||
// clang-format on
|
||||
|
|
|
@ -110,3 +110,60 @@ func @inline_bcasted_shape_operands(%a : tensor<?xindex>, %b : tensor<?xindex>,
|
|||
%1 = shape.cstr_broadcastable %0, %c : tensor<?xindex>, tensor<?xindex>
|
||||
return %1 : !shape.witness
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @move_shape_of_into_assuming
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<?x32xf32>, %[[ARG2:.*]]: tensor<?x32xf32>)
|
||||
func @move_shape_of_into_assuming(%arg0 : !shape.witness,
|
||||
%arg1 : tensor<?x32xf32>, %arg2 : tensor<?x32xf32>) -> tensor<3xindex> {
|
||||
// CHECK: %[[ASSUMING_RESULTS:.*]]:3 = shape.assuming %[[ARG0]] -> (tensor<?x32xf32>, tensor<?x32xf32>, tensor<3xindex>) {
|
||||
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG2]]
|
||||
// CHECK: shape.assuming_yield %[[ARG1]], %[[ARG2]], %[[SHAPE]]
|
||||
// CHECK: }
|
||||
// CHECK-NOT: shape_of
|
||||
// CHECK: return %[[ASSUMING_RESULTS]]#2
|
||||
%0:2 = shape.assuming %arg0 -> (tensor<?x32xf32>, tensor<?x32xf32>) {
|
||||
shape.assuming_yield %arg1, %arg2 : tensor<?x32xf32>, tensor<?x32xf32>
|
||||
}
|
||||
%1 = shape.shape_of %0#1 : tensor<?x32xf32> -> tensor<3xindex>
|
||||
return %1 : tensor<3xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @move_cstr_broadcastable_into_assuming
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2xindex>, %[[ARG2:.*]]: tensor<3xindex>)
|
||||
func @move_cstr_broadcastable_into_assuming(%arg0 : !shape.witness,
|
||||
%arg1 : tensor<2xindex>, %arg2 : tensor<3xindex>) -> !shape.witness {
|
||||
// CHECK: %[[ASSUMING_RESULTS:.*]]:3 = shape.assuming %[[ARG0]] -> (tensor<2xindex>, tensor<3xindex>, !shape.witness) {
|
||||
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[ARG2]]
|
||||
// CHECK: shape.assuming_yield %[[ARG1]], %[[ARG2]], %[[WITNESS]]
|
||||
// CHECK: }
|
||||
// CHECK-NOT: cstr_broadcastable
|
||||
// CHECK: return %[[ASSUMING_RESULTS]]#2
|
||||
%0:2 = shape.assuming %arg0 -> (tensor<2xindex>, tensor<3xindex>) {
|
||||
shape.assuming_yield %arg1, %arg2 : tensor<2xindex>, tensor<3xindex>
|
||||
}
|
||||
%1 = shape.cstr_broadcastable %arg1, %0#1 : tensor<2xindex>, tensor<3xindex>
|
||||
return %1 : !shape.witness
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @not_move_shape_of_into_assuming
|
||||
func @not_move_shape_of_into_assuming(%arg0 : !shape.witness,
|
||||
%arg1 : tensor<?x32xf32>, %arg2 : tensor<?x32xf32>) -> tensor<3xindex> {
|
||||
// CHECK: shape.assuming
|
||||
// CHECK-SAME: {
|
||||
// CHECK-NOT: shape_of
|
||||
// CHECK: }
|
||||
// CHECK: "some.other.op"
|
||||
// CHECK: shape_of
|
||||
%0:2 = shape.assuming %arg0 -> (tensor<?x32xf32>, tensor<?x32xf32>) {
|
||||
shape.assuming_yield %arg1, %arg2 : tensor<?x32xf32>, tensor<?x32xf32>
|
||||
}
|
||||
"some.other.op"() : () -> ()
|
||||
%2 = shape.shape_of %0#1 : tensor<?x32xf32> -> tensor<3xindex>
|
||||
return %2 : tensor<3xindex>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue