2021-05-12 18:45:09 +08:00
|
|
|
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
limitations under the License.
|
|
|
|
|
|
|
|
==============================================================================*/
|
|
|
|
|
|
|
|
#include <functional>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
|
|
|
|
namespace mlir {
|
|
|
|
|
2021-05-17 18:55:32 +08:00
|
|
|
/// Needed to build `llvm::SmallSet`s of `mlir::Value`s.
|
2021-05-12 18:45:09 +08:00
|
|
|
static bool operator<(const Value &lhs, const Value &rhs) {
|
|
|
|
return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
|
|
|
|
}
|
|
|
|
|
|
|
|
namespace mhlo {
|
|
|
|
namespace {
|
|
|
|
|
|
|
|
/// Identify clusters of operations that can be rank-specialized together. The
|
|
|
|
/// required traits for clustered operations are:
|
|
|
|
/// - Element-wise: All operations in the group must be element-wise. This
|
|
|
|
/// allows to reshape operands before applying the operations as well as
|
|
|
|
/// reshaping the result to the desired shape afterwards. This way, we can,
|
|
|
|
/// e.g., apply unary ops to a completely flattened operand and restore the
|
|
|
|
/// original shape afterwards.
|
|
|
|
/// - Broadcasting semantics: All operations must implement broadcasting
|
|
|
|
/// semantics. Most importantly, this allows extending operand shapes such
|
2021-05-17 22:30:45 +08:00
|
|
|
/// that they match in rank. Operations that require all their operands to
|
|
|
|
/// be of the same shape also fulfill this requirement.
|
2021-05-12 18:45:09 +08:00
|
|
|
/// - Shape reification: All operations must implement
|
|
|
|
/// `InferShapedTypeOpInterface`. This is later needed to compute and to
|
|
|
|
/// restore the desired result shape.
|
|
|
|
|
|
|
|
bool IsClusterable(Operation *op) {
|
|
|
|
if (!llvm::isa<InferShapedTypeOpInterface>(op)) return false;
|
2021-05-17 22:30:45 +08:00
|
|
|
if (op->getNumOperands() == 0) return false;
|
|
|
|
return (op->hasTrait<OpTrait::Elementwise>() &&
|
|
|
|
op->hasTrait<OpTrait::SameOperandsAndResultShape>()) ||
|
|
|
|
(op->hasTrait<chlo::OpTrait::BroadcastingElementwise>() &&
|
|
|
|
op->hasTrait<chlo::OpTrait::Broadcasting>());
|
2021-05-12 18:45:09 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
struct RankSpecializationClusterPattern : public RewritePattern {
|
|
|
|
explicit RankSpecializationClusterPattern(MLIRContext *ctx)
|
|
|
|
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
|
|
|
|
|
2021-05-17 22:30:45 +08:00
|
|
|
LogicalResult matchAndRewrite(Operation *op,
|
2021-05-12 18:45:09 +08:00
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
// Only apply to operations that have not been clustered yet.
|
2021-05-17 22:30:45 +08:00
|
|
|
if (op->getParentOfType<chlo::RankSpecializationClusterOp>()) {
|
2021-05-12 18:45:09 +08:00
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
|
|
|
|
// Only cluster when rank specialization is needed.
|
2021-05-17 22:30:45 +08:00
|
|
|
if (!IsClusterable(op) || !llvm::any_of(op->getOperandTypes(), [](Type ty) {
|
|
|
|
return ty.isa<UnrankedTensorType>();
|
|
|
|
})) {
|
2021-05-12 18:45:09 +08:00
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
|
|
|
|
// Collect all collectively rank specializable ops.
|
|
|
|
SmallVector<Operation *, 16> cluster;
|
|
|
|
llvm::SmallSet<Value, 16> operand_set;
|
|
|
|
llvm::SmallSet<Value, 16> result_set;
|
2021-05-17 22:30:45 +08:00
|
|
|
|
|
|
|
Operation *root_op = op;
|
|
|
|
while (root_op->getNextNode() != nullptr &&
|
|
|
|
IsClusterable(root_op->getNextNode()))
|
|
|
|
root_op = root_op->getNextNode();
|
|
|
|
|
|
|
|
Operation *it = root_op;
|
|
|
|
while (it != nullptr && IsClusterable(it)) {
|
2021-05-12 18:45:09 +08:00
|
|
|
// Find results that escape the cluster.
|
2021-05-17 22:30:45 +08:00
|
|
|
for (OpOperand &use : it->getUses()) {
|
2021-05-12 18:45:09 +08:00
|
|
|
if (!llvm::is_contained(cluster, use.getOwner()))
|
|
|
|
result_set.insert(use.get());
|
|
|
|
}
|
|
|
|
|
|
|
|
// Update cluster operands.
|
2021-05-17 22:30:45 +08:00
|
|
|
for (OpResult v : it->getResults()) operand_set.erase(Value(v));
|
|
|
|
for (OpOperand &v : it->getOpOperands()) operand_set.insert(v.get());
|
2021-05-12 18:45:09 +08:00
|
|
|
|
2021-05-17 22:30:45 +08:00
|
|
|
cluster.push_back(it);
|
|
|
|
it = it->getPrevNode();
|
2021-05-12 18:45:09 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// Create `RankSpecializationClusterOp`.
|
|
|
|
auto operands = llvm::to_vector<16>(operand_set);
|
|
|
|
auto results = llvm::to_vector<16>(result_set);
|
|
|
|
auto result_types = llvm::to_vector<16>(
|
|
|
|
llvm::map_range(result_set, [](Value v) { return v.getType(); }));
|
2021-05-17 22:30:45 +08:00
|
|
|
Location loc = op->getLoc();
|
2021-05-12 18:45:09 +08:00
|
|
|
auto cluster_op = rewriter.create<chlo::RankSpecializationClusterOp>(
|
|
|
|
loc, result_types, operands);
|
|
|
|
|
|
|
|
// Create body block.
|
|
|
|
auto operand_types = llvm::to_vector<16>(
|
|
|
|
llvm::map_range(operand_set, [](Value v) { return v.getType(); }));
|
|
|
|
Block *block = rewriter.createBlock(&cluster_op.body(), {}, operand_types);
|
|
|
|
|
|
|
|
// Copy operations into the body.
|
|
|
|
BlockAndValueMapping bvm;
|
|
|
|
for (auto it : llvm::zip(operands, block->getArguments()))
|
|
|
|
bvm.map(std::get<0>(it), std::get<1>(it));
|
|
|
|
rewriter.setInsertionPointToStart(block);
|
|
|
|
for (Operation *it : llvm::reverse(cluster)) rewriter.clone(*it, bvm);
|
|
|
|
|
|
|
|
// Create `RankSpecializationClusterYieldOp`.
|
|
|
|
auto mapped_results = llvm::to_vector<16>(
|
|
|
|
llvm::map_range(results, [&](Value v) { return bvm.lookup(v); }));
|
|
|
|
rewriter.create<chlo::RankSpecializationClusterYieldOp>(loc,
|
|
|
|
mapped_results);
|
|
|
|
|
|
|
|
// Replace original ops with the new results.
|
|
|
|
for (auto it : llvm::zip(results, cluster_op.results()))
|
|
|
|
bvm.map(std::get<0>(it), std::get<1>(it));
|
|
|
|
for (Operation *it : cluster) {
|
|
|
|
if (it->getUses().empty()) {
|
|
|
|
rewriter.eraseOp(it);
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
auto replacements = llvm::to_vector<16>(llvm::map_range(
|
|
|
|
it->getResults(), [&](Value v) { return bvm.lookup(v); }));
|
2021-05-17 22:30:45 +08:00
|
|
|
rewriter.replaceOp(it, replacements);
|
2021-05-12 18:45:09 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
struct RankSpecializationClusterPass
|
|
|
|
: public PassWrapper<RankSpecializationClusterPass, FunctionPass> {
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
|
|
registry.insert<mhlo::MhloDialect, chlo::HloClientDialect>();
|
|
|
|
}
|
|
|
|
|
|
|
|
void runOnFunction() override {
|
|
|
|
MLIRContext *ctx = &getContext();
|
|
|
|
RewritePatternSet patterns(ctx);
|
|
|
|
mhlo::PopulateRankSpecializationClusterPatterns(ctx, &patterns);
|
|
|
|
if (failed(
|
|
|
|
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
|
|
|
|
return signalPassFailure();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2021-05-17 18:55:32 +08:00
|
|
|
/// Lower rank specialization cluster to SCF.
|
|
|
|
|
|
|
|
/// Maps an unranked tensor type to a ranked tensor type of the given `rank`
/// with the same element type and all dimensions dynamic. Any other type is
/// returned unchanged.
Type DeriveRankedTensorTypes(Type ty, int64_t rank) {
  if (auto unranked_ty = ty.dyn_cast<UnrankedTensorType>()) {
    SmallVector<int64_t, 8> dynamic_dims(rank, ShapedType::kDynamicSize);
    return RankedTensorType::get(dynamic_dims, unranked_ty.getElementType());
  }
  return ty;
}
|
|
|
|
|
|
|
|
/// Unary element-wise operations on unranked tensors can be applied to the
|
|
|
|
/// flattened tensor and reshaped to the expected shape afterwards.
|
|
|
|
struct LowerUnaryRankSpecializationClusterPattern
|
|
|
|
: public OpRewritePattern<chlo::RankSpecializationClusterOp> {
|
|
|
|
using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern;
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
// Only apply this to unary operations.
|
|
|
|
if (op.operands().size() != 1) return failure();
|
|
|
|
|
|
|
|
// Compute flattened operand shape.
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
Value arg = op.operands().front();
|
|
|
|
Value shape = rewriter.create<shape::ShapeOfOp>(loc, arg);
|
|
|
|
Value flat_shape = rewriter.create<tensor::FromElementsOp>(
|
|
|
|
loc,
|
|
|
|
rewriter
|
|
|
|
.create<shape::NumElementsOp>(loc, rewriter.getIndexType(), shape)
|
|
|
|
.result());
|
|
|
|
|
|
|
|
// Flatten operand.
|
|
|
|
Value flat_arg = rewriter.create<mhlo::DynamicReshapeOp>(
|
|
|
|
loc, DeriveRankedTensorTypes(arg.getType(), /*rank=*/1), arg,
|
|
|
|
flat_shape);
|
|
|
|
|
|
|
|
// Materialize ranked versions of the element-wise operations.
|
|
|
|
BlockAndValueMapping bvm;
|
|
|
|
bvm.map(op.getBody()->getArguments().front(), flat_arg);
|
|
|
|
for (Operation &nested_op : op.getBody()->without_terminator()) {
|
|
|
|
auto mapped_operands = llvm::to_vector<4>(llvm::map_range(
|
|
|
|
nested_op.getOperands(), [&](Value v) { return bvm.lookup(v); }));
|
|
|
|
auto ranked_result_types = llvm::to_vector<2>(llvm::map_range(
|
|
|
|
nested_op.getResultTypes(),
|
|
|
|
[](Type ty) { return DeriveRankedTensorTypes(ty, /*rank=*/1); }));
|
|
|
|
OperationState ranked_op_state(loc, nested_op.getName().getStringRef(),
|
|
|
|
mapped_operands, ranked_result_types,
|
|
|
|
nested_op.getAttrs());
|
|
|
|
Operation *ranked_op = rewriter.createOperation(ranked_op_state);
|
|
|
|
for (auto it :
|
|
|
|
llvm::zip(nested_op.getResults(), ranked_op->getResults())) {
|
|
|
|
bvm.map(std::get<0>(it), std::get<1>(it));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Collect results and restore their shape. We don't have to reify a shape
|
|
|
|
// computation in the unary case as the operand shapes to all the
|
|
|
|
// element-wise ops can only be the unique input shape.
|
|
|
|
SmallVector<Value> results;
|
|
|
|
for (Value v : llvm::cast<chlo::RankSpecializationClusterYieldOp>(
|
|
|
|
op.getBody()->getTerminator())
|
|
|
|
.results()) {
|
|
|
|
Value flat_result = bvm.lookup(v);
|
|
|
|
Value result = rewriter.create<mhlo::DynamicReshapeOp>(
|
|
|
|
loc, v.getType(), flat_result, shape);
|
|
|
|
results.push_back(result);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Replace the rank specialization cluster.
|
|
|
|
rewriter.replaceOp(op, results);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
struct RankSpecializationToSCFPass
|
|
|
|
: public PassWrapper<RankSpecializationToSCFPass, FunctionPass> {
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
|
|
registry.insert<mhlo::MhloDialect, chlo::HloClientDialect,
|
|
|
|
shape::ShapeDialect>();
|
|
|
|
}
|
|
|
|
|
|
|
|
void runOnFunction() override {
|
|
|
|
MLIRContext *ctx = &getContext();
|
|
|
|
RewritePatternSet patterns(ctx);
|
|
|
|
PopulateRankSpecializationToSCFPatterns(ctx, &patterns);
|
|
|
|
if (failed(
|
|
|
|
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
|
|
|
|
return signalPassFailure();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2021-05-12 18:45:09 +08:00
|
|
|
} // namespace
|
|
|
|
|
|
|
|
/// Adds `RankSpecializationClusterPattern` to the pattern list pointed to by
/// `patterns`.
void PopulateRankSpecializationClusterPatterns(
    MLIRContext *context, OwningRewritePatternList *patterns) {
  OwningRewritePatternList &pattern_list = *patterns;
  pattern_list.insert<RankSpecializationClusterPattern>(context);
}
|
|
|
|
|
2021-05-17 18:55:32 +08:00
|
|
|
/// Adds `LowerUnaryRankSpecializationClusterPattern` to the pattern list
/// pointed to by `patterns`.
void PopulateRankSpecializationToSCFPatterns(
    MLIRContext *context, OwningRewritePatternList *patterns) {
  OwningRewritePatternList &pattern_list = *patterns;
  pattern_list.insert<LowerUnaryRankSpecializationClusterPattern>(context);
}
|
|
|
|
|
2021-05-12 18:45:09 +08:00
|
|
|
std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() {
|
|
|
|
return std::make_unique<RankSpecializationClusterPass>();
|
|
|
|
}
|
|
|
|
|
2021-05-17 18:55:32 +08:00
|
|
|
std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass() {
|
|
|
|
return std::make_unique<RankSpecializationToSCFPass>();
|
|
|
|
}
|
|
|
|
|
2021-05-12 18:45:09 +08:00
|
|
|
} // namespace mhlo
|
|
|
|
} // namespace mlir
|