[MLIR][MHLO] Add pattern to move ops into the assuming region

This will eventually allow making assuming regions' constraints independent
of each other.

PiperOrigin-RevId: 365985081

commit eade942635 (parent 5ec66775d4)
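In effect, the pattern performs a rewrite like the following. This is a
minimal before/after sketch distilled from the tests added below; the value
names (%w, %a, %b, %s) are illustrative only, not part of the change:

  // Before: the shape.shape_of depends on a result of the assuming op.
  %0:2 = shape.assuming %w -> (tensor<?x32xf32>, tensor<?x32xf32>) {
    shape.assuming_yield %a, %b : tensor<?x32xf32>, tensor<?x32xf32>
  }
  %s = shape.shape_of %0#1 : tensor<?x32xf32> -> tensor<3xindex>

  // After: the op is cloned into a rebuilt assuming op, which yields the
  // extra result; the operand %0#1 is remapped to the yielded value %b.
  %1:3 = shape.assuming %w
      -> (tensor<?x32xf32>, tensor<?x32xf32>, tensor<3xindex>) {
    %s = shape.shape_of %b : tensor<?x32xf32> -> tensor<3xindex>
    shape.assuming_yield %a, %b, %s
        : tensor<?x32xf32>, tensor<?x32xf32>, tensor<3xindex>
  }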
@@ -25,6 +25,7 @@ limitations under the License.
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
@@ -96,6 +97,71 @@ struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern<OpTy> {
   }
 };

+/// Move an operation into a preceding assuming op. This allows processing
+/// operations that depend on the assuming op's results. It will eventually
+/// allow making assuming regions' constraints independent of each other.
+template <typename OpTy>
+struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Only move into an immediately preceding `assuming` op.
+    auto assuming_op =
+        llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
+    if (!assuming_op) return failure();
+
+    Block *body = assuming_op.getBody();
+    auto yield_op = cast<shape::AssumingYieldOp>(body->getTerminator());
+
+    // Find the operands to use if the op was within the assuming region. We
+    // will later use their copies, as we copy the assuming op and its body.
+    SmallVector<Value, 8> new_operands_unmapped;
+    for (auto operand : op->getOperands()) {
+      new_operands_unmapped.push_back(operand);
+      for (auto result : llvm::enumerate(assuming_op->getResults())) {
+        if (result.value() == operand)
+          new_operands_unmapped.back() = yield_op->getOperand(result.index());
+      }
+    }
+
+    // Insert the rewritten assuming op right before the old one.
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(assuming_op);
+    auto new_assuming_op = rewriter.create<shape::AssumingOp>(
+        assuming_op.getLoc(), assuming_op.witness(),
+        [&](OpBuilder &b, Location loc) {
+          // Copy body.
+          BlockAndValueMapping mapping;
+          for (auto &nested : body->without_terminator())
+            b.clone(nested, mapping);
+
+          // Copy the op into the new body and use the mapped operands.
+          SmallVector<Value, 2> new_operands;
+          for (Value v_unmapped : new_operands_unmapped) {
+            Value v = mapping.lookupOrDefault(v_unmapped);
+            new_operands.push_back(v);
+          }
+          Value new_op = b.create<OpTy>(loc, op->getResultTypes(),
+                                        new_operands, op->getAttrs());
+
+          // Yield the previous results and also the new one.
+          SmallVector<Value, 2> mapped_results;
+          for (auto result : yield_op.operands())
+            mapped_results.push_back(mapping.lookupOrDefault(result));
+          mapped_results.push_back(new_op);
+          return mapped_results;
+        });
+
+    // Replace the assuming op and the root op with the corresponding result
+    // values.
+    ValueRange new_assuming_op_results = new_assuming_op->getResults();
+    rewriter.replaceOp(assuming_op, new_assuming_op_results.drop_back());
+    rewriter.replaceOp(op, new_assuming_op_results.back());
+    return success();
+  }
+};
+
 // TODO(frgossen): Only move up broadcasting operations if there is a consumer.
 struct MoveUpBroadcastInDimOpPattern
     : public OpRewritePattern<DynamicBroadcastInDimOp> {
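A note on the mechanics: the matched op is not spliced into the existing
region, since that would require growing the assuming op's result list in
place. Instead, the whole region is cloned into a fresh `shape.assuming`
whose yield carries one extra value, and both the old assuming op and the
matched op are then replaced through the rewriter. Operands of the matched op
that were results of the old assuming op are first remapped to the
corresponding `shape.assuming_yield` operands, so the clone consumes the
in-region values. A sketch with illustrative names, mirroring the
`cstr_broadcastable` test below:

  // Before: %0#1 escapes the region only to feed the constraint.
  %0:2 = shape.assuming %w -> (tensor<2xindex>, tensor<3xindex>) {
    shape.assuming_yield %a, %b : tensor<2xindex>, tensor<3xindex>
  }
  %c = shape.cstr_broadcastable %a, %0#1 : tensor<2xindex>, tensor<3xindex>

  // After: the cloned constraint uses the yielded value %b directly.
  %1:3 = shape.assuming %w
      -> (tensor<2xindex>, tensor<3xindex>, !shape.witness) {
    %c = shape.cstr_broadcastable %a, %b : tensor<2xindex>, tensor<3xindex>
    shape.assuming_yield %a, %b, %c
        : tensor<2xindex>, tensor<3xindex>, !shape.witness
  }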
@@ -168,6 +234,8 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
   // clang-format off
   patterns->insert<
       InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
+      MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
+      MoveIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
      MoveUpBroadcastInDimOpPattern,
      ShapeReificationPattern>(context);
  // clang-format on
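The two new instantiations register the pattern for `shape.shape_of` and
`shape.cstr_broadcastable` roots. For context, the tests below would
typically be exercised through a RUN line along these lines; the exact tool
and pass-flag names are an assumption here, as the test file header is not
part of this diff:

  // RUN: mlir-hlo-opt --mhlo-move-up-dynamic-broadcasts-for-fusion %s | FileCheck %s
  // (hypothetical invocation; flag name not shown in this diff)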
@@ -110,3 +110,60 @@ func @inline_bcasted_shape_operands(%a : tensor<?xindex>, %b : tensor<?xindex>,
   %1 = shape.cstr_broadcastable %0, %c : tensor<?xindex>, tensor<?xindex>
   return %1 : !shape.witness
 }
+
+// -----
+
+// CHECK-LABEL: @move_shape_of_into_assuming
+// CHECK-SAME:  (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<?x32xf32>, %[[ARG2:.*]]: tensor<?x32xf32>)
+func @move_shape_of_into_assuming(%arg0 : !shape.witness,
+    %arg1 : tensor<?x32xf32>, %arg2 : tensor<?x32xf32>) -> tensor<3xindex> {
+  // CHECK: %[[ASSUMING_RESULTS:.*]]:3 = shape.assuming %[[ARG0]] -> (tensor<?x32xf32>, tensor<?x32xf32>, tensor<3xindex>) {
+  // CHECK:   %[[SHAPE:.*]] = shape.shape_of %[[ARG2]]
+  // CHECK:   shape.assuming_yield %[[ARG1]], %[[ARG2]], %[[SHAPE]]
+  // CHECK: }
+  // CHECK-NOT: shape_of
+  // CHECK: return %[[ASSUMING_RESULTS]]#2
+  %0:2 = shape.assuming %arg0 -> (tensor<?x32xf32>, tensor<?x32xf32>) {
+    shape.assuming_yield %arg1, %arg2 : tensor<?x32xf32>, tensor<?x32xf32>
+  }
+  %1 = shape.shape_of %0#1 : tensor<?x32xf32> -> tensor<3xindex>
+  return %1 : tensor<3xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @move_cstr_broadcastable_into_assuming
+// CHECK-SAME:  (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2xindex>, %[[ARG2:.*]]: tensor<3xindex>)
+func @move_cstr_broadcastable_into_assuming(%arg0 : !shape.witness,
+    %arg1 : tensor<2xindex>, %arg2 : tensor<3xindex>) -> !shape.witness {
+  // CHECK: %[[ASSUMING_RESULTS:.*]]:3 = shape.assuming %[[ARG0]] -> (tensor<2xindex>, tensor<3xindex>, !shape.witness) {
+  // CHECK:   %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[ARG2]]
+  // CHECK:   shape.assuming_yield %[[ARG1]], %[[ARG2]], %[[WITNESS]]
+  // CHECK: }
+  // CHECK-NOT: cstr_broadcastable
+  // CHECK: return %[[ASSUMING_RESULTS]]#2
+  %0:2 = shape.assuming %arg0 -> (tensor<2xindex>, tensor<3xindex>) {
+    shape.assuming_yield %arg1, %arg2 : tensor<2xindex>, tensor<3xindex>
+  }
+  %1 = shape.cstr_broadcastable %arg1, %0#1 : tensor<2xindex>, tensor<3xindex>
+  return %1 : !shape.witness
+}
+
+// -----
+
+// CHECK-LABEL: @not_move_shape_of_into_assuming
+func @not_move_shape_of_into_assuming(%arg0 : !shape.witness,
+    %arg1 : tensor<?x32xf32>, %arg2 : tensor<?x32xf32>) -> tensor<3xindex> {
+  // CHECK:      shape.assuming
+  // CHECK-SAME: {
+  // CHECK-NOT:    shape_of
+  // CHECK:      }
+  // CHECK:      "some.other.op"
+  // CHECK:      shape_of
+  %0:2 = shape.assuming %arg0 -> (tensor<?x32xf32>, tensor<?x32xf32>) {
+    shape.assuming_yield %arg1, %arg2 : tensor<?x32xf32>, tensor<?x32xf32>
+  }
+  "some.other.op"() : () -> ()
+  %2 = shape.shape_of %0#1 : tensor<?x32xf32> -> tensor<3xindex>
+  return %2 : tensor<3xindex>
+}