[MLIR][HLO] Add `rank-specialization-cluster` pass
Add a pass that clusters unranked C/HLO operations into a single `chlo.rank_specialization_cluster` op. The clustered operations are moved into the op's body so that later passes can rank-specialize all of them together.

PiperOrigin-RevId: 373336725
parent 7f84779868
commit 313d24bc8f
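For illustration, a hand-written sketch of the intended rewrite (the new regression test at the end of this change checks exactly this pattern): two CHLO broadcasting ops over unranked tensors are wrapped into one cluster whose body yields the final value.

// Before:
%0 = chlo.broadcast_multiply %arg0, %arg1
    : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
%1 = chlo.broadcast_add %0, %arg2
    : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>

// After clustering (operand order as checked by the test):
%res = "chlo.rank_specialization_cluster"(%arg2, %arg0, %arg1) ( {
^bb0(%arg2_: tensor<*xf32>, %arg0_: tensor<*xf32>, %arg1_: tensor<*xf32>):
  %2 = chlo.broadcast_multiply %arg0_, %arg1_
      : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  %3 = chlo.broadcast_add %2, %arg2_
      : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  "chlo.rank_specialization_cluster_yield"(%3) : (tensor<*xf32>) -> ()
}) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>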
@@ -706,6 +706,28 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "rank_specialization",
+    srcs = ["lib/Dialect/mhlo/transforms/rank_specialization.cc"],
+    hdrs = [
+        "include/mlir-hlo/Dialect/mhlo/transforms/passes.h",
+        "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h",
+    ],
+    deps = [
+        ":hlo",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:InferTypeOpInterface",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:Shape",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:Transforms",
+    ],
+    alwayslink = 1,
+)
+
 cc_library(
     name = "lhlo_legalize_to_gpu",
     srcs = ["lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc"],
@@ -1079,6 +1101,7 @@ cc_library(
         ":mhlo_fusion",
         ":mhlo_to_mhlo_lowering_patterns",
         ":move_up_dynamic_broadcasts_for_fusion",
+        ":rank_specialization",
         ":sink_constants_to_control_flow",
         ":test_passes",
         ":transform_unranked_hlo",
@@ -48,6 +48,22 @@ class HloClientDialect : public Dialect {
 }  // namespace chlo
 }  // namespace mlir
 
+namespace mlir {
+namespace chlo {
+namespace OpTrait {
+
+template <typename ConcreteType>
+class BroadcastingElementwise
+    : public mlir::OpTrait::TraitBase<ConcreteType, BroadcastingElementwise> {};
+
+template <typename ConcreteType>
+class Broadcasting
+    : public mlir::OpTrait::TraitBase<ConcreteType, Broadcasting> {};
+
+}  // namespace OpTrait
+}  // namespace chlo
+}  // namespace mlir
+
 #define GET_OP_CLASSES
 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc"
 
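The new traits are plain native op traits, so transformations can query them on a generic `Operation *`. A minimal sketch (my illustration, not part of the change; it mirrors the `IsClusterable` check in the new pass below):

#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir/IR/Operation.h"

// Returns true if `op` carries both of the new CHLO broadcast traits.
bool HasChloBroadcastTraits(mlir::Operation *op) {
  return op->hasTrait<mlir::chlo::OpTrait::Broadcasting>() &&
         op->hasTrait<mlir::chlo::OpTrait::BroadcastingElementwise>();
}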
@@ -61,6 +61,17 @@ class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
   let verifier = [{ return Verify(*this); }];
 }
 
+class HLOClient_NativeOpTrait<string name> : NativeOpTrait<name> {
+  let cppNamespace = "::mlir::chlo::OpTrait";
+}
+
+def HLOClient_Broadcasting : HLOClient_NativeOpTrait<"Broadcasting"> {
+}
+
+def HLOClient_BroadcastingElementwise
+    : HLOClient_NativeOpTrait<"BroadcastingElementwise"> {
+}
+
 //===----------------------------------------------------------------------===//
 // CHLO binary elementwise op definitions.
 // From the client perspective, each of these support both explicit rank
@@ -78,6 +89,7 @@ class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
 
 class HLOClient_BroadcastBinaryElementwiseOp<
   string mnemonic, list<OpTrait> traits> : HLOClient_Op<mnemonic, traits # [
+    HLOClient_BroadcastingElementwise, HLOClient_Broadcasting,
     DeclareOpInterfaceMethods<InferShapedTypeOpInterface, [
         "inferReturnTypeComponents", "reifyReturnTypeShapes"]>]> {
   let arguments = (ins
@@ -121,3 +121,10 @@ def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "FuncOp"> {
   let summary = "Test pass for materializing 'broadcast_dimensions' attributes.";
   let constructor = "createTestUnfuseBatchNormPass()";
 }
+
+/// Rank specialization passes.
+
+def RankSpecializationClusterPass
+    : Pass<"mhlo-rank-specialization-cluster", "FuncOp"> {
+  let constructor = "createRankSpecializationClusterPass()";
+}
@@ -69,6 +69,11 @@ createLegalizeTrigonometricToApproximationPass();
 
 std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass();
 
+/// Rank specialization passes.
+///   - Find compatible operations and group them together in one rank
+///     specialization region.
+std::unique_ptr<FunctionPass> createRankSpecializationClusterPass();
+
 std::unique_ptr<FunctionPass> createOptimizeMhloPass();
 std::unique_ptr<FunctionPass> createLowerComplexPass();
 std::unique_ptr<::mlir::Pass> createLegalizeGeneralDotPass();
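For context, the new constructor can be slotted into a pass pipeline like any other function pass. A minimal sketch under assumed setup (the wrapper function and its name are my illustration, not part of this change):

#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Clustering runs per function, so the pass is nested under FuncOp.
void AddRankSpecializationClusterPass(mlir::PassManager &pm) {
  pm.addNestedPass<mlir::FuncOp>(
      mlir::mhlo::createRankSpecializationClusterPass());
}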
@@ -100,6 +100,10 @@ void PopulateMoveUpDynamicBroadcastsForFusionLegality(ConversionTarget *target);
 void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns);
 
+/// Populate rank specialization clustering patterns.
+void PopulateRankSpecializationClusterPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns);
+
 }  // namespace mhlo
 
 namespace chlo {
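The populate hook also lets other code reuse the clustering pattern without scheduling the full pass. A minimal sketch, assuming the usual pattern-driver setup (the helper and its name are mine, not part of this change):

#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Applies only the clustering pattern to `func`, without running the full
// pass. Returns failure if the rewrite driver does not converge.
mlir::LogicalResult ClusterRankSpecializableOps(mlir::FuncOp func) {
  mlir::MLIRContext *ctx = func.getContext();
  mlir::OwningRewritePatternList patterns(ctx);
  mlir::mhlo::PopulateRankSpecializationClusterPatterns(ctx, &patterns);
  return mlir::applyPatternsAndFoldGreedily(func, std::move(patterns));
}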
@@ -48,7 +48,6 @@ add_mlir_library(ChloPasses
 )
 
 add_mlir_library(MhloPasses
-  move_up_dynamic_broadcasts_for_fusion.cc
   legalize_gather_to_torch_index_select.cc
   legalize_trigonometric_to_approximation.cc
   lower_complex.cc
@@ -57,8 +56,10 @@ add_mlir_library(MhloPasses
   materialize_broadcasts.cc
   materialize_broadcasts_pass.cc
   mhlo_fusion.cc
+  move_up_dynamic_broadcasts_for_fusion.cc
   optimize_mhlo.cc
   optimize_mhlo_pass.cc
+  rank_specialization.cc
   sink_constants_to_control_flow.cc
   test_infer_shaped_type_pass.cc
   transform_unranked_hlo.cc
@@ -0,0 +1,179 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+==============================================================================*/
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+
+// Needed to build `llvm::SmallSet`s of `mlir::Value`s.
+static bool operator<(const Value &lhs, const Value &rhs) {
+  return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
+}
+
+namespace mhlo {
+namespace {
+
+/// Identify clusters of operations that can be rank-specialized together. The
+/// required traits for clustered operations are:
+///   - Element-wise: All operations in the group must be element-wise. This
+///     allows reshaping operands before applying the operations, as well as
+///     reshaping the result to the desired shape afterwards. This way, we can,
+///     e.g., apply unary ops to a completely flattened operand and restore the
+///     original shape afterwards.
+///   - Broadcasting semantics: All operations must implement broadcasting
+///     semantics. Most importantly, this allows extending operand shapes such
+///     that they match in rank.
+///   - Shape reification: All operations must implement
+///     `InferShapedTypeOpInterface`. This is later needed to compute and to
+///     restore the desired result shape.
+bool IsClusterable(Operation *op) {
+  if (!llvm::isa<InferShapedTypeOpInterface>(op)) return false;
+  unsigned int num_operands = op->getNumOperands();
+  if (num_operands == 0) return false;
+  if (num_operands == 1) return op->hasTrait<OpTrait::Elementwise>();
+  return op->hasTrait<chlo::OpTrait::BroadcastingElementwise>() &&
+         op->hasTrait<chlo::OpTrait::Broadcasting>();
+}
+
+struct RankSpecializationClusterPattern : public RewritePattern {
+  explicit RankSpecializationClusterPattern(MLIRContext *ctx)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *root_op,
+                                PatternRewriter &rewriter) const override {
+    // Only apply to operations that have not been clustered yet.
+    if (root_op->getParentOfType<chlo::RankSpecializationClusterOp>()) {
+      return failure();
+    }
+
+    // Only cluster when rank specialization is needed.
+    if (!IsClusterable(root_op) ||
+        !llvm::any_of(root_op->getOperandTypes(),
+                      [](Type ty) { return ty.isa<UnrankedTensorType>(); })) {
+      return failure();
+    }
+
+    // Collect all collectively rank specializable ops.
+    SmallVector<Operation *, 16> cluster;
+    llvm::SmallSet<Value, 16> operand_set;
+    llvm::SmallSet<Value, 16> result_set;
+    Operation *new_op = root_op;
+    while (new_op != nullptr && IsClusterable(new_op)) {
+      // Find results that escape the cluster.
+      for (OpOperand &use : new_op->getUses()) {
+        if (!llvm::is_contained(cluster, use.getOwner()))
+          result_set.insert(use.get());
+      }
+
+      // Update cluster operands.
+      for (OpResult v : new_op->getResults()) operand_set.erase(Value(v));
+      for (OpOperand &v : new_op->getOpOperands()) operand_set.insert(v.get());
+
+      cluster.push_back(new_op);
+      new_op = new_op->getPrevNode();
+    }
+
+    // Create `RankSpecializationClusterOp`.
+    auto operands = llvm::to_vector<16>(operand_set);
+    auto results = llvm::to_vector<16>(result_set);
+    auto result_types = llvm::to_vector<16>(
+        llvm::map_range(result_set, [](Value v) { return v.getType(); }));
+    Location loc = root_op->getLoc();
+    auto cluster_op = rewriter.create<chlo::RankSpecializationClusterOp>(
+        loc, result_types, operands);
+
+    // Create body block.
+    auto operand_types = llvm::to_vector<16>(
+        llvm::map_range(operand_set, [](Value v) { return v.getType(); }));
+    Block *block = rewriter.createBlock(&cluster_op.body(), {}, operand_types);
+
+    // Copy operations into the body.
+    BlockAndValueMapping bvm;
+    for (auto it : llvm::zip(operands, block->getArguments()))
+      bvm.map(std::get<0>(it), std::get<1>(it));
+    rewriter.setInsertionPointToStart(block);
+    for (Operation *it : llvm::reverse(cluster)) rewriter.clone(*it, bvm);
+
+    // Create `RankSpecializationClusterYieldOp`.
+    auto mapped_results = llvm::to_vector<16>(
+        llvm::map_range(results, [&](Value v) { return bvm.lookup(v); }));
+    rewriter.create<chlo::RankSpecializationClusterYieldOp>(loc,
+                                                            mapped_results);
+
+    // Replace original ops with the new results.
+    for (auto it : llvm::zip(results, cluster_op.results()))
+      bvm.map(std::get<0>(it), std::get<1>(it));
+    for (Operation *it : cluster) {
+      if (it->getUses().empty()) {
+        rewriter.eraseOp(it);
+        continue;
+      }
+      auto replacements = llvm::to_vector<16>(llvm::map_range(
+          it->getResults(), [&](Value v) { return bvm.lookup(v); }));
+      rewriter.replaceOp(it, replacements);
+    }
+
+    return success();
+  }
+};
+
+struct RankSpecializationClusterPass
+    : public PassWrapper<RankSpecializationClusterPass, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<mhlo::MhloDialect, chlo::HloClientDialect>();
+  }
+
+  void runOnFunction() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    mhlo::PopulateRankSpecializationClusterPatterns(ctx, &patterns);
+    if (failed(
+            applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+void PopulateRankSpecializationClusterPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns) {
+  patterns->insert<RankSpecializationClusterPattern>(context);
+}
+
+std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() {
+  return std::make_unique<RankSpecializationClusterPass>();
+}
+
+}  // namespace mhlo
+}  // namespace mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-hlo-opt %s --mhlo-rank-specialization-cluster | FileCheck %s
+
+// CHECK-LABEL: @add_mul
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
+func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
+    %arg2 : tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG2]], %[[ARG0]], %[[ARG1]]) ( {
+  // CHECK: ^bb0(%[[ARG2_:.*]]: tensor<*xf32>, %[[ARG0_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>):
+  // CHECK:   %[[TMP:.*]] = chlo.broadcast_multiply %[[ARG0_]], %[[ARG1_]]
+  // CHECK:   %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[ARG2_]]
+  // CHECK:   "chlo.rank_specialization_cluster_yield"(%[[INNER_RES]])
+  // CHECK: }) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: return %[[RES]]
+  %0 = chlo.broadcast_multiply %arg0, %arg1
+      : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  %1 = chlo.broadcast_add %0, %arg2
+      : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %1 : tensor<*xf32>
+}