228 lines
8.6 KiB
C++
228 lines
8.6 KiB
C++
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
|
|
|
|
// This file implements logic for fusing linalg ops obtained after LHLO
// lowering.
|
|
|
|
#include "llvm/ADT/ArrayRef.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SCF/SCF.h"
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Interfaces/ViewLikeInterface.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
namespace mlir {
|
|
namespace lmhlo {
|
|
namespace {
|
|
|
|
using linalg::LinalgOp;
|
|
|
|
class LhloFuseLinalgPass
|
|
: public PassWrapper<LhloFuseLinalgPass, FunctionPass> {
|
|
void getDependentDialects(DialectRegistry& registry) const override {
|
|
registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
|
|
}
|
|
|
|
public:
|
|
LhloFuseLinalgPass() = default;
|
|
LhloFuseLinalgPass(const LhloFuseLinalgPass&) {}
|
|
LhloFuseLinalgPass(bool use_parallel_loops,
|
|
llvm::ArrayRef<unsigned> tile_sizes) {
|
|
tile_sizes_ = tile_sizes;
|
|
use_parallel_loops_.setValue(use_parallel_loops);
|
|
}
|
|
|
|
void runOnFunction() override {
|
|
auto func = getFunction();
|
|
|
|
// TODO(pifon): Remove assumption that the function has a single block.
|
|
if (!llvm::hasSingleElement(func)) {
|
|
emitError(func.getLoc(), "The function needs to have a single block.");
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
|
|
// The fusion in Linalg is currently possible only when the consumer op is
|
|
// tiled. In order to greedily fuse the ops, we have to start from the tiled
|
|
// root linalg ops, i.e. linalg ops that write to output buffers of the
|
|
// function or are returned in case of escaping allocations.
|
|
llvm::SmallDenseSet<Value> result_buffers;
|
|
for (auto func_arg : func.getArguments()) {
|
|
result_buffers.insert(func_arg);
|
|
}
|
|
for (auto& block : func) {
|
|
auto returnOp = mlir::dyn_cast<mlir::ReturnOp>(block.getTerminator());
|
|
if (!returnOp) continue;
|
|
for (auto operand : returnOp.getOperands()) {
|
|
result_buffers.insert(operand);
|
|
}
|
|
}
|
|
// Resolve aliasing operations (like casts) on the result to identify
|
|
// results. This only handles escaping results.
|
|
// TODO(herhut): Use BufferizeAliasAnalysis for this.
|
|
llvm::SmallVector<Value, 4> worklist(result_buffers.begin(),
|
|
result_buffers.end());
|
|
while (!worklist.empty()) {
|
|
Value result = worklist.pop_back_val();
|
|
auto definingOp = result.getDefiningOp();
|
|
if (!definingOp) {
|
|
continue;
|
|
}
|
|
|
|
if (auto viewLike = dyn_cast<ViewLikeOpInterface>(definingOp)) {
|
|
auto alias = viewLike.getViewSource();
|
|
if (result_buffers.insert(alias).second) {
|
|
worklist.push_back(alias);
|
|
}
|
|
continue;
|
|
}
|
|
|
|
if (auto tensor_load = dyn_cast<memref::TensorLoadOp>(definingOp)) {
|
|
auto alias = tensor_load.memref();
|
|
if (result_buffers.insert(alias).second) {
|
|
worklist.push_back(alias);
|
|
}
|
|
continue;
|
|
}
|
|
|
|
if (auto tensor_to_memref = dyn_cast<memref::BufferCastOp>(definingOp)) {
|
|
auto alias = tensor_to_memref.tensor();
|
|
if (result_buffers.insert(alias).second) {
|
|
worklist.push_back(alias);
|
|
}
|
|
continue;
|
|
}
|
|
|
|
if (auto tensor_cast = dyn_cast<tensor::CastOp>(definingOp)) {
|
|
auto alias = tensor_cast.source();
|
|
if (result_buffers.insert(alias).second) {
|
|
worklist.push_back(alias);
|
|
}
|
|
continue;
|
|
}
|
|
|
|
if (auto regionInterface =
|
|
dyn_cast<RegionBranchOpInterface>(definingOp)) {
|
|
for (Region& region : regionInterface.getOperation()->getRegions()) {
|
|
// Only consider regions that can return to the parent region.
|
|
SmallVector<RegionSuccessor, 2> successorRegions;
|
|
regionInterface.getSuccessorRegions(region.getRegionNumber(),
|
|
successorRegions);
|
|
if (llvm::none_of(successorRegions, [&](auto successorRegion) {
|
|
return successorRegion.isParent();
|
|
}))
|
|
continue;
|
|
|
|
// Iterate over all immediate terminators and record the values
|
|
// corresponding to result_buffers of interest.
|
|
for (Block& block : region) {
|
|
if (block.empty()) continue;
|
|
Operation& operation = block.back();
|
|
if (!operation.hasTrait<OpTrait::ReturnLike>()) continue;
|
|
auto idx = result.dyn_cast<OpResult>().getResultNumber();
|
|
if (result_buffers.insert(operation.getOperand(idx)).second) {
|
|
worklist.push_back(operation.getOperand(idx));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
MLIRContext* ctx = func.getContext();
|
|
OpBuilder b(func);
|
|
func.walk([&](linalg::GenericOp generic_op) {
|
|
SmallVector<int64_t, 2> tile_sizes(tile_sizes_.begin(),
|
|
tile_sizes_.end());
|
|
if (tile_sizes.empty()) {
|
|
tile_sizes = SmallVector<int64_t, 2>(generic_op.getNumLoops(), 1);
|
|
}
|
|
auto op = cast<LinalgOp>(generic_op.getOperation());
|
|
for (OpOperand* op_operand : op.getOutputBufferOperands()) {
|
|
if (!result_buffers.count(op_operand->get())) continue;
|
|
if (tileGenericOp(op, tile_sizes, &b)) {
|
|
generic_op.erase();
|
|
return;
|
|
}
|
|
}
|
|
});
|
|
auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
|
|
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
|
|
|
|
// Fuse producers of tiled linalg ops.
|
|
llvm::SmallDenseSet<Operation*> erase_set;
|
|
SmallVector<LinalgOp, 8> linalg_ops;
|
|
func.walk([&](LinalgOp op) { linalg_ops.push_back(op); });
|
|
for (LinalgOp op : llvm::reverse(linalg_ops)) {
|
|
for (OpOperand* inputOperand : op.getInputOperands()) {
|
|
linalg::Aliases aliases;
|
|
linalg::LinalgDependenceGraph graph(aliases, linalg_ops);
|
|
if (auto info = fuseProducerOfBuffer(b, *inputOperand, graph)) {
|
|
auto originalOp = info->originalProducer.getOperation();
|
|
erase_set.insert(originalOp);
|
|
auto originalOpInLinalgOpsVector = std::find_if(
|
|
linalg_ops.begin(), linalg_ops.end(),
|
|
[&](const Operation* op) { return op == originalOp; });
|
|
*originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
|
|
}
|
|
}
|
|
|
|
auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
|
|
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
|
|
}
|
|
for (auto* e : erase_set) e->erase();
|
|
}
|
|
|
|
private:
|
|
bool tileGenericOp(LinalgOp op, ArrayRef<int64_t> tile_sizes, OpBuilder* b) {
|
|
auto loopType = use_parallel_loops_
|
|
? linalg::LinalgTilingLoopType::ParallelLoops
|
|
: linalg::LinalgTilingLoopType::Loops;
|
|
auto tiled_generic_op = linalg::tileLinalgOp(*b, op,
|
|
linalg::LinalgTilingOptions()
|
|
.setTileSizes(tile_sizes)
|
|
.setLoopType(loopType));
|
|
return tiled_generic_op.hasValue();
|
|
}
|
|
|
|
Option<bool> use_parallel_loops_{
|
|
*this, "use-parallel-loops",
|
|
llvm::cl::desc(
|
|
"Tiles GenericOp consumer to parallel loops before linalg fusion"),
|
|
llvm::cl::init(false)};
|
|
|
|
ListOption<unsigned> tile_sizes_{
|
|
*this, "tile-sizes",
|
|
llvm::cl::desc(
|
|
"Tile sizes by which to tile linalg generic before linalg fusion"),
|
|
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
|
|
};
|
|
|
|
} // namespace
|
|
|
|
std::unique_ptr<FunctionPass> createLhloFuseLinalgPass(
|
|
bool use_parallel_loops, ArrayRef<unsigned> tile_sizes) {
|
|
return std::make_unique<LhloFuseLinalgPass>(use_parallel_loops, tile_sizes);
|
|
}
|
|
|
|
} // namespace lmhlo
|
|
} // namespace mlir
|