PR #50020: [MLIR][DISC] support fusion on buffer
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50020

This pass implements the logic to group kLoop/kInput fusion patterns at the buffer level. The reason for this is that we can avoid a lot of headaches from handling `shape-only` consumers (e.g. memref.dim, shape.shapeOf) specially, since shapes are already resolved in the buffer world. It may be better to move this pass to the tensor level after more shape inference/constraint infrastructure is ready at the mhlo level.

Copybara import of the project:

-- e31f8344b59aa9860097197585215ea1689b8ff4 by Wenyi Zhao <reyizero@gmail.com>:
   [MLIR][DISC] support fusion on buffer (same description as above)
-- 35f2eb2791241b0ab5db1ddcaf1b4006278ddccf by Wenyi Zhao <reyizero@gmail.com>:
   fix
-- 923c8d61f7fe00a2a0df22d5be396508f0667964 by Wenyi Zhao <reyizero@gmail.com>:
   fix sanity check failure

PiperOrigin-RevId: 379743424
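For context, the sketch below shows one way the new pass could be run over a bufferized module; it is illustrative only and not part of the PR. `createLhloFusionPass` and its `max_num_arguments_per_kernel` parameter come from this change, while the header path and the surrounding pass-manager scaffolding are assumptions.

// Minimal sketch (assumed scaffolding, not part of this PR): run buffer-level
// fusion on every function in a module. The pass groups fusible lmhlo ops into
// kLoop/kInput fusion patterns as described in the files below.
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"  // assumed location of createLhloFusionPass
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"

mlir::LogicalResult runBufferFusion(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // The pass is a FunctionPass, so nest it on FuncOp; 64 is the default bound
  // on the number of buffers (inputs + outputs) a fused kernel may take.
  pm.addNestedPass<mlir::FuncOp>(
      mlir::lmhlo::createLhloFusionPass(/*max_num_arguments_per_kernel=*/64));
  return pm.run(module);
}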
This commit is contained in:
  parent 82696f8598
  commit 34dc5f2a79

  BUILD | 38
@@ -1213,6 +1213,42 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "fusion_utils",
+    srcs = ["lib/Dialect/mhlo/transforms/fusion_utils.cc"],
+    hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/fusion_utils.h"],
+    deps = [
+        ":lhlo",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Shape",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:Support",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "lhlo_fusion",
+    srcs = ["lib/Dialect/mhlo/transforms/lhlo_fusion.cc"],
+    deps = [
+        ":cycle_detector",
+        ":fusion_utils",
+        ":lhlo",
+        ":pass_details",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Shape",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TransformUtils",
+    ],
+    alwayslink = 1,
+)
+
 cc_library(
     name = "chlo_legalize_to_hlo",
     srcs = ["lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc"],
@@ -1259,6 +1295,7 @@ cc_library(
     ],
     deps = [
         ":DiscRalPassIncGen",
+        ":LmhloPassIncGen",
         ":MhloPassIncGen",
         "@llvm-project//mlir:Pass",
     ],
@@ -1316,6 +1353,7 @@ cc_library(
         ":legalize_trigonometric_to_approximation",
         ":lhlo",
         ":lhlo_fuse_linalg",
+        ":lhlo_fusion",
         ":lhlo_legalize_to_affine",
         ":lhlo_legalize_to_gpu",
        ":lhlo_legalize_to_parallel_loops",
@@ -25,6 +25,14 @@ namespace mhlo {
 #include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc"
 
 }  // end namespace mhlo
+
+namespace lmhlo {
+
+#define GEN_PASS_CLASSES
+#include "mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.h.inc"
+
+}  // end namespace lmhlo
+
 }  // end namespace mlir
 
 namespace mlir {
@@ -0,0 +1,244 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_FUSION_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_FUSION_UTILS_H_

#include <memory>
#include <vector>

#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/Support/Debug.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project

// This file implements some helper functions and classes used to do fusion
// & code generation.

namespace mlir {
namespace lmhlo {

// kLoop fusion template satisfies:
//   - all ops in the fusion pattern are element-wise.
//   - all the shapes of the outputs of the fusion pattern are the same or have
//   the same number of elements, and thus can fit into the same parallel loop.
//
// kInput fusion template satisfies:
//   - any op in the fusion pattern is either element-wise or a reduction.
//   - if an op is a reduction, its output cannot be consumed by other
//     ops in the same fusion pattern.
//   - all the effective shapes of the outputs of the fusion pattern are the same.
//     - For an element-wise op, its effective shape is its output shape.
//     - For a reduction op, its effective shape is its operand shape.
//   - currently our downstream codegen engine only supports 2d -> 1d tensor
//   reduction. TODO: lift this limitation.
//     - 2D row reduction: out[i] = sum({in[i][j] for all j})
//     - 2D column reduction: out[j] = sum({in[i][j] for all i})
enum FusionType {
  // Not a fusion pattern
  kNone,
  // kLoop fusion pattern
  kLoop,
  // kInput fusion pattern and all reduce ops of the fused pattern are row
  // reductions
  kRowReduction,
  // kInput fusion pattern and all reduce ops of the fused pattern are column
  // reductions
  kColReduction,
};

// Returns true if the op is an elementwise unary lmhlo op.
// TODO: use fusibility interface
bool isElementWiseUnary(Operation* op);

// Returns true if the op is an elementwise binary lmhlo op.
// TODO: use fusibility interface
bool isElementWiseBinary(Operation* op);

// Returns true if the op is an elementwise lmhlo op.
// TODO: use fusibility interface
bool isElementWise(Operation* op);

// Returns true if this op is a rank-2 row reduction.
bool isRank2RowReduction(Operation* op);

// Returns true if this op is a rank-2 column reduction.
bool isRank2ColReduction(Operation* op);

// Returns true if the op is supported by the downstream fusion codegen
// engine.
bool isFusible(Operation* op);

// Returns the number of operands that are supposed to be written.
// For some ops (e.g. lmhlo ops), some operands are the output memrefs.
// Thus these operands are supposed to be updated.
int getNumResultOperands(Operation* op);

// Returns data users of the value and its aliases (e.g. memref.cast).
// Here non-data users means DimOp, DeallocOp and ShapeOfOp.
SmallVector<Operation*, 4> getValueUsers(Value v);

// Represents a list of lmhlo ops that are going to be fused.
class FusionPattern {
 public:
  using FusionOpList = SmallVector<Operation*, 4>;
  using FusionValueList = SmallVector<Value, 4>;

  // Create a new fusion pattern from a single op.
  FusionPattern(Operation* op);

  // Create a new fusion pattern from the ops inside the lmhlo fusion op.
  FusionPattern(lmhlo::FusionOp op);

  // Returns the op list this fusion pattern represents.
  FusionOpList& getOpList() { return op_list_; }

  // Returns the dominant op of this fusion pattern.
  // For kLoop fusion, a dominant op may be any op that has external users.
  // For kInput fusion, a dominant op may be a row reduction (if one exists), or
  // a column reduction op.
  Operation* getDominantOp() { return dominant_op_; }

  // Sets the dominant op to the op provided.
  void setDominantOp(Operation* op) { dominant_op_ = op; }

  // Returns the fusion kind of the fusion pattern.
  FusionType getFusionType() { return fusion_type_; }

  // Sets the fusion type to the type provided.
  void setFusionType(FusionType type) { fusion_type_ = type; }

  // Returns true if this is a fusible fusion pattern.
  bool isFusible() { return getFusionType() != FusionType::kNone; }

  // Returns true if this fusion pattern is a kLoop fusion.
  bool isKLoopFusion() { return getFusionType() == FusionType::kLoop; }

  // Returns true if this fusion pattern is a kInput fusion.
  bool isKInputFusion() {
    return (getFusionType() == FusionType::kRowReduction ||
            getFusionType() == FusionType::kColReduction);
  }

  // Returns true if two fusion patterns can be merged into one bigger fusion
  // pattern.
  bool isMergeable(FusionPattern& other);

  // Merges two fusion patterns and returns the merged pattern. The original
  // pattern remains unmodified.
  FusionPattern merge(FusionPattern& other);

  // Merges two fusion patterns and returns the merged pattern. Replaces the
  // original pattern with the new merged pattern.
  FusionPattern& mergeInplace(FusionPattern& other);

  // Returns values that are consumed by the lmhlo ops inside the fusion
  // pattern.
  FusionValueList& getOperands() { return operands_; }

  // Returns values that are outputs of any lmhlo op in the fused pattern and
  // have consumers outside the fusion pattern.
  FusionValueList& getResults() { return results_; }

  // Returns values that are outputs of any lmhlo op in the fused pattern and
  // are only consumed by the lmhlo ops inside the fused pattern.
  FusionValueList& getInternalResults() { return internal_results_; }

  // Returns the number of ops this fusion pattern contains.
  int size() { return op_list_.size(); }

  // Returns the effective size (i.e. not counting const ops) of the ops this
  // fusion pattern contains.
  int effectiveSize();

  // Sorts the ops inside the fusion pattern according to the keys provided.
  void sortFusionOpListBy(DenseMap<Operation*, int>& op_to_idx);

 private:
  FusionPattern(SmallVectorImpl<Operation*>& op_list);

 private:
  // Calculates the inputs and outputs of the fusion pattern.
  void calculateOperandsAndResults();

 private:
  FusionOpList op_list_;
  Operation* dominant_op_ = nullptr;
  FusionType fusion_type_ = FusionType::kNone;
  FusionValueList operands_;
  FusionValueList results_;
  FusionValueList internal_results_;
};

// Represents a list of disjoint fusion patterns for a block.
using FusionPlan = std::vector<FusionPattern>;

using llvm::EquivalenceClasses;

// Supports using EquivalenceClasses for Value
class ValueWrapper {
 public:
  explicit ValueWrapper(Value value) : value_(std::move(value)) {}

  Value getValue() const { return value_; }

  bool operator==(const ValueWrapper& rhs) const {
    return getValue() == rhs.getValue();
  }

 private:
  Value value_;
};

bool operator<(const ValueWrapper& lhs, const ValueWrapper& rhs);

// This is a simple shape constraint analysis, which is used to
// guide fusion decisions (e.g. we only fuse shape-compatible ops).
//
// Currently, we only consider shape equality and same-number-elements equality
// propagation based on the shape constraint traits of elementwise ops (assuming
// that implicit shape broadcast is forbidden).
class ShapeConstraintAnalysis {
 public:
  explicit ShapeConstraintAnalysis(const SmallVectorImpl<Operation*>& op_list) {
    PropagateEquality(op_list);
  }

  // Returns true if `lhs` and `rhs` are supposed to have the same shape.
  bool HasSameShape(Value lhs, Value rhs) {
    return same_shape_impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs));
  }

  // Returns true if `lhs` and `rhs` are supposed to have the same number of
  // elements.
  bool HasSameNumElements(Value lhs, Value rhs) {
    return same_num_elements_impl_.isEquivalent(ValueWrapper(lhs),
                                                ValueWrapper(rhs));
  }

 private:
  // Shape equality propagation based on the shape constraints of
  // elementwise ops.
  void PropagateEquality(const SmallVectorImpl<Operation*>& op_list);

  // a UnionFind set
  EquivalenceClasses<ValueWrapper> same_shape_impl_;
  EquivalenceClasses<ValueWrapper> same_num_elements_impl_;
};

}  // namespace lmhlo
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_FUSION_UTILS_H_
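A minimal sketch of how the utilities declared above are intended to be used; this snippet is illustrative only and is not part of the diff. It assumes `op` is an lmhlo op nested in a function that has already been bufferized.

// Illustrative sketch (not in this PR): classify a single lmhlo op and count
// the buffers a one-op fusion pattern would read and expose to the outside.
void classifySingleOp(mlir::Operation* op) {
  mlir::lmhlo::FusionPattern pattern(op);
  if (!pattern.isFusible()) return;
  // kLoop for elementwise ops, kRowReduction/kColReduction for rank-2 reductions.
  mlir::lmhlo::FusionType type = pattern.getFusionType();
  // Buffers the pattern reads vs. buffers it writes that are visible outside it.
  unsigned num_inputs = pattern.getOperands().size();
  unsigned num_outputs = pattern.getResults().size();
  (void)type; (void)num_inputs; (void)num_outputs;
}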
@@ -56,3 +56,12 @@ def LegalizeTensorLoadOpPass : Pass<"lhlo-legalize-tensor-load-op", "FuncOp"> {
   let constructor = "createLegalizeTensorLoadOpPass()";
 }
 
+def LhloFusionPass : FunctionPass<"lhlo-fusion"> {
+  let summary = "Fuse lmhlo ops to kLoop/kInput fusion patterns.";
+  let constructor = "createLhloFusionPass()";
+  let options = [
+    Option<"max_num_arguments_per_kernel_", "max-num-arguments-per-kernel", "int",
+           /*default=*/"64", "Maximum allowed number of arguments per fused kernel.">,
+  ];
+}
+
@@ -117,6 +117,10 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
 // Legalizes tensor load ops that are inserted during mhlo to lmhlo conversion.
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeTensorLoadOpPass();
 
+// fuse lmhlo ops to kLoop/kInput fusion patterns
+std::unique_ptr<OperationPass<FuncOp>> createLhloFusionPass(
+    int max_num_arguments_per_kernel = 64);
+
 }  // namespace lmhlo
 
 namespace disc_ral {
@@ -137,8 +137,10 @@ add_mlir_library(MhloLhloToLinalg
 )
 
 add_mlir_library(LmhloPasses
+  fusion_utils.cc
   legalize_tensor_load_op.cc
   lhlo_fuse_linalg.cc
+  lhlo_fusion.cc
   lhlo_legalize_to_affine.cc
   lhlo_legalize_to_gpu.cc
   lhlo_legalize_to_parallel_loops.cc
@@ -0,0 +1,394 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir-hlo/Dialect/mhlo/transforms/fusion_utils.h"

#include <algorithm>

#include "mlir/Dialect/Shape/IR/Shape.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"          // TF:llvm-project
#include "mlir/IR/Matchers.h"

// This file implements some helper functions and classes used to do fusion
// & code generation.

namespace mlir {
namespace lmhlo {

// Returns true if the op is an elementwise unary lmhlo op.
// TODO(disc): use fusibility interface
bool isElementWiseUnary(Operation* op) {
  // clang-format off
  return isa<
    lmhlo::AbsOp,
    lmhlo::CeilOp,
    lmhlo::ConvertOp,
    lmhlo::CopyOp,
    lmhlo::CosOp,
    lmhlo::ExpOp,
    lmhlo::FloorOp,
    lmhlo::IsFiniteOp,
    lmhlo::LogOp,
    lmhlo::NegOp,
    lmhlo::NotOp,
    lmhlo::RsqrtOp,
    lmhlo::SignOp,
    lmhlo::SqrtOp,
    lmhlo::TanhOp
  >(op);
  // clang-format on
}

// Returns true if the op is an elementwise binary lmhlo op.
// TODO(disc): use fusibility interface
bool isElementWiseBinary(Operation* op) {
  // clang-format off
  return isa<
    lmhlo::AddOp,
    lmhlo::AndOp,
    lmhlo::CompareOp,
    lmhlo::DivOp,
    lmhlo::MaxOp,
    lmhlo::MinOp,
    lmhlo::MulOp,
    lmhlo::OrOp,
    lmhlo::PowOp,
    lmhlo::SubOp
  >(op);
  // clang-format on
}

// Returns true if the op is an elementwise lmhlo op.
// TODO(disc): use fusibility interface
bool isElementWise(Operation* op) {
  return isElementWiseUnary(op) || isElementWiseBinary(op);
}

// Returns true if this op is a rank-2 row reduction.
bool isRank2RowReduction(Operation* op) {
  auto reduce_op = dyn_cast<lmhlo::ReduceOp>(op);
  if (!reduce_op || reduce_op.dimensions().getNumElements() != 1) return false;

  int rank = op->getOperand(0).getType().cast<MemRefType>().getRank();
  auto dimensions = reduce_op.dimensions().getValues<int64_t>();
  return ((*dimensions.begin() == 1) && (rank == 2));
}

// Returns true if this op is a rank-2 column reduction.
bool isRank2ColReduction(Operation* op) {
  auto reduce_op = dyn_cast<lmhlo::ReduceOp>(op);
  if (!reduce_op || reduce_op.dimensions().getNumElements() != 1) return false;

  int rank = op->getOperand(0).getType().cast<MemRefType>().getRank();
  auto dimensions = reduce_op.dimensions().getValues<int64_t>();
  return ((*dimensions.begin() == 0) && (rank == 2));
}

// Returns true if the op is supported by the downstream fusion codegen
// engine.
bool isFusible(Operation* op) {
  // Only scalar const ops are supported by the fusion codegen engine a.t.m.
  if (dyn_cast<lmhlo::ConstOp>(op)) {
    MemRefType type = op->getOperand(0).getType().cast<MemRefType>();
    return (type.getRank() == 0);
  }

  // All elementwise ops are supported by the fusion codegen engine.
  if (isElementWise(op)) return true;

  // Only rank-2 tensor -> rank-1 tensor reductions are supported now.
  if (isRank2RowReduction(op) || isRank2ColReduction(op)) return true;

  // clang-format off
  return isa<
    lmhlo::BroadcastInDimOp,
    lmhlo::BroadcastOp,
    lmhlo::ConcatenateOp,
    lmhlo::DynamicBroadcastInDimOp,
    lmhlo::DynamicGatherOp,
    lmhlo::DynamicIotaOp,
    lmhlo::DynamicPadOp,
    lmhlo::DynamicReshapeOp,
    lmhlo::GatherOp,
    lmhlo::RealDynamicSliceOp,
    lmhlo::ReshapeOp,
    lmhlo::SelectOp,
    lmhlo::SliceOp,
    lmhlo::TransposeOp
  >(op);
  // clang-format on
}

// Returns the number of operands that are supposed to be written.
// For some ops (e.g. lmhlo ops), some operands are the output memrefs.
// Thus these operands are supposed to be updated.
int getNumResultOperands(Operation* op) {
  if (op->getDialect()->getNamespace() != "lmhlo") {
    return 0;
  }

  auto isWritable = [&](Value operand) -> bool {
    llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 2> effects;
    MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
    // Suppose that operands of ops without `MemoryEffectOpInterface` are
    // read-only.
    if (!interface) return false;

    interface.getEffectsOnValue(operand, effects);
    return llvm::any_of(
        effects, [](const mlir::MemoryEffects::EffectInstance& instance) {
          return mlir::isa<mlir::MemoryEffects::Write>(instance.getEffect());
        });
  };

  return llvm::count_if(op->getOperands(),
                        [&](Value v) { return isWritable(v); });
}

// Returns data users of the value and its aliases (e.g. memref.cast).
// Here non-data users means DimOp, DeallocOp and ShapeOfOp.
SmallVector<Operation*, 4> getValueUsers(Value v) {
  SmallVector<Operation*, 4> users;
  SmallVector<Value, 4> worklist;
  worklist.push_back(v);
  while (!worklist.empty()) {
    Value curr = worklist.back();
    worklist.pop_back();
    for (Operation* user : curr.getUsers()) {
      // Skip non-data users
      if (isa<memref::DimOp, memref::DeallocOp, shape::ShapeOfOp>(user)) {
        continue;
      }
      // alias value
      if (isa<memref::CastOp>(user)) {
        worklist.push_back(user->getResult(0));
      } else {
        users.push_back(user);
      }
    }
  }
  return users;
}

// Create a new fusion pattern from a single op.
FusionPattern::FusionPattern(Operation* op) {
  op_list_.push_back(op);
  if (isRank2RowReduction(op)) {
    fusion_type_ = FusionType::kRowReduction;
  } else if (isRank2ColReduction(op)) {
    fusion_type_ = FusionType::kColReduction;
  } else if (mlir::lmhlo::isFusible(op)) {
    fusion_type_ = FusionType::kLoop;
  } else {
    fusion_type_ = FusionType::kNone;
  }
  dominant_op_ = op;
  calculateOperandsAndResults();
}

// Create a new fusion pattern from the ops inside the lmhlo fusion op.
FusionPattern::FusionPattern(lmhlo::FusionOp op) {
  for (Operation& op : op.region().getBlocks().front()) {
    op_list_.push_back(&op);
  }

  // Figure out the fusion type and dominant op for the fusion pattern.
  for (Operation* op : op_list_) {
    if (isRank2RowReduction(op)) {
      fusion_type_ = FusionType::kRowReduction;
      dominant_op_ = op;
    } else if (isRank2ColReduction(op)) {
      if (fusion_type_ != FusionType::kRowReduction) {
        fusion_type_ = FusionType::kColReduction;
        dominant_op_ = op;
      }
    } else if (lmhlo::isFusible(op)) {
      // Ignore if already a kRowReduction or kColReduction, otherwise update
      // the fusion type to kLoop and dominant op to current op. This supposes
      // that the last op inside the block is a valid candidate dominant op if
      // the fusion pattern is a kLoop.
      if (fusion_type_ == FusionType::kNone ||
          fusion_type_ == FusionType::kLoop) {
        fusion_type_ = FusionType::kLoop;
        dominant_op_ = op;
      }
    } else {
      // Not a supported fusionOp, early stop.
      fusion_type_ = FusionType::kNone;
      dominant_op_ = nullptr;
      break;
    }
  }

  if (isFusible()) calculateOperandsAndResults();
}

// Create a new fusion pattern from a valid fusion op list.
FusionPattern::FusionPattern(SmallVectorImpl<Operation*>& op_list)
    : op_list_(op_list.begin(), op_list.end()) {
  calculateOperandsAndResults();
}

// Returns true if two fusion patterns can be merged into one bigger fusion
// pattern.
bool FusionPattern::isMergeable(FusionPattern& other) {
  if (!this->isFusible() || !other.isFusible()) return false;
  return true;
}

// Merges two fusion patterns and returns the merged pattern. The original
// pattern remains unmodified.
FusionPattern FusionPattern::merge(FusionPattern& other) {
  assert(isMergeable(other));
  FusionOpList new_op_list = op_list_;
  new_op_list.insert(new_op_list.end(), other.getOpList().begin(),
                     other.getOpList().end());
  FusionPattern new_fusion_pattern{new_op_list};

  FusionType newType = FusionType::kLoop;
  Operation* newDominant = getDominantOp();

  // kRowReduction + (kRowReduction | kColReduction | kLoop) = kRowReduction
  // kColReduction + (kColReduction | kLoop) = kColReduction
  // kLoop + kLoop = kLoop
  if (getFusionType() == FusionType::kRowReduction ||
      other.getFusionType() == FusionType::kRowReduction) {
    newType = FusionType::kRowReduction;
    if (getFusionType() != FusionType::kRowReduction)
      newDominant = other.getDominantOp();
  } else if (getFusionType() == FusionType::kColReduction ||
             other.getFusionType() == FusionType::kColReduction) {
    newType = FusionType::kColReduction;
    if (getFusionType() != FusionType::kColReduction)
      newDominant = other.getDominantOp();
  }

  new_fusion_pattern.setDominantOp(newDominant);
  new_fusion_pattern.setFusionType(newType);
  return new_fusion_pattern;
}

// Merges two fusion patterns and returns the merged pattern. Replaces the
// original pattern with the new merged pattern.
FusionPattern& FusionPattern::mergeInplace(FusionPattern& other) {
  *this = merge(other);
  return *this;
}

// Returns the effective size (i.e. not counting const ops) of the ops this
// fusion pattern contains.
int FusionPattern::effectiveSize() {
  return llvm::count_if(
      op_list_, [](Operation* op) { return !matchPattern(op, m_Constant()); });
}

// Sorts the ops inside the fusion pattern according to the keys provided.
void FusionPattern::sortFusionOpListBy(DenseMap<Operation*, int>& op_to_idx) {
  std::sort(op_list_.begin(), op_list_.end(),
            [&](Operation* lhs, Operation* rhs) {
              return op_to_idx[lhs] < op_to_idx[rhs];
            });
}

// Calculates the inputs and outputs of the fusion pattern.
void FusionPattern::calculateOperandsAndResults() {
  DenseSet<Value> input_set;
  DenseSet<Value> result_set;
  DenseSet<Value> internal_result_set;
  DenseSet<Operation*> op_set(op_list_.begin(), op_list_.end());

  DenseMap<Value, Operation*> last_writer;
  for (Operation* op : op_list_) {
    int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
    for (Value v : op->getOperands().drop_front(num_input_operand)) {
      bool inserted = last_writer.try_emplace(v, op).second;
      (void)inserted;
      assert(inserted);

      bool has_external_user = false;
      for (Operation* user : getValueUsers(v)) {
        if (!op_set.contains(user)) {
          has_external_user = true;
          break;
        }
      }

      if (has_external_user) {
        results_.push_back(v);
      } else {
        internal_results_.push_back(v);
      }
    }
  }

  for (Operation* op : op_list_) {
    int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
    for (Value value : op->getOperands().take_front(num_input_operand)) {
      if (last_writer.find(value) != last_writer.end()) {
        // Skip if the defining op is in the pattern.
        continue;
      }
      input_set.insert(value);
    }
  }

  for (Value v : input_set) operands_.push_back(v);
}

// Supports using EquivalenceClasses for Value
bool operator<(const ValueWrapper& lhs, const ValueWrapper& rhs) {
  auto lhs_value = lhs.getValue().getAsOpaquePointer();
  auto rhs_value = rhs.getValue().getAsOpaquePointer();
  return lhs_value < rhs_value;
}

// Shape equality propagation based on the shape constraints of
// elementwise ops.
void ShapeConstraintAnalysis::PropagateEquality(
    const SmallVectorImpl<Operation*>& op_list) {
  bool converged = true;
  do {
    converged = true;
    auto update = [&](Value lhs, Value rhs,
                      EquivalenceClasses<ValueWrapper>& impl) {
      if (!impl.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs))) {
        converged = false;
        impl.unionSets(ValueWrapper(lhs), ValueWrapper(rhs));
      }
    };
    for (Operation* op : op_list) {
      int num_operand = op->getNumOperands();
      // Propagates same num_elements equality, and shape equality
      if (isElementWise(op)) {
        Value lhs = op->getOperand(0);
        for (Value rhs : op->getOperands().drop_front()) {
          update(lhs, rhs, same_num_elements_impl_);
          update(lhs, rhs, same_shape_impl_);
        }
      }
      // Propagates same num_elements equality, not shape equality
      if (isa<lmhlo::DynamicReshapeOp, lmhlo::ReshapeOp, lmhlo::TransposeOp>(
              op)) {
        Value input = op->getOperand(0);
        // The last operand is the output memref by design
        Value output = op->getOperand(num_operand - 1);
        update(input, output, same_num_elements_impl_);
      }
    }
  } while (!converged);
}

}  // namespace lmhlo
}  // namespace mlir
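As a quick illustration of how `ShapeConstraintAnalysis` above is meant to be queried (again, illustrative only and not part of the diff): elementwise ops make their operands and results both shape-equal and element-count-equal, while reshape-like ops only preserve the element count.

// Illustrative sketch (not in this PR): query the analysis over one block.
// Assumes `block` holds the lmhlo ops of a single bufferized function body.
bool mayShareParallelLoop(mlir::Block& block, mlir::Value a, mlir::Value b) {
  llvm::SmallVector<mlir::Operation*, 8> ops;
  for (mlir::Operation& op : block) ops.push_back(&op);
  mlir::lmhlo::ShapeConstraintAnalysis analysis(ops);
  // Same shape implies the same number of elements; the reverse does not hold.
  return analysis.HasSameNumElements(a, b);
}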
|  | @ -0,0 +1,570 @@ | ||||||
|  | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" | ||||||
|  | #include "mlir-hlo/Dialect/mhlo/transforms/fusion_utils.h" | ||||||
|  | #include "mlir-hlo/utils/cycle_detector.h" | ||||||
|  | #include "mlir/Dialect/Shape/IR/Shape.h"      // TF:llvm-project
 | ||||||
|  | #include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 | ||||||
|  | #include "mlir/IR/MLIRContext.h"              // TF:llvm-project
 | ||||||
|  | #include "mlir/IR/Matchers.h" | ||||||
|  | #include "mlir/Pass/Pass.h"               // TF:local_config_mlir
 | ||||||
|  | #include "mlir/Transforms/RegionUtils.h"  // TF:llvm-project
 | ||||||
|  | 
 | ||||||
|  | // This pass has similar functionality of the fusion pass in XLA stack.
 | ||||||
|  | // However, unlike XLA, it targets the fully dynamic shape scenario.
 | ||||||
|  | // Currently, it implements the kLoop and kInput fusion templates.
 | ||||||
|  | // During conversion, it tries to greedily find kLoop/kInput fusion
 | ||||||
|  | // patterns.
 | ||||||
|  | //
 | ||||||
|  | // Similar to XLA, this pass supports fusion pattern having multiple outputs
 | ||||||
|  | // if all the shape of outputs are consistent. Following are some examples.
 | ||||||
|  | //
 | ||||||
|  | //        kLoop                          kInput
 | ||||||
|  | // +----+  +----+  +----+    +----+    +----+    +----+
 | ||||||
|  | // |elem|  |elem|  |elem|    |elem<----+elem+---->elem+----+
 | ||||||
|  | // +-+--+  +-+--+  +-+--+    +-+--+    +----+    +-+--+    |
 | ||||||
|  | //   |       |       |         |                   |       |
 | ||||||
|  | //   |               |         |                   |       |
 | ||||||
|  | // +-v--+    |     +-v--+   +--v---+            +--v---+   |
 | ||||||
|  | // |elem+<---+----<+elem|   |reduce|            |reduce|   |
 | ||||||
|  | // +-+--+          +-+--+   +--+---+            +--+---+   |
 | ||||||
|  | //   |               |         |                   |       |
 | ||||||
|  | //   |               |         |                   |       |
 | ||||||
|  | //   v               v         v                   v       v
 | ||||||
|  | //
 | ||||||
|  | // To this end, we also add an simple shape constraint analysis phase.
 | ||||||
|  | // For kLoop fusion template, it requires all the outputs of the fused
 | ||||||
|  | // pattern have the same shape. However, we don't know the actual value
 | ||||||
|  | // of the shape at the compile time in the dynamic shape world.
 | ||||||
|  | // Fortunately, we could still infer the relationship among different ops
 | ||||||
|  | // according to their shape constraint traits. Currently, We only consider
 | ||||||
|  | // shape equality propagation for elementwise ops (assuming that implicit
 | ||||||
|  | // shape broadcast is forbidden). The above process could be built on the
 | ||||||
|  | // shape dialect once it is ready.
 | ||||||
|  | //
 | ||||||
|  | // TODO(disc): This file implements fusion on buffer level, re-visit this after
 | ||||||
|  | // more shape inference/constraint infras are ready in mhlo level.
 | ||||||
|  | // TODO(disc): Not using fusibility interface a.t.m, re-visit this if necessary.
 | ||||||
|  | 
 | ||||||
|  | namespace mlir { | ||||||
|  | namespace lmhlo { | ||||||
|  | namespace { | ||||||
|  | 
 | ||||||
|  | struct FusionOptions { | ||||||
|  |   // Maximum allowed number of arguments per fused kernel. Here arguments
 | ||||||
|  |   // include both ready-only buffers and writable buffers.
 | ||||||
|  |   int max_num_arguments_per_kernel; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | // A fusion planner that can propose a fusion plan for a block of ops.
 | ||||||
|  | // The fusion plan is consisted of a group of fusion patterns.
 | ||||||
|  | //
 | ||||||
|  | // Currently all proposed patterns followed xla kLoop/kInput like fusion
 | ||||||
|  | // templates while are adapted to the fully dynamic shape world.
 | ||||||
|  | //
 | ||||||
|  | // kLoop fusion template satisfies:
 | ||||||
|  | //   - all ops in the fusion pattern are element-wise.
 | ||||||
|  | //   - all the shapes of outputs of fusion pattern are same or have same number
 | ||||||
|  | //   of elements, and thus can fit into a same parallel loop.
 | ||||||
|  | //
 | ||||||
|  | // kInput fusion template satisfies:
 | ||||||
|  | //   - any op in the fusion pattern is either element-wise or a reduction.
 | ||||||
|  | //   - if a op is a reduction, its output cannot be consumed by other
 | ||||||
|  | //     ops in the same fusion pattern.
 | ||||||
|  | //   - all the effective shapes of outputs of fusion pattern are same.
 | ||||||
|  | //     - For element-wise op, its effective shape is its output shape.
 | ||||||
|  | //     - For reduction op, its effective shape is its operand shape.
 | ||||||
|  | //   - currently our downstreaming codegen engine only support 2d -> 1d tensor
 | ||||||
|  | //   reduction. TODO(disc): lift this limitation.
 | ||||||
|  | //     - 2D row reduction: out[i] = sum({in[i][j] for all j})
 | ||||||
|  | //     - 2D column reduction: out[j] = sum({in[i][j] for all i}
 | ||||||
|  | class FusionPlanner { | ||||||
|  |  public: | ||||||
|  |   explicit FusionPlanner(const FusionOptions& options, Block* block) | ||||||
|  |       : options_(options), block_(block) { | ||||||
|  |     // Move up metadata-only ops (e.g. dim, shape_of) as far as possible.
 | ||||||
|  |     MoveUpMetadataOnlyOpsForFusion(); | ||||||
|  | 
 | ||||||
|  |     for (Operation& op : *block) { | ||||||
|  |       op_list_.push_back(&op); | ||||||
|  |     } | ||||||
|  |     shape_analysis_.reset(new ShapeConstraintAnalysis(op_list_)); | ||||||
|  |     cycle_detector_.reset(new GraphCycles(op_list_.size())); | ||||||
|  |     BuildNodeMap(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Returns a fusion plan if success, otherwise none.
 | ||||||
|  |   llvm::Optional<FusionPlan> Run() { | ||||||
|  |     // Greedily search connected fusible pattern, and ops belonging to
 | ||||||
|  |     // a same fusion pattern are grouped into a cluster.
 | ||||||
|  |     RunEdgeContractionLoop(); | ||||||
|  | 
 | ||||||
|  |     // After doing edge contraction, each unique cluster having size
 | ||||||
|  |     // more than one represents a potential fusion pattern.
 | ||||||
|  |     // We collect all these clusters and construct a fusion plan.
 | ||||||
|  |     FusionPlan plan; | ||||||
|  |     DenseSet<Cluster*> seen_clusters; | ||||||
|  |     for (Operation* op : op_list_) { | ||||||
|  |       Cluster* cluster = GetClusterForNode(op); | ||||||
|  |       if (!seen_clusters.insert(cluster).second) continue; | ||||||
|  |       FusionPattern& fusion_pattern = cluster->fused_pattern(); | ||||||
|  |       // Make sure the ops in a fusion pattern are in topological ordering.
 | ||||||
|  |       fusion_pattern.sortFusionOpListBy(op_to_node_id_); | ||||||
|  |       if (!fusion_pattern.isFusible() || fusion_pattern.effectiveSize() <= 1) { | ||||||
|  |         continue; | ||||||
|  |       } | ||||||
|  |       plan.emplace_back(fusion_pattern); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Re-order ops inside the blocks to make sure all producers are placed
 | ||||||
|  |     // before its consumers after fusion.
 | ||||||
|  |     ReorderOperationsInsideBlock(); | ||||||
|  |     return plan; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Returns the op_list this planner operates on.
 | ||||||
|  |   const SmallVectorImpl<Operation*>& op_list() const { return op_list_; } | ||||||
|  | 
 | ||||||
|  |  private: | ||||||
|  |   // Represent a (partial) fused pattern
 | ||||||
|  |   class Cluster { | ||||||
|  |    public: | ||||||
|  |     Cluster(int node_id, FusionPlanner* planner) | ||||||
|  |         : node_id_(node_id), pattern_(planner->op_list()[node_id]) {} | ||||||
|  | 
 | ||||||
|  |     // Merges `other` into this cluster, and clears `other`.
 | ||||||
|  |     void Merge(Cluster* other) { | ||||||
|  |       pattern_.mergeInplace(other->fused_pattern()); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // The number of nodes in this cluster.
 | ||||||
|  |     int cluster_size() { return pattern_.size(); } | ||||||
|  | 
 | ||||||
|  |     // The ID of the cluster as represented in `cycle_detector_`.
 | ||||||
|  |     int cycles_graph_node_id() const { return node_id_; } | ||||||
|  | 
 | ||||||
|  |     // Sets the ID of the cluster as represented in `cycle_detector_`.
 | ||||||
|  |     void set_cycles_graph_node_id(int cycles_graph_node_id) { | ||||||
|  |       node_id_ = cycles_graph_node_id; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Currently the fused pattern this cluster holds.
 | ||||||
|  |     FusionPattern& fused_pattern() { return pattern_; } | ||||||
|  | 
 | ||||||
|  |    private: | ||||||
|  |     // ID of the representative node of this cluster.
 | ||||||
|  |     int node_id_; | ||||||
|  | 
 | ||||||
|  |     // the fused pattern this cluster holds.
 | ||||||
|  |     FusionPattern pattern_; | ||||||
|  |   }; | ||||||
|  | 
 | ||||||
|  |  private: | ||||||
|  |   // Returns a new cluster with specified `cycles_graph_node_id`
 | ||||||
|  |   Cluster* MakeCluster(int cycles_graph_node_id) { | ||||||
|  |     cluster_storage_.emplace_back(new Cluster(cycles_graph_node_id, this)); | ||||||
|  |     return cluster_storage_.back().get(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Metadata ops (e.g. shapeOf, dimOp) don't change data thus we move forward
 | ||||||
|  |   // them as far as possible inside the same block to enable more fusion
 | ||||||
|  |   // opportunities.
 | ||||||
|  |   void MoveUpMetadataOnlyOpsForFusion() { | ||||||
|  |     SmallVector<Operation*, 4> ops; | ||||||
|  |     for (Operation& op : *block_) { | ||||||
|  |       ops.push_back(&op); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     auto inBlock = [&](Operation* op, Block* block) { | ||||||
|  |       return op && op->getBlock() == block; | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     for (Operation* op : ops) { | ||||||
|  |       Block* block = op->getBlock(); | ||||||
|  |       if (isa<shape::ShapeOfOp>(op)) { | ||||||
|  |         Operation* definingOp = op->getOperand(0).getDefiningOp(); | ||||||
|  |         if (!inBlock(definingOp, block)) { | ||||||
|  |           op->moveBefore(block, block->begin()); | ||||||
|  |         } else { | ||||||
|  |           op->moveAfter(definingOp); | ||||||
|  |         } | ||||||
|  |       } else if (isa<memref::DimOp>(op)) { | ||||||
|  |         Operation* firstOperandOp = op->getOperand(0).getDefiningOp(); | ||||||
|  |         Operation* secondOperandOp = op->getOperand(1).getDefiningOp(); | ||||||
|  |         if (!inBlock(firstOperandOp, block) && | ||||||
|  |             !inBlock(secondOperandOp, block)) { | ||||||
|  |           op->moveBefore(block, block->begin()); | ||||||
|  |         } else if (!inBlock(firstOperandOp, block)) { | ||||||
|  |           op->moveAfter(secondOperandOp); | ||||||
|  |         } else if (!inBlock(secondOperandOp, block)) { | ||||||
|  |           op->moveAfter(firstOperandOp); | ||||||
|  |         } else if (firstOperandOp->isBeforeInBlock(secondOperandOp)) { | ||||||
|  |           op->moveAfter(secondOperandOp); | ||||||
|  |         } else { | ||||||
|  |           op->moveAfter(firstOperandOp); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Returns all the values touched by this op or its nested ops.
 | ||||||
|  |   SmallVector<Value, 4> GetAllPossibleUsedValues(Operation* op) { | ||||||
|  |     SmallVector<Value, 4> values; | ||||||
|  |     op->walk([&](Operation* nest_op) { | ||||||
|  |       for (Value v : nest_op->getOperands()) { | ||||||
|  |         values.push_back(v); | ||||||
|  |       } | ||||||
|  |     }); | ||||||
|  |     return values; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Builds the initial dependency graph.
 | ||||||
|  |   void BuildNodeMap() { | ||||||
|  |     int num_nodes = op_list_.size(); | ||||||
|  |     for (int node_id = 0; node_id < num_nodes; ++node_id) { | ||||||
|  |       Operation* op = op_list_[node_id]; | ||||||
|  |       MakeCluster(node_id); | ||||||
|  |       op_to_node_id_[op] = node_id; | ||||||
|  |       leader_for_node_.insert(node_id); | ||||||
|  |       for (Value operand : GetAllPossibleUsedValues(op)) { | ||||||
|  |         Operation* operand_op = FindLastWriter(operand); | ||||||
|  |         // Only consider the operand_op inside the target block.
 | ||||||
|  |         auto iter = op_to_node_id_.find(operand_op); | ||||||
|  |         if (iter == op_to_node_id_.end()) { | ||||||
|  |           continue; | ||||||
|  |         } | ||||||
|  |         // Add an edge to connect the last writer and the current consumer.
 | ||||||
|  |         cycle_detector_->InsertEdge(iter->second, node_id); | ||||||
|  |       } | ||||||
|  | 
 | ||||||
|  |       // For some ops (e.g. lmhlo ops), some operands are the output memrefs
 | ||||||
|  |       // Thus these operands are supposed to be updated.
 | ||||||
|  |       // Suppose that a op (or its nested ops) can only write the buffers
 | ||||||
|  |       // explicit passed in as operands of this op.
 | ||||||
|  |       int num_input_operand = op->getNumOperands() - getNumResultOperands(op); | ||||||
|  |       for (Value v : op->getOperands().drop_front(num_input_operand)) { | ||||||
|  |         auto it = last_writer_.try_emplace(v, op); | ||||||
|  |         (void)it; | ||||||
|  |         // Currently, a buffer is only supposed to be written once (as the
 | ||||||
|  |         // output operand of one lmhlo op).
 | ||||||
|  |         assert(it.second); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Returns the cluster contains this op.
 | ||||||
|  |   Cluster* GetClusterForNode(Operation* n) { | ||||||
|  |     int id = op_to_node_id_[n]; | ||||||
|  |     id = leader_for_node_.getLeaderValue(id); | ||||||
|  |     return cluster_storage_[id].get(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Returns the cluster contains the op having `node_id`.
 | ||||||
|  |   Cluster* GetClusterForCyclesGraphNode(int node_id) { | ||||||
|  |     return cluster_storage_[leader_for_node_.getLeaderValue(node_id)].get(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Merges the clusters `cluster_from` and `cluster_to`.
 | ||||||
|  |   bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) { | ||||||
|  |     int from = cluster_from->cycles_graph_node_id(); | ||||||
|  |     int to = cluster_to->cycles_graph_node_id(); | ||||||
|  | 
 | ||||||
|  |     auto optional_merged_node = cycle_detector_->ContractEdge(from, to); | ||||||
|  |     if (!optional_merged_node.hasValue()) { | ||||||
|  |       llvm::dbgs() << "Could not contract " << from << " -> " << to | ||||||
|  |                    << " because contracting the edge would create a cycle."; | ||||||
|  |       return false; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Merge the clusters.
 | ||||||
|  |     cluster_from->Merge(cluster_to); | ||||||
|  |     cluster_from->set_cycles_graph_node_id(*optional_merged_node); | ||||||
|  | 
 | ||||||
|  |     // Merge the UnionFind Set.
 | ||||||
|  |     leader_for_node_.unionSets(from, to); | ||||||
|  |     return true; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   using FnTy = llvm::function_ref<bool(Cluster*, Cluster*)>; | ||||||
|  |   bool ForEachEdgeInPostOrder(FnTy fn, bool enable_cross_fusion = false) { | ||||||
|  |     bool changed = false; | ||||||
|  |     for (int32_t node : cycle_detector_->AllNodesInPostOrder()) { | ||||||
|  |       Cluster* cluster_from = GetClusterForCyclesGraphNode(node); | ||||||
|  |       // Make a copy of the set of successors because we may modify the graph in
 | ||||||
|  |       // TryToContractEdge.
 | ||||||
|  |       std::vector<int32_t> successors_copy = | ||||||
|  |           cycle_detector_->SuccessorsCopy(cluster_from->cycles_graph_node_id()); | ||||||
|  | 
 | ||||||
|  |       for (int to : successors_copy) { | ||||||
|  |         Cluster* cluster_to = GetClusterForCyclesGraphNode(to); | ||||||
|  |         bool contracted_edge = fn(cluster_from, cluster_to); | ||||||
|  |         changed |= contracted_edge; | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if (!enable_cross_fusion) return changed; | ||||||
|  | 
 | ||||||
|  |     // To enable even more fusion opportunities (e.g. horizontal fusion)
 | ||||||
|  |     for (int32_t lhs : cycle_detector_->AllNodesInPostOrder()) { | ||||||
|  |       Cluster* cluster_lhs = GetClusterForCyclesGraphNode(lhs); | ||||||
|  |       if (!cluster_lhs) { | ||||||
|  |         continue; | ||||||
|  |       } | ||||||
|  | 
 | ||||||
|  |       for (int32_t rhs : cycle_detector_->AllNodesInPostOrder()) { | ||||||
|  |         Cluster* cluster_rhs = GetClusterForCyclesGraphNode(rhs); | ||||||
|  |         if (!cluster_rhs || cluster_lhs == cluster_rhs) { | ||||||
|  |           continue; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         bool contracted_edge = fn(cluster_lhs, cluster_rhs); | ||||||
|  |         changed |= contracted_edge; | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     return changed; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // This function checks if fusing `from` with `to` is valid and, if so, | ||||||
|  |   // performs the merge. The validity is based on the operations in the | ||||||
|  |   // clusters and the compatibility of the shapes of the outputs of the | ||||||
|  |   // would-be fused clusters. | ||||||
|  |   // Returns true if the merge was performed. | ||||||
|  |   bool TryToContractEdge(Cluster* from, Cluster* to) { | ||||||
|  |     // Try the merge and check whether it is valid. | ||||||
|  |     if (!from->fused_pattern().isMergeable(to->fused_pattern())) return false; | ||||||
|  |     FusionPattern fused_pattern = | ||||||
|  |         from->fused_pattern().merge(to->fused_pattern()); | ||||||
|  |     auto& op_list = fused_pattern.getOpList(); | ||||||
|  |     auto& operands = fused_pattern.getOperands(); | ||||||
|  |     auto& results = fused_pattern.getResults(); | ||||||
|  | 
 | ||||||
|  |     if (results.size() + operands.size() > | ||||||
|  |         options_.max_num_arguments_per_kernel) { | ||||||
|  |       // Some backend devices (e.g. GPU) do not support a kernel with too | ||||||
|  |       // many arguments. | ||||||
|  |       return false; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // We currently do not support a constant op as the final output of a | ||||||
|  |     // fusion pattern. | ||||||
|  |     // TODO(disc): copy small constants when necessary. | ||||||
|  |     for (Value result : results) { | ||||||
|  |       Operation* result_op = FindLastWriter(result); | ||||||
|  |       assert(result_op); | ||||||
|  |       if (isa<lmhlo::ConstOp>(result_op)) { | ||||||
|  |         return false; | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // A ReduceOp cannot have a consumer within the fusion pattern. | ||||||
|  |     for (Operation* op : op_list) { | ||||||
|  |       if (!isa<lmhlo::ReduceOp>(op)) continue; | ||||||
|  |       int num_input_operand = op->getNumOperands() - getNumResultOperands(op); | ||||||
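|  |       // lmhlo ops list their input buffers before their output buffers, so | ||||||
|  |       // dropping the inputs leaves only the reduce's result buffers. | ||||||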
|  |       for (Value v : op->getOperands().drop_front(num_input_operand)) { | ||||||
|  |         for (Operation* user : getValueUsers(v)) { | ||||||
|  |           if (user == op) continue; | ||||||
|  |           if (std::find(op_list.begin(), op_list.end(), user) != | ||||||
|  |               op_list.end()) { | ||||||
|  |             return false; | ||||||
|  |           } | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // All outputs of a fusion pattern should have compatible shapes. | ||||||
|  |     // Here `compatible` means: | ||||||
|  |     // - if `to` and `from` are both kInput fusions, all outputs should have | ||||||
|  |     //   the same shape. | ||||||
|  |     // - otherwise, all outputs should have the same number of elements. | ||||||
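|  |     // For example, a memref<?x?xf32> and the memref<?x?x?xf32> produced by a | ||||||
|  |     // dynamic_reshape of it have the same number of elements, so both may be | ||||||
|  |     // kept as outputs of a kLoop fusion. | ||||||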
|  | 
 | ||||||
|  |     // If there are no outside users, these ops may be eliminated. We still | ||||||
|  |     // fuse them here and let a later pass do the DCE. | ||||||
|  |     if (results.empty()) return true; | ||||||
|  | 
 | ||||||
|  |     bool check_same_shape = (to->fused_pattern().isKInputFusion() && | ||||||
|  |                              from->fused_pattern().isKInputFusion()); | ||||||
|  |     auto get_effective_shape = [&](Value v) { | ||||||
|  |       auto result_op = FindLastWriter(v); | ||||||
|  |       assert(result_op); | ||||||
|  |       // The effective shape of a reduce op is its operand's shape. | ||||||
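|  |       // (The iteration space of a kInput fusion follows the reduce input, so | ||||||
|  |       // we compare against that input rather than the reduced result.) | ||||||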
|  |       return isa<lmhlo::ReduceOp>(result_op) ? result_op->getOperand(0) : v; | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     Value ref_shape = get_effective_shape(results[0]); | ||||||
|  |     if (!llvm::all_of(results, [&](Value result) { | ||||||
|  |           Value shape = get_effective_shape(result); | ||||||
|  |           return check_same_shape | ||||||
|  |                      ? shape_analysis_->HasSameShape(ref_shape, shape) | ||||||
|  |                      : shape_analysis_->HasSameNumElements(ref_shape, shape); | ||||||
|  |         })) { | ||||||
|  |       return false; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     return MergeClusters(from, to); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Greedily fuses connected nodes. | ||||||
|  |   bool RunEdgeContractionLoop() { | ||||||
|  |     using std::placeholders::_1; | ||||||
|  |     using std::placeholders::_2; | ||||||
|  |     bool changed = false; | ||||||
|  | 
 | ||||||
|  |     // Run edge contraction repeatedly until nothing more can be fused. | ||||||
|  |     while (ForEachEdgeInPostOrder( | ||||||
|  |         std::bind(&FusionPlanner::TryToContractEdge, this, _1, _2), false)) { | ||||||
|  |       changed = true; | ||||||
|  |     } | ||||||
|  |     return changed; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Here `value` is supposed to be a pointer to a buffer. | ||||||
|  |   // Returns the defining op of `value` if no known op updates the buffer, | ||||||
|  |   // otherwise returns the last op that updates the buffer pointed to by | ||||||
|  |   // `value`. | ||||||
|  |   Operation* FindLastWriter(Value value) { | ||||||
|  |     auto it = last_writer_.find(value); | ||||||
|  |     if (it != last_writer_.end()) { | ||||||
|  |       return it->second; | ||||||
|  |     } | ||||||
|  |     return value.getDefiningOp(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Re-order ops inside the block to make sure that producers are before | ||||||
|  |   // consumers after fusion. | ||||||
|  |   void ReorderOperationsInsideBlock() { | ||||||
|  |     auto reorder_func = [&](Cluster* from, Cluster* to) { | ||||||
|  |       FusionPattern& from_pattern = from->fused_pattern(); | ||||||
|  |       FusionPattern& to_pattern = to->fused_pattern(); | ||||||
|  | 
 | ||||||
|  |       Operation* last_op_in_from = from_pattern.getOpList().back(); | ||||||
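|  |       // Move any op of `to` that currently appears before the end of `from` | ||||||
|  |       // to just after `from`'s last op, so producers stay before consumers. | ||||||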
|  |       for (Operation* op : llvm::reverse(to_pattern.getOpList())) { | ||||||
|  |         if (!last_op_in_from->isBeforeInBlock(op)) | ||||||
|  |           op->moveAfter(last_op_in_from); | ||||||
|  |       } | ||||||
|  |       return false; | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     ForEachEdgeInPostOrder(reorder_func); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Hyper-parameters that control the behaviour of the fusion planner. | ||||||
|  |   FusionOptions options_; | ||||||
|  | 
 | ||||||
|  |   // The block that the fusion planner works on. | ||||||
|  |   Block* block_; | ||||||
|  | 
 | ||||||
|  |   // Ops inside the block. | ||||||
|  |   SmallVector<Operation*, 4> op_list_; | ||||||
|  | 
 | ||||||
|  |   // Shape equality checker. | ||||||
|  |   std::unique_ptr<ShapeConstraintAnalysis> shape_analysis_; | ||||||
|  | 
 | ||||||
|  |   // op -> node_id | ||||||
|  |   DenseMap<Operation*, int> op_to_node_id_; | ||||||
|  | 
 | ||||||
|  |   // Used to make sure we do not introduce a cycle after fusion. | ||||||
|  |   std::unique_ptr<GraphCycles> cycle_detector_; | ||||||
|  |   std::vector<std::unique_ptr<Cluster>> cluster_storage_; | ||||||
|  | 
 | ||||||
|  |   // A UnionFind set. Each set represents a (partial) fusion pattern and has | ||||||
|  |   // a leader as its representative. | ||||||
|  |   EquivalenceClasses<int32_t> leader_for_node_; | ||||||
|  | 
 | ||||||
|  |   // Maps a buffer value to the last op (if any) that is known to update the | ||||||
|  |   // buffer pointed to by that value. | ||||||
|  |   DenseMap<Value, Operation*> last_writer_; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | struct LhloFusionPass : public LhloFusionPassBase<LhloFusionPass> { | ||||||
|  |   using LhloFusionPassBase<LhloFusionPass>::LhloFusionPassBase; | ||||||
|  |   explicit LhloFusionPass(int max_num_arguments_per_kernel) | ||||||
|  |       : LhloFusionPassBase<LhloFusionPass>::LhloFusionPassBase() { | ||||||
|  |     this->max_num_arguments_per_kernel_ = max_num_arguments_per_kernel; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   void runOnFunction() override { | ||||||
|  |     FuncOp func = getFunction(); | ||||||
|  | 
 | ||||||
|  |     // Collect all blocks inside the function. | ||||||
|  |     SmallVector<Block*, 4> blocks; | ||||||
|  |     CollectBlocksInsideFunction(func, blocks); | ||||||
|  | 
 | ||||||
|  |     // Process each block and do fusion within the block. | ||||||
|  |     FusionOptions options; | ||||||
|  |     options.max_num_arguments_per_kernel = max_num_arguments_per_kernel_; | ||||||
|  |     for (Block* block : blocks) { | ||||||
|  |       FusionPlanner planner(options, block); | ||||||
|  |       llvm::Optional<FusionPlan> plan = planner.Run(); | ||||||
|  |       if (!plan) { | ||||||
|  |         emitError(func.getLoc(), | ||||||
|  |                   "an error occurs while trying to find fusion candidates"); | ||||||
|  |         signalPassFailure(); | ||||||
|  |         return; | ||||||
|  |       } | ||||||
|  |       if (!ApplyFusionPlan(*plan)) { | ||||||
|  |         emitError(func.getLoc(), "apply fusion plan failed"); | ||||||
|  |         signalPassFailure(); | ||||||
|  |         return; | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   bool ApplyFusionPlan(FusionPlan& plan) { | ||||||
|  |     for (FusionPattern& pattern : plan) { | ||||||
|  |       auto& op_list = pattern.getOpList(); | ||||||
|  |       OpBuilder b(op_list.back()); | ||||||
|  | 
 | ||||||
|  |       // Get the fused locations. | ||||||
|  |       SmallVector<Location, 4> locations; | ||||||
|  |       locations.reserve(op_list.size()); | ||||||
|  |       for (Operation* op : op_list) { | ||||||
|  |         locations.push_back(op->getLoc()); | ||||||
|  |       } | ||||||
|  |       Location fused_loc = | ||||||
|  |           FusedLoc::get(op_list.back()->getContext(), locations); | ||||||
|  | 
 | ||||||
|  |       // Move ops inside the fusion pattern to the region attached to the fusion op. | ||||||
|  |       FusionOp fusion = b.create<lmhlo::FusionOp>(fused_loc); | ||||||
|  |       Region& region = fusion.region(); | ||||||
|  |       Block& block = region.front(); | ||||||
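|  |       // Iterating in reverse while always inserting at the block front | ||||||
|  |       // preserves the original relative order of the ops. | ||||||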
|  |       for (Operation* op : llvm::reverse(op_list)) { | ||||||
|  |         op->moveBefore(&block, block.begin()); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |     return true; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   void CollectBlocksInsideFunction(FuncOp op, SmallVectorImpl<Block*>& blocks) { | ||||||
|  |     op.walk([&](Block* block) { | ||||||
|  |       // It does not make sense to fuse the region attached to these ops. | ||||||
|  |       if (!isa<lmhlo::ReduceOp, lmhlo::FusionOp>(block->getParentOp())) | ||||||
|  |         blocks.push_back(block); | ||||||
|  |     }); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | }  // namespace | ||||||
|  | 
 | ||||||
|  | std::unique_ptr<OperationPass<FuncOp>> createLhloFusionPass( | ||||||
|  |     int max_num_arguments_per_kernel) { | ||||||
|  |   return std::make_unique<LhloFusionPass>(max_num_arguments_per_kernel); | ||||||
|  | } | ||||||
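|  | // Illustrative usage (not part of this patch): the pass runs on functions, | ||||||
|  | // so a client pipeline would typically schedule it along the lines of | ||||||
|  | //   pm.addNestedPass<FuncOp>(createLhloFusionPass(/*max_num_arguments_per_kernel=*/64)); | ||||||
|  | // where `pm` is an mlir::PassManager and 64 is only an example limit. | ||||||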
|  | 
 | ||||||
|  | }  // namespace lmhlo | ||||||
|  | }  // namespace mlir | ||||||
|  | @ -0,0 +1,286 @@ | ||||||
|  | // RUN: mlir-hlo-opt --lhlo-fusion -split-input-file %s -o - | FileCheck %s | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: @simple_kloop_fusion | ||||||
|  | // CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?x?xf32>) -> memref<?x?xf32> | ||||||
|  | func @simple_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, | ||||||
|  |                           %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>) -> memref<?x?xf32> { | ||||||
|  |   // CHECK: "lmhlo.fusion"() ( { | ||||||
|  |   // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   // CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[ARG3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   // CHECK: }) : () -> () | ||||||
|  |   // CHECK: return %[[ARG3]] : memref<?x?xf32> | ||||||
|  |   "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   "lmhlo.add"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   return %arg3 : memref<?x?xf32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: @simple_multi_output_kloop_fusion | ||||||
|  | // CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>) | ||||||
|  | func @simple_multi_output_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, | ||||||
|  |                           %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>) { | ||||||
|  |   // CHECK: "lmhlo.fusion"() ( { | ||||||
|  |   // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   // CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[ARG3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   // CHECK: }) : () -> () | ||||||
|  |   // CHECK: return %[[ARG1]], %[[ARG3]] : memref<?x?xf32>, memref<?x?xf32> | ||||||
|  |   "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   "lmhlo.add"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   return %arg1, %arg3 : memref<?x?xf32>, memref<?x?xf32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: @simple_multi_output_kloop_fusion_with_reorder | ||||||
|  | // CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?x?xf32>, %[[ARG4:.*]]: memref<2xindex>, %[[ARG5:.*]]: memref<?x?xf32>) | ||||||
|  | func @simple_multi_output_kloop_fusion_with_reorder(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, | ||||||
|  |                           %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>, | ||||||
|  |                           %arg4: memref<2xindex>, %arg5:  memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) { | ||||||
|  |   // CHECK: "lmhlo.fusion"() ( { | ||||||
|  |   // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   // CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[ARG3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   // CHECK: }) : () -> () | ||||||
|  |   // CHECK: "lmhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[ARG4]], %[[ARG5]]) | ||||||
|  |   // CHECK: return %[[ARG1]], %[[ARG3]], %[[ARG5]] : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32> | ||||||
|  |   "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   "lmhlo.dynamic_broadcast_in_dim"(%arg1, %arg4, %arg5) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (memref<?x?xf32>, memref<2xindex>, memref<?x?xf32>) -> () | ||||||
|  |   "lmhlo.add"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   return %arg1, %arg3, %arg5 : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: @same_num_elements_multi_output_kloop_fusion | ||||||
|  | // CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<2xi64>, %[[ARG3:.*]]: memref<?x?x?xf32>, %[[ARG4:.*]]: memref<?x?x?xf32>, %[[ARG5:.*]]: memref<?x?x?xf32>) | ||||||
|  | func @same_num_elements_multi_output_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, | ||||||
|  |                           %arg2: memref<2xi64>, %arg3: memref<?x?x?xf32>, | ||||||
|  |                           %arg4: memref<?x?x?xf32>, %arg5:  memref<?x?x?xf32>) -> (memref<?x?xf32>, memref<?x?x?xf32>) { | ||||||
|  |   // CHECK: "lmhlo.fusion"() ( { | ||||||
|  |   // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   // CHECK: "lmhlo.dynamic_reshape"(%[[ARG1]], %[[ARG2]], %[[ARG3]]) | ||||||
|  |   // CHECK: "lmhlo.add"(%[[ARG3]], %[[ARG4]], %[[ARG5]]) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> () | ||||||
|  |   // CHECK: }) : () -> () | ||||||
|  |   // CHECK: return %[[ARG1]], %[[ARG5]] : memref<?x?xf32>, memref<?x?x?xf32> | ||||||
|  |   "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   "lmhlo.dynamic_reshape"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<2xi64>, memref<?x?x?xf32>) -> () | ||||||
|  |   "lmhlo.add"(%arg3, %arg4, %arg5) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> () | ||||||
|  |   return %arg1, %arg5 : memref<?x?xf32>, memref<?x?x?xf32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: @check_not_kloop_fusion | ||||||
|  | func @check_not_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>) { | ||||||
|  |   // CHECK-NOT: "lmhlo.fusion" | ||||||
|  |   "lmhlo.add"(%arg0, %arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   "lmhlo.subtract"(%arg2, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   return %arg1, %arg3: memref<?x?xf32>, memref<?x?xf32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: @kloop_fusion_with_dealloc | ||||||
|  | // CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>) | ||||||
|  | func @kloop_fusion_with_dealloc(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>) { | ||||||
|  |   // CHECK: %[[TMP3:.*]] = memref.alloc | ||||||
|  |   // CHECK: %[[TMP5:.*]] = memref.alloc | ||||||
|  |   // CHECK: %[[TMP9:.*]] = memref.alloc | ||||||
|  |   // CHECK: %[[TMP13:.*]] = memref.alloc | ||||||
|  |   // CHECK: %[[TMP16:.*]] = memref.alloc | ||||||
|  |   // CHECK: "lmhlo.fusion"() ( { | ||||||
|  |   // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[TMP3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   // CHECK: "lmhlo.multiply"(%[[ARG0]], %[[ARG1]], %[[TMP5]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   // CHECK: "lmhlo.abs"(%[[TMP3]], %[[TMP9]]) : (memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   // CHECK: "lmhlo.abs"(%[[TMP5]], %[[TMP13]]) : (memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   // CHECK: "lmhlo.multiply"(%[[TMP9]], %[[TMP13]], %[[TMP16]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   // CHECK: }) : () -> () | ||||||
|  |   // CHECK: memref.dealloc %[[TMP3]] : memref<?x?xf32> | ||||||
|  |   // CHECK: memref.dealloc %[[TMP5]] : memref<?x?xf32> | ||||||
|  |   // CHECK: memref.dealloc %[[TMP13]] : memref<?x?xf32> | ||||||
|  |   // CHECK: return %[[TMP9]], %[[TMP16]] : memref<?x?xf32>, memref<?x?xf32> | ||||||
|  |   %c0 = constant 0 : index | ||||||
|  |   %c1 = constant 1 : index | ||||||
|  |   %0 = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex> | ||||||
|  |   %1 = tensor.extract %0[%c0] : tensor<2xindex> | ||||||
|  |   %2 = tensor.extract %0[%c1] : tensor<2xindex> | ||||||
|  |   %3 = memref.alloc(%1, %2) : memref<?x?xf32> | ||||||
|  |   "lmhlo.add"(%arg0, %arg1, %3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   %4 = memref.alloc(%1, %2) : memref<?x?xf32> | ||||||
|  |   "lmhlo.multiply"(%arg0, %arg1, %4) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   %5 = shape.shape_of %3 : memref<?x?xf32> -> tensor<2xindex> | ||||||
|  |   %6 = tensor.extract %5[%c0] : tensor<2xindex> | ||||||
|  |   %7 = tensor.extract %5[%c1] : tensor<2xindex> | ||||||
|  |   %8 = memref.alloc(%6, %7) : memref<?x?xf32> | ||||||
|  |   "lmhlo.abs"(%3, %8) : (memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   memref.dealloc %3 : memref<?x?xf32> | ||||||
|  |   %9 = shape.shape_of %4 : memref<?x?xf32> -> tensor<2xindex> | ||||||
|  |   %10 = tensor.extract %9[%c0] : tensor<2xindex> | ||||||
|  |   %11 = tensor.extract %9[%c1] : tensor<2xindex> | ||||||
|  |   %12 = memref.alloc(%10, %11) : memref<?x?xf32> | ||||||
|  |   "lmhlo.abs"(%4, %12) : (memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   memref.dealloc %4 : memref<?x?xf32> | ||||||
|  |   %13 = shape.shape_of %8 : memref<?x?xf32> -> tensor<2xindex> | ||||||
|  |   %14 = tensor.extract %13[%c0] : tensor<2xindex> | ||||||
|  |   %15 = tensor.extract %13[%c1] : tensor<2xindex> | ||||||
|  |   %16 = memref.alloc(%14, %15) : memref<?x?xf32> | ||||||
|  |   "lmhlo.multiply"(%8, %12, %16) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   memref.dealloc %12 : memref<?x?xf32> | ||||||
|  |   return %8, %16 : memref<?x?xf32>, memref<?x?xf32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: @simple_kinput | ||||||
|  | // CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?xf32>, %[[ARG3:.*]]: memref<f32> | ||||||
|  | func @simple_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?xf32>, %init: memref<f32>) -> memref<?xf32> { | ||||||
|  |   // CHECK: "lmhlo.fusion"() ( { | ||||||
|  |   // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   // CHECK: "lmhlo.reduce"(%[[ARG1]], %[[ARG3]], %[[ARG2]]) ( { | ||||||
|  |   // CHECK: }) : () -> () | ||||||
|  |   // CHECK: return %[[ARG2]] : memref<?xf32> | ||||||
|  |   "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   "lmhlo.reduce"(%arg1, %init, %arg2) ( { | ||||||
|  |   ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>): | ||||||
|  |     "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> () | ||||||
|  |     "lmhlo.terminator"() : () -> () | ||||||
|  |   } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> () | ||||||
|  |   return %arg2: memref<?xf32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: @multi_output_kinput | ||||||
|  | // CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?xf32>, %[[ARG3:.*]]: memref<f32> | ||||||
|  | func @multi_output_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?xf32>, %init: memref<f32>) -> (memref<?x?xf32>, memref<?xf32>) { | ||||||
|  |   // CHECK: "lmhlo.fusion"() ( { | ||||||
|  |   // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   // CHECK: "lmhlo.reduce"(%[[ARG1]], %[[ARG3]], %[[ARG2]]) ( { | ||||||
|  |   // CHECK: }) : () -> () | ||||||
|  |   // CHECK: return %[[ARG1]], %[[ARG2]] : memref<?x?xf32>, memref<?xf32> | ||||||
|  |   "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   "lmhlo.reduce"(%arg1, %init, %arg2) ( { | ||||||
|  |   ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>): | ||||||
|  |     "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> () | ||||||
|  |     "lmhlo.terminator"() : () -> () | ||||||
|  |   } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> () | ||||||
|  |   return %arg1, %arg2: memref<?x?xf32>, memref<?xf32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: @row_red_and_row_red_kinput | ||||||
|  | // CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?xf32>, %[[ARG4:.*]]: memref<?xf32>, %[[ARG5:.*]]: memref<?x?xf32>, %[[ARG6:.*]]: memref<f32> | ||||||
|  | func @row_red_and_row_red_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: memref<?xf32>, %arg4: memref<?xf32>, %arg5: memref<?x?xf32>, %init: memref<f32>) -> (memref<?xf32>, memref<?xf32>) { | ||||||
|  |   // CHECK: "lmhlo.fusion"() ( { | ||||||
|  |   // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   // CHECK: "lmhlo.abs"(%[[ARG2]], %[[ARG5]]) : (memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   // CHECK: "lmhlo.reduce"(%[[ARG5]], %[[ARG6]], %[[ARG3]]) ( { | ||||||
|  |   // CHECK: "lmhlo.reduce"(%[[ARG2]], %[[ARG6]], %[[ARG4]]) ( { | ||||||
|  |   // CHECK: }) : () -> () | ||||||
|  |   // CHECK: return %[[ARG3]], %[[ARG4]] : memref<?xf32>, memref<?xf32> | ||||||
|  |   "lmhlo.add"(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   "lmhlo.abs"(%arg2, %arg5) : (memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   "lmhlo.reduce"(%arg5, %init, %arg3) ( { | ||||||
|  |   ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>): | ||||||
|  |     "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> () | ||||||
|  |     "lmhlo.terminator"() : () -> () | ||||||
|  |   } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> () | ||||||
|  |   "lmhlo.reduce"(%arg2, %init, %arg4) ( { | ||||||
|  |   ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>): | ||||||
|  |     "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> () | ||||||
|  |     "lmhlo.terminator"() : () -> () | ||||||
|  |   } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> () | ||||||
|  |   return %arg3, %arg4: memref<?xf32>, memref<?xf32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: @row_red_and_col_red_kinput | ||||||
|  | // CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?xf32>, %[[ARG4:.*]]: memref<?xf32>, %[[ARG5:.*]]: memref<?x?xf32>, %[[ARG6:.*]]: memref<f32> | ||||||
|  | func @row_red_and_col_red_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: memref<?xf32>, %arg4: memref<?xf32>, %arg5: memref<?x?xf32>, %init: memref<f32>) -> (memref<?xf32>, memref<?xf32>) { | ||||||
|  |   // CHECK: "lmhlo.fusion"() ( { | ||||||
|  |   // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   // CHECK: "lmhlo.abs"(%[[ARG2]], %[[ARG5]]) : (memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   // CHECK: "lmhlo.reduce"(%[[ARG5]], %[[ARG6]], %[[ARG3]]) ( { | ||||||
|  |   // CHECK: "lmhlo.reduce"(%[[ARG2]], %[[ARG6]], %[[ARG4]]) ( { | ||||||
|  |   // CHECK: }) : () -> () | ||||||
|  |   // CHECK: return %[[ARG3]], %[[ARG4]] : memref<?xf32>, memref<?xf32> | ||||||
|  |   "lmhlo.add"(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   "lmhlo.abs"(%arg2, %arg5) : (memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   "lmhlo.reduce"(%arg5, %init, %arg3) ( { | ||||||
|  |   ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>): | ||||||
|  |     "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> () | ||||||
|  |     "lmhlo.terminator"() : () -> () | ||||||
|  |   } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> () | ||||||
|  |   "lmhlo.reduce"(%arg2, %init, %arg4) ( { | ||||||
|  |   ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>): | ||||||
|  |     "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> () | ||||||
|  |     "lmhlo.terminator"() : () -> () | ||||||
|  |   } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> () | ||||||
|  |   return %arg3, %arg4: memref<?xf32>, memref<?xf32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: @reduce_should_not_have_consumer_in_the_fusion | ||||||
|  | // CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32> | ||||||
|  | func @reduce_should_not_have_consumer_in_the_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) | ||||||
|  | -> (memref<?x?xf32>, memref<?xf32>) { | ||||||
|  |   // CHECK: %[[TMP4:.*]] = memref.alloc | ||||||
|  |   // CHECK: %[[TMP7:.*]] = memref.alloc | ||||||
|  |   // CHECK: %[[TMP8:.*]] = memref.alloc | ||||||
|  |   // CHECK: %[[TMP9:.*]] = memref.alloc | ||||||
|  |   // CHECK: "lmhlo.fusion"() ( { | ||||||
|  |   // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[TMP4]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   // CHECK: "lmhlo.subtract"(%[[ARG0]], %[[TMP4]], %[[TMP7]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   // CHECK: "lmhlo.constant"(%[[TMP8]]) {value = dense<0.000000e+00> : tensor<f32>} : (memref<f32>) -> () | ||||||
|  |   // CHECK: "lmhlo.reduce"(%[[TMP7]], %[[TMP8]], %[[TMP9]]) ( { | ||||||
|  |   // CHECK: }) : () -> () | ||||||
|  |   // CHECK: memref.dealloc %[[TMP4]] : memref<?x?xf32> | ||||||
|  |   // CHECK: memref.dealloc %[[TMP8]] : memref<f32> | ||||||
|  |   // CHECK: %[[TMP12:.*]] = memref.alloc | ||||||
|  |   // CHECK: "lmhlo.add"(%[[TMP9]], %[[TMP9]], %[[TMP12]]) : (memref<?xf32>, memref<?xf32>, memref<?xf32>) -> () | ||||||
|  |   // CHECK: memref.dealloc %[[TMP9]] : memref<?xf32> | ||||||
|  |   // CHECK: return %[[TMP7]], %[[TMP12]] : memref<?x?xf32>, memref<?xf32> | ||||||
|  |   %c1 = constant 1 : index | ||||||
|  |   %c0 = constant 0 : index | ||||||
|  |   %0 = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex> | ||||||
|  |   %1 = tensor.extract %0[%c0] : tensor<2xindex> | ||||||
|  |   %2 = tensor.extract %0[%c1] : tensor<2xindex> | ||||||
|  |   %3 = memref.alloc(%1, %2) : memref<?x?xf32> | ||||||
|  |   "lmhlo.add"(%arg0, %arg1, %3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   %4 = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex> | ||||||
|  |   %5 = tensor.extract %4[%c0] : tensor<2xindex> | ||||||
|  |   %6 = tensor.extract %4[%c1] : tensor<2xindex> | ||||||
|  |   %7 = memref.alloc(%5, %6) : memref<?x?xf32> | ||||||
|  |   "lmhlo.subtract"(%arg0, %3, %7) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () | ||||||
|  |   memref.dealloc %3 : memref<?x?xf32> | ||||||
|  |   %8 = memref.alloc() : memref<f32> | ||||||
|  |   "lmhlo.constant"(%8) {value = dense<0.000000e+00> : tensor<f32>} : (memref<f32>) -> () | ||||||
|  |   %9 = memref.alloc(%5) : memref<?xf32> | ||||||
|  |   "lmhlo.reduce"(%7, %8, %9) ( { | ||||||
|  |   ^bb0(%arg2: memref<f32>, %arg3: memref<f32>, %arg4: memref<f32>):  // no predecessors | ||||||
|  |     "lmhlo.add"(%arg2, %arg3, %arg4) : (memref<f32>, memref<f32>, memref<f32>) -> () | ||||||
|  |     "lmhlo.terminator"() : () -> () | ||||||
|  |   }) {dimensions = dense<1> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> () | ||||||
|  |   memref.dealloc %8 : memref<f32> | ||||||
|  |   %10 = shape.shape_of %9 : memref<?xf32> -> tensor<1xindex> | ||||||
|  |   %11 = tensor.extract %10[%c0] : tensor<1xindex> | ||||||
|  |   %12 = memref.alloc(%11) : memref<?xf32> | ||||||
|  |   "lmhlo.add"(%9, %9, %12) : (memref<?xf32>, memref<?xf32>, memref<?xf32>) -> () | ||||||
|  |   memref.dealloc %9 : memref<?xf32> | ||||||
|  |   return %7, %12 : memref<?x?xf32>, memref<?xf32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: @const_should_not_be_output | ||||||
|  | func @const_should_not_be_output(%arg0: memref<f32>) -> (memref<f32>, memref<f32>) { | ||||||
|  |   // CHECK-NOT: lmhlo.fusion | ||||||
|  |   %0 = memref.alloc() : memref<f32> | ||||||
|  |   "lmhlo.constant"(%0) {value = dense<1.000000e+00> : tensor<f32>} : (memref<f32>) -> () | ||||||
|  |   %1 = memref.alloc() : memref<f32> | ||||||
|  |   "lmhlo.add"(%arg0, %0, %1) : (memref<f32>, memref<f32>, memref<f32>) -> () | ||||||
|  |   return %0, %1 : memref<f32>, memref<f32> | ||||||
|  | } | ||||||