PR #50020: [MLIR][DISC] support fusion on buffer
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50020 This pass implements the logic to group kLoop/kInput fusion patterns on buffer level. The reason for this is that we can avoid a lot of headaches to handle `shape-only` consumers specially (e.g. memref.dim, shape.shapeOf) since shapes are already resolved in buffer world. It may be better to move this pass to tensor level after more shape inference/constraint infras are ready on mhlo level. Copybara import of the project: -- e31f8344b59aa9860097197585215ea1689b8ff4 by Wenyi Zhao <reyizero@gmail.com>: [MLIR][DISC] support fusion on buffer This pass implements the logic to group kLoop/kInput fusion patterns on buffer level. The reason for this is that we can avoid a lot of headaches to handle `shape-only` consumers specially (e.g. memref.dim, shape.shapeOf) since shapes are already resolved in buffer world. It may be better to move this pass to tensor level after more shape inference/constraint infras are ready on mhlo level. -- 35f2eb2791241b0ab5db1ddcaf1b4006278ddccf by Wenyi Zhao <reyizero@gmail.com>: fix -- 923c8d61f7fe00a2a0df22d5be396508f0667964 by Wenyi Zhao <reyizero@gmail.com>: fix sanity check failure PiperOrigin-RevId: 379743424
This commit is contained in:
parent
82696f8598
commit
34dc5f2a79
38
BUILD
38
BUILD
|
@ -1213,6 +1213,42 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "fusion_utils",
|
||||||
|
srcs = ["lib/Dialect/mhlo/transforms/fusion_utils.cc"],
|
||||||
|
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/fusion_utils.h"],
|
||||||
|
deps = [
|
||||||
|
":lhlo",
|
||||||
|
"@llvm-project//llvm:Core",
|
||||||
|
"@llvm-project//llvm:Support",
|
||||||
|
"@llvm-project//mlir:IR",
|
||||||
|
"@llvm-project//mlir:Shape",
|
||||||
|
"@llvm-project//mlir:StandardOps",
|
||||||
|
"@llvm-project//mlir:Support",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "lhlo_fusion",
|
||||||
|
srcs = ["lib/Dialect/mhlo/transforms/lhlo_fusion.cc"],
|
||||||
|
deps = [
|
||||||
|
":cycle_detector",
|
||||||
|
":fusion_utils",
|
||||||
|
":lhlo",
|
||||||
|
":pass_details",
|
||||||
|
"@llvm-project//llvm:Core",
|
||||||
|
"@llvm-project//llvm:Support",
|
||||||
|
"@llvm-project//mlir:IR",
|
||||||
|
"@llvm-project//mlir:Pass",
|
||||||
|
"@llvm-project//mlir:Shape",
|
||||||
|
"@llvm-project//mlir:StandardOps",
|
||||||
|
"@llvm-project//mlir:Support",
|
||||||
|
"@llvm-project//mlir:TransformUtils",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "chlo_legalize_to_hlo",
|
name = "chlo_legalize_to_hlo",
|
||||||
srcs = ["lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc"],
|
srcs = ["lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc"],
|
||||||
|
@ -1259,6 +1295,7 @@ cc_library(
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":DiscRalPassIncGen",
|
":DiscRalPassIncGen",
|
||||||
|
":LmhloPassIncGen",
|
||||||
":MhloPassIncGen",
|
":MhloPassIncGen",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
],
|
],
|
||||||
|
@ -1316,6 +1353,7 @@ cc_library(
|
||||||
":legalize_trigonometric_to_approximation",
|
":legalize_trigonometric_to_approximation",
|
||||||
":lhlo",
|
":lhlo",
|
||||||
":lhlo_fuse_linalg",
|
":lhlo_fuse_linalg",
|
||||||
|
":lhlo_fusion",
|
||||||
":lhlo_legalize_to_affine",
|
":lhlo_legalize_to_affine",
|
||||||
":lhlo_legalize_to_gpu",
|
":lhlo_legalize_to_gpu",
|
||||||
":lhlo_legalize_to_parallel_loops",
|
":lhlo_legalize_to_parallel_loops",
|
||||||
|
|
|
@ -25,6 +25,14 @@ namespace mhlo {
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc"
|
#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc"
|
||||||
|
|
||||||
} // end namespace mhlo
|
} // end namespace mhlo
|
||||||
|
|
||||||
|
namespace lmhlo {
|
||||||
|
|
||||||
|
#define GEN_PASS_CLASSES
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.h.inc"
|
||||||
|
|
||||||
|
} // end namespace lmhlo
|
||||||
|
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
|
@ -0,0 +1,244 @@
|
||||||
|
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_FUSION_UTILS_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_FUSION_UTILS_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "llvm/ADT/EquivalenceClasses.h"
|
||||||
|
#include "llvm/Support/Debug.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
|
|
||||||
|
// This file implements some helper functions and classes used to do fusion
|
||||||
|
// & code generation.
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace lmhlo {
|
||||||
|
|
||||||
|
// kLoop fusion template satisfies:
|
||||||
|
// - all ops in the fusion pattern are element-wise.
|
||||||
|
// - all the shapes of outputs of fusion pattern are same or have same number
|
||||||
|
// of elements, and thus can fit into a same parallel loop.
|
||||||
|
//
|
||||||
|
// kInput fusion template satisfies:
|
||||||
|
// - any op in the fusion pattern is either element-wise or a reduction.
|
||||||
|
// - if a op is a reduction, its output cannot be consumed by other
|
||||||
|
// ops in the same fusion pattern.
|
||||||
|
// - all the effective shapes of outputs of fusion pattern are same.
|
||||||
|
// - For element-wise op, its effective shape is its output shape.
|
||||||
|
// - For reduction op, its effective shape is its operand shape.
|
||||||
|
// - currently our downstreaming codegen engine only support 2d -> 1d tensor
|
||||||
|
// reduction. TODO: lift this limitation.
|
||||||
|
// - 2D row reduction: out[i] = sum({in[i][j] for all j})
|
||||||
|
// - 2D column reduction: out[j] = sum({in[i][j] for all i})
|
||||||
|
enum FusionType {
|
||||||
|
// Not a fusion pattern
|
||||||
|
kNone,
|
||||||
|
// kLoop fusion pattern
|
||||||
|
kLoop,
|
||||||
|
// kInput fusion pattern and all reduce ops of the fused pattern are row
|
||||||
|
// reduction
|
||||||
|
kRowReduction,
|
||||||
|
// kInput fusion pattern and all reduce ops of the fused pattern are column
|
||||||
|
// reduction
|
||||||
|
kColReduction,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Returns true if the op is an elementwise unary lmhlo op.
|
||||||
|
// TODO: use fusibility interface
|
||||||
|
bool isElementWiseUnary(Operation* op);
|
||||||
|
|
||||||
|
// Returns true if the op is an elementwise binary lmhlo op.
|
||||||
|
// TODO: use fusibility interface
|
||||||
|
bool isElementWiseBinary(Operation* op);
|
||||||
|
|
||||||
|
// Returns true if the op is an elementwise lmhlo op.
|
||||||
|
// TODO: use fusibility interface
|
||||||
|
bool isElementWise(Operation* op);
|
||||||
|
|
||||||
|
// Returns true if this op is a rank-2 row reduction.
|
||||||
|
bool isRank2RowReduction(Operation* op);
|
||||||
|
|
||||||
|
// Returns true if this op is a rank-2 column reduction.
|
||||||
|
bool isRank2ColReduction(Operation* op);
|
||||||
|
|
||||||
|
// Returns true if the op is supported by the downstreaming fusion codegen
|
||||||
|
// engine.
|
||||||
|
bool isFusible(Operation* op);
|
||||||
|
|
||||||
|
// Returns the number of operands that are supposed to be written.
|
||||||
|
// For some ops (e.g. lmhlo ops), some operands are the output memrefs
|
||||||
|
// Thus these operands are supposed to be updated.
|
||||||
|
int getNumResultOperands(Operation* op);
|
||||||
|
|
||||||
|
// Returns data users of the value and its aliases (e.g. memref.cast).
|
||||||
|
// Here non-data users means DimOp, DeallocOp and ShapeOfOp.
|
||||||
|
SmallVector<Operation*, 4> getValueUsers(Value v);
|
||||||
|
|
||||||
|
// Represents a list of lmhlo ops that are going to be fused.
|
||||||
|
class FusionPattern {
|
||||||
|
public:
|
||||||
|
using FusionOpList = SmallVector<Operation*, 4>;
|
||||||
|
using FusionValueList = SmallVector<Value, 4>;
|
||||||
|
|
||||||
|
// Create a new fusion pattern from a single op.
|
||||||
|
FusionPattern(Operation* op);
|
||||||
|
|
||||||
|
// Create a new fusion pattern from the ops inside the lmhlo fusion op.
|
||||||
|
FusionPattern(lmhlo::FusionOp op);
|
||||||
|
|
||||||
|
// Returns the op list this fusion pattern represents.
|
||||||
|
FusionOpList& getOpList() { return op_list_; }
|
||||||
|
|
||||||
|
// Returns the dominant op of this fusion pattern.
|
||||||
|
// For kLoop fusion, a dominant op may be any op that has external users.
|
||||||
|
// For kInput fusion, a dominant op may be a row reduction (if exists), or
|
||||||
|
// a column reduction op.
|
||||||
|
Operation* getDominantOp() { return dominant_op_; }
|
||||||
|
|
||||||
|
// Sets the dominant op to the op provided.
|
||||||
|
void setDominantOp(Operation* op) { dominant_op_ = op; }
|
||||||
|
|
||||||
|
// Returns the fusion kind of the fusion pattern.
|
||||||
|
FusionType getFusionType() { return fusion_type_; }
|
||||||
|
|
||||||
|
// Sets the fusion type to the the type provided.
|
||||||
|
void setFusionType(FusionType type) { fusion_type_ = type; }
|
||||||
|
|
||||||
|
// Returns true if this a fusible fusion pattern.
|
||||||
|
bool isFusible() { return getFusionType() != FusionType::kNone; }
|
||||||
|
|
||||||
|
// Returns true if this fusion pattern is a kLoop fusion.
|
||||||
|
bool isKLoopFusion() { return getFusionType() == FusionType::kLoop; }
|
||||||
|
|
||||||
|
// Returns true if this fusion pattern is a kInput fusion.
|
||||||
|
bool isKInputFusion() {
|
||||||
|
return (getFusionType() == FusionType::kRowReduction ||
|
||||||
|
getFusionType() == FusionType::kColReduction);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns true if two fusion patterns can be merged into one bigger fusion
|
||||||
|
// pattern.
|
||||||
|
bool isMergeable(FusionPattern& other);
|
||||||
|
|
||||||
|
// Merges two fusion patterns and returns the merged pattern. The original
|
||||||
|
// pattern remains unmodified.
|
||||||
|
FusionPattern merge(FusionPattern& other);
|
||||||
|
|
||||||
|
// Merges two fusion patterns and returns the merged pattern. Replaces the
|
||||||
|
// original pattern with new merged pattern.
|
||||||
|
FusionPattern& mergeInplace(FusionPattern& other);
|
||||||
|
|
||||||
|
// Returns values that are consumed by the lmhlo ops inside the fusion
|
||||||
|
// pattern.
|
||||||
|
FusionValueList& getOperands() { return operands_; }
|
||||||
|
|
||||||
|
// Returns values that are outputs of any lmhlo op in the fused pattern and
|
||||||
|
// have consumers outside the fusion pattern.
|
||||||
|
FusionValueList& getResults() { return results_; }
|
||||||
|
|
||||||
|
// Returns values that are outputs of any lmhlo op in the fused pattern and
|
||||||
|
// are only consumed by the lmhlo ops inside the fused pattern.
|
||||||
|
FusionValueList& getInternalResults() { return internal_results_; }
|
||||||
|
|
||||||
|
// Returns the size of the ops this fusion pattern contains.
|
||||||
|
int size() { return op_list_.size(); }
|
||||||
|
|
||||||
|
// Returns the effective size (e.g. not counting const ops) of the ops this
|
||||||
|
// fusion pattern contains.
|
||||||
|
int effectiveSize();
|
||||||
|
|
||||||
|
// Sorts the ops inside the fusion pattern according to the keys provided.
|
||||||
|
void sortFusionOpListBy(DenseMap<Operation*, int>& op_to_idx);
|
||||||
|
|
||||||
|
private:
|
||||||
|
FusionPattern(SmallVectorImpl<Operation*>& op_list);
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Calculates the inputs and outputs of the fusion pattern.
|
||||||
|
void calculateOperandsAndResults();
|
||||||
|
|
||||||
|
private:
|
||||||
|
FusionOpList op_list_;
|
||||||
|
Operation* dominant_op_ = nullptr;
|
||||||
|
FusionType fusion_type_ = FusionType::kNone;
|
||||||
|
FusionValueList operands_;
|
||||||
|
FusionValueList results_;
|
||||||
|
FusionValueList internal_results_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Represents a list of disjoint fusion patterns for a block.
|
||||||
|
using FusionPlan = std::vector<FusionPattern>;
|
||||||
|
|
||||||
|
using llvm::EquivalenceClasses;
|
||||||
|
|
||||||
|
// Supports using EquivalenceClasses for Value
|
||||||
|
class ValueWrapper {
|
||||||
|
public:
|
||||||
|
explicit ValueWrapper(Value value) : value_(std::move(value)) {}
|
||||||
|
|
||||||
|
Value getValue() const { return value_; }
|
||||||
|
|
||||||
|
bool operator==(const ValueWrapper& rhs) const {
|
||||||
|
return getValue() == rhs.getValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Value value_;
|
||||||
|
};
|
||||||
|
|
||||||
|
bool operator<(const ValueWrapper& lhs, const ValueWrapper& rhs);
|
||||||
|
|
||||||
|
// This is a simple shape constraint analysis, which is used to
|
||||||
|
// guide fusion decision (e.g. we only fuse shape-compatible ops).
|
||||||
|
//
|
||||||
|
// Currently, We only consider shape equality and same-number-elements equality
|
||||||
|
// propagation based on the shape constraint traits of elementwise ops (assuming
|
||||||
|
// that implicit shape broadcast is forbidden).
|
||||||
|
class ShapeConstraintAnalysis {
|
||||||
|
public:
|
||||||
|
explicit ShapeConstraintAnalysis(const SmallVectorImpl<Operation*>& op_list) {
|
||||||
|
PropagateEquality(op_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns true if `lhs` and `rhs` are supposed to have same shape.
|
||||||
|
bool HasSameShape(Value lhs, Value rhs) {
|
||||||
|
return same_shape_impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns true if `lhs` and `rhs` are supposed to have same number of
|
||||||
|
// elements.
|
||||||
|
bool HasSameNumElements(Value lhs, Value rhs) {
|
||||||
|
return same_num_elements_impl_.isEquivalent(ValueWrapper(lhs),
|
||||||
|
ValueWrapper(rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// shape equality propagation based on the shape constrains of
|
||||||
|
// elementwise ops.
|
||||||
|
void PropagateEquality(const SmallVectorImpl<Operation*>& op_list);
|
||||||
|
|
||||||
|
// a UnionFind set
|
||||||
|
EquivalenceClasses<ValueWrapper> same_shape_impl_;
|
||||||
|
EquivalenceClasses<ValueWrapper> same_num_elements_impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace lmhlo
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_FUSION_UTILS_H_
|
|
@ -56,3 +56,12 @@ def LegalizeTensorLoadOpPass : Pass<"lhlo-legalize-tensor-load-op", "FuncOp"> {
|
||||||
let constructor = "createLegalizeTensorLoadOpPass()";
|
let constructor = "createLegalizeTensorLoadOpPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def LhloFusionPass : FunctionPass<"lhlo-fusion"> {
|
||||||
|
let summary = "Fuse lmhlo ops to kLoop/kInput fusion patterns.";
|
||||||
|
let constructor = "createLhloFusionPass()";
|
||||||
|
let options = [
|
||||||
|
Option<"max_num_arguments_per_kernel_", "max-num-arguments-per-kernel", "int",
|
||||||
|
/*default=*/"64", "Maximum allowed number of arguments per fused kernel.">,
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -117,6 +117,10 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
|
||||||
// Legalizes tensor load ops that are inserted during mhlo to lmhlo conversion.
|
// Legalizes tensor load ops that are inserted during mhlo to lmhlo conversion.
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTensorLoadOpPass();
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTensorLoadOpPass();
|
||||||
|
|
||||||
|
// fuse lmhlo ops to kLoop/kInput fusion patterns
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createLhloFusionPass(
|
||||||
|
int max_num_arguments_per_kernel = 64);
|
||||||
|
|
||||||
} // namespace lmhlo
|
} // namespace lmhlo
|
||||||
|
|
||||||
namespace disc_ral {
|
namespace disc_ral {
|
||||||
|
|
|
@ -137,8 +137,10 @@ add_mlir_library(MhloLhloToLinalg
|
||||||
)
|
)
|
||||||
|
|
||||||
add_mlir_library(LmhloPasses
|
add_mlir_library(LmhloPasses
|
||||||
|
fusion_utils.cc
|
||||||
legalize_tensor_load_op.cc
|
legalize_tensor_load_op.cc
|
||||||
lhlo_fuse_linalg.cc
|
lhlo_fuse_linalg.cc
|
||||||
|
lhlo_fusion.cc
|
||||||
lhlo_legalize_to_affine.cc
|
lhlo_legalize_to_affine.cc
|
||||||
lhlo_legalize_to_gpu.cc
|
lhlo_legalize_to_gpu.cc
|
||||||
lhlo_legalize_to_parallel_loops.cc
|
lhlo_legalize_to_parallel_loops.cc
|
||||||
|
|
|
@ -0,0 +1,394 @@
|
||||||
|
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/fusion_utils.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Shape/IR/Shape.h" // TF:llvm-project
|
||||||
|
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
|
||||||
|
// This file implements some helper functions and classes used to do fusion
|
||||||
|
// & code generation.
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace lmhlo {
|
||||||
|
|
||||||
|
// Returns true if the op is an elementwise unary lmhlo op.
|
||||||
|
// TODO(disc): use fusibility interface
|
||||||
|
bool isElementWiseUnary(Operation* op) {
|
||||||
|
// clang-format off
|
||||||
|
return isa<
|
||||||
|
lmhlo::AbsOp,
|
||||||
|
lmhlo::CeilOp,
|
||||||
|
lmhlo::ConvertOp,
|
||||||
|
lmhlo::CopyOp,
|
||||||
|
lmhlo::CosOp,
|
||||||
|
lmhlo::ExpOp,
|
||||||
|
lmhlo::FloorOp,
|
||||||
|
lmhlo::IsFiniteOp,
|
||||||
|
lmhlo::LogOp,
|
||||||
|
lmhlo::NegOp,
|
||||||
|
lmhlo::NotOp,
|
||||||
|
lmhlo::RsqrtOp,
|
||||||
|
lmhlo::SignOp,
|
||||||
|
lmhlo::SqrtOp,
|
||||||
|
lmhlo::TanhOp
|
||||||
|
>(op);
|
||||||
|
// clang-format on
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns true if the op is an elementwise binary lmhlo op.
|
||||||
|
// TODO(disc): use fusibility interface
|
||||||
|
bool isElementWiseBinary(Operation* op) {
|
||||||
|
// clang-format off
|
||||||
|
return isa<
|
||||||
|
lmhlo::AddOp,
|
||||||
|
lmhlo::AndOp,
|
||||||
|
lmhlo::CompareOp,
|
||||||
|
lmhlo::DivOp,
|
||||||
|
lmhlo::MaxOp,
|
||||||
|
lmhlo::MinOp,
|
||||||
|
lmhlo::MulOp,
|
||||||
|
lmhlo::OrOp,
|
||||||
|
lmhlo::PowOp,
|
||||||
|
lmhlo::SubOp
|
||||||
|
>(op);
|
||||||
|
// clang-format on
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns true if the op is an elementwise lmhlo op.
|
||||||
|
// TODO(disc): use fusibility interface
|
||||||
|
bool isElementWise(Operation* op) {
|
||||||
|
return isElementWiseUnary(op) || isElementWiseBinary(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns true if this op is a rank-2 row reduction.
|
||||||
|
bool isRank2RowReduction(Operation* op) {
|
||||||
|
auto reduce_op = dyn_cast<lmhlo::ReduceOp>(op);
|
||||||
|
if (!reduce_op || reduce_op.dimensions().getNumElements() != 1) return false;
|
||||||
|
|
||||||
|
int rank = op->getOperand(0).getType().cast<MemRefType>().getRank();
|
||||||
|
auto dimensions = reduce_op.dimensions().getValues<int64_t>();
|
||||||
|
return ((*dimensions.begin() == 1) && (rank == 2));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns true if this op is a rank-2 column reduction.
|
||||||
|
bool isRank2ColReduction(Operation* op) {
|
||||||
|
auto reduce_op = dyn_cast<lmhlo::ReduceOp>(op);
|
||||||
|
if (!reduce_op || reduce_op.dimensions().getNumElements() != 1) return false;
|
||||||
|
|
||||||
|
int rank = op->getOperand(0).getType().cast<MemRefType>().getRank();
|
||||||
|
auto dimensions = reduce_op.dimensions().getValues<int64_t>();
|
||||||
|
return ((*dimensions.begin() == 0) && (rank == 2));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns true if the op is supported by the downstreaming fusion codegen
|
||||||
|
// engine.
|
||||||
|
bool isFusible(Operation* op) {
|
||||||
|
// Only scalar const are supported by the fusion codegen engine a.t.m.
|
||||||
|
if (dyn_cast<lmhlo::ConstOp>(op)) {
|
||||||
|
MemRefType type = op->getOperand(0).getType().cast<MemRefType>();
|
||||||
|
return (type.getRank() == 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// All element ops are supported by the fusion codegen engine.
|
||||||
|
if (isElementWise(op)) return true;
|
||||||
|
|
||||||
|
// Only rank-2 tensor -> rank-1 tensor reduction are supported now.
|
||||||
|
if (isRank2RowReduction(op) || isRank2ColReduction(op)) return true;
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
return isa<
|
||||||
|
lmhlo::BroadcastInDimOp,
|
||||||
|
lmhlo::BroadcastOp,
|
||||||
|
lmhlo::ConcatenateOp,
|
||||||
|
lmhlo::DynamicBroadcastInDimOp,
|
||||||
|
lmhlo::DynamicGatherOp,
|
||||||
|
lmhlo::DynamicIotaOp,
|
||||||
|
lmhlo::DynamicPadOp,
|
||||||
|
lmhlo::DynamicReshapeOp,
|
||||||
|
lmhlo::GatherOp,
|
||||||
|
lmhlo::RealDynamicSliceOp,
|
||||||
|
lmhlo::ReshapeOp,
|
||||||
|
lmhlo::SelectOp,
|
||||||
|
lmhlo::SliceOp,
|
||||||
|
lmhlo::TransposeOp
|
||||||
|
>(op);
|
||||||
|
// clang-format on
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the number of operands that are supposed to be written.
|
||||||
|
// For some ops (e.g. lmhlo ops), some operands are the output memrefs
|
||||||
|
// Thus these operands are supposed to be updated.
|
||||||
|
int getNumResultOperands(Operation* op) {
|
||||||
|
if (op->getDialect()->getNamespace() != "lmhlo") {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto isWritable = [&](Value operand) -> bool {
|
||||||
|
llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 2> effects;
|
||||||
|
MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
|
||||||
|
// Suppose that operands of op without `MemoryEffectOpInterface` are
|
||||||
|
// readonly.
|
||||||
|
if (!interface) return false;
|
||||||
|
|
||||||
|
interface.getEffectsOnValue(operand, effects);
|
||||||
|
return llvm::any_of(
|
||||||
|
effects, [](const mlir::MemoryEffects::EffectInstance& instance) {
|
||||||
|
return mlir::isa<mlir::MemoryEffects::Write>(instance.getEffect());
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
return llvm::count_if(op->getOperands(),
|
||||||
|
[&](Value v) { return isWritable(v); });
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns data users of the value and its aliases (e.g. memref.cast).
|
||||||
|
// Here non-data users means DimOp, DeallocOp and ShapeOfOp.
|
||||||
|
SmallVector<Operation*, 4> getValueUsers(Value v) {
|
||||||
|
SmallVector<Operation*, 4> users;
|
||||||
|
SmallVector<Value, 4> worklist;
|
||||||
|
worklist.push_back(v);
|
||||||
|
while (!worklist.empty()) {
|
||||||
|
Value curr = worklist.back();
|
||||||
|
worklist.pop_back();
|
||||||
|
for (Operation* user : curr.getUsers()) {
|
||||||
|
// Skip non-data users
|
||||||
|
if (isa<memref::DimOp, memref::DeallocOp, shape::ShapeOfOp>(user)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// alias value
|
||||||
|
if (isa<memref::CastOp>(user)) {
|
||||||
|
worklist.push_back(user->getResult(0));
|
||||||
|
} else {
|
||||||
|
users.push_back(user);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return users;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new fusion pattern from a single op.
|
||||||
|
FusionPattern::FusionPattern(Operation* op) {
|
||||||
|
op_list_.push_back(op);
|
||||||
|
if (isRank2RowReduction(op)) {
|
||||||
|
fusion_type_ = FusionType::kRowReduction;
|
||||||
|
} else if (isRank2ColReduction(op)) {
|
||||||
|
fusion_type_ = FusionType::kColReduction;
|
||||||
|
} else if (mlir::lmhlo::isFusible(op)) {
|
||||||
|
fusion_type_ = FusionType::kLoop;
|
||||||
|
} else {
|
||||||
|
fusion_type_ = FusionType::kNone;
|
||||||
|
}
|
||||||
|
dominant_op_ = op;
|
||||||
|
calculateOperandsAndResults();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new fusion pattern from the ops inside the lmhlo fusion op.
|
||||||
|
FusionPattern::FusionPattern(lmhlo::FusionOp op) {
|
||||||
|
for (Operation& op : op.region().getBlocks().front()) {
|
||||||
|
op_list_.push_back(&op);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Figure out fusion type and dominant op for the fusion pattern.
|
||||||
|
for (Operation* op : op_list_) {
|
||||||
|
if (isRank2RowReduction(op)) {
|
||||||
|
fusion_type_ = FusionType::kRowReduction;
|
||||||
|
dominant_op_ = op;
|
||||||
|
} else if (isRank2ColReduction(op)) {
|
||||||
|
if (fusion_type_ != FusionType::kRowReduction) {
|
||||||
|
fusion_type_ = FusionType::kColReduction;
|
||||||
|
dominant_op_ = op;
|
||||||
|
}
|
||||||
|
} else if (lmhlo::isFusible(op)) {
|
||||||
|
// Ignore if already a kRowReduction or kColReduction, otherwise update
|
||||||
|
// the fusion type to kLoop and dominant op to current op. This supposes
|
||||||
|
// that the last op inside the block is a valid candidate dominant op if
|
||||||
|
// the fusion pattern is a kLoop.
|
||||||
|
if (fusion_type_ == FusionType::kNone ||
|
||||||
|
fusion_type_ == FusionType::kLoop) {
|
||||||
|
fusion_type_ = FusionType::kLoop;
|
||||||
|
dominant_op_ = op;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Not a supported fusionOp, early stop.
|
||||||
|
fusion_type_ = FusionType::kNone;
|
||||||
|
dominant_op_ = nullptr;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isFusible()) calculateOperandsAndResults();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new fusion pattern from a valid fusion op list.
|
||||||
|
FusionPattern::FusionPattern(SmallVectorImpl<Operation*>& op_list)
|
||||||
|
: op_list_(op_list.begin(), op_list.end()) {
|
||||||
|
calculateOperandsAndResults();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns true if two fusion patterns can be merged into one bigger fusion
|
||||||
|
// pattern.
|
||||||
|
bool FusionPattern::isMergeable(FusionPattern& other) {
|
||||||
|
if (!this->isFusible() || !other.isFusible()) return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merges two fusion patterns and returns the merged pattern. The original
|
||||||
|
// pattern remains unmodified.
|
||||||
|
FusionPattern FusionPattern::merge(FusionPattern& other) {
|
||||||
|
assert(isMergeable(other));
|
||||||
|
FusionOpList new_op_list = op_list_;
|
||||||
|
new_op_list.insert(new_op_list.end(), other.getOpList().begin(),
|
||||||
|
other.getOpList().end());
|
||||||
|
FusionPattern new_fusion_pattern{new_op_list};
|
||||||
|
|
||||||
|
FusionType newType = FusionType::kLoop;
|
||||||
|
Operation* newDominant = getDominantOp();
|
||||||
|
|
||||||
|
// kRowReduction + (kRowReduction | kColReduction | kLoop) = kRowReduction
|
||||||
|
// kColReduction + (kColReduction | kLoop) = kColReduction
|
||||||
|
// kLoop + kLoop = kLoop
|
||||||
|
if (getFusionType() == FusionType::kRowReduction ||
|
||||||
|
other.getFusionType() == FusionType::kRowReduction) {
|
||||||
|
newType = FusionType::kRowReduction;
|
||||||
|
if (getFusionType() != FusionType::kRowReduction)
|
||||||
|
newDominant = other.getDominantOp();
|
||||||
|
} else if (getFusionType() == FusionType::kColReduction ||
|
||||||
|
other.getFusionType() == FusionType::kColReduction) {
|
||||||
|
newType = FusionType::kColReduction;
|
||||||
|
if (getFusionType() != FusionType::kColReduction)
|
||||||
|
newDominant = other.getDominantOp();
|
||||||
|
}
|
||||||
|
|
||||||
|
new_fusion_pattern.setDominantOp(newDominant);
|
||||||
|
new_fusion_pattern.setFusionType(newType);
|
||||||
|
return new_fusion_pattern;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merges two fusion patterns and returns the merged pattern. Replaces the
|
||||||
|
// original pattern with new merged pattern.
|
||||||
|
FusionPattern& FusionPattern::mergeInplace(FusionPattern& other) {
|
||||||
|
*this = merge(other);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the effective size (e.g. not counting const ops) of the ops this
|
||||||
|
// fusion pattern contains.
|
||||||
|
int FusionPattern::effectiveSize() {
|
||||||
|
return llvm::count_if(
|
||||||
|
op_list_, [](Operation* op) { return !matchPattern(op, m_Constant()); });
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sorts the ops inside the fusion pattern according to the keys provided.
|
||||||
|
void FusionPattern::sortFusionOpListBy(DenseMap<Operation*, int>& op_to_idx) {
|
||||||
|
std::sort(op_list_.begin(), op_list_.end(),
|
||||||
|
[&](Operation* lhs, Operation* rhs) {
|
||||||
|
return op_to_idx[lhs] < op_to_idx[rhs];
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculates the inputs and outputs of the fusion pattern.
|
||||||
|
void FusionPattern::calculateOperandsAndResults() {
|
||||||
|
DenseSet<Value> input_set;
|
||||||
|
DenseSet<Value> result_set;
|
||||||
|
DenseSet<Value> internal_result_set;
|
||||||
|
DenseSet<Operation*> op_set(op_list_.begin(), op_list_.end());
|
||||||
|
|
||||||
|
DenseMap<Value, Operation*> last_writer;
|
||||||
|
for (Operation* op : op_list_) {
|
||||||
|
int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
|
||||||
|
for (Value v : op->getOperands().drop_front(num_input_operand)) {
|
||||||
|
bool inserted = last_writer.try_emplace(v, op).second;
|
||||||
|
(void)inserted;
|
||||||
|
assert(inserted);
|
||||||
|
|
||||||
|
bool has_external_user = false;
|
||||||
|
for (Operation* user : getValueUsers(v)) {
|
||||||
|
if (!op_set.contains(user)) {
|
||||||
|
has_external_user = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (has_external_user) {
|
||||||
|
results_.push_back(v);
|
||||||
|
} else {
|
||||||
|
internal_results_.push_back(v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (Operation* op : op_list_) {
|
||||||
|
int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
|
||||||
|
for (Value value : op->getOperands().take_front(num_input_operand)) {
|
||||||
|
if (last_writer.find(value) != last_writer.end()) {
|
||||||
|
// skip if defining op is in the pattern
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
input_set.insert(value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (Value v : input_set) operands_.push_back(v);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Supports using EquivalenceClasses for Value
|
||||||
|
bool operator<(const ValueWrapper& lhs, const ValueWrapper& rhs) {
|
||||||
|
auto lhs_value = lhs.getValue().getAsOpaquePointer();
|
||||||
|
auto rhs_value = rhs.getValue().getAsOpaquePointer();
|
||||||
|
return lhs_value < rhs_value;
|
||||||
|
}
|
||||||
|
|
||||||
|
// shape equality propagation based on the shape constrains of
|
||||||
|
// elementwise ops.
|
||||||
|
void ShapeConstraintAnalysis::PropagateEquality(
|
||||||
|
const SmallVectorImpl<Operation*>& op_list) {
|
||||||
|
bool converged = true;
|
||||||
|
do {
|
||||||
|
converged = true;
|
||||||
|
auto update = [&](Value lhs, Value rhs,
|
||||||
|
EquivalenceClasses<ValueWrapper>& impl) {
|
||||||
|
if (!impl.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs))) {
|
||||||
|
converged = false;
|
||||||
|
impl.unionSets(ValueWrapper(lhs), ValueWrapper(rhs));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
for (Operation* op : op_list) {
|
||||||
|
int num_operand = op->getNumOperands();
|
||||||
|
// Propagates same num_elements equality, and shape equality
|
||||||
|
if (isElementWise(op)) {
|
||||||
|
Value lhs = op->getOperand(0);
|
||||||
|
for (Value rhs : op->getOperands().drop_front()) {
|
||||||
|
update(lhs, rhs, same_num_elements_impl_);
|
||||||
|
update(lhs, rhs, same_shape_impl_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Propagates same num_elements equality, not shape equality
|
||||||
|
if (isa<lmhlo::DynamicReshapeOp, lmhlo::ReshapeOp, lmhlo::TransposeOp>(
|
||||||
|
op)) {
|
||||||
|
Value input = op->getOperand(0);
|
||||||
|
// The last operand is the output memref by design
|
||||||
|
Value output = op->getOperand(num_operand - 1);
|
||||||
|
update(input, output, same_num_elements_impl_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} while (!converged);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace lmhlo
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,570 @@
|
||||||
|
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/fusion_utils.h"
|
||||||
|
#include "mlir-hlo/utils/cycle_detector.h"
|
||||||
|
#include "mlir/Dialect/Shape/IR/Shape.h" // TF:llvm-project
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
|
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project
|
||||||
|
|
||||||
|
// This pass has similar functionality of the fusion pass in XLA stack.
|
||||||
|
// However, unlike XLA, it targets the fully dynamic shape scenario.
|
||||||
|
// Currently, it implements the kLoop and kInput fusion templates.
|
||||||
|
// During conversion, it tries to greedily find kLoop/kInput fusion
|
||||||
|
// patterns.
|
||||||
|
//
|
||||||
|
// Similar to XLA, this pass supports fusion pattern having multiple outputs
|
||||||
|
// if all the shape of outputs are consistent. Following are some examples.
|
||||||
|
//
|
||||||
|
// kLoop kInput
|
||||||
|
// +----+ +----+ +----+ +----+ +----+ +----+
|
||||||
|
// |elem| |elem| |elem| |elem<----+elem+---->elem+----+
|
||||||
|
// +-+--+ +-+--+ +-+--+ +-+--+ +----+ +-+--+ |
|
||||||
|
// | | | | | |
|
||||||
|
// | | | | |
|
||||||
|
// +-v--+ | +-v--+ +--v---+ +--v---+ |
|
||||||
|
// |elem+<---+----<+elem| |reduce| |reduce| |
|
||||||
|
// +-+--+ +-+--+ +--+---+ +--+---+ |
|
||||||
|
// | | | | |
|
||||||
|
// | | | | |
|
||||||
|
// v v v v v
|
||||||
|
//
|
||||||
|
// To this end, we also add an simple shape constraint analysis phase.
|
||||||
|
// For kLoop fusion template, it requires all the outputs of the fused
|
||||||
|
// pattern have the same shape. However, we don't know the actual value
|
||||||
|
// of the shape at the compile time in the dynamic shape world.
|
||||||
|
// Fortunately, we could still infer the relationship among different ops
|
||||||
|
// according to their shape constraint traits. Currently, We only consider
|
||||||
|
// shape equality propagation for elementwise ops (assuming that implicit
|
||||||
|
// shape broadcast is forbidden). The above process could be built on the
|
||||||
|
// shape dialect once it is ready.
|
||||||
|
//
|
||||||
|
// TODO(disc): This file implements fusion on buffer level, re-visit this after
|
||||||
|
// more shape inference/constraint infras are ready in mhlo level.
|
||||||
|
// TODO(disc): Not using fusibility interface a.t.m, re-visit this if necessary.
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace lmhlo {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct FusionOptions {
|
||||||
|
// Maximum allowed number of arguments per fused kernel. Here arguments
|
||||||
|
// include both ready-only buffers and writable buffers.
|
||||||
|
int max_num_arguments_per_kernel;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A fusion planner that can propose a fusion plan for a block of ops.
|
||||||
|
// The fusion plan is consisted of a group of fusion patterns.
|
||||||
|
//
|
||||||
|
// Currently all proposed patterns followed xla kLoop/kInput like fusion
|
||||||
|
// templates while are adapted to the fully dynamic shape world.
|
||||||
|
//
|
||||||
|
// kLoop fusion template satisfies:
|
||||||
|
// - all ops in the fusion pattern are element-wise.
|
||||||
|
// - all the shapes of outputs of fusion pattern are same or have same number
|
||||||
|
// of elements, and thus can fit into a same parallel loop.
|
||||||
|
//
|
||||||
|
// kInput fusion template satisfies:
|
||||||
|
// - any op in the fusion pattern is either element-wise or a reduction.
|
||||||
|
// - if a op is a reduction, its output cannot be consumed by other
|
||||||
|
// ops in the same fusion pattern.
|
||||||
|
// - all the effective shapes of outputs of fusion pattern are same.
|
||||||
|
// - For element-wise op, its effective shape is its output shape.
|
||||||
|
// - For reduction op, its effective shape is its operand shape.
|
||||||
|
// - currently our downstreaming codegen engine only support 2d -> 1d tensor
|
||||||
|
// reduction. TODO(disc): lift this limitation.
|
||||||
|
// - 2D row reduction: out[i] = sum({in[i][j] for all j})
|
||||||
|
// - 2D column reduction: out[j] = sum({in[i][j] for all i}
|
||||||
|
class FusionPlanner {
|
||||||
|
public:
|
||||||
|
explicit FusionPlanner(const FusionOptions& options, Block* block)
|
||||||
|
: options_(options), block_(block) {
|
||||||
|
// Move up metadata-only ops (e.g. dim, shape_of) as far as possible.
|
||||||
|
MoveUpMetadataOnlyOpsForFusion();
|
||||||
|
|
||||||
|
for (Operation& op : *block) {
|
||||||
|
op_list_.push_back(&op);
|
||||||
|
}
|
||||||
|
shape_analysis_.reset(new ShapeConstraintAnalysis(op_list_));
|
||||||
|
cycle_detector_.reset(new GraphCycles(op_list_.size()));
|
||||||
|
BuildNodeMap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a fusion plan if success, otherwise none.
|
||||||
|
llvm::Optional<FusionPlan> Run() {
|
||||||
|
// Greedily search connected fusible pattern, and ops belonging to
|
||||||
|
// a same fusion pattern are grouped into a cluster.
|
||||||
|
RunEdgeContractionLoop();
|
||||||
|
|
||||||
|
// After doing edge contraction, each unique cluster having size
|
||||||
|
// more than one represents a potential fusion pattern.
|
||||||
|
// We collect all these clusters and construct a fusion plan.
|
||||||
|
FusionPlan plan;
|
||||||
|
DenseSet<Cluster*> seen_clusters;
|
||||||
|
for (Operation* op : op_list_) {
|
||||||
|
Cluster* cluster = GetClusterForNode(op);
|
||||||
|
if (!seen_clusters.insert(cluster).second) continue;
|
||||||
|
FusionPattern& fusion_pattern = cluster->fused_pattern();
|
||||||
|
// Make sure the ops in a fusion pattern are in topological ordering.
|
||||||
|
fusion_pattern.sortFusionOpListBy(op_to_node_id_);
|
||||||
|
if (!fusion_pattern.isFusible() || fusion_pattern.effectiveSize() <= 1) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
plan.emplace_back(fusion_pattern);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-order ops inside the blocks to make sure all producers are placed
|
||||||
|
// before its consumers after fusion.
|
||||||
|
ReorderOperationsInsideBlock();
|
||||||
|
return plan;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the op_list this planner operates on.
|
||||||
|
const SmallVectorImpl<Operation*>& op_list() const { return op_list_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Represent a (partial) fused pattern
|
||||||
|
class Cluster {
|
||||||
|
public:
|
||||||
|
Cluster(int node_id, FusionPlanner* planner)
|
||||||
|
: node_id_(node_id), pattern_(planner->op_list()[node_id]) {}
|
||||||
|
|
||||||
|
// Merges `other` into this cluster, and clears `other`.
|
||||||
|
void Merge(Cluster* other) {
|
||||||
|
pattern_.mergeInplace(other->fused_pattern());
|
||||||
|
}
|
||||||
|
|
||||||
|
// The number of nodes in this cluster.
|
||||||
|
int cluster_size() { return pattern_.size(); }
|
||||||
|
|
||||||
|
// The ID of the cluster as represented in `cycle_detector_`.
|
||||||
|
int cycles_graph_node_id() const { return node_id_; }
|
||||||
|
|
||||||
|
// Sets the ID of the cluster as represented in `cycle_detector_`.
|
||||||
|
void set_cycles_graph_node_id(int cycles_graph_node_id) {
|
||||||
|
node_id_ = cycles_graph_node_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Currently the fused pattern this cluster holds.
|
||||||
|
FusionPattern& fused_pattern() { return pattern_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
// ID of the representative node of this cluster.
|
||||||
|
int node_id_;
|
||||||
|
|
||||||
|
// the fused pattern this cluster holds.
|
||||||
|
FusionPattern pattern_;
|
||||||
|
};
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Returns a new cluster with specified `cycles_graph_node_id`
|
||||||
|
Cluster* MakeCluster(int cycles_graph_node_id) {
|
||||||
|
cluster_storage_.emplace_back(new Cluster(cycles_graph_node_id, this));
|
||||||
|
return cluster_storage_.back().get();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Metadata ops (e.g. shapeOf, dimOp) don't change data thus we move forward
|
||||||
|
// them as far as possible inside the same block to enable more fusion
|
||||||
|
// opportunities.
|
||||||
|
void MoveUpMetadataOnlyOpsForFusion() {
|
||||||
|
SmallVector<Operation*, 4> ops;
|
||||||
|
for (Operation& op : *block_) {
|
||||||
|
ops.push_back(&op);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto inBlock = [&](Operation* op, Block* block) {
|
||||||
|
return op && op->getBlock() == block;
|
||||||
|
};
|
||||||
|
|
||||||
|
for (Operation* op : ops) {
|
||||||
|
Block* block = op->getBlock();
|
||||||
|
if (isa<shape::ShapeOfOp>(op)) {
|
||||||
|
Operation* definingOp = op->getOperand(0).getDefiningOp();
|
||||||
|
if (!inBlock(definingOp, block)) {
|
||||||
|
op->moveBefore(block, block->begin());
|
||||||
|
} else {
|
||||||
|
op->moveAfter(definingOp);
|
||||||
|
}
|
||||||
|
} else if (isa<memref::DimOp>(op)) {
|
||||||
|
Operation* firstOperandOp = op->getOperand(0).getDefiningOp();
|
||||||
|
Operation* secondOperandOp = op->getOperand(1).getDefiningOp();
|
||||||
|
if (!inBlock(firstOperandOp, block) &&
|
||||||
|
!inBlock(secondOperandOp, block)) {
|
||||||
|
op->moveBefore(block, block->begin());
|
||||||
|
} else if (!inBlock(firstOperandOp, block)) {
|
||||||
|
op->moveAfter(secondOperandOp);
|
||||||
|
} else if (!inBlock(secondOperandOp, block)) {
|
||||||
|
op->moveAfter(firstOperandOp);
|
||||||
|
} else if (firstOperandOp->isBeforeInBlock(secondOperandOp)) {
|
||||||
|
op->moveAfter(secondOperandOp);
|
||||||
|
} else {
|
||||||
|
op->moveAfter(firstOperandOp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns all the values touched by this op or its nested ops.
|
||||||
|
SmallVector<Value, 4> GetAllPossibleUsedValues(Operation* op) {
|
||||||
|
SmallVector<Value, 4> values;
|
||||||
|
op->walk([&](Operation* nest_op) {
|
||||||
|
for (Value v : nest_op->getOperands()) {
|
||||||
|
values.push_back(v);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return values;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Builds the initial dependency graph.
|
||||||
|
void BuildNodeMap() {
|
||||||
|
int num_nodes = op_list_.size();
|
||||||
|
for (int node_id = 0; node_id < num_nodes; ++node_id) {
|
||||||
|
Operation* op = op_list_[node_id];
|
||||||
|
MakeCluster(node_id);
|
||||||
|
op_to_node_id_[op] = node_id;
|
||||||
|
leader_for_node_.insert(node_id);
|
||||||
|
for (Value operand : GetAllPossibleUsedValues(op)) {
|
||||||
|
Operation* operand_op = FindLastWriter(operand);
|
||||||
|
// Only consider the operand_op inside the target block.
|
||||||
|
auto iter = op_to_node_id_.find(operand_op);
|
||||||
|
if (iter == op_to_node_id_.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// Add an edge to connect the last writer and the current consumer.
|
||||||
|
cycle_detector_->InsertEdge(iter->second, node_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
// For some ops (e.g. lmhlo ops), some operands are the output memrefs
|
||||||
|
// Thus these operands are supposed to be updated.
|
||||||
|
// Suppose that a op (or its nested ops) can only write the buffers
|
||||||
|
// explicit passed in as operands of this op.
|
||||||
|
int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
|
||||||
|
for (Value v : op->getOperands().drop_front(num_input_operand)) {
|
||||||
|
auto it = last_writer_.try_emplace(v, op);
|
||||||
|
(void)it;
|
||||||
|
// Currently, a buffer is only supposed to be written once (as the
|
||||||
|
// output operand of one lmhlo op).
|
||||||
|
assert(it.second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the cluster contains this op.
|
||||||
|
Cluster* GetClusterForNode(Operation* n) {
|
||||||
|
int id = op_to_node_id_[n];
|
||||||
|
id = leader_for_node_.getLeaderValue(id);
|
||||||
|
return cluster_storage_[id].get();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the cluster contains the op having `node_id`.
|
||||||
|
Cluster* GetClusterForCyclesGraphNode(int node_id) {
|
||||||
|
return cluster_storage_[leader_for_node_.getLeaderValue(node_id)].get();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merges the clusters `cluster_from` and `cluster_to`.
|
||||||
|
bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
|
||||||
|
int from = cluster_from->cycles_graph_node_id();
|
||||||
|
int to = cluster_to->cycles_graph_node_id();
|
||||||
|
|
||||||
|
auto optional_merged_node = cycle_detector_->ContractEdge(from, to);
|
||||||
|
if (!optional_merged_node.hasValue()) {
|
||||||
|
llvm::dbgs() << "Could not contract " << from << " -> " << to
|
||||||
|
<< " because contracting the edge would create a cycle.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge the clusters.
|
||||||
|
cluster_from->Merge(cluster_to);
|
||||||
|
cluster_from->set_cycles_graph_node_id(*optional_merged_node);
|
||||||
|
|
||||||
|
// Merge the UnionFind Set.
|
||||||
|
leader_for_node_.unionSets(from, to);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
using FnTy = llvm::function_ref<bool(Cluster*, Cluster*)>;
|
||||||
|
bool ForEachEdgeInPostOrder(FnTy fn, bool enable_cross_fusion = false) {
|
||||||
|
bool changed = false;
|
||||||
|
for (int32_t node : cycle_detector_->AllNodesInPostOrder()) {
|
||||||
|
Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
|
||||||
|
// Make a copy of the set of successors because we may modify the graph in
|
||||||
|
// TryToContractEdge.
|
||||||
|
std::vector<int32_t> successors_copy =
|
||||||
|
cycle_detector_->SuccessorsCopy(cluster_from->cycles_graph_node_id());
|
||||||
|
|
||||||
|
for (int to : successors_copy) {
|
||||||
|
Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
|
||||||
|
bool contracted_edge = fn(cluster_from, cluster_to);
|
||||||
|
changed |= contracted_edge;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!enable_cross_fusion) return changed;
|
||||||
|
|
||||||
|
// To enable even more fusion opportunities (e.g. horizontal fusion)
|
||||||
|
for (int32_t lhs : cycle_detector_->AllNodesInPostOrder()) {
|
||||||
|
Cluster* cluster_lhs = GetClusterForCyclesGraphNode(lhs);
|
||||||
|
if (!cluster_lhs) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int32_t rhs : cycle_detector_->AllNodesInPostOrder()) {
|
||||||
|
Cluster* cluster_rhs = GetClusterForCyclesGraphNode(rhs);
|
||||||
|
if (!cluster_rhs || cluster_lhs == cluster_rhs) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool contracted_edge = fn(cluster_lhs, cluster_rhs);
|
||||||
|
changed |= contracted_edge;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
|
||||||
|
// This function check if fusing `from` with `to` is valid and if so perform
|
||||||
|
// the merge. The validity is based on the operations in the clusters and
|
||||||
|
// the compatibility of the shapes of the outputs of the would-be fused
|
||||||
|
// clusters.
|
||||||
|
// Returns true is the merge was performed.
|
||||||
|
bool TryToContractEdge(Cluster* from, Cluster* to) {
|
||||||
|
// Try merge and check if valid.
|
||||||
|
if (!from->fused_pattern().isMergeable(to->fused_pattern())) return false;
|
||||||
|
FusionPattern fused_pattern =
|
||||||
|
from->fused_pattern().merge(to->fused_pattern());
|
||||||
|
auto& op_list = fused_pattern.getOpList();
|
||||||
|
auto& operands = fused_pattern.getOperands();
|
||||||
|
auto& results = fused_pattern.getResults();
|
||||||
|
|
||||||
|
if (results.size() + operands.size() >
|
||||||
|
options_.max_num_arguments_per_kernel) {
|
||||||
|
// some backend devices (e.g. GPU) do not support a kernel with
|
||||||
|
// too many arguments.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// We currently do not support a constant op as final output of a fusion
|
||||||
|
// pattern.
|
||||||
|
// TODO(disc): copy small const in case necessary.
|
||||||
|
for (Value result : results) {
|
||||||
|
Operation* result_op = FindLastWriter(result);
|
||||||
|
assert(result_op);
|
||||||
|
if (isa<lmhlo::ConstOp>(result_op)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReduceOp can not have consumer within the fusion pattern.
|
||||||
|
for (Operation* op : op_list) {
|
||||||
|
if (!isa<lmhlo::ReduceOp>(op)) continue;
|
||||||
|
int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
|
||||||
|
for (Value v : op->getOperands().drop_front(num_input_operand)) {
|
||||||
|
for (Operation* user : getValueUsers(v)) {
|
||||||
|
if (user == op) continue;
|
||||||
|
if (std::find(op_list.begin(), op_list.end(), user) !=
|
||||||
|
op_list.end()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// All outputs of a fusion pattern should have compatible shape.
|
||||||
|
// Here `compatible` means:
|
||||||
|
// - if `to` and `from` are both kInput fusion, all output should have same
|
||||||
|
// shape.
|
||||||
|
// - otherwise, all output should have same number of elements.
|
||||||
|
|
||||||
|
// No outside users, these ops may be eliminated. We fused it here and let
|
||||||
|
// latter pass to do such DCE.
|
||||||
|
if (results.empty()) return true;
|
||||||
|
|
||||||
|
bool check_same_shape = (to->fused_pattern().isKInputFusion() &&
|
||||||
|
from->fused_pattern().isKInputFusion());
|
||||||
|
auto get_effective_shape = [&](Value v) {
|
||||||
|
auto result_op = FindLastWriter(v);
|
||||||
|
assert(result_op);
|
||||||
|
// effective shape of reduce op is its operand's shape.
|
||||||
|
return isa<lmhlo::ReduceOp>(result_op) ? result_op->getOperand(0) : v;
|
||||||
|
};
|
||||||
|
|
||||||
|
Value ref_shape = get_effective_shape(results[0]);
|
||||||
|
if (!llvm::all_of(results, [&](Value result) {
|
||||||
|
Value shape = get_effective_shape(result);
|
||||||
|
return check_same_shape
|
||||||
|
? shape_analysis_->HasSameShape(ref_shape, shape)
|
||||||
|
: shape_analysis_->HasSameNumElements(ref_shape, shape);
|
||||||
|
})) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return MergeClusters(from, to);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Greedily fuse connected node.
|
||||||
|
bool RunEdgeContractionLoop() {
|
||||||
|
using std::placeholders::_1;
|
||||||
|
using std::placeholders::_2;
|
||||||
|
bool changed = false;
|
||||||
|
|
||||||
|
// Run fusion pass repeatedly until nothing to be fused
|
||||||
|
while (ForEachEdgeInPostOrder(
|
||||||
|
std::bind(&FusionPlanner::TryToContractEdge, this, _1, _2), false)) {
|
||||||
|
// empty statement by design
|
||||||
|
}
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Here `value` is supported to be a pointer to buffer.
|
||||||
|
// Returns the defining op of `value `if no known op updates the buffer,
|
||||||
|
// otherwise returns the last op that updates the buffer pointed by the
|
||||||
|
// `value`.
|
||||||
|
Operation* FindLastWriter(Value value) {
|
||||||
|
auto it = last_writer_.find(value);
|
||||||
|
if (it != last_writer_.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
return value.getDefiningOp();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-order ops inside the block to make sure that producers are before
|
||||||
|
// consumers after fusion.
|
||||||
|
void ReorderOperationsInsideBlock() {
|
||||||
|
auto reorder_func = [&](Cluster* from, Cluster* to) {
|
||||||
|
FusionPattern& from_pattern = from->fused_pattern();
|
||||||
|
FusionPattern& to_pattern = to->fused_pattern();
|
||||||
|
|
||||||
|
Operation* last_op_in_from = from_pattern.getOpList().back();
|
||||||
|
for (Operation* op : llvm::reverse(to_pattern.getOpList())) {
|
||||||
|
if (!last_op_in_from->isBeforeInBlock(op))
|
||||||
|
op->moveAfter(last_op_in_from);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
ForEachEdgeInPostOrder(reorder_func);
|
||||||
|
}
|
||||||
|
|
||||||
|
// hyper-parameters that controls the behaviour of the fusion planner.
|
||||||
|
FusionOptions options_;
|
||||||
|
|
||||||
|
// The block that fusion planner works on.
|
||||||
|
Block* block_;
|
||||||
|
|
||||||
|
// Ops inside the block
|
||||||
|
SmallVector<Operation*, 4> op_list_;
|
||||||
|
|
||||||
|
// Shape equality checker
|
||||||
|
std::unique_ptr<ShapeConstraintAnalysis> shape_analysis_;
|
||||||
|
|
||||||
|
// op -> node_id
|
||||||
|
DenseMap<Operation*, int> op_to_node_id_;
|
||||||
|
|
||||||
|
// make sure not introduce cycle after fusion
|
||||||
|
std::unique_ptr<GraphCycles> cycle_detector_;
|
||||||
|
std::vector<std::unique_ptr<Cluster>> cluster_storage_;
|
||||||
|
|
||||||
|
// a UnionFind set. Each set represents a (partial) fused pattern
|
||||||
|
// and has a leader as representation.
|
||||||
|
EquivalenceClasses<int32_t> leader_for_node_;
|
||||||
|
|
||||||
|
// Here `value` is supported to be a pointer to buffer.
|
||||||
|
// Returns the defining op of `value `if no known op updates the buffer,
|
||||||
|
// otherwise returns the last op that updates the buffer pointed by the
|
||||||
|
// `value`.
|
||||||
|
DenseMap<Value, Operation*> last_writer_;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LhloFusionPass : public LhloFusionPassBase<LhloFusionPass> {
|
||||||
|
using LhloFusionPassBase<LhloFusionPass>::LhloFusionPassBase;
|
||||||
|
explicit LhloFusionPass(int max_num_arguments_per_kernel)
|
||||||
|
: LhloFusionPassBase<LhloFusionPass>::LhloFusionPassBase() {
|
||||||
|
this->max_num_arguments_per_kernel_ = max_num_arguments_per_kernel;
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnFunction() override {
|
||||||
|
FuncOp func = getFunction();
|
||||||
|
|
||||||
|
// collect all blocks inside the function.
|
||||||
|
SmallVector<Block*, 4> blocks;
|
||||||
|
CollectBlocksInsideFunction(func, blocks);
|
||||||
|
|
||||||
|
// process each block and do fusion within a block.
|
||||||
|
FusionOptions options;
|
||||||
|
options.max_num_arguments_per_kernel = max_num_arguments_per_kernel_;
|
||||||
|
for (Block* block : blocks) {
|
||||||
|
FusionPlanner planner(options, block);
|
||||||
|
llvm::Optional<FusionPlan> plan = planner.Run();
|
||||||
|
if (!plan) {
|
||||||
|
emitError(func.getLoc(),
|
||||||
|
"an error occurs while trying to find fusion candidates");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!ApplyFusionPlan(*plan)) {
|
||||||
|
emitError(func.getLoc(), "apply fusion plan failed");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ApplyFusionPlan(FusionPlan& plan) {
|
||||||
|
for (FusionPattern& pattern : plan) {
|
||||||
|
auto& op_list = pattern.getOpList();
|
||||||
|
OpBuilder b(op_list.back());
|
||||||
|
|
||||||
|
// Get the fused locations
|
||||||
|
SmallVector<Location, 4> locations;
|
||||||
|
locations.reserve(op_list.size());
|
||||||
|
for (Operation* op : op_list) {
|
||||||
|
locations.push_back(op->getLoc());
|
||||||
|
}
|
||||||
|
Location fused_loc =
|
||||||
|
FusedLoc::get(op_list.back()->getContext(), locations);
|
||||||
|
|
||||||
|
// Move ops inside fusion pattern to the region attached to the fusion op.
|
||||||
|
FusionOp fusion = b.create<lmhlo::FusionOp>(fused_loc);
|
||||||
|
Region& region = fusion.region();
|
||||||
|
Block& block = region.front();
|
||||||
|
for (Operation* op : llvm::reverse(op_list)) {
|
||||||
|
op->moveBefore(&block, block.begin());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void CollectBlocksInsideFunction(FuncOp op, SmallVectorImpl<Block*>& blocks) {
|
||||||
|
op.walk([&](Block* block) {
|
||||||
|
// It does not make sense to fuse the region attached to these ops.
|
||||||
|
if (!isa<lmhlo::ReduceOp, lmhlo::FusionOp>(block->getParentOp()))
|
||||||
|
blocks.push_back(block);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createLhloFusionPass(
|
||||||
|
int max_num_arguments_per_kernel) {
|
||||||
|
return std::make_unique<LhloFusionPass>(max_num_arguments_per_kernel);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace lmhlo
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,286 @@
|
||||||
|
// RUN: mlir-hlo-opt --lhlo-fusion -split-input-file %s -o - | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: @simple_kloop_fusion
|
||||||
|
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?x?xf32>) -> memref<?x?xf32>
|
||||||
|
func @simple_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
|
||||||
|
%arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>) -> memref<?x?xf32> {
|
||||||
|
// CHECK: "lmhlo.fusion"() ( {
|
||||||
|
// CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
// CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[ARG3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
// CHECK: }) : () -> ()
|
||||||
|
// CHECK: return %[[ARG3]] : memref<?x?xf32>
|
||||||
|
"lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
"lmhlo.add"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
return %arg3 : memref<?x?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @simple_multi_output_kloop_fusion
|
||||||
|
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>)
|
||||||
|
func @simple_multi_output_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
|
||||||
|
%arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>) {
|
||||||
|
// CHECK: "lmhlo.fusion"() ( {
|
||||||
|
// CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
// CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[ARG3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
// CHECK: }) : () -> ()
|
||||||
|
// CHECK: return %[[ARG1]], %[[ARG3]] : memref<?x?xf32>, memref<?x?xf32>
|
||||||
|
"lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
"lmhlo.add"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
return %arg1, %arg3 : memref<?x?xf32>, memref<?x?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @simple_multi_output_kloop_fusion_with_reorder
|
||||||
|
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?x?xf32>, %[[ARG4:.*]]: memref<2xindex>, %[[ARG5:.*]]: memref<?x?xf32>)
|
||||||
|
func @simple_multi_output_kloop_fusion_with_reorder(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
|
||||||
|
%arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>,
|
||||||
|
%arg4: memref<2xindex>, %arg5: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) {
|
||||||
|
// CHECK: "lmhlo.fusion"() ( {
|
||||||
|
// CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
// CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[ARG3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
// CHECK: }) : () -> ()
|
||||||
|
// CHECK: "lmhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[ARG4]], %[[ARG5]])
|
||||||
|
// CHECK: return %[[ARG1]], %[[ARG3]], %[[ARG5]] : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
|
||||||
|
"lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
"lmhlo.dynamic_broadcast_in_dim"(%arg1, %arg4, %arg5) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (memref<?x?xf32>, memref<2xindex>, memref<?x?xf32>) -> ()
|
||||||
|
"lmhlo.add"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
return %arg1, %arg3, %arg5 : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @same_num_elements_multi_output_kloop_fusion
|
||||||
|
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<2xi64>, %[[ARG3:.*]]: memref<?x?x?xf32>, %[[ARG4:.*]]: memref<?x?x?xf32>, %[[ARG5:.*]]: memref<?x?x?xf32>)
|
||||||
|
func @same_num_elements_multi_output_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
|
||||||
|
%arg2: memref<2xi64>, %arg3: memref<?x?x?xf32>,
|
||||||
|
%arg4: memref<?x?x?xf32>, %arg5: memref<?x?x?xf32>) -> (memref<?x?xf32>, memref<?x?x?xf32>) {
|
||||||
|
// CHECK: "lmhlo.fusion"() ( {
|
||||||
|
// CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
// CHECK: "lmhlo.dynamic_reshape"(%[[ARG1]], %[[ARG2]], %[[ARG3]])
|
||||||
|
// CHECK: "lmhlo.add"(%[[ARG3]], %[[ARG4]], %[[ARG5]]) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
|
||||||
|
// CHECK: }) : () -> ()
|
||||||
|
// CHECK: return %[[ARG1]], %[[ARG5]] : memref<?x?xf32>, memref<?x?x?xf32>
|
||||||
|
"lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
"lmhlo.dynamic_reshape"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<2xi64>, memref<?x?x?xf32>) -> ()
|
||||||
|
"lmhlo.add"(%arg3, %arg4, %arg5) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
|
||||||
|
return %arg1, %arg5 : memref<?x?xf32>, memref<?x?x?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @check_not_kloop_fusion
|
||||||
|
func @check_not_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>) {
|
||||||
|
// CHECK-NOT: "lmhlo.fusion"
|
||||||
|
"lmhlo.add"(%arg0, %arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
"lmhlo.subtract"(%arg2, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
return %arg1, %arg3: memref<?x?xf32>, memref<?x?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @kloop_fusion_with_dealloc
|
||||||
|
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>)
|
||||||
|
func @kloop_fusion_with_dealloc(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>) {
|
||||||
|
// CHECK: %[[TMP3:.*]] = memref.alloc
|
||||||
|
// CHECK: %[[TMP5:.*]] = memref.alloc
|
||||||
|
// CHECK: %[[TMP9:.*]] = memref.alloc
|
||||||
|
// CHECK: %[[TMP13:.*]] = memref.alloc
|
||||||
|
// CHECK: %[[TMP16:.*]] = memref.alloc
|
||||||
|
// CHECK: "lmhlo.fusion"() ( {
|
||||||
|
// CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[TMP3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
// CHECK: "lmhlo.multiply"(%[[ARG0]], %[[ARG1]], %[[TMP5]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
// CHECK: "lmhlo.abs"(%[[TMP3]], %[[TMP9]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
// CHECK: "lmhlo.abs"(%[[TMP5]], %[[TMP13]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
// CHECK: "lmhlo.multiply"(%[[TMP9]], %[[TMP13]], %[[TMP16]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
// CHECK: }) : () -> ()
|
||||||
|
// CHECK: memref.dealloc %[[TMP3]] : memref<?x?xf32>
|
||||||
|
// CHECK: memref.dealloc %[[TMP5]] : memref<?x?xf32>
|
||||||
|
// CHECK: memref.dealloc %[[TMP13]] : memref<?x?xf32>
|
||||||
|
// CHECK: return %[[TMP9]], %[[TMP16]] : memref<?x?xf32>, memref<?x?xf32>
|
||||||
|
%c0 = constant 0 : index
|
||||||
|
%c1 = constant 1 : index
|
||||||
|
%0 = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
|
||||||
|
%1 = tensor.extract %0[%c0] : tensor<2xindex>
|
||||||
|
%2 = tensor.extract %0[%c1] : tensor<2xindex>
|
||||||
|
%3 = memref.alloc(%1, %2) : memref<?x?xf32>
|
||||||
|
"lmhlo.add"(%arg0, %arg1, %3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
%4 = memref.alloc(%1, %2) : memref<?x?xf32>
|
||||||
|
"lmhlo.multiply"(%arg0, %arg1, %4) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
%5 = shape.shape_of %3 : memref<?x?xf32> -> tensor<2xindex>
|
||||||
|
%6 = tensor.extract %5[%c0] : tensor<2xindex>
|
||||||
|
%7 = tensor.extract %5[%c1] : tensor<2xindex>
|
||||||
|
%8 = memref.alloc(%6, %7) : memref<?x?xf32>
|
||||||
|
"lmhlo.abs"(%3, %8) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
memref.dealloc %3 : memref<?x?xf32>
|
||||||
|
%9 = shape.shape_of %4 : memref<?x?xf32> -> tensor<2xindex>
|
||||||
|
%10 = tensor.extract %9[%c0] : tensor<2xindex>
|
||||||
|
%11 = tensor.extract %9[%c1] : tensor<2xindex>
|
||||||
|
%12 = memref.alloc(%10, %11) : memref<?x?xf32>
|
||||||
|
"lmhlo.abs"(%4, %12) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
memref.dealloc %4 : memref<?x?xf32>
|
||||||
|
%13 = shape.shape_of %8 : memref<?x?xf32> -> tensor<2xindex>
|
||||||
|
%14 = tensor.extract %13[%c0] : tensor<2xindex>
|
||||||
|
%15 = tensor.extract %13[%c1] : tensor<2xindex>
|
||||||
|
%16 = memref.alloc(%14, %15) : memref<?x?xf32>
|
||||||
|
"lmhlo.multiply"(%8, %12, %16) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
memref.dealloc %12 : memref<?x?xf32>
|
||||||
|
return %8, %16 : memref<?x?xf32>, memref<?x?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @simple_kinput
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?xf32>, %[[ARG3:.*]]: memref<f32>
|
||||||
|
func @simple_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?xf32>, %init: memref<f32>) -> memref<?xf32> {
|
||||||
|
// CHECK: "lmhlo.fusion"() ( {
|
||||||
|
// CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
// CHECK: "lmhlo.reduce"(%[[ARG1]], %[[ARG3]], %[[ARG2]]) ( {
|
||||||
|
// CHECK: }) : () -> ()
|
||||||
|
// CHECK: return %[[ARG2]] : memref<?xf32>
|
||||||
|
"lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
"lmhlo.reduce"(%arg1, %init, %arg2) ( {
|
||||||
|
^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
|
||||||
|
"lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||||
|
"lmhlo.terminator"() : () -> ()
|
||||||
|
} ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
|
||||||
|
return %arg2: memref<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @multi_output_kinput
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?xf32>, %[[ARG3:.*]]: memref<f32>
|
||||||
|
func @multi_output_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?xf32>, %init: memref<f32>) -> (memref<?x?xf32>, memref<?xf32>) {
|
||||||
|
// CHECK: "lmhlo.fusion"() ( {
|
||||||
|
// CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
// CHECK: "lmhlo.reduce"(%[[ARG1]], %[[ARG3]], %[[ARG2]]) ( {
|
||||||
|
// CHECK: }) : () -> ()
|
||||||
|
// CHECK: return %[[ARG1]], %[[ARG2]] : memref<?x?xf32>, memref<?xf32>
|
||||||
|
"lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
"lmhlo.reduce"(%arg1, %init, %arg2) ( {
|
||||||
|
^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
|
||||||
|
"lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||||
|
"lmhlo.terminator"() : () -> ()
|
||||||
|
} ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
|
||||||
|
return %arg1, %arg2: memref<?x?xf32>, memref<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @row_red_and_row_red_kinput
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?xf32>, %[[ARG4:.*]]: memref<?xf32>, %[[ARG5:.*]]: memref<?x?xf32>, %[[ARG6:.*]]: memref<f32>
|
||||||
|
func @row_red_and_row_red_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: memref<?xf32>, %arg4: memref<?xf32>, %arg5: memref<?x?xf32>, %init: memref<f32>) -> (memref<?xf32>, memref<?xf32>) {
|
||||||
|
// CHECK: "lmhlo.fusion"() ( {
|
||||||
|
// CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
// CHECK: "lmhlo.abs"(%[[ARG2]], %[[ARG5]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
// CHECK: "lmhlo.reduce"(%[[ARG5]], %[[ARG6]], %[[ARG3]]) ( {
|
||||||
|
// CHECK: "lmhlo.reduce"(%[[ARG2]], %[[ARG6]], %[[ARG4]]) ( {
|
||||||
|
// CHECK: }) : () -> ()
|
||||||
|
// CHECK: return %[[ARG3]], %[[ARG4]] : memref<?xf32>, memref<?xf32>
|
||||||
|
"lmhlo.add"(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
"lmhlo.abs"(%arg2, %arg5) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
"lmhlo.reduce"(%arg5, %init, %arg3) ( {
|
||||||
|
^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
|
||||||
|
"lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||||
|
"lmhlo.terminator"() : () -> ()
|
||||||
|
} ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
|
||||||
|
"lmhlo.reduce"(%arg2, %init, %arg4) ( {
|
||||||
|
^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
|
||||||
|
"lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||||
|
"lmhlo.terminator"() : () -> ()
|
||||||
|
} ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
|
||||||
|
return %arg3, %arg4: memref<?xf32>, memref<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @row_red_and_col_red_kinput
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?xf32>, %[[ARG4:.*]]: memref<?xf32>, %[[ARG5:.*]]: memref<?x?xf32>, %[[ARG6:.*]]: memref<f32>
|
||||||
|
func @row_red_and_col_red_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: memref<?xf32>, %arg4: memref<?xf32>, %arg5: memref<?x?xf32>, %init: memref<f32>) -> (memref<?xf32>, memref<?xf32>) {
|
||||||
|
// CHECK: "lmhlo.fusion"() ( {
|
||||||
|
// CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
// CHECK: "lmhlo.abs"(%[[ARG2]], %[[ARG5]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
// CHECK: "lmhlo.reduce"(%[[ARG5]], %[[ARG6]], %[[ARG3]]) ( {
|
||||||
|
// CHECK: "lmhlo.reduce"(%[[ARG2]], %[[ARG6]], %[[ARG4]]) ( {
|
||||||
|
// CHECK: }) : () -> ()
|
||||||
|
// CHECK: return %[[ARG3]], %[[ARG4]] : memref<?xf32>, memref<?xf32>
|
||||||
|
"lmhlo.add"(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
"lmhlo.abs"(%arg2, %arg5) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
"lmhlo.reduce"(%arg5, %init, %arg3) ( {
|
||||||
|
^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
|
||||||
|
"lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||||
|
"lmhlo.terminator"() : () -> ()
|
||||||
|
} ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
|
||||||
|
"lmhlo.reduce"(%arg2, %init, %arg4) ( {
|
||||||
|
^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
|
||||||
|
"lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||||
|
"lmhlo.terminator"() : () -> ()
|
||||||
|
} ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
|
||||||
|
return %arg3, %arg4: memref<?xf32>, memref<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @reduce_should_not_have_consumer_in_the_fusion
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>
|
||||||
|
func @reduce_should_not_have_consumer_in_the_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>)
|
||||||
|
-> (memref<?x?xf32>, memref<?xf32>) {
|
||||||
|
// CHECK: %[[TMP4:.*]] = memref.alloc
|
||||||
|
// CHECK: %[[TMP7:.*]] = memref.alloc
|
||||||
|
// CHECK: %[[TMP8:.*]] = memref.alloc
|
||||||
|
// CHECK: %[[TMP9:.*]] = memref.alloc
|
||||||
|
// CHECK: "lmhlo.fusion"() ( {
|
||||||
|
// CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[TMP4]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
// CHECK: "lmhlo.subtract"(%[[ARG0]], %[[TMP4]], %[[TMP7]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
// CHECK: "lmhlo.constant"(%[[TMP8]]) {value = dense<0.000000e+00> : tensor<f32>} : (memref<f32>) -> ()
|
||||||
|
// CHECK: "lmhlo.reduce"(%[[TMP7]], %[[TMP8]], %[[TMP9]]) ( {
|
||||||
|
// CHECK: }) : () -> ()
|
||||||
|
// CHECK: memref.dealloc %[[TMP4]] : memref<?x?xf32>
|
||||||
|
// CHECK: memref.dealloc %[[TMP8]] : memref<f32>
|
||||||
|
// CHECK: %[[TMP12:.*]] = memref.alloc
|
||||||
|
// CHECK: "lmhlo.add"(%[[TMP9]], %[[TMP9]], %[[TMP12]]) : (memref<?xf32>, memref<?xf32>, memref<?xf32>) -> ()
|
||||||
|
// CHECK: memref.dealloc %[[TMP9]] : memref<?xf32>
|
||||||
|
// CHECK: return %[[TMP7]], %[[TMP12]] : memref<?x?xf32>, memref<?xf32>
|
||||||
|
%c1 = constant 1 : index
|
||||||
|
%c0 = constant 0 : index
|
||||||
|
%0 = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
|
||||||
|
%1 = tensor.extract %0[%c0] : tensor<2xindex>
|
||||||
|
%2 = tensor.extract %0[%c1] : tensor<2xindex>
|
||||||
|
%3 = memref.alloc(%1, %2) : memref<?x?xf32>
|
||||||
|
"lmhlo.add"(%arg0, %arg1, %3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
%4 = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
|
||||||
|
%5 = tensor.extract %4[%c0] : tensor<2xindex>
|
||||||
|
%6 = tensor.extract %4[%c1] : tensor<2xindex>
|
||||||
|
%7 = memref.alloc(%5, %6) : memref<?x?xf32>
|
||||||
|
"lmhlo.subtract"(%arg0, %3, %7) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
|
memref.dealloc %3 : memref<?x?xf32>
|
||||||
|
%8 = memref.alloc() : memref<f32>
|
||||||
|
"lmhlo.constant"(%8) {value = dense<0.000000e+00> : tensor<f32>} : (memref<f32>) -> ()
|
||||||
|
%9 = memref.alloc(%5) : memref<?xf32>
|
||||||
|
"lmhlo.reduce"(%7, %8, %9) ( {
|
||||||
|
^bb0(%arg2: memref<f32>, %arg3: memref<f32>, %arg4: memref<f32>): // no predecessors
|
||||||
|
"lmhlo.add"(%arg2, %arg3, %arg4) : (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||||
|
"lmhlo.terminator"() : () -> ()
|
||||||
|
}) {dimensions = dense<1> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
|
||||||
|
memref.dealloc %8 : memref<f32>
|
||||||
|
%10 = shape.shape_of %9 : memref<?xf32> -> tensor<1xindex>
|
||||||
|
%11 = tensor.extract %10[%c0] : tensor<1xindex>
|
||||||
|
%12 = memref.alloc(%11) : memref<?xf32>
|
||||||
|
"lmhlo.add"(%9, %9, %12) : (memref<?xf32>, memref<?xf32>, memref<?xf32>) -> ()
|
||||||
|
memref.dealloc %9 : memref<?xf32>
|
||||||
|
return %7, %12 : memref<?x?xf32>, memref<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @const_should_not_be_output
|
||||||
|
func @const_should_not_be_output(%arg0: memref<f32>) -> (memref<f32>, memref<f32>) {
|
||||||
|
// CHECK-NOT: lmhlo.fusion
|
||||||
|
%0 = memref.alloc() : memref<f32>
|
||||||
|
"lmhlo.constant"(%0) {value = dense<1.000000e+00> : tensor<f32>} : (memref<f32>) -> ()
|
||||||
|
%1 = memref.alloc() : memref<f32>
|
||||||
|
"lmhlo.add"(%arg0, %0, %1) : (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||||
|
return %0, %1 : memref<f32>, memref<f32>
|
||||||
|
}
|
Loading…
Reference in New Issue