PR #50020: [MLIR][DISC] support fusion on buffer

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50020

This pass implements the logic to group kLoop/kInput fusion patterns on
buffer level. The reason for this is that we can avoid a lot of
headaches to handle `shape-only` consumers specially (e.g. memref.dim,
shape.shapeOf) since shapes are already resolved in buffer world. It may
be better to move this pass to tensor level after more shape
inference/constraint infras are ready on mhlo level.
Copybara import of the project:

--
e31f8344b59aa9860097197585215ea1689b8ff4 by Wenyi Zhao <reyizero@gmail.com>:

[MLIR][DISC] support fusion on buffer

This pass implements the logic to group kLoop/kInput fusion patterns on
buffer level. The reason for this is that we can avoid a lot of
headaches to handle `shape-only` consumers specially (e.g. memref.dim,
shape.shapeOf) since shapes are already resolved in buffer world. It may
be better to move this pass to tensor level after more shape
inference/constraint infras are ready on mhlo level.

--
35f2eb2791241b0ab5db1ddcaf1b4006278ddccf by Wenyi Zhao <reyizero@gmail.com>:

fix

--
923c8d61f7fe00a2a0df22d5be396508f0667964 by Wenyi Zhao <reyizero@gmail.com>:

fix sanity check failure

PiperOrigin-RevId: 379743424
This commit is contained in:
Wenyi Zhao 2021-06-16 09:50:41 -07:00 committed by TensorFlow MLIR Team
parent 82696f8598
commit 34dc5f2a79
9 changed files with 1555 additions and 0 deletions

38
BUILD
View File

@ -1213,6 +1213,42 @@ cc_library(
],
)
cc_library(
name = "fusion_utils",
srcs = ["lib/Dialect/mhlo/transforms/fusion_utils.cc"],
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/fusion_utils.h"],
deps = [
":lhlo",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
alwayslink = 1,
)
cc_library(
name = "lhlo_fusion",
srcs = ["lib/Dialect/mhlo/transforms/lhlo_fusion.cc"],
deps = [
":cycle_detector",
":fusion_utils",
":lhlo",
":pass_details",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
],
alwayslink = 1,
)
cc_library(
name = "chlo_legalize_to_hlo",
srcs = ["lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc"],
@ -1259,6 +1295,7 @@ cc_library(
],
deps = [
":DiscRalPassIncGen",
":LmhloPassIncGen",
":MhloPassIncGen",
"@llvm-project//mlir:Pass",
],
@ -1316,6 +1353,7 @@ cc_library(
":legalize_trigonometric_to_approximation",
":lhlo",
":lhlo_fuse_linalg",
":lhlo_fusion",
":lhlo_legalize_to_affine",
":lhlo_legalize_to_gpu",
":lhlo_legalize_to_parallel_loops",

View File

@ -25,6 +25,14 @@ namespace mhlo {
#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc"
} // end namespace mhlo
namespace lmhlo {
#define GEN_PASS_CLASSES
#include "mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.h.inc"
} // end namespace lmhlo
} // end namespace mlir
namespace mlir {

View File

@ -0,0 +1,244 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_FUSION_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_FUSION_UTILS_H_
#include <memory>
#include <vector>
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/Support/Debug.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
// This file implements some helper functions and classes used to do fusion
// & code generation.
namespace mlir {
namespace lmhlo {
// kLoop fusion template satisfies:
// - all ops in the fusion pattern are element-wise.
// - all the shapes of outputs of fusion pattern are same or have same number
// of elements, and thus can fit into a same parallel loop.
//
// kInput fusion template satisfies:
// - any op in the fusion pattern is either element-wise or a reduction.
// - if a op is a reduction, its output cannot be consumed by other
// ops in the same fusion pattern.
// - all the effective shapes of outputs of fusion pattern are same.
// - For element-wise op, its effective shape is its output shape.
// - For reduction op, its effective shape is its operand shape.
// - currently our downstreaming codegen engine only support 2d -> 1d tensor
// reduction. TODO: lift this limitation.
// - 2D row reduction: out[i] = sum({in[i][j] for all j})
// - 2D column reduction: out[j] = sum({in[i][j] for all i})
enum FusionType {
// Not a fusion pattern
kNone,
// kLoop fusion pattern
kLoop,
// kInput fusion pattern and all reduce ops of the fused pattern are row
// reduction
kRowReduction,
// kInput fusion pattern and all reduce ops of the fused pattern are column
// reduction
kColReduction,
};
// Returns true if the op is an elementwise unary lmhlo op.
// TODO: use fusibility interface
bool isElementWiseUnary(Operation* op);
// Returns true if the op is an elementwise binary lmhlo op.
// TODO: use fusibility interface
bool isElementWiseBinary(Operation* op);
// Returns true if the op is an elementwise lmhlo op.
// TODO: use fusibility interface
bool isElementWise(Operation* op);
// Returns true if this op is a rank-2 row reduction.
bool isRank2RowReduction(Operation* op);
// Returns true if this op is a rank-2 column reduction.
bool isRank2ColReduction(Operation* op);
// Returns true if the op is supported by the downstreaming fusion codegen
// engine.
bool isFusible(Operation* op);
// Returns the number of operands that are supposed to be written.
// For some ops (e.g. lmhlo ops), some operands are the output memrefs
// Thus these operands are supposed to be updated.
int getNumResultOperands(Operation* op);
// Returns data users of the value and its aliases (e.g. memref.cast).
// Here non-data users means DimOp, DeallocOp and ShapeOfOp.
SmallVector<Operation*, 4> getValueUsers(Value v);
// Represents a list of lmhlo ops that are going to be fused.
class FusionPattern {
public:
using FusionOpList = SmallVector<Operation*, 4>;
using FusionValueList = SmallVector<Value, 4>;
// Create a new fusion pattern from a single op.
FusionPattern(Operation* op);
// Create a new fusion pattern from the ops inside the lmhlo fusion op.
FusionPattern(lmhlo::FusionOp op);
// Returns the op list this fusion pattern represents.
FusionOpList& getOpList() { return op_list_; }
// Returns the dominant op of this fusion pattern.
// For kLoop fusion, a dominant op may be any op that has external users.
// For kInput fusion, a dominant op may be a row reduction (if exists), or
// a column reduction op.
Operation* getDominantOp() { return dominant_op_; }
// Sets the dominant op to the op provided.
void setDominantOp(Operation* op) { dominant_op_ = op; }
// Returns the fusion kind of the fusion pattern.
FusionType getFusionType() { return fusion_type_; }
// Sets the fusion type to the the type provided.
void setFusionType(FusionType type) { fusion_type_ = type; }
// Returns true if this a fusible fusion pattern.
bool isFusible() { return getFusionType() != FusionType::kNone; }
// Returns true if this fusion pattern is a kLoop fusion.
bool isKLoopFusion() { return getFusionType() == FusionType::kLoop; }
// Returns true if this fusion pattern is a kInput fusion.
bool isKInputFusion() {
return (getFusionType() == FusionType::kRowReduction ||
getFusionType() == FusionType::kColReduction);
}
// Returns true if two fusion patterns can be merged into one bigger fusion
// pattern.
bool isMergeable(FusionPattern& other);
// Merges two fusion patterns and returns the merged pattern. The original
// pattern remains unmodified.
FusionPattern merge(FusionPattern& other);
// Merges two fusion patterns and returns the merged pattern. Replaces the
// original pattern with new merged pattern.
FusionPattern& mergeInplace(FusionPattern& other);
// Returns values that are consumed by the lmhlo ops inside the fusion
// pattern.
FusionValueList& getOperands() { return operands_; }
// Returns values that are outputs of any lmhlo op in the fused pattern and
// have consumers outside the fusion pattern.
FusionValueList& getResults() { return results_; }
// Returns values that are outputs of any lmhlo op in the fused pattern and
// are only consumed by the lmhlo ops inside the fused pattern.
FusionValueList& getInternalResults() { return internal_results_; }
// Returns the size of the ops this fusion pattern contains.
int size() { return op_list_.size(); }
// Returns the effective size (e.g. not counting const ops) of the ops this
// fusion pattern contains.
int effectiveSize();
// Sorts the ops inside the fusion pattern according to the keys provided.
void sortFusionOpListBy(DenseMap<Operation*, int>& op_to_idx);
private:
FusionPattern(SmallVectorImpl<Operation*>& op_list);
private:
// Calculates the inputs and outputs of the fusion pattern.
void calculateOperandsAndResults();
private:
FusionOpList op_list_;
Operation* dominant_op_ = nullptr;
FusionType fusion_type_ = FusionType::kNone;
FusionValueList operands_;
FusionValueList results_;
FusionValueList internal_results_;
};
// Represents a list of disjoint fusion patterns for a block.
using FusionPlan = std::vector<FusionPattern>;
using llvm::EquivalenceClasses;
// Supports using EquivalenceClasses for Value
class ValueWrapper {
public:
explicit ValueWrapper(Value value) : value_(std::move(value)) {}
Value getValue() const { return value_; }
bool operator==(const ValueWrapper& rhs) const {
return getValue() == rhs.getValue();
}
private:
Value value_;
};
bool operator<(const ValueWrapper& lhs, const ValueWrapper& rhs);
// This is a simple shape constraint analysis, which is used to
// guide fusion decision (e.g. we only fuse shape-compatible ops).
//
// Currently, We only consider shape equality and same-number-elements equality
// propagation based on the shape constraint traits of elementwise ops (assuming
// that implicit shape broadcast is forbidden).
class ShapeConstraintAnalysis {
public:
explicit ShapeConstraintAnalysis(const SmallVectorImpl<Operation*>& op_list) {
PropagateEquality(op_list);
}
// Returns true if `lhs` and `rhs` are supposed to have same shape.
bool HasSameShape(Value lhs, Value rhs) {
return same_shape_impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs));
}
// Returns true if `lhs` and `rhs` are supposed to have same number of
// elements.
bool HasSameNumElements(Value lhs, Value rhs) {
return same_num_elements_impl_.isEquivalent(ValueWrapper(lhs),
ValueWrapper(rhs));
}
private:
// shape equality propagation based on the shape constrains of
// elementwise ops.
void PropagateEquality(const SmallVectorImpl<Operation*>& op_list);
// a UnionFind set
EquivalenceClasses<ValueWrapper> same_shape_impl_;
EquivalenceClasses<ValueWrapper> same_num_elements_impl_;
};
} // namespace lmhlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_FUSION_UTILS_H_

View File

@ -56,3 +56,12 @@ def LegalizeTensorLoadOpPass : Pass<"lhlo-legalize-tensor-load-op", "FuncOp"> {
let constructor = "createLegalizeTensorLoadOpPass()";
}
def LhloFusionPass : FunctionPass<"lhlo-fusion"> {
let summary = "Fuse lmhlo ops to kLoop/kInput fusion patterns.";
let constructor = "createLhloFusionPass()";
let options = [
Option<"max_num_arguments_per_kernel_", "max-num-arguments-per-kernel", "int",
/*default=*/"64", "Maximum allowed number of arguments per fused kernel.">,
];
}

View File

@ -117,6 +117,10 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
// Legalizes tensor load ops that are inserted during mhlo to lmhlo conversion.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTensorLoadOpPass();
// fuse lmhlo ops to kLoop/kInput fusion patterns
std::unique_ptr<OperationPass<FuncOp>> createLhloFusionPass(
int max_num_arguments_per_kernel = 64);
} // namespace lmhlo
namespace disc_ral {

View File

@ -137,8 +137,10 @@ add_mlir_library(MhloLhloToLinalg
)
add_mlir_library(LmhloPasses
fusion_utils.cc
legalize_tensor_load_op.cc
lhlo_fuse_linalg.cc
lhlo_fusion.cc
lhlo_legalize_to_affine.cc
lhlo_legalize_to_gpu.cc
lhlo_legalize_to_parallel_loops.cc

View File

@ -0,0 +1,394 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/transforms/fusion_utils.h"
#include <algorithm>
#include "mlir/Dialect/Shape/IR/Shape.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Matchers.h"
// This file implements some helper functions and classes used to do fusion
// & code generation.
namespace mlir {
namespace lmhlo {
// Returns true if the op is an elementwise unary lmhlo op.
// TODO(disc): use fusibility interface
bool isElementWiseUnary(Operation* op) {
// clang-format off
return isa<
lmhlo::AbsOp,
lmhlo::CeilOp,
lmhlo::ConvertOp,
lmhlo::CopyOp,
lmhlo::CosOp,
lmhlo::ExpOp,
lmhlo::FloorOp,
lmhlo::IsFiniteOp,
lmhlo::LogOp,
lmhlo::NegOp,
lmhlo::NotOp,
lmhlo::RsqrtOp,
lmhlo::SignOp,
lmhlo::SqrtOp,
lmhlo::TanhOp
>(op);
// clang-format on
}
// Returns true if the op is an elementwise binary lmhlo op.
// TODO(disc): use fusibility interface
bool isElementWiseBinary(Operation* op) {
// clang-format off
return isa<
lmhlo::AddOp,
lmhlo::AndOp,
lmhlo::CompareOp,
lmhlo::DivOp,
lmhlo::MaxOp,
lmhlo::MinOp,
lmhlo::MulOp,
lmhlo::OrOp,
lmhlo::PowOp,
lmhlo::SubOp
>(op);
// clang-format on
}
// Returns true if the op is an elementwise lmhlo op.
// TODO(disc): use fusibility interface
bool isElementWise(Operation* op) {
return isElementWiseUnary(op) || isElementWiseBinary(op);
}
// Returns true if this op is a rank-2 row reduction.
bool isRank2RowReduction(Operation* op) {
auto reduce_op = dyn_cast<lmhlo::ReduceOp>(op);
if (!reduce_op || reduce_op.dimensions().getNumElements() != 1) return false;
int rank = op->getOperand(0).getType().cast<MemRefType>().getRank();
auto dimensions = reduce_op.dimensions().getValues<int64_t>();
return ((*dimensions.begin() == 1) && (rank == 2));
}
// Returns true if this op is a rank-2 column reduction.
bool isRank2ColReduction(Operation* op) {
auto reduce_op = dyn_cast<lmhlo::ReduceOp>(op);
if (!reduce_op || reduce_op.dimensions().getNumElements() != 1) return false;
int rank = op->getOperand(0).getType().cast<MemRefType>().getRank();
auto dimensions = reduce_op.dimensions().getValues<int64_t>();
return ((*dimensions.begin() == 0) && (rank == 2));
}
// Returns true if the op is supported by the downstreaming fusion codegen
// engine.
bool isFusible(Operation* op) {
// Only scalar const are supported by the fusion codegen engine a.t.m.
if (dyn_cast<lmhlo::ConstOp>(op)) {
MemRefType type = op->getOperand(0).getType().cast<MemRefType>();
return (type.getRank() == 0);
}
// All element ops are supported by the fusion codegen engine.
if (isElementWise(op)) return true;
// Only rank-2 tensor -> rank-1 tensor reduction are supported now.
if (isRank2RowReduction(op) || isRank2ColReduction(op)) return true;
// clang-format off
return isa<
lmhlo::BroadcastInDimOp,
lmhlo::BroadcastOp,
lmhlo::ConcatenateOp,
lmhlo::DynamicBroadcastInDimOp,
lmhlo::DynamicGatherOp,
lmhlo::DynamicIotaOp,
lmhlo::DynamicPadOp,
lmhlo::DynamicReshapeOp,
lmhlo::GatherOp,
lmhlo::RealDynamicSliceOp,
lmhlo::ReshapeOp,
lmhlo::SelectOp,
lmhlo::SliceOp,
lmhlo::TransposeOp
>(op);
// clang-format on
}
// Returns the number of operands that are supposed to be written.
// For some ops (e.g. lmhlo ops), some operands are the output memrefs
// Thus these operands are supposed to be updated.
int getNumResultOperands(Operation* op) {
if (op->getDialect()->getNamespace() != "lmhlo") {
return 0;
}
auto isWritable = [&](Value operand) -> bool {
llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 2> effects;
MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
// Suppose that operands of op without `MemoryEffectOpInterface` are
// readonly.
if (!interface) return false;
interface.getEffectsOnValue(operand, effects);
return llvm::any_of(
effects, [](const mlir::MemoryEffects::EffectInstance& instance) {
return mlir::isa<mlir::MemoryEffects::Write>(instance.getEffect());
});
};
return llvm::count_if(op->getOperands(),
[&](Value v) { return isWritable(v); });
}
// Returns data users of the value and its aliases (e.g. memref.cast).
// Here non-data users means DimOp, DeallocOp and ShapeOfOp.
SmallVector<Operation*, 4> getValueUsers(Value v) {
SmallVector<Operation*, 4> users;
SmallVector<Value, 4> worklist;
worklist.push_back(v);
while (!worklist.empty()) {
Value curr = worklist.back();
worklist.pop_back();
for (Operation* user : curr.getUsers()) {
// Skip non-data users
if (isa<memref::DimOp, memref::DeallocOp, shape::ShapeOfOp>(user)) {
continue;
}
// alias value
if (isa<memref::CastOp>(user)) {
worklist.push_back(user->getResult(0));
} else {
users.push_back(user);
}
}
}
return users;
}
// Create a new fusion pattern from a single op.
FusionPattern::FusionPattern(Operation* op) {
op_list_.push_back(op);
if (isRank2RowReduction(op)) {
fusion_type_ = FusionType::kRowReduction;
} else if (isRank2ColReduction(op)) {
fusion_type_ = FusionType::kColReduction;
} else if (mlir::lmhlo::isFusible(op)) {
fusion_type_ = FusionType::kLoop;
} else {
fusion_type_ = FusionType::kNone;
}
dominant_op_ = op;
calculateOperandsAndResults();
}
// Create a new fusion pattern from the ops inside the lmhlo fusion op.
FusionPattern::FusionPattern(lmhlo::FusionOp op) {
for (Operation& op : op.region().getBlocks().front()) {
op_list_.push_back(&op);
}
// Figure out fusion type and dominant op for the fusion pattern.
for (Operation* op : op_list_) {
if (isRank2RowReduction(op)) {
fusion_type_ = FusionType::kRowReduction;
dominant_op_ = op;
} else if (isRank2ColReduction(op)) {
if (fusion_type_ != FusionType::kRowReduction) {
fusion_type_ = FusionType::kColReduction;
dominant_op_ = op;
}
} else if (lmhlo::isFusible(op)) {
// Ignore if already a kRowReduction or kColReduction, otherwise update
// the fusion type to kLoop and dominant op to current op. This supposes
// that the last op inside the block is a valid candidate dominant op if
// the fusion pattern is a kLoop.
if (fusion_type_ == FusionType::kNone ||
fusion_type_ == FusionType::kLoop) {
fusion_type_ = FusionType::kLoop;
dominant_op_ = op;
}
} else {
// Not a supported fusionOp, early stop.
fusion_type_ = FusionType::kNone;
dominant_op_ = nullptr;
break;
}
}
if (isFusible()) calculateOperandsAndResults();
}
// Create a new fusion pattern from a valid fusion op list.
FusionPattern::FusionPattern(SmallVectorImpl<Operation*>& op_list)
: op_list_(op_list.begin(), op_list.end()) {
calculateOperandsAndResults();
}
// Returns true if two fusion patterns can be merged into one bigger fusion
// pattern.
bool FusionPattern::isMergeable(FusionPattern& other) {
if (!this->isFusible() || !other.isFusible()) return false;
return true;
}
// Merges two fusion patterns and returns the merged pattern. The original
// pattern remains unmodified.
FusionPattern FusionPattern::merge(FusionPattern& other) {
assert(isMergeable(other));
FusionOpList new_op_list = op_list_;
new_op_list.insert(new_op_list.end(), other.getOpList().begin(),
other.getOpList().end());
FusionPattern new_fusion_pattern{new_op_list};
FusionType newType = FusionType::kLoop;
Operation* newDominant = getDominantOp();
// kRowReduction + (kRowReduction | kColReduction | kLoop) = kRowReduction
// kColReduction + (kColReduction | kLoop) = kColReduction
// kLoop + kLoop = kLoop
if (getFusionType() == FusionType::kRowReduction ||
other.getFusionType() == FusionType::kRowReduction) {
newType = FusionType::kRowReduction;
if (getFusionType() != FusionType::kRowReduction)
newDominant = other.getDominantOp();
} else if (getFusionType() == FusionType::kColReduction ||
other.getFusionType() == FusionType::kColReduction) {
newType = FusionType::kColReduction;
if (getFusionType() != FusionType::kColReduction)
newDominant = other.getDominantOp();
}
new_fusion_pattern.setDominantOp(newDominant);
new_fusion_pattern.setFusionType(newType);
return new_fusion_pattern;
}
// Merges two fusion patterns and returns the merged pattern. Replaces the
// original pattern with new merged pattern.
FusionPattern& FusionPattern::mergeInplace(FusionPattern& other) {
*this = merge(other);
return *this;
}
// Returns the effective size (e.g. not counting const ops) of the ops this
// fusion pattern contains.
int FusionPattern::effectiveSize() {
return llvm::count_if(
op_list_, [](Operation* op) { return !matchPattern(op, m_Constant()); });
}
// Sorts the ops inside the fusion pattern according to the keys provided.
void FusionPattern::sortFusionOpListBy(DenseMap<Operation*, int>& op_to_idx) {
std::sort(op_list_.begin(), op_list_.end(),
[&](Operation* lhs, Operation* rhs) {
return op_to_idx[lhs] < op_to_idx[rhs];
});
}
// Calculates the inputs and outputs of the fusion pattern.
void FusionPattern::calculateOperandsAndResults() {
DenseSet<Value> input_set;
DenseSet<Value> result_set;
DenseSet<Value> internal_result_set;
DenseSet<Operation*> op_set(op_list_.begin(), op_list_.end());
DenseMap<Value, Operation*> last_writer;
for (Operation* op : op_list_) {
int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
for (Value v : op->getOperands().drop_front(num_input_operand)) {
bool inserted = last_writer.try_emplace(v, op).second;
(void)inserted;
assert(inserted);
bool has_external_user = false;
for (Operation* user : getValueUsers(v)) {
if (!op_set.contains(user)) {
has_external_user = true;
break;
}
}
if (has_external_user) {
results_.push_back(v);
} else {
internal_results_.push_back(v);
}
}
}
for (Operation* op : op_list_) {
int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
for (Value value : op->getOperands().take_front(num_input_operand)) {
if (last_writer.find(value) != last_writer.end()) {
// skip if defining op is in the pattern
continue;
}
input_set.insert(value);
}
}
for (Value v : input_set) operands_.push_back(v);
}
// Supports using EquivalenceClasses for Value
bool operator<(const ValueWrapper& lhs, const ValueWrapper& rhs) {
auto lhs_value = lhs.getValue().getAsOpaquePointer();
auto rhs_value = rhs.getValue().getAsOpaquePointer();
return lhs_value < rhs_value;
}
// shape equality propagation based on the shape constrains of
// elementwise ops.
void ShapeConstraintAnalysis::PropagateEquality(
const SmallVectorImpl<Operation*>& op_list) {
bool converged = true;
do {
converged = true;
auto update = [&](Value lhs, Value rhs,
EquivalenceClasses<ValueWrapper>& impl) {
if (!impl.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs))) {
converged = false;
impl.unionSets(ValueWrapper(lhs), ValueWrapper(rhs));
}
};
for (Operation* op : op_list) {
int num_operand = op->getNumOperands();
// Propagates same num_elements equality, and shape equality
if (isElementWise(op)) {
Value lhs = op->getOperand(0);
for (Value rhs : op->getOperands().drop_front()) {
update(lhs, rhs, same_num_elements_impl_);
update(lhs, rhs, same_shape_impl_);
}
}
// Propagates same num_elements equality, not shape equality
if (isa<lmhlo::DynamicReshapeOp, lmhlo::ReshapeOp, lmhlo::TransposeOp>(
op)) {
Value input = op->getOperand(0);
// The last operand is the output memref by design
Value output = op->getOperand(num_operand - 1);
update(input, output, same_num_elements_impl_);
}
}
} while (!converged);
}
} // namespace lmhlo
} // namespace mlir

View File

@ -0,0 +1,570 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/fusion_utils.h"
#include "mlir-hlo/utils/cycle_detector.h"
#include "mlir/Dialect/Shape/IR/Shape.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project
// This pass has similar functionality of the fusion pass in XLA stack.
// However, unlike XLA, it targets the fully dynamic shape scenario.
// Currently, it implements the kLoop and kInput fusion templates.
// During conversion, it tries to greedily find kLoop/kInput fusion
// patterns.
//
// Similar to XLA, this pass supports fusion pattern having multiple outputs
// if all the shape of outputs are consistent. Following are some examples.
//
// kLoop kInput
// +----+ +----+ +----+ +----+ +----+ +----+
// |elem| |elem| |elem| |elem<----+elem+---->elem+----+
// +-+--+ +-+--+ +-+--+ +-+--+ +----+ +-+--+ |
// | | | | | |
// | | | | |
// +-v--+ | +-v--+ +--v---+ +--v---+ |
// |elem+<---+----<+elem| |reduce| |reduce| |
// +-+--+ +-+--+ +--+---+ +--+---+ |
// | | | | |
// | | | | |
// v v v v v
//
// To this end, we also add an simple shape constraint analysis phase.
// For kLoop fusion template, it requires all the outputs of the fused
// pattern have the same shape. However, we don't know the actual value
// of the shape at the compile time in the dynamic shape world.
// Fortunately, we could still infer the relationship among different ops
// according to their shape constraint traits. Currently, We only consider
// shape equality propagation for elementwise ops (assuming that implicit
// shape broadcast is forbidden). The above process could be built on the
// shape dialect once it is ready.
//
// TODO(disc): This file implements fusion on buffer level, re-visit this after
// more shape inference/constraint infras are ready in mhlo level.
// TODO(disc): Not using fusibility interface a.t.m, re-visit this if necessary.
namespace mlir {
namespace lmhlo {
namespace {
struct FusionOptions {
// Maximum allowed number of arguments per fused kernel. Here arguments
// include both ready-only buffers and writable buffers.
int max_num_arguments_per_kernel;
};
// A fusion planner that can propose a fusion plan for a block of ops.
// The fusion plan is consisted of a group of fusion patterns.
//
// Currently all proposed patterns followed xla kLoop/kInput like fusion
// templates while are adapted to the fully dynamic shape world.
//
// kLoop fusion template satisfies:
// - all ops in the fusion pattern are element-wise.
// - all the shapes of outputs of fusion pattern are same or have same number
// of elements, and thus can fit into a same parallel loop.
//
// kInput fusion template satisfies:
// - any op in the fusion pattern is either element-wise or a reduction.
// - if a op is a reduction, its output cannot be consumed by other
// ops in the same fusion pattern.
// - all the effective shapes of outputs of fusion pattern are same.
// - For element-wise op, its effective shape is its output shape.
// - For reduction op, its effective shape is its operand shape.
// - currently our downstreaming codegen engine only support 2d -> 1d tensor
// reduction. TODO(disc): lift this limitation.
// - 2D row reduction: out[i] = sum({in[i][j] for all j})
// - 2D column reduction: out[j] = sum({in[i][j] for all i}
class FusionPlanner {
public:
explicit FusionPlanner(const FusionOptions& options, Block* block)
: options_(options), block_(block) {
// Move up metadata-only ops (e.g. dim, shape_of) as far as possible.
MoveUpMetadataOnlyOpsForFusion();
for (Operation& op : *block) {
op_list_.push_back(&op);
}
shape_analysis_.reset(new ShapeConstraintAnalysis(op_list_));
cycle_detector_.reset(new GraphCycles(op_list_.size()));
BuildNodeMap();
}
// Returns a fusion plan if success, otherwise none.
llvm::Optional<FusionPlan> Run() {
// Greedily search connected fusible pattern, and ops belonging to
// a same fusion pattern are grouped into a cluster.
RunEdgeContractionLoop();
// After doing edge contraction, each unique cluster having size
// more than one represents a potential fusion pattern.
// We collect all these clusters and construct a fusion plan.
FusionPlan plan;
DenseSet<Cluster*> seen_clusters;
for (Operation* op : op_list_) {
Cluster* cluster = GetClusterForNode(op);
if (!seen_clusters.insert(cluster).second) continue;
FusionPattern& fusion_pattern = cluster->fused_pattern();
// Make sure the ops in a fusion pattern are in topological ordering.
fusion_pattern.sortFusionOpListBy(op_to_node_id_);
if (!fusion_pattern.isFusible() || fusion_pattern.effectiveSize() <= 1) {
continue;
}
plan.emplace_back(fusion_pattern);
}
// Re-order ops inside the blocks to make sure all producers are placed
// before its consumers after fusion.
ReorderOperationsInsideBlock();
return plan;
}
// Returns the op_list this planner operates on.
const SmallVectorImpl<Operation*>& op_list() const { return op_list_; }
private:
// Represent a (partial) fused pattern
class Cluster {
public:
Cluster(int node_id, FusionPlanner* planner)
: node_id_(node_id), pattern_(planner->op_list()[node_id]) {}
// Merges `other` into this cluster, and clears `other`.
void Merge(Cluster* other) {
pattern_.mergeInplace(other->fused_pattern());
}
// The number of nodes in this cluster.
int cluster_size() { return pattern_.size(); }
// The ID of the cluster as represented in `cycle_detector_`.
int cycles_graph_node_id() const { return node_id_; }
// Sets the ID of the cluster as represented in `cycle_detector_`.
void set_cycles_graph_node_id(int cycles_graph_node_id) {
node_id_ = cycles_graph_node_id;
}
// Currently the fused pattern this cluster holds.
FusionPattern& fused_pattern() { return pattern_; }
private:
// ID of the representative node of this cluster.
int node_id_;
// the fused pattern this cluster holds.
FusionPattern pattern_;
};
private:
// Returns a new cluster with specified `cycles_graph_node_id`
Cluster* MakeCluster(int cycles_graph_node_id) {
cluster_storage_.emplace_back(new Cluster(cycles_graph_node_id, this));
return cluster_storage_.back().get();
}
// Metadata ops (e.g. shapeOf, dimOp) don't change data thus we move forward
// them as far as possible inside the same block to enable more fusion
// opportunities.
void MoveUpMetadataOnlyOpsForFusion() {
SmallVector<Operation*, 4> ops;
for (Operation& op : *block_) {
ops.push_back(&op);
}
auto inBlock = [&](Operation* op, Block* block) {
return op && op->getBlock() == block;
};
for (Operation* op : ops) {
Block* block = op->getBlock();
if (isa<shape::ShapeOfOp>(op)) {
Operation* definingOp = op->getOperand(0).getDefiningOp();
if (!inBlock(definingOp, block)) {
op->moveBefore(block, block->begin());
} else {
op->moveAfter(definingOp);
}
} else if (isa<memref::DimOp>(op)) {
Operation* firstOperandOp = op->getOperand(0).getDefiningOp();
Operation* secondOperandOp = op->getOperand(1).getDefiningOp();
if (!inBlock(firstOperandOp, block) &&
!inBlock(secondOperandOp, block)) {
op->moveBefore(block, block->begin());
} else if (!inBlock(firstOperandOp, block)) {
op->moveAfter(secondOperandOp);
} else if (!inBlock(secondOperandOp, block)) {
op->moveAfter(firstOperandOp);
} else if (firstOperandOp->isBeforeInBlock(secondOperandOp)) {
op->moveAfter(secondOperandOp);
} else {
op->moveAfter(firstOperandOp);
}
}
}
}
// Returns all the values touched by this op or its nested ops.
SmallVector<Value, 4> GetAllPossibleUsedValues(Operation* op) {
SmallVector<Value, 4> values;
op->walk([&](Operation* nest_op) {
for (Value v : nest_op->getOperands()) {
values.push_back(v);
}
});
return values;
}
// Builds the initial dependency graph.
void BuildNodeMap() {
int num_nodes = op_list_.size();
for (int node_id = 0; node_id < num_nodes; ++node_id) {
Operation* op = op_list_[node_id];
MakeCluster(node_id);
op_to_node_id_[op] = node_id;
leader_for_node_.insert(node_id);
for (Value operand : GetAllPossibleUsedValues(op)) {
Operation* operand_op = FindLastWriter(operand);
// Only consider the operand_op inside the target block.
auto iter = op_to_node_id_.find(operand_op);
if (iter == op_to_node_id_.end()) {
continue;
}
// Add an edge to connect the last writer and the current consumer.
cycle_detector_->InsertEdge(iter->second, node_id);
}
// For some ops (e.g. lmhlo ops), some operands are the output memrefs
// Thus these operands are supposed to be updated.
// Suppose that a op (or its nested ops) can only write the buffers
// explicit passed in as operands of this op.
int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
for (Value v : op->getOperands().drop_front(num_input_operand)) {
auto it = last_writer_.try_emplace(v, op);
(void)it;
// Currently, a buffer is only supposed to be written once (as the
// output operand of one lmhlo op).
assert(it.second);
}
}
}
// Returns the cluster contains this op.
Cluster* GetClusterForNode(Operation* n) {
int id = op_to_node_id_[n];
id = leader_for_node_.getLeaderValue(id);
return cluster_storage_[id].get();
}
// Returns the cluster contains the op having `node_id`.
Cluster* GetClusterForCyclesGraphNode(int node_id) {
return cluster_storage_[leader_for_node_.getLeaderValue(node_id)].get();
}
// Merges the clusters `cluster_from` and `cluster_to`.
bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
int from = cluster_from->cycles_graph_node_id();
int to = cluster_to->cycles_graph_node_id();
auto optional_merged_node = cycle_detector_->ContractEdge(from, to);
if (!optional_merged_node.hasValue()) {
llvm::dbgs() << "Could not contract " << from << " -> " << to
<< " because contracting the edge would create a cycle.";
return false;
}
// Merge the clusters.
cluster_from->Merge(cluster_to);
cluster_from->set_cycles_graph_node_id(*optional_merged_node);
// Merge the UnionFind Set.
leader_for_node_.unionSets(from, to);
return true;
}
using FnTy = llvm::function_ref<bool(Cluster*, Cluster*)>;
bool ForEachEdgeInPostOrder(FnTy fn, bool enable_cross_fusion = false) {
bool changed = false;
for (int32_t node : cycle_detector_->AllNodesInPostOrder()) {
Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
// Make a copy of the set of successors because we may modify the graph in
// TryToContractEdge.
std::vector<int32_t> successors_copy =
cycle_detector_->SuccessorsCopy(cluster_from->cycles_graph_node_id());
for (int to : successors_copy) {
Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
bool contracted_edge = fn(cluster_from, cluster_to);
changed |= contracted_edge;
}
}
if (!enable_cross_fusion) return changed;
// To enable even more fusion opportunities (e.g. horizontal fusion)
for (int32_t lhs : cycle_detector_->AllNodesInPostOrder()) {
Cluster* cluster_lhs = GetClusterForCyclesGraphNode(lhs);
if (!cluster_lhs) {
continue;
}
for (int32_t rhs : cycle_detector_->AllNodesInPostOrder()) {
Cluster* cluster_rhs = GetClusterForCyclesGraphNode(rhs);
if (!cluster_rhs || cluster_lhs == cluster_rhs) {
continue;
}
bool contracted_edge = fn(cluster_lhs, cluster_rhs);
changed |= contracted_edge;
}
}
return changed;
}
// This function check if fusing `from` with `to` is valid and if so perform
// the merge. The validity is based on the operations in the clusters and
// the compatibility of the shapes of the outputs of the would-be fused
// clusters.
// Returns true is the merge was performed.
bool TryToContractEdge(Cluster* from, Cluster* to) {
// Try merge and check if valid.
if (!from->fused_pattern().isMergeable(to->fused_pattern())) return false;
FusionPattern fused_pattern =
from->fused_pattern().merge(to->fused_pattern());
auto& op_list = fused_pattern.getOpList();
auto& operands = fused_pattern.getOperands();
auto& results = fused_pattern.getResults();
if (results.size() + operands.size() >
options_.max_num_arguments_per_kernel) {
// some backend devices (e.g. GPU) do not support a kernel with
// too many arguments.
return false;
}
// We currently do not support a constant op as final output of a fusion
// pattern.
// TODO(disc): copy small const in case necessary.
for (Value result : results) {
Operation* result_op = FindLastWriter(result);
assert(result_op);
if (isa<lmhlo::ConstOp>(result_op)) {
return false;
}
}
// ReduceOp can not have consumer within the fusion pattern.
for (Operation* op : op_list) {
if (!isa<lmhlo::ReduceOp>(op)) continue;
int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
for (Value v : op->getOperands().drop_front(num_input_operand)) {
for (Operation* user : getValueUsers(v)) {
if (user == op) continue;
if (std::find(op_list.begin(), op_list.end(), user) !=
op_list.end()) {
return false;
}
}
}
}
// All outputs of a fusion pattern should have compatible shape.
// Here `compatible` means:
// - if `to` and `from` are both kInput fusion, all output should have same
// shape.
// - otherwise, all output should have same number of elements.
// No outside users, these ops may be eliminated. We fused it here and let
// latter pass to do such DCE.
if (results.empty()) return true;
bool check_same_shape = (to->fused_pattern().isKInputFusion() &&
from->fused_pattern().isKInputFusion());
auto get_effective_shape = [&](Value v) {
auto result_op = FindLastWriter(v);
assert(result_op);
// effective shape of reduce op is its operand's shape.
return isa<lmhlo::ReduceOp>(result_op) ? result_op->getOperand(0) : v;
};
Value ref_shape = get_effective_shape(results[0]);
if (!llvm::all_of(results, [&](Value result) {
Value shape = get_effective_shape(result);
return check_same_shape
? shape_analysis_->HasSameShape(ref_shape, shape)
: shape_analysis_->HasSameNumElements(ref_shape, shape);
})) {
return false;
}
return MergeClusters(from, to);
}
// Greedily fuse connected node.
bool RunEdgeContractionLoop() {
using std::placeholders::_1;
using std::placeholders::_2;
bool changed = false;
// Run fusion pass repeatedly until nothing to be fused
while (ForEachEdgeInPostOrder(
std::bind(&FusionPlanner::TryToContractEdge, this, _1, _2), false)) {
// empty statement by design
}
return changed;
}
// Here `value` is supported to be a pointer to buffer.
// Returns the defining op of `value `if no known op updates the buffer,
// otherwise returns the last op that updates the buffer pointed by the
// `value`.
Operation* FindLastWriter(Value value) {
auto it = last_writer_.find(value);
if (it != last_writer_.end()) {
return it->second;
}
return value.getDefiningOp();
}
// Re-order ops inside the block to make sure that producers are before
// consumers after fusion.
void ReorderOperationsInsideBlock() {
auto reorder_func = [&](Cluster* from, Cluster* to) {
FusionPattern& from_pattern = from->fused_pattern();
FusionPattern& to_pattern = to->fused_pattern();
Operation* last_op_in_from = from_pattern.getOpList().back();
for (Operation* op : llvm::reverse(to_pattern.getOpList())) {
if (!last_op_in_from->isBeforeInBlock(op))
op->moveAfter(last_op_in_from);
}
return false;
};
ForEachEdgeInPostOrder(reorder_func);
}
// hyper-parameters that controls the behaviour of the fusion planner.
FusionOptions options_;
// The block that fusion planner works on.
Block* block_;
// Ops inside the block
SmallVector<Operation*, 4> op_list_;
// Shape equality checker
std::unique_ptr<ShapeConstraintAnalysis> shape_analysis_;
// op -> node_id
DenseMap<Operation*, int> op_to_node_id_;
// make sure not introduce cycle after fusion
std::unique_ptr<GraphCycles> cycle_detector_;
std::vector<std::unique_ptr<Cluster>> cluster_storage_;
// a UnionFind set. Each set represents a (partial) fused pattern
// and has a leader as representation.
EquivalenceClasses<int32_t> leader_for_node_;
// Here `value` is supported to be a pointer to buffer.
// Returns the defining op of `value `if no known op updates the buffer,
// otherwise returns the last op that updates the buffer pointed by the
// `value`.
DenseMap<Value, Operation*> last_writer_;
};
struct LhloFusionPass : public LhloFusionPassBase<LhloFusionPass> {
using LhloFusionPassBase<LhloFusionPass>::LhloFusionPassBase;
explicit LhloFusionPass(int max_num_arguments_per_kernel)
: LhloFusionPassBase<LhloFusionPass>::LhloFusionPassBase() {
this->max_num_arguments_per_kernel_ = max_num_arguments_per_kernel;
}
void runOnFunction() override {
FuncOp func = getFunction();
// collect all blocks inside the function.
SmallVector<Block*, 4> blocks;
CollectBlocksInsideFunction(func, blocks);
// process each block and do fusion within a block.
FusionOptions options;
options.max_num_arguments_per_kernel = max_num_arguments_per_kernel_;
for (Block* block : blocks) {
FusionPlanner planner(options, block);
llvm::Optional<FusionPlan> plan = planner.Run();
if (!plan) {
emitError(func.getLoc(),
"an error occurs while trying to find fusion candidates");
signalPassFailure();
return;
}
if (!ApplyFusionPlan(*plan)) {
emitError(func.getLoc(), "apply fusion plan failed");
signalPassFailure();
return;
}
}
}
bool ApplyFusionPlan(FusionPlan& plan) {
for (FusionPattern& pattern : plan) {
auto& op_list = pattern.getOpList();
OpBuilder b(op_list.back());
// Get the fused locations
SmallVector<Location, 4> locations;
locations.reserve(op_list.size());
for (Operation* op : op_list) {
locations.push_back(op->getLoc());
}
Location fused_loc =
FusedLoc::get(op_list.back()->getContext(), locations);
// Move ops inside fusion pattern to the region attached to the fusion op.
FusionOp fusion = b.create<lmhlo::FusionOp>(fused_loc);
Region& region = fusion.region();
Block& block = region.front();
for (Operation* op : llvm::reverse(op_list)) {
op->moveBefore(&block, block.begin());
}
}
return true;
}
void CollectBlocksInsideFunction(FuncOp op, SmallVectorImpl<Block*>& blocks) {
op.walk([&](Block* block) {
// It does not make sense to fuse the region attached to these ops.
if (!isa<lmhlo::ReduceOp, lmhlo::FusionOp>(block->getParentOp()))
blocks.push_back(block);
});
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>> createLhloFusionPass(
int max_num_arguments_per_kernel) {
return std::make_unique<LhloFusionPass>(max_num_arguments_per_kernel);
}
} // namespace lmhlo
} // namespace mlir

286
tests/lhlo-fusion.mlir Normal file
View File

@ -0,0 +1,286 @@
// RUN: mlir-hlo-opt --lhlo-fusion -split-input-file %s -o - | FileCheck %s
// CHECK-LABEL: @simple_kloop_fusion
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?x?xf32>) -> memref<?x?xf32>
func @simple_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
%arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>) -> memref<?x?xf32> {
// CHECK: "lmhlo.fusion"() ( {
// CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[ARG3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: }) : () -> ()
// CHECK: return %[[ARG3]] : memref<?x?xf32>
"lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
"lmhlo.add"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
return %arg3 : memref<?x?xf32>
}
// -----
// CHECK-LABEL: @simple_multi_output_kloop_fusion
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>)
func @simple_multi_output_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
%arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>) {
// CHECK: "lmhlo.fusion"() ( {
// CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[ARG3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: }) : () -> ()
// CHECK: return %[[ARG1]], %[[ARG3]] : memref<?x?xf32>, memref<?x?xf32>
"lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
"lmhlo.add"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
return %arg1, %arg3 : memref<?x?xf32>, memref<?x?xf32>
}
// -----
// CHECK-LABEL: @simple_multi_output_kloop_fusion_with_reorder
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?x?xf32>, %[[ARG4:.*]]: memref<2xindex>, %[[ARG5:.*]]: memref<?x?xf32>)
func @simple_multi_output_kloop_fusion_with_reorder(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
%arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>,
%arg4: memref<2xindex>, %arg5: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) {
// CHECK: "lmhlo.fusion"() ( {
// CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[ARG3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: }) : () -> ()
// CHECK: "lmhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[ARG4]], %[[ARG5]])
// CHECK: return %[[ARG1]], %[[ARG3]], %[[ARG5]] : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
"lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
"lmhlo.dynamic_broadcast_in_dim"(%arg1, %arg4, %arg5) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (memref<?x?xf32>, memref<2xindex>, memref<?x?xf32>) -> ()
"lmhlo.add"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
return %arg1, %arg3, %arg5 : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
}
// -----
// CHECK-LABEL: @same_num_elements_multi_output_kloop_fusion
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<2xi64>, %[[ARG3:.*]]: memref<?x?x?xf32>, %[[ARG4:.*]]: memref<?x?x?xf32>, %[[ARG5:.*]]: memref<?x?x?xf32>)
func @same_num_elements_multi_output_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
%arg2: memref<2xi64>, %arg3: memref<?x?x?xf32>,
%arg4: memref<?x?x?xf32>, %arg5: memref<?x?x?xf32>) -> (memref<?x?xf32>, memref<?x?x?xf32>) {
// CHECK: "lmhlo.fusion"() ( {
// CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: "lmhlo.dynamic_reshape"(%[[ARG1]], %[[ARG2]], %[[ARG3]])
// CHECK: "lmhlo.add"(%[[ARG3]], %[[ARG4]], %[[ARG5]]) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
// CHECK: }) : () -> ()
// CHECK: return %[[ARG1]], %[[ARG5]] : memref<?x?xf32>, memref<?x?x?xf32>
"lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
"lmhlo.dynamic_reshape"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<2xi64>, memref<?x?x?xf32>) -> ()
"lmhlo.add"(%arg3, %arg4, %arg5) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
return %arg1, %arg5 : memref<?x?xf32>, memref<?x?x?xf32>
}
// -----
// CHECK-LABEL: @check_not_kloop_fusion
func @check_not_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>) {
// CHECK-NOT: "lmhlo.fusion"
"lmhlo.add"(%arg0, %arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
"lmhlo.subtract"(%arg2, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
return %arg1, %arg3: memref<?x?xf32>, memref<?x?xf32>
}
// -----
// CHECK-LABEL: @kloop_fusion_with_dealloc
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>)
func @kloop_fusion_with_dealloc(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>) {
// CHECK: %[[TMP3:.*]] = memref.alloc
// CHECK: %[[TMP5:.*]] = memref.alloc
// CHECK: %[[TMP9:.*]] = memref.alloc
// CHECK: %[[TMP13:.*]] = memref.alloc
// CHECK: %[[TMP16:.*]] = memref.alloc
// CHECK: "lmhlo.fusion"() ( {
// CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[TMP3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: "lmhlo.multiply"(%[[ARG0]], %[[ARG1]], %[[TMP5]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: "lmhlo.abs"(%[[TMP3]], %[[TMP9]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: "lmhlo.abs"(%[[TMP5]], %[[TMP13]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: "lmhlo.multiply"(%[[TMP9]], %[[TMP13]], %[[TMP16]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: }) : () -> ()
// CHECK: memref.dealloc %[[TMP3]] : memref<?x?xf32>
// CHECK: memref.dealloc %[[TMP5]] : memref<?x?xf32>
// CHECK: memref.dealloc %[[TMP13]] : memref<?x?xf32>
// CHECK: return %[[TMP9]], %[[TMP16]] : memref<?x?xf32>, memref<?x?xf32>
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
%1 = tensor.extract %0[%c0] : tensor<2xindex>
%2 = tensor.extract %0[%c1] : tensor<2xindex>
%3 = memref.alloc(%1, %2) : memref<?x?xf32>
"lmhlo.add"(%arg0, %arg1, %3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
%4 = memref.alloc(%1, %2) : memref<?x?xf32>
"lmhlo.multiply"(%arg0, %arg1, %4) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
%5 = shape.shape_of %3 : memref<?x?xf32> -> tensor<2xindex>
%6 = tensor.extract %5[%c0] : tensor<2xindex>
%7 = tensor.extract %5[%c1] : tensor<2xindex>
%8 = memref.alloc(%6, %7) : memref<?x?xf32>
"lmhlo.abs"(%3, %8) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
memref.dealloc %3 : memref<?x?xf32>
%9 = shape.shape_of %4 : memref<?x?xf32> -> tensor<2xindex>
%10 = tensor.extract %9[%c0] : tensor<2xindex>
%11 = tensor.extract %9[%c1] : tensor<2xindex>
%12 = memref.alloc(%10, %11) : memref<?x?xf32>
"lmhlo.abs"(%4, %12) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
memref.dealloc %4 : memref<?x?xf32>
%13 = shape.shape_of %8 : memref<?x?xf32> -> tensor<2xindex>
%14 = tensor.extract %13[%c0] : tensor<2xindex>
%15 = tensor.extract %13[%c1] : tensor<2xindex>
%16 = memref.alloc(%14, %15) : memref<?x?xf32>
"lmhlo.multiply"(%8, %12, %16) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
memref.dealloc %12 : memref<?x?xf32>
return %8, %16 : memref<?x?xf32>, memref<?x?xf32>
}
// -----
// CHECK-LABEL: @simple_kinput
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?xf32>, %[[ARG3:.*]]: memref<f32>
func @simple_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?xf32>, %init: memref<f32>) -> memref<?xf32> {
// CHECK: "lmhlo.fusion"() ( {
// CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: "lmhlo.reduce"(%[[ARG1]], %[[ARG3]], %[[ARG2]]) ( {
// CHECK: }) : () -> ()
// CHECK: return %[[ARG2]] : memref<?xf32>
"lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
"lmhlo.reduce"(%arg1, %init, %arg2) ( {
^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
"lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
return %arg2: memref<?xf32>
}
// -----
// CHECK-LABEL: @multi_output_kinput
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?xf32>, %[[ARG3:.*]]: memref<f32>
func @multi_output_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?xf32>, %init: memref<f32>) -> (memref<?x?xf32>, memref<?xf32>) {
// CHECK: "lmhlo.fusion"() ( {
// CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: "lmhlo.reduce"(%[[ARG1]], %[[ARG3]], %[[ARG2]]) ( {
// CHECK: }) : () -> ()
// CHECK: return %[[ARG1]], %[[ARG2]] : memref<?x?xf32>, memref<?xf32>
"lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
"lmhlo.reduce"(%arg1, %init, %arg2) ( {
^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
"lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
return %arg1, %arg2: memref<?x?xf32>, memref<?xf32>
}
// -----
// CHECK-LABEL: @row_red_and_row_red_kinput
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?xf32>, %[[ARG4:.*]]: memref<?xf32>, %[[ARG5:.*]]: memref<?x?xf32>, %[[ARG6:.*]]: memref<f32>
func @row_red_and_row_red_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: memref<?xf32>, %arg4: memref<?xf32>, %arg5: memref<?x?xf32>, %init: memref<f32>) -> (memref<?xf32>, memref<?xf32>) {
// CHECK: "lmhlo.fusion"() ( {
// CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: "lmhlo.abs"(%[[ARG2]], %[[ARG5]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: "lmhlo.reduce"(%[[ARG5]], %[[ARG6]], %[[ARG3]]) ( {
// CHECK: "lmhlo.reduce"(%[[ARG2]], %[[ARG6]], %[[ARG4]]) ( {
// CHECK: }) : () -> ()
// CHECK: return %[[ARG3]], %[[ARG4]] : memref<?xf32>, memref<?xf32>
"lmhlo.add"(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
"lmhlo.abs"(%arg2, %arg5) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
"lmhlo.reduce"(%arg5, %init, %arg3) ( {
^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
"lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
"lmhlo.reduce"(%arg2, %init, %arg4) ( {
^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
"lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
return %arg3, %arg4: memref<?xf32>, memref<?xf32>
}
// -----
// CHECK-LABEL: @row_red_and_col_red_kinput
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?xf32>, %[[ARG4:.*]]: memref<?xf32>, %[[ARG5:.*]]: memref<?x?xf32>, %[[ARG6:.*]]: memref<f32>
func @row_red_and_col_red_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: memref<?xf32>, %arg4: memref<?xf32>, %arg5: memref<?x?xf32>, %init: memref<f32>) -> (memref<?xf32>, memref<?xf32>) {
// CHECK: "lmhlo.fusion"() ( {
// CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: "lmhlo.abs"(%[[ARG2]], %[[ARG5]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: "lmhlo.reduce"(%[[ARG5]], %[[ARG6]], %[[ARG3]]) ( {
// CHECK: "lmhlo.reduce"(%[[ARG2]], %[[ARG6]], %[[ARG4]]) ( {
// CHECK: }) : () -> ()
// CHECK: return %[[ARG3]], %[[ARG4]] : memref<?xf32>, memref<?xf32>
"lmhlo.add"(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
"lmhlo.abs"(%arg2, %arg5) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
"lmhlo.reduce"(%arg5, %init, %arg3) ( {
^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
"lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
"lmhlo.reduce"(%arg2, %init, %arg4) ( {
^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
"lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
return %arg3, %arg4: memref<?xf32>, memref<?xf32>
}
// -----
// CHECK-LABEL: @reduce_should_not_have_consumer_in_the_fusion
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>
func @reduce_should_not_have_consumer_in_the_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>)
-> (memref<?x?xf32>, memref<?xf32>) {
// CHECK: %[[TMP4:.*]] = memref.alloc
// CHECK: %[[TMP7:.*]] = memref.alloc
// CHECK: %[[TMP8:.*]] = memref.alloc
// CHECK: %[[TMP9:.*]] = memref.alloc
// CHECK: "lmhlo.fusion"() ( {
// CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[TMP4]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: "lmhlo.subtract"(%[[ARG0]], %[[TMP4]], %[[TMP7]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: "lmhlo.constant"(%[[TMP8]]) {value = dense<0.000000e+00> : tensor<f32>} : (memref<f32>) -> ()
// CHECK: "lmhlo.reduce"(%[[TMP7]], %[[TMP8]], %[[TMP9]]) ( {
// CHECK: }) : () -> ()
// CHECK: memref.dealloc %[[TMP4]] : memref<?x?xf32>
// CHECK: memref.dealloc %[[TMP8]] : memref<f32>
// CHECK: %[[TMP12:.*]] = memref.alloc
// CHECK: "lmhlo.add"(%[[TMP9]], %[[TMP9]], %[[TMP12]]) : (memref<?xf32>, memref<?xf32>, memref<?xf32>) -> ()
// CHECK: memref.dealloc %[[TMP9]] : memref<?xf32>
// CHECK: return %[[TMP7]], %[[TMP12]] : memref<?x?xf32>, memref<?xf32>
%c1 = constant 1 : index
%c0 = constant 0 : index
%0 = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
%1 = tensor.extract %0[%c0] : tensor<2xindex>
%2 = tensor.extract %0[%c1] : tensor<2xindex>
%3 = memref.alloc(%1, %2) : memref<?x?xf32>
"lmhlo.add"(%arg0, %arg1, %3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
%4 = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
%5 = tensor.extract %4[%c0] : tensor<2xindex>
%6 = tensor.extract %4[%c1] : tensor<2xindex>
%7 = memref.alloc(%5, %6) : memref<?x?xf32>
"lmhlo.subtract"(%arg0, %3, %7) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
memref.dealloc %3 : memref<?x?xf32>
%8 = memref.alloc() : memref<f32>
"lmhlo.constant"(%8) {value = dense<0.000000e+00> : tensor<f32>} : (memref<f32>) -> ()
%9 = memref.alloc(%5) : memref<?xf32>
"lmhlo.reduce"(%7, %8, %9) ( {
^bb0(%arg2: memref<f32>, %arg3: memref<f32>, %arg4: memref<f32>): // no predecessors
"lmhlo.add"(%arg2, %arg3, %arg4) : (memref<f32>, memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
memref.dealloc %8 : memref<f32>
%10 = shape.shape_of %9 : memref<?xf32> -> tensor<1xindex>
%11 = tensor.extract %10[%c0] : tensor<1xindex>
%12 = memref.alloc(%11) : memref<?xf32>
"lmhlo.add"(%9, %9, %12) : (memref<?xf32>, memref<?xf32>, memref<?xf32>) -> ()
memref.dealloc %9 : memref<?xf32>
return %7, %12 : memref<?x?xf32>, memref<?xf32>
}
// -----
// CHECK-LABEL: @const_should_not_be_output
func @const_should_not_be_output(%arg0: memref<f32>) -> (memref<f32>, memref<f32>) {
// CHECK-NOT: lmhlo.fusion
%0 = memref.alloc() : memref<f32>
"lmhlo.constant"(%0) {value = dense<1.000000e+00> : tensor<f32>} : (memref<f32>) -> ()
%1 = memref.alloc() : memref<f32>
"lmhlo.add"(%arg0, %0, %1) : (memref<f32>, memref<f32>, memref<f32>) -> ()
return %0, %1 : memref<f32>, memref<f32>
}