/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ // This file implements logic for fusing linalg ops obtained after LHLO // lowering. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { namespace lmhlo { namespace { using linalg::LinalgOp; class LhloFuseLinalgPass : public PassWrapper { void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); } public: LhloFuseLinalgPass() = default; LhloFuseLinalgPass(const LhloFuseLinalgPass&) {} LhloFuseLinalgPass(bool use_parallel_loops, llvm::ArrayRef tile_sizes) { tile_sizes_ = tile_sizes; use_parallel_loops_.setValue(use_parallel_loops); } void runOnFunction() override { auto func = getFunction(); // TODO(pifon): Remove assumption that the function has a single block. if (!llvm::hasSingleElement(func)) { emitError(func.getLoc(), "The function needs to have a single block."); signalPassFailure(); return; } // The fusion in Linalg is currently possible only when the consumer op is // tiled. In order to greedily fuse the ops, we have to start from the tiled // root linalg ops, i.e. linalg ops that write to output buffers of the // function or are returned in case of escaping allocations. llvm::SmallDenseSet result_buffers; for (auto func_arg : func.getArguments()) { result_buffers.insert(func_arg); } for (auto& block : func) { auto returnOp = mlir::dyn_cast(block.getTerminator()); if (!returnOp) continue; for (auto operand : returnOp.getOperands()) { result_buffers.insert(operand); } } // Resolve aliasing operations (like casts) on the result to identify // results. This only handles escaping results. // TODO(herhut): Use BufferizeAliasAnalysis for this. llvm::SmallVector worklist(result_buffers.begin(), result_buffers.end()); while (!worklist.empty()) { Value result = worklist.pop_back_val(); auto definingOp = result.getDefiningOp(); if (!definingOp) { continue; } if (auto viewLike = dyn_cast(definingOp)) { auto alias = viewLike.getViewSource(); if (result_buffers.insert(alias).second) { worklist.push_back(alias); } continue; } if (auto tensor_load = dyn_cast(definingOp)) { auto alias = tensor_load.memref(); if (result_buffers.insert(alias).second) { worklist.push_back(alias); } continue; } if (auto tensor_to_memref = dyn_cast(definingOp)) { auto alias = tensor_to_memref.tensor(); if (result_buffers.insert(alias).second) { worklist.push_back(alias); } continue; } if (auto tensor_cast = dyn_cast(definingOp)) { auto alias = tensor_cast.source(); if (result_buffers.insert(alias).second) { worklist.push_back(alias); } continue; } if (auto regionInterface = dyn_cast(definingOp)) { for (Region& region : regionInterface.getOperation()->getRegions()) { // Only consider regions that can return to the parent region. SmallVector successorRegions; regionInterface.getSuccessorRegions(region.getRegionNumber(), successorRegions); if (llvm::none_of(successorRegions, [&](auto successorRegion) { return successorRegion.isParent(); })) continue; // Iterate over all immediate terminators and record the values // corresponding to result_buffers of interest. for (Block& block : region) { if (block.empty()) continue; Operation& operation = block.back(); if (!operation.hasTrait()) continue; auto idx = result.dyn_cast().getResultNumber(); if (result_buffers.insert(operation.getOperand(idx)).second) { worklist.push_back(operation.getOperand(idx)); } } } } } MLIRContext* ctx = func.getContext(); OpBuilder b(func); func.walk([&](linalg::GenericOp generic_op) { SmallVector tile_sizes(tile_sizes_.begin(), tile_sizes_.end()); if (tile_sizes.empty()) { tile_sizes = SmallVector(generic_op.getNumLoops(), 1); } auto op = cast(generic_op.getOperation()); for (const Value result : op.getOutputBuffers()) { if (!result_buffers.count(result)) continue; if (tileGenericOp(op, tile_sizes, &b)) { generic_op.erase(); return; } } }); auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx); applyPatternsAndFoldGreedily(func, std::move(patterns)); // Fuse producers of tiled linalg ops. llvm::SmallDenseSet erase_set; SmallVector linalg_ops; func.walk([&](LinalgOp op) { linalg_ops.push_back(op); }); for (LinalgOp op : llvm::reverse(linalg_ops)) { for (OpOperand& inputOperand : op.getInputOpOperands()) { linalg::Aliases aliases; linalg::LinalgDependenceGraph graph(aliases, linalg_ops); if (auto info = fuseProducerOfBuffer(b, inputOperand, graph)) { auto originalOp = info->originalProducer.getOperation(); erase_set.insert(originalOp); auto originalOpInLinalgOpsVector = std::find_if( linalg_ops.begin(), linalg_ops.end(), [&](const Operation* op) { return op == originalOp; }); *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); } } auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx); applyPatternsAndFoldGreedily(func, std::move(patterns)); } for (auto* e : erase_set) e->erase(); } private: bool tileGenericOp(LinalgOp op, ArrayRef tile_sizes, OpBuilder* b) { auto loopType = use_parallel_loops_ ? linalg::LinalgTilingLoopType::ParallelLoops : linalg::LinalgTilingLoopType::Loops; auto tiled_generic_op = linalg::tileLinalgOp(*b, op, linalg::LinalgTilingOptions() .setTileSizes(tile_sizes) .setLoopType(loopType)); return tiled_generic_op.hasValue(); } Option use_parallel_loops_{ *this, "use-parallel-loops", llvm::cl::desc( "Tiles GenericOp consumer to parallel loops before linalg fusion"), llvm::cl::init(false)}; ListOption tile_sizes_{ *this, "tile-sizes", llvm::cl::desc( "Tile sizes by which to tile linalg generic before linalg fusion"), llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; }; } // namespace std::unique_ptr createLhloFuseLinalgPass( bool use_parallel_loops, ArrayRef tile_sizes) { return std::make_unique(use_parallel_loops, tile_sizes); } } // namespace lmhlo } // namespace mlir