/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/Support/Casting.h"
|
|
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
|
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
|
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
|
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
|
#include "mlir/Dialect/SCF/SCF.h"
|
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/BlockAndValueMapping.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/MLIRContext.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/IR/OperationSupport.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
namespace mlir {
|
|
namespace mhlo {
|
|
namespace {
|
|
|
|
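// Reify the shape computation behind a `shape.shape_of` op when its argument
// is produced by an op that implements `InferShapedTypeOpInterface`. This
// makes the shape computation explicit in the IR so the patterns below can
// operate on it. A minimal sketch of the rewrite (illustrative IR only):
//
//   %shape = shape.shape_of %result : tensor<?x32xf32> -> tensor<2xindex>
//
// is replaced by whatever shape computation the defining op reifies, with a
// `tensor.cast` inserted if the reified type differs from the original.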
struct ShapeReificationPattern : public OpRewritePattern<shape::ShapeOfOp> {
  explicit ShapeReificationPattern(MLIRContext *context)
      : OpRewritePattern<shape::ShapeOfOp>(context) {
    // Recursively reify until we hit an op that doesn't support it.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    // Only reify shape computation if operand allows for it.
    auto shape_origin = op.arg().getDefiningOp<InferShapedTypeOpInterface>();
    if (!shape_origin) return failure();

    llvm::SmallVector<Value, 1> reifications;
    if (failed(shape_origin.reifyReturnTypeShapes(
            rewriter, shape_origin->getOperands(), reifications)))
      return failure();
    assert(reifications.size() == 1);
    Value reified_shape = reifications.front();

    // Insert cast if needed.
    if (reified_shape.getType() != op.getType()) {
      reified_shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
                                                      reified_shape);
    }

    rewriter.replaceOp(op, reified_shape);
    return success();
  }
};

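// Inline the operands of a `shape.broadcast` op into a consuming op's operand
// list where that is equivalent, e.g. for variadic `shape.cstr_broadcastable`.
// A minimal sketch (illustrative IR; values are hypothetical):
//
//   %s = shape.broadcast %a, %b : ...
//   %w = shape.cstr_broadcastable %s, %c : ...
//
// becomes
//
//   %w = shape.cstr_broadcastable %a, %b, %c : ...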
template <typename OpTy>
struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Find all the shape operands, direct and indirect.
    SmallVector<Value, 8> inlined_operands;
    for (Value direct : op->getOperands()) {
      if (auto bcast_op = direct.getDefiningOp<shape::BroadcastOp>()) {
        for (Value indirect : bcast_op->getOperands())
          inlined_operands.push_back(indirect);
      } else {
        inlined_operands.push_back(direct);
      }
    }

    // Only rewrite if it makes a difference.
    if (inlined_operands.size() == op.getNumOperands()) return failure();

    // Inline shape operands.
    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
                                      inlined_operands, op->getAttrs());
    return success();
  }
};

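// Shared implementation for the move-into-assuming patterns below: clone the
// immediately preceding `shape.assuming` op, clone `op` into the new region
// with its operands remapped through the region's yield, and replace both ops
// with the new assuming op's results.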
LogicalResult MoveIntoAssumingOpMatchAndRewrite(Operation *op,
                                                PatternRewriter &rewriter) {
  // Only move into immediately preceding `assuming` op.
  auto assuming_op =
      llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
  if (!assuming_op) return failure();

  Block *body = assuming_op.getBody();
  auto yield_op = cast<shape::AssumingYieldOp>(body->getTerminator());

  // Find the operands to use if the op was within the assuming region. We
  // will later use their copies, as we copy the assuming op and its body.
  SmallVector<Value, 8> new_operands_unmapped =
      llvm::to_vector<8>(llvm::map_range(op->getOperands(), [&](Value v) {
        for (auto result : llvm::enumerate(assuming_op->getResults())) {
          if (result.value() == v) return yield_op->getOperand(result.index());
        }
        return v;
      }));

  // Insert the rewritten assuming op right before the old one.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(assuming_op);
  auto new_assuming_op = rewriter.create<shape::AssumingOp>(
      assuming_op.getLoc(), assuming_op.witness(),
      [&](OpBuilder &b, Location) {
        // Copy body.
        BlockAndValueMapping mapping;
        for (auto &nested : body->without_terminator())
          b.clone(nested, mapping);

        // Copy op into the new body and use the mapped operands.
        for (auto it : llvm::zip(op->getOperands(), new_operands_unmapped)) {
          Value old_operand, new_operand_unmapped;
          std::tie(old_operand, new_operand_unmapped) = it;
          mapping.map(old_operand,
                      mapping.lookupOrDefault(new_operand_unmapped));
        }
        Operation *new_op = b.clone(*op, mapping);

        // Yield the previous results and also the new ones.
        auto mapped_results = llvm::to_vector<8>(llvm::map_range(
            yield_op.operands(),
            [&](Value v) { return mapping.lookupOrDefault(v); }));
        mapped_results.append(new_op->getResults().begin(),
                              new_op->getResults().end());
        return mapped_results;
      });

  // Replace the assuming op and the root op with the corresponding result
  // values.
  ValueRange new_assuming_op_results = new_assuming_op->getResults();
  rewriter.replaceOp(assuming_op, new_assuming_op_results.drop_back());
  rewriter.replaceOp(op, new_assuming_op_results.back());
  return success();
}

/// Move operation into a preceding assuming op. This allows processing
/// operations that depend on the assuming op's results. It will eventually
/// allow making assuming regions' constraints independent from each other.
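/// A minimal sketch of the rewrite (illustrative IR; values are
/// hypothetical):
///
///   %0 = shape.assuming %w -> (tensor<?xf32>) { ... }
///   %1 = shape.shape_of %0 : tensor<?xf32> -> tensor<1xindex>
///
/// becomes
///
///   %0:2 = shape.assuming %w -> (tensor<?xf32>, tensor<1xindex>) {
///     ...
///     %s = shape.shape_of %inner : tensor<?xf32> -> tensor<1xindex>
///     shape.assuming_yield ..., %s
///   }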
template <typename OpTy>
struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    return MoveIntoAssumingOpMatchAndRewrite(op.getOperation(), rewriter);
  }
};

// Move elementwise operations into assuming regions. This will eventually
// allow for more fusion opportunities.
struct MoveElementwiseOpsIntoAssumingOpPattern : public RewritePattern {
  explicit MoveElementwiseOpsIntoAssumingOpPattern(MLIRContext *ctx)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // Apply to all elementwise and broadcasting elementwise operations.
    if (!op->hasTrait<mlir::OpTrait::Elementwise>() &&
        !op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>())
      return failure();

    return MoveIntoAssumingOpMatchAndRewrite(op, rewriter);
  }
};

/// Move operation out of assuming op. This is only valid for
/// constraint-independent ops, like `cstr_broadcastable` and `shape_of`. It
/// will eventually allow making assuming regions' constraints independent
/// from each other.
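/// A minimal sketch of the rewrite (illustrative IR; values are
/// hypothetical), assuming %arg is defined outside the region:
///
///   %0 = shape.assuming %w -> (tensor<1xindex>) {
///     %s = shape.shape_of %arg : tensor<?xf32> -> tensor<1xindex>
///     shape.assuming_yield %s : tensor<1xindex>
///   }
///
/// becomes
///
///   %0 = shape.shape_of %arg : tensor<?xf32> -> tensor<1xindex>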
template <typename OpTy>
struct MoveOutOfAssumingOpPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Must be inside of an assuming op.
    auto assuming_op = op->template getParentOfType<shape::AssumingOp>();
    if (!assuming_op) return failure();

    // Operands must not be defined within the assuming op.
    Block *body = assuming_op.getBody();
    auto is_available = [&](Value v) {
      Operation *def = v.getDefiningOp();
      return def == nullptr || def->getBlock() != body;
    };
    if (!llvm::all_of(op->getOperands(), is_available)) return failure();

    // Move op before the assuming region.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(assuming_op);
    Operation *new_op = rewriter.clone(*op);
    rewriter.replaceOp(op, new_op->getResults());

    // If the assuming region yields none of the new op's results, these values
    // are exclusively used in the assuming op's body. In these cases there is
    // no need for further rewrites.
    auto is_new_op_result = [&](Value v) {
      return llvm::is_contained(new_op->getResults(), v);
    };
    auto yield_op = cast<shape::AssumingYieldOp>(body->getTerminator());
    if (llvm::none_of(yield_op.operands(), is_new_op_result)) return success();

    // If the assuming region yields any of the new op's results, these values
    // can instead bypass the assuming region. There is no need to yield them
    // explicitly as they are assumed to be independent. The assuming op is
    // rewritten accordingly.
    SmallVector<Value, 2> replacement_values;
    auto new_assuming_op = rewriter.create<shape::AssumingOp>(
        assuming_op.getLoc(), assuming_op.witness(),
        [&](OpBuilder &b, Location) {
          // Copy body.
          BlockAndValueMapping mapping;
          for (Operation &nested : body->without_terminator()) {
            b.clone(nested, mapping);
          }

          // Collect new yield operands.
          SmallVector<Value, 2> new_yield_operands;
          for (Value result : yield_op.operands()) {
            if (is_new_op_result(result)) {
              replacement_values.push_back(result);
            } else {
              new_yield_operands.push_back(mapping.lookup(result));
              replacement_values.push_back(nullptr);
            }
          }
          return new_yield_operands;
        });

    // Use the assuming op's results for the missing replacement values.
    auto src = new_assuming_op.getResults().begin();
    for (auto &dst : replacement_values) {
      if (dst) continue;
      dst = *src++;
    }

    rewriter.replaceOp(assuming_op, replacement_values);
    return success();
  }
};

/// Merge assuming regions if their constraints are independent from each
/// other.
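/// A minimal sketch of the rewrite (illustrative IR; values are
/// hypothetical):
///
///   %0 = shape.assuming %w0 -> (...) { ... }
///   %1 = shape.assuming %w1 -> (...) { ... }
///
/// becomes
///
///   %w = shape.assuming_all %w0, %w1
///   %r:n = shape.assuming %w -> (...) { /* both bodies, concatenated */ }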
struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
  using OpRewritePattern<shape::AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::AssumingOp op,
                                PatternRewriter &rewriter) const override {
    // Merge assuming op with directly preceding one if both witnesses are
    // available.
    auto preceding_op =
        llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
    if (!preceding_op) return failure();
    if (op.witness().getDefiningOp() == preceding_op) return failure();

    // Merge witnesses.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(preceding_op);
    Value new_witness = rewriter.create<shape::AssumingAllOp>(
        op.witness().getDefiningOp()->getLoc(),
        ValueRange{preceding_op.witness(), op.witness()});

    // Merge assuming ops.
    Block *body_a = preceding_op.getBody();
    Block *body_b = op.getBody();
    auto new_assuming_op = rewriter.create<shape::AssumingOp>(
        preceding_op.getLoc(), new_witness, [&](OpBuilder &b, Location) {
          // Copy preceding op's body.
          BlockAndValueMapping mapping;
          for (auto &nested : body_a->without_terminator()) {
            b.clone(nested, mapping);
          }

          // Map result values of preceding assuming op.
          auto yield_op_a =
              llvm::dyn_cast<shape::AssumingYieldOp>(body_a->getTerminator());
          for (auto pair :
               llvm::zip(preceding_op->getResults(), yield_op_a.operands())) {
            mapping.map(std::get<0>(pair),
                        mapping.lookupOrDefault(std::get<1>(pair)));
          }

          // Copy op's body.
          for (auto &nested : body_b->without_terminator()) {
            b.clone(nested, mapping);
          }

          // Collect merged assuming op's results.
          SmallVector<Value, 4> mapped_results;
          auto yield_op_b =
              llvm::dyn_cast<shape::AssumingYieldOp>(body_b->getTerminator());
          for (Value v : yield_op_a.operands()) {
            mapped_results.push_back(mapping.lookupOrDefault(v));
          }
          for (Value v : yield_op_b.operands()) {
            mapped_results.push_back(mapping.lookupOrDefault(v));
          }
          return mapped_results;
        });

    // Replace the two assuming ops with the new corresponding results.
    ValueRange new_results = new_assuming_op->getResults();
    size_t split_at = preceding_op->getNumResults();
    rewriter.replaceOp(preceding_op, new_results.take_front(split_at));
    rewriter.replaceOp(op, new_results.drop_front(split_at));
    return success();
  }
};

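// Move a dynamic broadcast above its elementwise, shape-preserving producer
// so that broadcasts happen as early as possible and fusible ops end up next
// to each other. A minimal sketch (illustrative IR; values and types are
// hypothetical):
//
//   %0 = "mhlo.abs"(%arg) : (tensor<?xf32>) -> tensor<?xf32>
//   %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %shape) {...}
//
// becomes
//
//   %0 = "mhlo.dynamic_broadcast_in_dim"(%arg, %shape) {...}
//   %1 = "mhlo.abs"(%0) : (tensor<?x?xf32>) -> tensor<?x?xf32>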
struct EarlyBroadcastInDimOpPattern
    : public OpRewritePattern<DynamicBroadcastInDimOp> {
  using OpRewritePattern<DynamicBroadcastInDimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DynamicBroadcastInDimOp bcast_op,
                                PatternRewriter &rewriter) const override {
    Operation *producer_op = bcast_op.operand().getDefiningOp();
    if (!producer_op ||
        !producer_op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>() ||
        !producer_op->hasTrait<mlir::OpTrait::Elementwise>()) {
      return failure();
    }

    // Materialize broadcast on operands.
    SmallVector<Value, 2> bcasted_operands;
    Location loc = bcast_op.getLoc();
    ArrayRef<int64_t> ty_shape = bcast_op.getType().getShape();
    for (Value operand : producer_op->getOperands()) {
      // The broadcast only works on ranked operands.
      auto operand_ty = operand.getType().dyn_cast<RankedTensorType>();
      if (!operand_ty) {
        return bcast_op.emitError()
               << "Can only move up broadcasts over ranked tensor operands.";
      }

      auto bcasted_operand_ty =
          RankedTensorType::get(ty_shape, operand_ty.getElementType());
      bcasted_operands.push_back(rewriter.create<DynamicBroadcastInDimOp>(
          loc, bcasted_operand_ty, operand, bcast_op.output_dimensions(),
          bcast_op.broadcast_dimensions()));
    }

    // Create a copy of the producer op with the new broadcasted operands.
    OperationState new_producer_op_state(
        loc, producer_op->getName().getStringRef(), bcasted_operands,
        bcast_op.getType(), producer_op->getAttrs());
    Operation *new_producer_op =
        rewriter.createOperation(new_producer_op_state);

    // The original result of the broadcast now falls directly out of the new
    // producer op. Use it instead.
    rewriter.replaceOp(bcast_op, new_producer_op->getResults());

    return success();
  }
};

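// Pass that greedily applies the broadcast propagation patterns, together
// with a number of supporting canonicalization patterns, to a function.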
struct BroadcastPropagationPass
    : public BroadcastPropagationPassBase<BroadcastPropagationPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
  }

  void runOnFunction() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    mhlo::PopulateBroadcastsPropagationPatterns(ctx, &patterns);
    if (failed(
            applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

} // namespace

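// Populate `patterns` with the broadcast propagation patterns defined above,
// plus the related canonicalization patterns.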
void PopulateBroadcastsPropagationPatterns(MLIRContext *context,
                                           OwningRewritePatternList *patterns) {
  // clang-format off
  patterns->insert<
      InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
      MergeAssumingOpsPattern,
      MoveElementwiseOpsIntoAssumingOpPattern,
      MoveIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
      MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
      MoveOutOfAssumingOpPattern<shape::CstrBroadcastableOp>,
      MoveOutOfAssumingOpPattern<shape::ShapeOfOp>,
      EarlyBroadcastInDimOpPattern,
      ShapeReificationPattern>(context);
  // clang-format on
  mhlo::DynamicBroadcastInDimOp::getCanonicalizationPatterns(*patterns,
                                                             context);
  mhlo::DynamicReshapeOp::getCanonicalizationPatterns(*patterns, context);
  shape::AssumingAllOp::getCanonicalizationPatterns(*patterns, context);
  shape::AssumingOp::getCanonicalizationPatterns(*patterns, context);
  shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context);
  shape::CstrBroadcastableOp::getCanonicalizationPatterns(*patterns, context);
  tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
}

std::unique_ptr<FunctionPass> createBroadcastPropagationPass() {
  return std::make_unique<BroadcastPropagationPass>();
}

} // namespace mhlo
} // namespace mlir