From 34dc5f2a79e9049d0143548309cca3dd750d850b Mon Sep 17 00:00:00 2001 From: Wenyi Zhao <951425797@qq.com> Date: Wed, 16 Jun 2021 09:50:41 -0700 Subject: [PATCH] PR #50020: [MLIR][DISC] support fusion on buffer Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50020 This pass implements the logic to group kLoop/kInput fusion patterns on buffer level. The reason for this is that we can avoid a lot of headaches to handle `shape-only` consumers specially (e.g. memref.dim, shape.shapeOf) since shapes are already resolved in buffer world. It may be better to move this pass to tensor level after more shape inference/constraint infras are ready on mhlo level. Copybara import of the project: -- e31f8344b59aa9860097197585215ea1689b8ff4 by Wenyi Zhao : [MLIR][DISC] support fusion on buffer This pass implements the logic to group kLoop/kInput fusion patterns on buffer level. The reason for this is that we can avoid a lot of headaches to handle `shape-only` consumers specially (e.g. memref.dim, shape.shapeOf) since shapes are already resolved in buffer world. It may be better to move this pass to tensor level after more shape inference/constraint infras are ready on mhlo level. -- 35f2eb2791241b0ab5db1ddcaf1b4006278ddccf by Wenyi Zhao : fix -- 923c8d61f7fe00a2a0df22d5be396508f0667964 by Wenyi Zhao : fix sanity check failure PiperOrigin-RevId: 379743424 --- BUILD | 38 ++ .../Dialect/mhlo/transforms/PassDetail.h | 8 + .../Dialect/mhlo/transforms/fusion_utils.h | 244 ++++++++ .../Dialect/mhlo/transforms/lmhlo_passes.td | 9 + .../mlir-hlo/Dialect/mhlo/transforms/passes.h | 4 + lib/Dialect/mhlo/transforms/CMakeLists.txt | 2 + lib/Dialect/mhlo/transforms/fusion_utils.cc | 394 ++++++++++++ lib/Dialect/mhlo/transforms/lhlo_fusion.cc | 570 ++++++++++++++++++ tests/lhlo-fusion.mlir | 286 +++++++++ 9 files changed, 1555 insertions(+) create mode 100644 include/mlir-hlo/Dialect/mhlo/transforms/fusion_utils.h create mode 100644 lib/Dialect/mhlo/transforms/fusion_utils.cc create mode 100644 lib/Dialect/mhlo/transforms/lhlo_fusion.cc create mode 100644 tests/lhlo-fusion.mlir diff --git a/BUILD b/BUILD index 6215c5f..aa0d8e8 100644 --- a/BUILD +++ b/BUILD @@ -1213,6 +1213,42 @@ cc_library( ], ) +cc_library( + name = "fusion_utils", + srcs = ["lib/Dialect/mhlo/transforms/fusion_utils.cc"], + hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/fusion_utils.h"], + deps = [ + ":lhlo", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Shape", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + ], + alwayslink = 1, +) + +cc_library( + name = "lhlo_fusion", + srcs = ["lib/Dialect/mhlo/transforms/lhlo_fusion.cc"], + deps = [ + ":cycle_detector", + ":fusion_utils", + ":lhlo", + ":pass_details", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Shape", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], + alwayslink = 1, +) + cc_library( name = "chlo_legalize_to_hlo", srcs = ["lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc"], @@ -1259,6 +1295,7 @@ cc_library( ], deps = [ ":DiscRalPassIncGen", + ":LmhloPassIncGen", ":MhloPassIncGen", "@llvm-project//mlir:Pass", ], @@ -1316,6 +1353,7 @@ cc_library( ":legalize_trigonometric_to_approximation", ":lhlo", ":lhlo_fuse_linalg", + ":lhlo_fusion", ":lhlo_legalize_to_affine", ":lhlo_legalize_to_gpu", ":lhlo_legalize_to_parallel_loops", diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/PassDetail.h b/include/mlir-hlo/Dialect/mhlo/transforms/PassDetail.h index f3b6e2e..156d69b 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/PassDetail.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/PassDetail.h @@ -25,6 +25,14 @@ namespace mhlo { #include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" } // end namespace mhlo + +namespace lmhlo { + +#define GEN_PASS_CLASSES +#include "mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.h.inc" + +} // end namespace lmhlo + } // end namespace mlir namespace mlir { diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/fusion_utils.h b/include/mlir-hlo/Dialect/mhlo/transforms/fusion_utils.h new file mode 100644 index 0000000..717f322 --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/transforms/fusion_utils.h @@ -0,0 +1,244 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_FUSION_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_FUSION_UTILS_H_ + +#include +#include + +#include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/Support/Debug.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project + +// This file implements some helper functions and classes used to do fusion +// & code generation. + +namespace mlir { +namespace lmhlo { + +// kLoop fusion template satisfies: +// - all ops in the fusion pattern are element-wise. +// - all the shapes of outputs of fusion pattern are same or have same number +// of elements, and thus can fit into a same parallel loop. +// +// kInput fusion template satisfies: +// - any op in the fusion pattern is either element-wise or a reduction. +// - if a op is a reduction, its output cannot be consumed by other +// ops in the same fusion pattern. +// - all the effective shapes of outputs of fusion pattern are same. +// - For element-wise op, its effective shape is its output shape. +// - For reduction op, its effective shape is its operand shape. +// - currently our downstreaming codegen engine only support 2d -> 1d tensor +// reduction. TODO: lift this limitation. +// - 2D row reduction: out[i] = sum({in[i][j] for all j}) +// - 2D column reduction: out[j] = sum({in[i][j] for all i}) +enum FusionType { + // Not a fusion pattern + kNone, + // kLoop fusion pattern + kLoop, + // kInput fusion pattern and all reduce ops of the fused pattern are row + // reduction + kRowReduction, + // kInput fusion pattern and all reduce ops of the fused pattern are column + // reduction + kColReduction, +}; + +// Returns true if the op is an elementwise unary lmhlo op. +// TODO: use fusibility interface +bool isElementWiseUnary(Operation* op); + +// Returns true if the op is an elementwise binary lmhlo op. +// TODO: use fusibility interface +bool isElementWiseBinary(Operation* op); + +// Returns true if the op is an elementwise lmhlo op. +// TODO: use fusibility interface +bool isElementWise(Operation* op); + +// Returns true if this op is a rank-2 row reduction. +bool isRank2RowReduction(Operation* op); + +// Returns true if this op is a rank-2 column reduction. +bool isRank2ColReduction(Operation* op); + +// Returns true if the op is supported by the downstreaming fusion codegen +// engine. +bool isFusible(Operation* op); + +// Returns the number of operands that are supposed to be written. +// For some ops (e.g. lmhlo ops), some operands are the output memrefs +// Thus these operands are supposed to be updated. +int getNumResultOperands(Operation* op); + +// Returns data users of the value and its aliases (e.g. memref.cast). +// Here non-data users means DimOp, DeallocOp and ShapeOfOp. +SmallVector getValueUsers(Value v); + +// Represents a list of lmhlo ops that are going to be fused. +class FusionPattern { + public: + using FusionOpList = SmallVector; + using FusionValueList = SmallVector; + + // Create a new fusion pattern from a single op. + FusionPattern(Operation* op); + + // Create a new fusion pattern from the ops inside the lmhlo fusion op. + FusionPattern(lmhlo::FusionOp op); + + // Returns the op list this fusion pattern represents. + FusionOpList& getOpList() { return op_list_; } + + // Returns the dominant op of this fusion pattern. + // For kLoop fusion, a dominant op may be any op that has external users. + // For kInput fusion, a dominant op may be a row reduction (if exists), or + // a column reduction op. + Operation* getDominantOp() { return dominant_op_; } + + // Sets the dominant op to the op provided. + void setDominantOp(Operation* op) { dominant_op_ = op; } + + // Returns the fusion kind of the fusion pattern. + FusionType getFusionType() { return fusion_type_; } + + // Sets the fusion type to the the type provided. + void setFusionType(FusionType type) { fusion_type_ = type; } + + // Returns true if this a fusible fusion pattern. + bool isFusible() { return getFusionType() != FusionType::kNone; } + + // Returns true if this fusion pattern is a kLoop fusion. + bool isKLoopFusion() { return getFusionType() == FusionType::kLoop; } + + // Returns true if this fusion pattern is a kInput fusion. + bool isKInputFusion() { + return (getFusionType() == FusionType::kRowReduction || + getFusionType() == FusionType::kColReduction); + } + + // Returns true if two fusion patterns can be merged into one bigger fusion + // pattern. + bool isMergeable(FusionPattern& other); + + // Merges two fusion patterns and returns the merged pattern. The original + // pattern remains unmodified. + FusionPattern merge(FusionPattern& other); + + // Merges two fusion patterns and returns the merged pattern. Replaces the + // original pattern with new merged pattern. + FusionPattern& mergeInplace(FusionPattern& other); + + // Returns values that are consumed by the lmhlo ops inside the fusion + // pattern. + FusionValueList& getOperands() { return operands_; } + + // Returns values that are outputs of any lmhlo op in the fused pattern and + // have consumers outside the fusion pattern. + FusionValueList& getResults() { return results_; } + + // Returns values that are outputs of any lmhlo op in the fused pattern and + // are only consumed by the lmhlo ops inside the fused pattern. + FusionValueList& getInternalResults() { return internal_results_; } + + // Returns the size of the ops this fusion pattern contains. + int size() { return op_list_.size(); } + + // Returns the effective size (e.g. not counting const ops) of the ops this + // fusion pattern contains. + int effectiveSize(); + + // Sorts the ops inside the fusion pattern according to the keys provided. + void sortFusionOpListBy(DenseMap& op_to_idx); + + private: + FusionPattern(SmallVectorImpl& op_list); + + private: + // Calculates the inputs and outputs of the fusion pattern. + void calculateOperandsAndResults(); + + private: + FusionOpList op_list_; + Operation* dominant_op_ = nullptr; + FusionType fusion_type_ = FusionType::kNone; + FusionValueList operands_; + FusionValueList results_; + FusionValueList internal_results_; +}; + +// Represents a list of disjoint fusion patterns for a block. +using FusionPlan = std::vector; + +using llvm::EquivalenceClasses; + +// Supports using EquivalenceClasses for Value +class ValueWrapper { + public: + explicit ValueWrapper(Value value) : value_(std::move(value)) {} + + Value getValue() const { return value_; } + + bool operator==(const ValueWrapper& rhs) const { + return getValue() == rhs.getValue(); + } + + private: + Value value_; +}; + +bool operator<(const ValueWrapper& lhs, const ValueWrapper& rhs); + +// This is a simple shape constraint analysis, which is used to +// guide fusion decision (e.g. we only fuse shape-compatible ops). +// +// Currently, We only consider shape equality and same-number-elements equality +// propagation based on the shape constraint traits of elementwise ops (assuming +// that implicit shape broadcast is forbidden). +class ShapeConstraintAnalysis { + public: + explicit ShapeConstraintAnalysis(const SmallVectorImpl& op_list) { + PropagateEquality(op_list); + } + + // Returns true if `lhs` and `rhs` are supposed to have same shape. + bool HasSameShape(Value lhs, Value rhs) { + return same_shape_impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs)); + } + + // Returns true if `lhs` and `rhs` are supposed to have same number of + // elements. + bool HasSameNumElements(Value lhs, Value rhs) { + return same_num_elements_impl_.isEquivalent(ValueWrapper(lhs), + ValueWrapper(rhs)); + } + + private: + // shape equality propagation based on the shape constrains of + // elementwise ops. + void PropagateEquality(const SmallVectorImpl& op_list); + + // a UnionFind set + EquivalenceClasses same_shape_impl_; + EquivalenceClasses same_num_elements_impl_; +}; + +} // namespace lmhlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_FUSION_UTILS_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td b/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td index e1d840c..59388f0 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td +++ b/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td @@ -56,3 +56,12 @@ def LegalizeTensorLoadOpPass : Pass<"lhlo-legalize-tensor-load-op", "FuncOp"> { let constructor = "createLegalizeTensorLoadOpPass()"; } +def LhloFusionPass : FunctionPass<"lhlo-fusion"> { + let summary = "Fuse lmhlo ops to kLoop/kInput fusion patterns."; + let constructor = "createLhloFusionPass()"; + let options = [ + Option<"max_num_arguments_per_kernel_", "max-num-arguments-per-kernel", "int", + /*default=*/"64", "Maximum allowed number of arguments per fused kernel.">, + ]; +} + diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h index cc2b6eb..5eb953f 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h @@ -117,6 +117,10 @@ std::unique_ptr> createLegalizeLhloToParallelLoopsPass(); // Legalizes tensor load ops that are inserted during mhlo to lmhlo conversion. std::unique_ptr> createLegalizeTensorLoadOpPass(); +// fuse lmhlo ops to kLoop/kInput fusion patterns +std::unique_ptr> createLhloFusionPass( + int max_num_arguments_per_kernel = 64); + } // namespace lmhlo namespace disc_ral { diff --git a/lib/Dialect/mhlo/transforms/CMakeLists.txt b/lib/Dialect/mhlo/transforms/CMakeLists.txt index 81da420..8caba8b 100644 --- a/lib/Dialect/mhlo/transforms/CMakeLists.txt +++ b/lib/Dialect/mhlo/transforms/CMakeLists.txt @@ -137,8 +137,10 @@ add_mlir_library(MhloLhloToLinalg ) add_mlir_library(LmhloPasses + fusion_utils.cc legalize_tensor_load_op.cc lhlo_fuse_linalg.cc + lhlo_fusion.cc lhlo_legalize_to_affine.cc lhlo_legalize_to_gpu.cc lhlo_legalize_to_parallel_loops.cc diff --git a/lib/Dialect/mhlo/transforms/fusion_utils.cc b/lib/Dialect/mhlo/transforms/fusion_utils.cc new file mode 100644 index 0000000..10d4da8 --- /dev/null +++ b/lib/Dialect/mhlo/transforms/fusion_utils.cc @@ -0,0 +1,394 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir-hlo/Dialect/mhlo/transforms/fusion_utils.h" + +#include + +#include "mlir/Dialect/Shape/IR/Shape.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Matchers.h" + +// This file implements some helper functions and classes used to do fusion +// & code generation. + +namespace mlir { +namespace lmhlo { + +// Returns true if the op is an elementwise unary lmhlo op. +// TODO(disc): use fusibility interface +bool isElementWiseUnary(Operation* op) { + // clang-format off + return isa< + lmhlo::AbsOp, + lmhlo::CeilOp, + lmhlo::ConvertOp, + lmhlo::CopyOp, + lmhlo::CosOp, + lmhlo::ExpOp, + lmhlo::FloorOp, + lmhlo::IsFiniteOp, + lmhlo::LogOp, + lmhlo::NegOp, + lmhlo::NotOp, + lmhlo::RsqrtOp, + lmhlo::SignOp, + lmhlo::SqrtOp, + lmhlo::TanhOp + >(op); + // clang-format on +} + +// Returns true if the op is an elementwise binary lmhlo op. +// TODO(disc): use fusibility interface +bool isElementWiseBinary(Operation* op) { + // clang-format off + return isa< + lmhlo::AddOp, + lmhlo::AndOp, + lmhlo::CompareOp, + lmhlo::DivOp, + lmhlo::MaxOp, + lmhlo::MinOp, + lmhlo::MulOp, + lmhlo::OrOp, + lmhlo::PowOp, + lmhlo::SubOp + >(op); + // clang-format on +} + +// Returns true if the op is an elementwise lmhlo op. +// TODO(disc): use fusibility interface +bool isElementWise(Operation* op) { + return isElementWiseUnary(op) || isElementWiseBinary(op); +} + +// Returns true if this op is a rank-2 row reduction. +bool isRank2RowReduction(Operation* op) { + auto reduce_op = dyn_cast(op); + if (!reduce_op || reduce_op.dimensions().getNumElements() != 1) return false; + + int rank = op->getOperand(0).getType().cast().getRank(); + auto dimensions = reduce_op.dimensions().getValues(); + return ((*dimensions.begin() == 1) && (rank == 2)); +} + +// Returns true if this op is a rank-2 column reduction. +bool isRank2ColReduction(Operation* op) { + auto reduce_op = dyn_cast(op); + if (!reduce_op || reduce_op.dimensions().getNumElements() != 1) return false; + + int rank = op->getOperand(0).getType().cast().getRank(); + auto dimensions = reduce_op.dimensions().getValues(); + return ((*dimensions.begin() == 0) && (rank == 2)); +} + +// Returns true if the op is supported by the downstreaming fusion codegen +// engine. +bool isFusible(Operation* op) { + // Only scalar const are supported by the fusion codegen engine a.t.m. + if (dyn_cast(op)) { + MemRefType type = op->getOperand(0).getType().cast(); + return (type.getRank() == 0); + } + + // All element ops are supported by the fusion codegen engine. + if (isElementWise(op)) return true; + + // Only rank-2 tensor -> rank-1 tensor reduction are supported now. + if (isRank2RowReduction(op) || isRank2ColReduction(op)) return true; + + // clang-format off + return isa< + lmhlo::BroadcastInDimOp, + lmhlo::BroadcastOp, + lmhlo::ConcatenateOp, + lmhlo::DynamicBroadcastInDimOp, + lmhlo::DynamicGatherOp, + lmhlo::DynamicIotaOp, + lmhlo::DynamicPadOp, + lmhlo::DynamicReshapeOp, + lmhlo::GatherOp, + lmhlo::RealDynamicSliceOp, + lmhlo::ReshapeOp, + lmhlo::SelectOp, + lmhlo::SliceOp, + lmhlo::TransposeOp + >(op); + // clang-format on +} + +// Returns the number of operands that are supposed to be written. +// For some ops (e.g. lmhlo ops), some operands are the output memrefs +// Thus these operands are supposed to be updated. +int getNumResultOperands(Operation* op) { + if (op->getDialect()->getNamespace() != "lmhlo") { + return 0; + } + + auto isWritable = [&](Value operand) -> bool { + llvm::SmallVector effects; + MemoryEffectOpInterface interface = dyn_cast(op); + // Suppose that operands of op without `MemoryEffectOpInterface` are + // readonly. + if (!interface) return false; + + interface.getEffectsOnValue(operand, effects); + return llvm::any_of( + effects, [](const mlir::MemoryEffects::EffectInstance& instance) { + return mlir::isa(instance.getEffect()); + }); + }; + + return llvm::count_if(op->getOperands(), + [&](Value v) { return isWritable(v); }); +} + +// Returns data users of the value and its aliases (e.g. memref.cast). +// Here non-data users means DimOp, DeallocOp and ShapeOfOp. +SmallVector getValueUsers(Value v) { + SmallVector users; + SmallVector worklist; + worklist.push_back(v); + while (!worklist.empty()) { + Value curr = worklist.back(); + worklist.pop_back(); + for (Operation* user : curr.getUsers()) { + // Skip non-data users + if (isa(user)) { + continue; + } + // alias value + if (isa(user)) { + worklist.push_back(user->getResult(0)); + } else { + users.push_back(user); + } + } + } + return users; +} + +// Create a new fusion pattern from a single op. +FusionPattern::FusionPattern(Operation* op) { + op_list_.push_back(op); + if (isRank2RowReduction(op)) { + fusion_type_ = FusionType::kRowReduction; + } else if (isRank2ColReduction(op)) { + fusion_type_ = FusionType::kColReduction; + } else if (mlir::lmhlo::isFusible(op)) { + fusion_type_ = FusionType::kLoop; + } else { + fusion_type_ = FusionType::kNone; + } + dominant_op_ = op; + calculateOperandsAndResults(); +} + +// Create a new fusion pattern from the ops inside the lmhlo fusion op. +FusionPattern::FusionPattern(lmhlo::FusionOp op) { + for (Operation& op : op.region().getBlocks().front()) { + op_list_.push_back(&op); + } + + // Figure out fusion type and dominant op for the fusion pattern. + for (Operation* op : op_list_) { + if (isRank2RowReduction(op)) { + fusion_type_ = FusionType::kRowReduction; + dominant_op_ = op; + } else if (isRank2ColReduction(op)) { + if (fusion_type_ != FusionType::kRowReduction) { + fusion_type_ = FusionType::kColReduction; + dominant_op_ = op; + } + } else if (lmhlo::isFusible(op)) { + // Ignore if already a kRowReduction or kColReduction, otherwise update + // the fusion type to kLoop and dominant op to current op. This supposes + // that the last op inside the block is a valid candidate dominant op if + // the fusion pattern is a kLoop. + if (fusion_type_ == FusionType::kNone || + fusion_type_ == FusionType::kLoop) { + fusion_type_ = FusionType::kLoop; + dominant_op_ = op; + } + } else { + // Not a supported fusionOp, early stop. + fusion_type_ = FusionType::kNone; + dominant_op_ = nullptr; + break; + } + } + + if (isFusible()) calculateOperandsAndResults(); +} + +// Create a new fusion pattern from a valid fusion op list. +FusionPattern::FusionPattern(SmallVectorImpl& op_list) + : op_list_(op_list.begin(), op_list.end()) { + calculateOperandsAndResults(); +} + +// Returns true if two fusion patterns can be merged into one bigger fusion +// pattern. +bool FusionPattern::isMergeable(FusionPattern& other) { + if (!this->isFusible() || !other.isFusible()) return false; + return true; +} + +// Merges two fusion patterns and returns the merged pattern. The original +// pattern remains unmodified. +FusionPattern FusionPattern::merge(FusionPattern& other) { + assert(isMergeable(other)); + FusionOpList new_op_list = op_list_; + new_op_list.insert(new_op_list.end(), other.getOpList().begin(), + other.getOpList().end()); + FusionPattern new_fusion_pattern{new_op_list}; + + FusionType newType = FusionType::kLoop; + Operation* newDominant = getDominantOp(); + + // kRowReduction + (kRowReduction | kColReduction | kLoop) = kRowReduction + // kColReduction + (kColReduction | kLoop) = kColReduction + // kLoop + kLoop = kLoop + if (getFusionType() == FusionType::kRowReduction || + other.getFusionType() == FusionType::kRowReduction) { + newType = FusionType::kRowReduction; + if (getFusionType() != FusionType::kRowReduction) + newDominant = other.getDominantOp(); + } else if (getFusionType() == FusionType::kColReduction || + other.getFusionType() == FusionType::kColReduction) { + newType = FusionType::kColReduction; + if (getFusionType() != FusionType::kColReduction) + newDominant = other.getDominantOp(); + } + + new_fusion_pattern.setDominantOp(newDominant); + new_fusion_pattern.setFusionType(newType); + return new_fusion_pattern; +} + +// Merges two fusion patterns and returns the merged pattern. Replaces the +// original pattern with new merged pattern. +FusionPattern& FusionPattern::mergeInplace(FusionPattern& other) { + *this = merge(other); + return *this; +} + +// Returns the effective size (e.g. not counting const ops) of the ops this +// fusion pattern contains. +int FusionPattern::effectiveSize() { + return llvm::count_if( + op_list_, [](Operation* op) { return !matchPattern(op, m_Constant()); }); +} + +// Sorts the ops inside the fusion pattern according to the keys provided. +void FusionPattern::sortFusionOpListBy(DenseMap& op_to_idx) { + std::sort(op_list_.begin(), op_list_.end(), + [&](Operation* lhs, Operation* rhs) { + return op_to_idx[lhs] < op_to_idx[rhs]; + }); +} + +// Calculates the inputs and outputs of the fusion pattern. +void FusionPattern::calculateOperandsAndResults() { + DenseSet input_set; + DenseSet result_set; + DenseSet internal_result_set; + DenseSet op_set(op_list_.begin(), op_list_.end()); + + DenseMap last_writer; + for (Operation* op : op_list_) { + int num_input_operand = op->getNumOperands() - getNumResultOperands(op); + for (Value v : op->getOperands().drop_front(num_input_operand)) { + bool inserted = last_writer.try_emplace(v, op).second; + (void)inserted; + assert(inserted); + + bool has_external_user = false; + for (Operation* user : getValueUsers(v)) { + if (!op_set.contains(user)) { + has_external_user = true; + break; + } + } + + if (has_external_user) { + results_.push_back(v); + } else { + internal_results_.push_back(v); + } + } + } + + for (Operation* op : op_list_) { + int num_input_operand = op->getNumOperands() - getNumResultOperands(op); + for (Value value : op->getOperands().take_front(num_input_operand)) { + if (last_writer.find(value) != last_writer.end()) { + // skip if defining op is in the pattern + continue; + } + input_set.insert(value); + } + } + + for (Value v : input_set) operands_.push_back(v); +} + +// Supports using EquivalenceClasses for Value +bool operator<(const ValueWrapper& lhs, const ValueWrapper& rhs) { + auto lhs_value = lhs.getValue().getAsOpaquePointer(); + auto rhs_value = rhs.getValue().getAsOpaquePointer(); + return lhs_value < rhs_value; +} + +// shape equality propagation based on the shape constrains of +// elementwise ops. +void ShapeConstraintAnalysis::PropagateEquality( + const SmallVectorImpl& op_list) { + bool converged = true; + do { + converged = true; + auto update = [&](Value lhs, Value rhs, + EquivalenceClasses& impl) { + if (!impl.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs))) { + converged = false; + impl.unionSets(ValueWrapper(lhs), ValueWrapper(rhs)); + } + }; + for (Operation* op : op_list) { + int num_operand = op->getNumOperands(); + // Propagates same num_elements equality, and shape equality + if (isElementWise(op)) { + Value lhs = op->getOperand(0); + for (Value rhs : op->getOperands().drop_front()) { + update(lhs, rhs, same_num_elements_impl_); + update(lhs, rhs, same_shape_impl_); + } + } + // Propagates same num_elements equality, not shape equality + if (isa( + op)) { + Value input = op->getOperand(0); + // The last operand is the output memref by design + Value output = op->getOperand(num_operand - 1); + update(input, output, same_num_elements_impl_); + } + } + } while (!converged); +} + +} // namespace lmhlo +} // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/lhlo_fusion.cc b/lib/Dialect/mhlo/transforms/lhlo_fusion.cc new file mode 100644 index 0000000..3a6adac --- /dev/null +++ b/lib/Dialect/mhlo/transforms/lhlo_fusion.cc @@ -0,0 +1,570 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" +#include "mlir-hlo/Dialect/mhlo/transforms/fusion_utils.h" +#include "mlir-hlo/utils/cycle_detector.h" +#include "mlir/Dialect/Shape/IR/Shape.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Matchers.h" +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project + +// This pass has similar functionality of the fusion pass in XLA stack. +// However, unlike XLA, it targets the fully dynamic shape scenario. +// Currently, it implements the kLoop and kInput fusion templates. +// During conversion, it tries to greedily find kLoop/kInput fusion +// patterns. +// +// Similar to XLA, this pass supports fusion pattern having multiple outputs +// if all the shape of outputs are consistent. Following are some examples. +// +// kLoop kInput +// +----+ +----+ +----+ +----+ +----+ +----+ +// |elem| |elem| |elem| |elem<----+elem+---->elem+----+ +// +-+--+ +-+--+ +-+--+ +-+--+ +----+ +-+--+ | +// | | | | | | +// | | | | | +// +-v--+ | +-v--+ +--v---+ +--v---+ | +// |elem+<---+----<+elem| |reduce| |reduce| | +// +-+--+ +-+--+ +--+---+ +--+---+ | +// | | | | | +// | | | | | +// v v v v v +// +// To this end, we also add an simple shape constraint analysis phase. +// For kLoop fusion template, it requires all the outputs of the fused +// pattern have the same shape. However, we don't know the actual value +// of the shape at the compile time in the dynamic shape world. +// Fortunately, we could still infer the relationship among different ops +// according to their shape constraint traits. Currently, We only consider +// shape equality propagation for elementwise ops (assuming that implicit +// shape broadcast is forbidden). The above process could be built on the +// shape dialect once it is ready. +// +// TODO(disc): This file implements fusion on buffer level, re-visit this after +// more shape inference/constraint infras are ready in mhlo level. +// TODO(disc): Not using fusibility interface a.t.m, re-visit this if necessary. + +namespace mlir { +namespace lmhlo { +namespace { + +struct FusionOptions { + // Maximum allowed number of arguments per fused kernel. Here arguments + // include both ready-only buffers and writable buffers. + int max_num_arguments_per_kernel; +}; + +// A fusion planner that can propose a fusion plan for a block of ops. +// The fusion plan is consisted of a group of fusion patterns. +// +// Currently all proposed patterns followed xla kLoop/kInput like fusion +// templates while are adapted to the fully dynamic shape world. +// +// kLoop fusion template satisfies: +// - all ops in the fusion pattern are element-wise. +// - all the shapes of outputs of fusion pattern are same or have same number +// of elements, and thus can fit into a same parallel loop. +// +// kInput fusion template satisfies: +// - any op in the fusion pattern is either element-wise or a reduction. +// - if a op is a reduction, its output cannot be consumed by other +// ops in the same fusion pattern. +// - all the effective shapes of outputs of fusion pattern are same. +// - For element-wise op, its effective shape is its output shape. +// - For reduction op, its effective shape is its operand shape. +// - currently our downstreaming codegen engine only support 2d -> 1d tensor +// reduction. TODO(disc): lift this limitation. +// - 2D row reduction: out[i] = sum({in[i][j] for all j}) +// - 2D column reduction: out[j] = sum({in[i][j] for all i} +class FusionPlanner { + public: + explicit FusionPlanner(const FusionOptions& options, Block* block) + : options_(options), block_(block) { + // Move up metadata-only ops (e.g. dim, shape_of) as far as possible. + MoveUpMetadataOnlyOpsForFusion(); + + for (Operation& op : *block) { + op_list_.push_back(&op); + } + shape_analysis_.reset(new ShapeConstraintAnalysis(op_list_)); + cycle_detector_.reset(new GraphCycles(op_list_.size())); + BuildNodeMap(); + } + + // Returns a fusion plan if success, otherwise none. + llvm::Optional Run() { + // Greedily search connected fusible pattern, and ops belonging to + // a same fusion pattern are grouped into a cluster. + RunEdgeContractionLoop(); + + // After doing edge contraction, each unique cluster having size + // more than one represents a potential fusion pattern. + // We collect all these clusters and construct a fusion plan. + FusionPlan plan; + DenseSet seen_clusters; + for (Operation* op : op_list_) { + Cluster* cluster = GetClusterForNode(op); + if (!seen_clusters.insert(cluster).second) continue; + FusionPattern& fusion_pattern = cluster->fused_pattern(); + // Make sure the ops in a fusion pattern are in topological ordering. + fusion_pattern.sortFusionOpListBy(op_to_node_id_); + if (!fusion_pattern.isFusible() || fusion_pattern.effectiveSize() <= 1) { + continue; + } + plan.emplace_back(fusion_pattern); + } + + // Re-order ops inside the blocks to make sure all producers are placed + // before its consumers after fusion. + ReorderOperationsInsideBlock(); + return plan; + } + + // Returns the op_list this planner operates on. + const SmallVectorImpl& op_list() const { return op_list_; } + + private: + // Represent a (partial) fused pattern + class Cluster { + public: + Cluster(int node_id, FusionPlanner* planner) + : node_id_(node_id), pattern_(planner->op_list()[node_id]) {} + + // Merges `other` into this cluster, and clears `other`. + void Merge(Cluster* other) { + pattern_.mergeInplace(other->fused_pattern()); + } + + // The number of nodes in this cluster. + int cluster_size() { return pattern_.size(); } + + // The ID of the cluster as represented in `cycle_detector_`. + int cycles_graph_node_id() const { return node_id_; } + + // Sets the ID of the cluster as represented in `cycle_detector_`. + void set_cycles_graph_node_id(int cycles_graph_node_id) { + node_id_ = cycles_graph_node_id; + } + + // Currently the fused pattern this cluster holds. + FusionPattern& fused_pattern() { return pattern_; } + + private: + // ID of the representative node of this cluster. + int node_id_; + + // the fused pattern this cluster holds. + FusionPattern pattern_; + }; + + private: + // Returns a new cluster with specified `cycles_graph_node_id` + Cluster* MakeCluster(int cycles_graph_node_id) { + cluster_storage_.emplace_back(new Cluster(cycles_graph_node_id, this)); + return cluster_storage_.back().get(); + } + + // Metadata ops (e.g. shapeOf, dimOp) don't change data thus we move forward + // them as far as possible inside the same block to enable more fusion + // opportunities. + void MoveUpMetadataOnlyOpsForFusion() { + SmallVector ops; + for (Operation& op : *block_) { + ops.push_back(&op); + } + + auto inBlock = [&](Operation* op, Block* block) { + return op && op->getBlock() == block; + }; + + for (Operation* op : ops) { + Block* block = op->getBlock(); + if (isa(op)) { + Operation* definingOp = op->getOperand(0).getDefiningOp(); + if (!inBlock(definingOp, block)) { + op->moveBefore(block, block->begin()); + } else { + op->moveAfter(definingOp); + } + } else if (isa(op)) { + Operation* firstOperandOp = op->getOperand(0).getDefiningOp(); + Operation* secondOperandOp = op->getOperand(1).getDefiningOp(); + if (!inBlock(firstOperandOp, block) && + !inBlock(secondOperandOp, block)) { + op->moveBefore(block, block->begin()); + } else if (!inBlock(firstOperandOp, block)) { + op->moveAfter(secondOperandOp); + } else if (!inBlock(secondOperandOp, block)) { + op->moveAfter(firstOperandOp); + } else if (firstOperandOp->isBeforeInBlock(secondOperandOp)) { + op->moveAfter(secondOperandOp); + } else { + op->moveAfter(firstOperandOp); + } + } + } + } + + // Returns all the values touched by this op or its nested ops. + SmallVector GetAllPossibleUsedValues(Operation* op) { + SmallVector values; + op->walk([&](Operation* nest_op) { + for (Value v : nest_op->getOperands()) { + values.push_back(v); + } + }); + return values; + } + + // Builds the initial dependency graph. + void BuildNodeMap() { + int num_nodes = op_list_.size(); + for (int node_id = 0; node_id < num_nodes; ++node_id) { + Operation* op = op_list_[node_id]; + MakeCluster(node_id); + op_to_node_id_[op] = node_id; + leader_for_node_.insert(node_id); + for (Value operand : GetAllPossibleUsedValues(op)) { + Operation* operand_op = FindLastWriter(operand); + // Only consider the operand_op inside the target block. + auto iter = op_to_node_id_.find(operand_op); + if (iter == op_to_node_id_.end()) { + continue; + } + // Add an edge to connect the last writer and the current consumer. + cycle_detector_->InsertEdge(iter->second, node_id); + } + + // For some ops (e.g. lmhlo ops), some operands are the output memrefs + // Thus these operands are supposed to be updated. + // Suppose that a op (or its nested ops) can only write the buffers + // explicit passed in as operands of this op. + int num_input_operand = op->getNumOperands() - getNumResultOperands(op); + for (Value v : op->getOperands().drop_front(num_input_operand)) { + auto it = last_writer_.try_emplace(v, op); + (void)it; + // Currently, a buffer is only supposed to be written once (as the + // output operand of one lmhlo op). + assert(it.second); + } + } + } + + // Returns the cluster contains this op. + Cluster* GetClusterForNode(Operation* n) { + int id = op_to_node_id_[n]; + id = leader_for_node_.getLeaderValue(id); + return cluster_storage_[id].get(); + } + + // Returns the cluster contains the op having `node_id`. + Cluster* GetClusterForCyclesGraphNode(int node_id) { + return cluster_storage_[leader_for_node_.getLeaderValue(node_id)].get(); + } + + // Merges the clusters `cluster_from` and `cluster_to`. + bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) { + int from = cluster_from->cycles_graph_node_id(); + int to = cluster_to->cycles_graph_node_id(); + + auto optional_merged_node = cycle_detector_->ContractEdge(from, to); + if (!optional_merged_node.hasValue()) { + llvm::dbgs() << "Could not contract " << from << " -> " << to + << " because contracting the edge would create a cycle."; + return false; + } + + // Merge the clusters. + cluster_from->Merge(cluster_to); + cluster_from->set_cycles_graph_node_id(*optional_merged_node); + + // Merge the UnionFind Set. + leader_for_node_.unionSets(from, to); + return true; + } + + using FnTy = llvm::function_ref; + bool ForEachEdgeInPostOrder(FnTy fn, bool enable_cross_fusion = false) { + bool changed = false; + for (int32_t node : cycle_detector_->AllNodesInPostOrder()) { + Cluster* cluster_from = GetClusterForCyclesGraphNode(node); + // Make a copy of the set of successors because we may modify the graph in + // TryToContractEdge. + std::vector successors_copy = + cycle_detector_->SuccessorsCopy(cluster_from->cycles_graph_node_id()); + + for (int to : successors_copy) { + Cluster* cluster_to = GetClusterForCyclesGraphNode(to); + bool contracted_edge = fn(cluster_from, cluster_to); + changed |= contracted_edge; + } + } + + if (!enable_cross_fusion) return changed; + + // To enable even more fusion opportunities (e.g. horizontal fusion) + for (int32_t lhs : cycle_detector_->AllNodesInPostOrder()) { + Cluster* cluster_lhs = GetClusterForCyclesGraphNode(lhs); + if (!cluster_lhs) { + continue; + } + + for (int32_t rhs : cycle_detector_->AllNodesInPostOrder()) { + Cluster* cluster_rhs = GetClusterForCyclesGraphNode(rhs); + if (!cluster_rhs || cluster_lhs == cluster_rhs) { + continue; + } + + bool contracted_edge = fn(cluster_lhs, cluster_rhs); + changed |= contracted_edge; + } + } + + return changed; + } + + // This function check if fusing `from` with `to` is valid and if so perform + // the merge. The validity is based on the operations in the clusters and + // the compatibility of the shapes of the outputs of the would-be fused + // clusters. + // Returns true is the merge was performed. + bool TryToContractEdge(Cluster* from, Cluster* to) { + // Try merge and check if valid. + if (!from->fused_pattern().isMergeable(to->fused_pattern())) return false; + FusionPattern fused_pattern = + from->fused_pattern().merge(to->fused_pattern()); + auto& op_list = fused_pattern.getOpList(); + auto& operands = fused_pattern.getOperands(); + auto& results = fused_pattern.getResults(); + + if (results.size() + operands.size() > + options_.max_num_arguments_per_kernel) { + // some backend devices (e.g. GPU) do not support a kernel with + // too many arguments. + return false; + } + + // We currently do not support a constant op as final output of a fusion + // pattern. + // TODO(disc): copy small const in case necessary. + for (Value result : results) { + Operation* result_op = FindLastWriter(result); + assert(result_op); + if (isa(result_op)) { + return false; + } + } + + // ReduceOp can not have consumer within the fusion pattern. + for (Operation* op : op_list) { + if (!isa(op)) continue; + int num_input_operand = op->getNumOperands() - getNumResultOperands(op); + for (Value v : op->getOperands().drop_front(num_input_operand)) { + for (Operation* user : getValueUsers(v)) { + if (user == op) continue; + if (std::find(op_list.begin(), op_list.end(), user) != + op_list.end()) { + return false; + } + } + } + } + + // All outputs of a fusion pattern should have compatible shape. + // Here `compatible` means: + // - if `to` and `from` are both kInput fusion, all output should have same + // shape. + // - otherwise, all output should have same number of elements. + + // No outside users, these ops may be eliminated. We fused it here and let + // latter pass to do such DCE. + if (results.empty()) return true; + + bool check_same_shape = (to->fused_pattern().isKInputFusion() && + from->fused_pattern().isKInputFusion()); + auto get_effective_shape = [&](Value v) { + auto result_op = FindLastWriter(v); + assert(result_op); + // effective shape of reduce op is its operand's shape. + return isa(result_op) ? result_op->getOperand(0) : v; + }; + + Value ref_shape = get_effective_shape(results[0]); + if (!llvm::all_of(results, [&](Value result) { + Value shape = get_effective_shape(result); + return check_same_shape + ? shape_analysis_->HasSameShape(ref_shape, shape) + : shape_analysis_->HasSameNumElements(ref_shape, shape); + })) { + return false; + } + + return MergeClusters(from, to); + } + + // Greedily fuse connected node. + bool RunEdgeContractionLoop() { + using std::placeholders::_1; + using std::placeholders::_2; + bool changed = false; + + // Run fusion pass repeatedly until nothing to be fused + while (ForEachEdgeInPostOrder( + std::bind(&FusionPlanner::TryToContractEdge, this, _1, _2), false)) { + // empty statement by design + } + return changed; + } + + // Here `value` is supported to be a pointer to buffer. + // Returns the defining op of `value `if no known op updates the buffer, + // otherwise returns the last op that updates the buffer pointed by the + // `value`. + Operation* FindLastWriter(Value value) { + auto it = last_writer_.find(value); + if (it != last_writer_.end()) { + return it->second; + } + return value.getDefiningOp(); + } + + // Re-order ops inside the block to make sure that producers are before + // consumers after fusion. + void ReorderOperationsInsideBlock() { + auto reorder_func = [&](Cluster* from, Cluster* to) { + FusionPattern& from_pattern = from->fused_pattern(); + FusionPattern& to_pattern = to->fused_pattern(); + + Operation* last_op_in_from = from_pattern.getOpList().back(); + for (Operation* op : llvm::reverse(to_pattern.getOpList())) { + if (!last_op_in_from->isBeforeInBlock(op)) + op->moveAfter(last_op_in_from); + } + return false; + }; + + ForEachEdgeInPostOrder(reorder_func); + } + + // hyper-parameters that controls the behaviour of the fusion planner. + FusionOptions options_; + + // The block that fusion planner works on. + Block* block_; + + // Ops inside the block + SmallVector op_list_; + + // Shape equality checker + std::unique_ptr shape_analysis_; + + // op -> node_id + DenseMap op_to_node_id_; + + // make sure not introduce cycle after fusion + std::unique_ptr cycle_detector_; + std::vector> cluster_storage_; + + // a UnionFind set. Each set represents a (partial) fused pattern + // and has a leader as representation. + EquivalenceClasses leader_for_node_; + + // Here `value` is supported to be a pointer to buffer. + // Returns the defining op of `value `if no known op updates the buffer, + // otherwise returns the last op that updates the buffer pointed by the + // `value`. + DenseMap last_writer_; +}; + +struct LhloFusionPass : public LhloFusionPassBase { + using LhloFusionPassBase::LhloFusionPassBase; + explicit LhloFusionPass(int max_num_arguments_per_kernel) + : LhloFusionPassBase::LhloFusionPassBase() { + this->max_num_arguments_per_kernel_ = max_num_arguments_per_kernel; + } + + void runOnFunction() override { + FuncOp func = getFunction(); + + // collect all blocks inside the function. + SmallVector blocks; + CollectBlocksInsideFunction(func, blocks); + + // process each block and do fusion within a block. + FusionOptions options; + options.max_num_arguments_per_kernel = max_num_arguments_per_kernel_; + for (Block* block : blocks) { + FusionPlanner planner(options, block); + llvm::Optional plan = planner.Run(); + if (!plan) { + emitError(func.getLoc(), + "an error occurs while trying to find fusion candidates"); + signalPassFailure(); + return; + } + if (!ApplyFusionPlan(*plan)) { + emitError(func.getLoc(), "apply fusion plan failed"); + signalPassFailure(); + return; + } + } + } + + bool ApplyFusionPlan(FusionPlan& plan) { + for (FusionPattern& pattern : plan) { + auto& op_list = pattern.getOpList(); + OpBuilder b(op_list.back()); + + // Get the fused locations + SmallVector locations; + locations.reserve(op_list.size()); + for (Operation* op : op_list) { + locations.push_back(op->getLoc()); + } + Location fused_loc = + FusedLoc::get(op_list.back()->getContext(), locations); + + // Move ops inside fusion pattern to the region attached to the fusion op. + FusionOp fusion = b.create(fused_loc); + Region& region = fusion.region(); + Block& block = region.front(); + for (Operation* op : llvm::reverse(op_list)) { + op->moveBefore(&block, block.begin()); + } + } + return true; + } + + void CollectBlocksInsideFunction(FuncOp op, SmallVectorImpl& blocks) { + op.walk([&](Block* block) { + // It does not make sense to fuse the region attached to these ops. + if (!isa(block->getParentOp())) + blocks.push_back(block); + }); + } +}; + +} // namespace + +std::unique_ptr> createLhloFusionPass( + int max_num_arguments_per_kernel) { + return std::make_unique(max_num_arguments_per_kernel); +} + +} // namespace lmhlo +} // namespace mlir diff --git a/tests/lhlo-fusion.mlir b/tests/lhlo-fusion.mlir new file mode 100644 index 0000000..6cc939b --- /dev/null +++ b/tests/lhlo-fusion.mlir @@ -0,0 +1,286 @@ +// RUN: mlir-hlo-opt --lhlo-fusion -split-input-file %s -o - | FileCheck %s + +// CHECK-LABEL: @simple_kloop_fusion +// CHECK-SAME: (%[[ARG0:.*]]: memref, %[[ARG1:.*]]: memref, %[[ARG2:.*]]: memref, %[[ARG3:.*]]: memref) -> memref +func @simple_kloop_fusion(%arg0: memref, %arg1: memref, + %arg2: memref, %arg3: memref) -> memref { + // CHECK: "lmhlo.fusion"() ( { + // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref, memref) -> () + // CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[ARG3]]) : (memref, memref, memref) -> () + // CHECK: }) : () -> () + // CHECK: return %[[ARG3]] : memref + "lmhlo.abs"(%arg0, %arg1) : (memref, memref) -> () + "lmhlo.add"(%arg1, %arg2, %arg3) : (memref, memref, memref) -> () + return %arg3 : memref +} + +// ----- + +// CHECK-LABEL: @simple_multi_output_kloop_fusion +// CHECK-SAME: (%[[ARG0:.*]]: memref, %[[ARG1:.*]]: memref, %[[ARG2:.*]]: memref, %[[ARG3:.*]]: memref) -> (memref, memref) +func @simple_multi_output_kloop_fusion(%arg0: memref, %arg1: memref, + %arg2: memref, %arg3: memref) -> (memref, memref) { + // CHECK: "lmhlo.fusion"() ( { + // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref, memref) -> () + // CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[ARG3]]) : (memref, memref, memref) -> () + // CHECK: }) : () -> () + // CHECK: return %[[ARG1]], %[[ARG3]] : memref, memref + "lmhlo.abs"(%arg0, %arg1) : (memref, memref) -> () + "lmhlo.add"(%arg1, %arg2, %arg3) : (memref, memref, memref) -> () + return %arg1, %arg3 : memref, memref +} + +// ----- + +// CHECK-LABEL: @simple_multi_output_kloop_fusion_with_reorder +// CHECK-SAME: (%[[ARG0:.*]]: memref, %[[ARG1:.*]]: memref, %[[ARG2:.*]]: memref, %[[ARG3:.*]]: memref, %[[ARG4:.*]]: memref<2xindex>, %[[ARG5:.*]]: memref) +func @simple_multi_output_kloop_fusion_with_reorder(%arg0: memref, %arg1: memref, + %arg2: memref, %arg3: memref, + %arg4: memref<2xindex>, %arg5: memref) -> (memref, memref, memref) { + // CHECK: "lmhlo.fusion"() ( { + // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref, memref) -> () + // CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[ARG3]]) : (memref, memref, memref) -> () + // CHECK: }) : () -> () + // CHECK: "lmhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[ARG4]], %[[ARG5]]) + // CHECK: return %[[ARG1]], %[[ARG3]], %[[ARG5]] : memref, memref, memref + "lmhlo.abs"(%arg0, %arg1) : (memref, memref) -> () + "lmhlo.dynamic_broadcast_in_dim"(%arg1, %arg4, %arg5) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (memref, memref<2xindex>, memref) -> () + "lmhlo.add"(%arg1, %arg2, %arg3) : (memref, memref, memref) -> () + return %arg1, %arg3, %arg5 : memref, memref, memref +} + +// ----- + +// CHECK-LABEL: @same_num_elements_multi_output_kloop_fusion +// CHECK-SAME: (%[[ARG0:.*]]: memref, %[[ARG1:.*]]: memref, %[[ARG2:.*]]: memref<2xi64>, %[[ARG3:.*]]: memref, %[[ARG4:.*]]: memref, %[[ARG5:.*]]: memref) +func @same_num_elements_multi_output_kloop_fusion(%arg0: memref, %arg1: memref, + %arg2: memref<2xi64>, %arg3: memref, + %arg4: memref, %arg5: memref) -> (memref, memref) { + // CHECK: "lmhlo.fusion"() ( { + // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref, memref) -> () + // CHECK: "lmhlo.dynamic_reshape"(%[[ARG1]], %[[ARG2]], %[[ARG3]]) + // CHECK: "lmhlo.add"(%[[ARG3]], %[[ARG4]], %[[ARG5]]) : (memref, memref, memref) -> () + // CHECK: }) : () -> () + // CHECK: return %[[ARG1]], %[[ARG5]] : memref, memref + "lmhlo.abs"(%arg0, %arg1) : (memref, memref) -> () + "lmhlo.dynamic_reshape"(%arg1, %arg2, %arg3) : (memref, memref<2xi64>, memref) -> () + "lmhlo.add"(%arg3, %arg4, %arg5) : (memref, memref, memref) -> () + return %arg1, %arg5 : memref, memref +} + +// ----- + +// CHECK-LABEL: @check_not_kloop_fusion +func @check_not_kloop_fusion(%arg0: memref, %arg1: memref, %arg2: memref, %arg3: memref) -> (memref, memref) { + // CHECK-NOT: "lmhlo.fusion" + "lmhlo.add"(%arg0, %arg0, %arg1) : (memref, memref, memref) -> () + "lmhlo.subtract"(%arg2, %arg2, %arg3) : (memref, memref, memref) -> () + return %arg1, %arg3: memref, memref +} + +// ----- + +// CHECK-LABEL: @kloop_fusion_with_dealloc +// CHECK-SAME: (%[[ARG0:.*]]: memref, %[[ARG1:.*]]: memref) +func @kloop_fusion_with_dealloc(%arg0: memref, %arg1: memref) -> (memref, memref) { + // CHECK: %[[TMP3:.*]] = memref.alloc + // CHECK: %[[TMP5:.*]] = memref.alloc + // CHECK: %[[TMP9:.*]] = memref.alloc + // CHECK: %[[TMP13:.*]] = memref.alloc + // CHECK: %[[TMP16:.*]] = memref.alloc + // CHECK: "lmhlo.fusion"() ( { + // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[TMP3]]) : (memref, memref, memref) -> () + // CHECK: "lmhlo.multiply"(%[[ARG0]], %[[ARG1]], %[[TMP5]]) : (memref, memref, memref) -> () + // CHECK: "lmhlo.abs"(%[[TMP3]], %[[TMP9]]) : (memref, memref) -> () + // CHECK: "lmhlo.abs"(%[[TMP5]], %[[TMP13]]) : (memref, memref) -> () + // CHECK: "lmhlo.multiply"(%[[TMP9]], %[[TMP13]], %[[TMP16]]) : (memref, memref, memref) -> () + // CHECK: }) : () -> () + // CHECK: memref.dealloc %[[TMP3]] : memref + // CHECK: memref.dealloc %[[TMP5]] : memref + // CHECK: memref.dealloc %[[TMP13]] : memref + // CHECK: return %[[TMP9]], %[[TMP16]] : memref, memref + %c0 = constant 0 : index + %c1 = constant 1 : index + %0 = shape.shape_of %arg0 : memref -> tensor<2xindex> + %1 = tensor.extract %0[%c0] : tensor<2xindex> + %2 = tensor.extract %0[%c1] : tensor<2xindex> + %3 = memref.alloc(%1, %2) : memref + "lmhlo.add"(%arg0, %arg1, %3) : (memref, memref, memref) -> () + %4 = memref.alloc(%1, %2) : memref + "lmhlo.multiply"(%arg0, %arg1, %4) : (memref, memref, memref) -> () + %5 = shape.shape_of %3 : memref -> tensor<2xindex> + %6 = tensor.extract %5[%c0] : tensor<2xindex> + %7 = tensor.extract %5[%c1] : tensor<2xindex> + %8 = memref.alloc(%6, %7) : memref + "lmhlo.abs"(%3, %8) : (memref, memref) -> () + memref.dealloc %3 : memref + %9 = shape.shape_of %4 : memref -> tensor<2xindex> + %10 = tensor.extract %9[%c0] : tensor<2xindex> + %11 = tensor.extract %9[%c1] : tensor<2xindex> + %12 = memref.alloc(%10, %11) : memref + "lmhlo.abs"(%4, %12) : (memref, memref) -> () + memref.dealloc %4 : memref + %13 = shape.shape_of %8 : memref -> tensor<2xindex> + %14 = tensor.extract %13[%c0] : tensor<2xindex> + %15 = tensor.extract %13[%c1] : tensor<2xindex> + %16 = memref.alloc(%14, %15) : memref + "lmhlo.multiply"(%8, %12, %16) : (memref, memref, memref) -> () + memref.dealloc %12 : memref + return %8, %16 : memref, memref +} + +// ----- + +// CHECK-LABEL: @simple_kinput +// CHECK-SAME: %[[ARG0:.*]]: memref, %[[ARG1:.*]]: memref, %[[ARG2:.*]]: memref, %[[ARG3:.*]]: memref +func @simple_kinput(%arg0: memref, %arg1: memref, %arg2: memref, %init: memref) -> memref { + // CHECK: "lmhlo.fusion"() ( { + // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref, memref) -> () + // CHECK: "lmhlo.reduce"(%[[ARG1]], %[[ARG3]], %[[ARG2]]) ( { + // CHECK: }) : () -> () + // CHECK: return %[[ARG2]] : memref + "lmhlo.abs"(%arg0, %arg1) : (memref, memref) -> () + "lmhlo.reduce"(%arg1, %init, %arg2) ( { + ^bb0(%targ1: memref, %targ2: memref, %tresult: memref): + "lmhlo.add"(%targ1, %targ2, %tresult) : (memref, memref, memref) -> () + "lmhlo.terminator"() : () -> () + } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref, memref, memref) -> () + return %arg2: memref +} + +// ----- + +// CHECK-LABEL: @multi_output_kinput +// CHECK-SAME: %[[ARG0:.*]]: memref, %[[ARG1:.*]]: memref, %[[ARG2:.*]]: memref, %[[ARG3:.*]]: memref +func @multi_output_kinput(%arg0: memref, %arg1: memref, %arg2: memref, %init: memref) -> (memref, memref) { + // CHECK: "lmhlo.fusion"() ( { + // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref, memref) -> () + // CHECK: "lmhlo.reduce"(%[[ARG1]], %[[ARG3]], %[[ARG2]]) ( { + // CHECK: }) : () -> () + // CHECK: return %[[ARG1]], %[[ARG2]] : memref, memref + "lmhlo.abs"(%arg0, %arg1) : (memref, memref) -> () + "lmhlo.reduce"(%arg1, %init, %arg2) ( { + ^bb0(%targ1: memref, %targ2: memref, %tresult: memref): + "lmhlo.add"(%targ1, %targ2, %tresult) : (memref, memref, memref) -> () + "lmhlo.terminator"() : () -> () + } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref, memref, memref) -> () + return %arg1, %arg2: memref, memref +} + +// ----- + +// CHECK-LABEL: @row_red_and_row_red_kinput +// CHECK-SAME: %[[ARG0:.*]]: memref, %[[ARG1:.*]]: memref, %[[ARG2:.*]]: memref, %[[ARG3:.*]]: memref, %[[ARG4:.*]]: memref, %[[ARG5:.*]]: memref, %[[ARG6:.*]]: memref +func @row_red_and_row_red_kinput(%arg0: memref, %arg1: memref, %arg2: memref, %arg3: memref, %arg4: memref, %arg5: memref, %init: memref) -> (memref, memref) { + // CHECK: "lmhlo.fusion"() ( { + // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (memref, memref, memref) -> () + // CHECK: "lmhlo.abs"(%[[ARG2]], %[[ARG5]]) : (memref, memref) -> () + // CHECK: "lmhlo.reduce"(%[[ARG5]], %[[ARG6]], %[[ARG3]]) ( { + // CHECK: "lmhlo.reduce"(%[[ARG2]], %[[ARG6]], %[[ARG4]]) ( { + // CHECK: }) : () -> () + // CHECK: return %[[ARG3]], %[[ARG4]] : memref, memref + "lmhlo.add"(%arg0, %arg1, %arg2) : (memref, memref, memref) -> () + "lmhlo.abs"(%arg2, %arg5) : (memref, memref) -> () + "lmhlo.reduce"(%arg5, %init, %arg3) ( { + ^bb0(%targ1: memref, %targ2: memref, %tresult: memref): + "lmhlo.add"(%targ1, %targ2, %tresult) : (memref, memref, memref) -> () + "lmhlo.terminator"() : () -> () + } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref, memref, memref) -> () + "lmhlo.reduce"(%arg2, %init, %arg4) ( { + ^bb0(%targ1: memref, %targ2: memref, %tresult: memref): + "lmhlo.add"(%targ1, %targ2, %tresult) : (memref, memref, memref) -> () + "lmhlo.terminator"() : () -> () + } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref, memref, memref) -> () + return %arg3, %arg4: memref, memref +} + +// ----- + +// CHECK-LABEL: @row_red_and_col_red_kinput +// CHECK-SAME: %[[ARG0:.*]]: memref, %[[ARG1:.*]]: memref, %[[ARG2:.*]]: memref, %[[ARG3:.*]]: memref, %[[ARG4:.*]]: memref, %[[ARG5:.*]]: memref, %[[ARG6:.*]]: memref +func @row_red_and_col_red_kinput(%arg0: memref, %arg1: memref, %arg2: memref, %arg3: memref, %arg4: memref, %arg5: memref, %init: memref) -> (memref, memref) { + // CHECK: "lmhlo.fusion"() ( { + // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (memref, memref, memref) -> () + // CHECK: "lmhlo.abs"(%[[ARG2]], %[[ARG5]]) : (memref, memref) -> () + // CHECK: "lmhlo.reduce"(%[[ARG5]], %[[ARG6]], %[[ARG3]]) ( { + // CHECK: "lmhlo.reduce"(%[[ARG2]], %[[ARG6]], %[[ARG4]]) ( { + // CHECK: }) : () -> () + // CHECK: return %[[ARG3]], %[[ARG4]] : memref, memref + "lmhlo.add"(%arg0, %arg1, %arg2) : (memref, memref, memref) -> () + "lmhlo.abs"(%arg2, %arg5) : (memref, memref) -> () + "lmhlo.reduce"(%arg5, %init, %arg3) ( { + ^bb0(%targ1: memref, %targ2: memref, %tresult: memref): + "lmhlo.add"(%targ1, %targ2, %tresult) : (memref, memref, memref) -> () + "lmhlo.terminator"() : () -> () + } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref, memref, memref) -> () + "lmhlo.reduce"(%arg2, %init, %arg4) ( { + ^bb0(%targ1: memref, %targ2: memref, %tresult: memref): + "lmhlo.add"(%targ1, %targ2, %tresult) : (memref, memref, memref) -> () + "lmhlo.terminator"() : () -> () + } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref, memref, memref) -> () + return %arg3, %arg4: memref, memref +} + +// ----- + +// CHECK-LABEL: @reduce_should_not_have_consumer_in_the_fusion +// CHECK-SAME: %[[ARG0:.*]]: memref, %[[ARG1:.*]]: memref +func @reduce_should_not_have_consumer_in_the_fusion(%arg0: memref, %arg1: memref) +-> (memref, memref) { + // CHECK: %[[TMP4:.*]] = memref.alloc + // CHECK: %[[TMP7:.*]] = memref.alloc + // CHECK: %[[TMP8:.*]] = memref.alloc + // CHECK: %[[TMP9:.*]] = memref.alloc + // CHECK: "lmhlo.fusion"() ( { + // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[TMP4]]) : (memref, memref, memref) -> () + // CHECK: "lmhlo.subtract"(%[[ARG0]], %[[TMP4]], %[[TMP7]]) : (memref, memref, memref) -> () + // CHECK: "lmhlo.constant"(%[[TMP8]]) {value = dense<0.000000e+00> : tensor} : (memref) -> () + // CHECK: "lmhlo.reduce"(%[[TMP7]], %[[TMP8]], %[[TMP9]]) ( { + // CHECK: }) : () -> () + // CHECK: memref.dealloc %[[TMP4]] : memref + // CHECK: memref.dealloc %[[TMP8]] : memref + // CHECK: %[[TMP12:.*]] = memref.alloc + // CHECK: "lmhlo.add"(%[[TMP9]], %[[TMP9]], %[[TMP12]]) : (memref, memref, memref) -> () + // CHECK: memref.dealloc %[[TMP9]] : memref + // CHECK: return %[[TMP7]], %[[TMP12]] : memref, memref + %c1 = constant 1 : index + %c0 = constant 0 : index + %0 = shape.shape_of %arg0 : memref -> tensor<2xindex> + %1 = tensor.extract %0[%c0] : tensor<2xindex> + %2 = tensor.extract %0[%c1] : tensor<2xindex> + %3 = memref.alloc(%1, %2) : memref + "lmhlo.add"(%arg0, %arg1, %3) : (memref, memref, memref) -> () + %4 = shape.shape_of %arg0 : memref -> tensor<2xindex> + %5 = tensor.extract %4[%c0] : tensor<2xindex> + %6 = tensor.extract %4[%c1] : tensor<2xindex> + %7 = memref.alloc(%5, %6) : memref + "lmhlo.subtract"(%arg0, %3, %7) : (memref, memref, memref) -> () + memref.dealloc %3 : memref + %8 = memref.alloc() : memref + "lmhlo.constant"(%8) {value = dense<0.000000e+00> : tensor} : (memref) -> () + %9 = memref.alloc(%5) : memref + "lmhlo.reduce"(%7, %8, %9) ( { + ^bb0(%arg2: memref, %arg3: memref, %arg4: memref): // no predecessors + "lmhlo.add"(%arg2, %arg3, %arg4) : (memref, memref, memref) -> () + "lmhlo.terminator"() : () -> () + }) {dimensions = dense<1> : tensor<1xi64>} : (memref, memref, memref) -> () + memref.dealloc %8 : memref + %10 = shape.shape_of %9 : memref -> tensor<1xindex> + %11 = tensor.extract %10[%c0] : tensor<1xindex> + %12 = memref.alloc(%11) : memref + "lmhlo.add"(%9, %9, %12) : (memref, memref, memref) -> () + memref.dealloc %9 : memref + return %7, %12 : memref, memref +} + +// ----- + +// CHECK-LABEL: @const_should_not_be_output +func @const_should_not_be_output(%arg0: memref) -> (memref, memref) { + // CHECK-NOT: lmhlo.fusion + %0 = memref.alloc() : memref + "lmhlo.constant"(%0) {value = dense<1.000000e+00> : tensor} : (memref) -> () + %1 = memref.alloc() : memref + "lmhlo.add"(%arg0, %0, %1) : (memref, memref, memref) -> () + return %0, %1 : memref, memref +}