[MLIR][MHLO] Add pattern to move ops into the assuming region

This will eventually allow to make assuming regions' constraints independent
from each other.

PiperOrigin-RevId: 365985081
This commit is contained in:
A. Unique TensorFlower 2021-03-31 01:22:20 -07:00 committed by TensorFlow MLIR Team
parent 5ec66775d4
commit eade942635
2 changed files with 125 additions and 0 deletions

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
@ -96,6 +97,71 @@ struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern<OpTy> {
} }
}; };
/// Move operation into a preceeding assuming op. This allows to process
/// operations that depend on the assuming op's results. It will eventually
/// allow to make assuming regions' constraints independent from each other.
template <typename OpTy>
struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
// Only move into immediately preceeding `assuming` op.
auto assuming_op =
llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
if (!assuming_op) return failure();
Block *body = assuming_op.getBody();
auto yield_op = cast<shape::AssumingYieldOp>(body->getTerminator());
// Find the operands to use if the op was within the assuming region. We
// will later use their copies, as we copy the assuming op and its body.
SmallVector<Value, 8> new_operands_unmapped;
for (auto operand : op->getOperands()) {
new_operands_unmapped.push_back(operand);
for (auto result : llvm::enumerate(assuming_op->getResults())) {
if (result.value() == operand)
new_operands_unmapped.back() = yield_op->getOperand(result.index());
}
}
// Insert the rewritten assuming op right before the old one.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(assuming_op);
auto new_assuming_op = rewriter.create<shape::AssumingOp>(
assuming_op.getLoc(), assuming_op.witness(),
[&](OpBuilder &b, Location loc) {
// Copy body.
BlockAndValueMapping mapping;
for (auto &nested : body->without_terminator())
b.clone(nested, mapping);
// Copy op into the new body and use the mapped operands.
SmallVector<Value, 2> new_operands;
for (Value v_unmapped : new_operands_unmapped) {
Value v = mapping.lookupOrDefault(v_unmapped);
new_operands.push_back(v);
}
Value new_op = b.create<OpTy>(loc, op->getResultTypes(), new_operands,
op->getAttrs());
// Yield the previous results and also the new one.
SmallVector<Value, 2> mapped_results;
for (auto result : yield_op.operands())
mapped_results.push_back(mapping.lookupOrDefault(result));
mapped_results.push_back(new_op);
return mapped_results;
});
// Replace the assuming op and the root op with the corresponding result
// value.
ValueRange new_assuming_op_results = new_assuming_op->getResults();
rewriter.replaceOp(assuming_op, new_assuming_op_results.drop_back());
rewriter.replaceOp(op, new_assuming_op_results.back());
return success();
}
};
// TODO(frgossen): Only move up broadcasting operations if there is a consumer. // TODO(frgossen): Only move up broadcasting operations if there is a consumer.
struct MoveUpBroadcastInDimOpPattern struct MoveUpBroadcastInDimOpPattern
: public OpRewritePattern<DynamicBroadcastInDimOp> { : public OpRewritePattern<DynamicBroadcastInDimOp> {
@ -168,6 +234,8 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
// clang-format off // clang-format off
patterns->insert< patterns->insert<
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>, InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
MoveIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
MoveUpBroadcastInDimOpPattern, MoveUpBroadcastInDimOpPattern,
ShapeReificationPattern>(context); ShapeReificationPattern>(context);
// clang-format on // clang-format on

View File

@ -110,3 +110,60 @@ func @inline_bcasted_shape_operands(%a : tensor<?xindex>, %b : tensor<?xindex>,
%1 = shape.cstr_broadcastable %0, %c : tensor<?xindex>, tensor<?xindex> %1 = shape.cstr_broadcastable %0, %c : tensor<?xindex>, tensor<?xindex>
return %1 : !shape.witness return %1 : !shape.witness
} }
// -----
// CHECK-LABEL: @move_shape_of_into_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<?x32xf32>, %[[ARG2:.*]]: tensor<?x32xf32>)
func @move_shape_of_into_assuming(%arg0 : !shape.witness,
%arg1 : tensor<?x32xf32>, %arg2 : tensor<?x32xf32>) -> tensor<3xindex> {
// CHECK: %[[ASSUMING_RESULTS:.*]]:3 = shape.assuming %[[ARG0]] -> (tensor<?x32xf32>, tensor<?x32xf32>, tensor<3xindex>) {
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG2]]
// CHECK: shape.assuming_yield %[[ARG1]], %[[ARG2]], %[[SHAPE]]
// CHECK: }
// CHECK-NOT: shape_of
// CHECK: return %[[ASSUMING_RESULTS]]#2
%0:2 = shape.assuming %arg0 -> (tensor<?x32xf32>, tensor<?x32xf32>) {
shape.assuming_yield %arg1, %arg2 : tensor<?x32xf32>, tensor<?x32xf32>
}
%1 = shape.shape_of %0#1 : tensor<?x32xf32> -> tensor<3xindex>
return %1 : tensor<3xindex>
}
// -----
// CHECK-LABEL: @move_cstr_broadcastable_into_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2xindex>, %[[ARG2:.*]]: tensor<3xindex>)
func @move_cstr_broadcastable_into_assuming(%arg0 : !shape.witness,
%arg1 : tensor<2xindex>, %arg2 : tensor<3xindex>) -> !shape.witness {
// CHECK: %[[ASSUMING_RESULTS:.*]]:3 = shape.assuming %[[ARG0]] -> (tensor<2xindex>, tensor<3xindex>, !shape.witness) {
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[ARG2]]
// CHECK: shape.assuming_yield %[[ARG1]], %[[ARG2]], %[[WITNESS]]
// CHECK: }
// CHECK-NOT: cstr_broadcastable
// CHECK: return %[[ASSUMING_RESULTS]]#2
%0:2 = shape.assuming %arg0 -> (tensor<2xindex>, tensor<3xindex>) {
shape.assuming_yield %arg1, %arg2 : tensor<2xindex>, tensor<3xindex>
}
%1 = shape.cstr_broadcastable %arg1, %0#1 : tensor<2xindex>, tensor<3xindex>
return %1 : !shape.witness
}
// -----
// CHECK-LABEL: @not_move_shape_of_into_assuming
func @not_move_shape_of_into_assuming(%arg0 : !shape.witness,
%arg1 : tensor<?x32xf32>, %arg2 : tensor<?x32xf32>) -> tensor<3xindex> {
// CHECK: shape.assuming
// CHECK-SAME: {
// CHECK-NOT: shape_of
// CHECK: }
// CHECK: "some.other.op"
// CHECK: shape_of
%0:2 = shape.assuming %arg0 -> (tensor<?x32xf32>, tensor<?x32xf32>) {
shape.assuming_yield %arg1, %arg2 : tensor<?x32xf32>, tensor<?x32xf32>
}
"some.other.op"() : () -> ()
%2 = shape.shape_of %0#1 : tensor<?x32xf32> -> tensor<3xindex>
return %2 : tensor<3xindex>
}