[MLIR][HLO] Add `rank-specialization-cluster` pass

Add a pass to cluster unranked C/HLO operations in one
`chlo.rank_specialization_cluster` op. The C/HLO operations are moved to the
body of the operation. Later passes can use this to rank-specialize all these
operations together.

PiperOrigin-RevId: 373336725
This commit is contained in:
A. Unique TensorFlower 2021-05-12 03:45:09 -07:00 committed by TensorFlow MLIR Team
parent 7f84779868
commit 313d24bc8f
9 changed files with 267 additions and 1 deletions

23
BUILD
View File

@@ -706,6 +706,28 @@ cc_library(
alwayslink = 1,
)
# Clustering pass library: groups unranked C/HLO ops into
# `chlo.rank_specialization_cluster` ops for later rank specialization.
cc_library(
name = "rank_specialization",
srcs = ["lib/Dialect/mhlo/transforms/rank_specialization.cc"],
hdrs = [
"include/mlir-hlo/Dialect/mhlo/transforms/passes.h",
"include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h",
],
deps = [
":hlo",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
# Registers the pass via static initializers, so force-link the library.
alwayslink = 1,
)
cc_library(
name = "lhlo_legalize_to_gpu",
srcs = ["lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc"],
@@ -1079,6 +1101,7 @@ cc_library(
":mhlo_fusion",
":mhlo_to_mhlo_lowering_patterns",
":move_up_dynamic_broadcasts_for_fusion",
":rank_specialization",
":sink_constants_to_control_flow",
":test_passes",
":transform_unranked_hlo",

View File

@@ -48,6 +48,22 @@ class HloClientDialect : public Dialect {
} // namespace chlo
} // namespace mlir
namespace mlir {
namespace chlo {
namespace OpTrait {

// Marker trait for element-wise ops with broadcasting semantics. Attached to
// ops in ODS via `HLOClient_BroadcastingElementwise`; queried through
// `op->hasTrait<chlo::OpTrait::BroadcastingElementwise>()`.
template <typename ConcreteType>
class BroadcastingElementwise
: public mlir::OpTrait::TraitBase<ConcreteType, BroadcastingElementwise> {};

// Marker trait for ops that implement broadcasting semantics (operand shapes
// may be extended to match in rank). Attached via `HLOClient_Broadcasting`.
template <typename ConcreteType>
class Broadcasting
: public mlir::OpTrait::TraitBase<ConcreteType, Broadcasting> {};

} // namespace OpTrait
} // namespace chlo
} // namespace mlir
#define GET_OP_CLASSES
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc"

View File

@@ -61,6 +61,17 @@ class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
let verifier = [{ return Verify(*this); }];
}
// Base class for native CHLO op traits; the generated trait classes live in
// the `::mlir::chlo::OpTrait` C++ namespace (declared in chlo_ops.h).
class HLOClient_NativeOpTrait<string name> : NativeOpTrait<name> {
let cppNamespace = "::mlir::chlo::OpTrait";
}

// Marker trait: the op implements broadcasting semantics.
def HLOClient_Broadcasting : HLOClient_NativeOpTrait<"Broadcasting"> {
}

// Marker trait: the op is element-wise with broadcasting semantics.
def HLOClient_BroadcastingElementwise
: HLOClient_NativeOpTrait<"BroadcastingElementwise"> {
}
//===----------------------------------------------------------------------===//
// CHLO binary elementwise op definitions.
// From the client perspective, each of these support both explicit rank
@@ -78,6 +89,7 @@ class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
class HLOClient_BroadcastBinaryElementwiseOp<
string mnemonic, list<OpTrait> traits> : HLOClient_Op<mnemonic, traits # [
HLOClient_BroadcastingElementwise, HLOClient_Broadcasting,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface, [
"inferReturnTypeComponents", "reifyReturnTypeShapes"]>]> {
let arguments = (ins

View File

@@ -121,3 +121,10 @@ def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "FuncOp"> {
let summary = "Test pass for materializing 'broadcast_dimensions' attributes.";
let constructor = "createTestUnfuseBatchNormPass()";
}
/// Rank specialization passes.
def RankSpecializationClusterPass
    : Pass<"mhlo-rank-specialization-cluster", "FuncOp"> {
  // Consistency with the other pass definitions in this file, which all
  // provide a summary (shown in `--help` output).
  let summary = "Cluster compatible unranked C/HLO operations into "
                "`chlo.rank_specialization_cluster` ops so later passes can "
                "rank-specialize them together.";
  let constructor = "createRankSpecializationClusterPass()";
}

View File

@@ -69,6 +69,11 @@ createLegalizeTrigonometricToApproximationPass();
std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass();
/// Rank specialization passes.
/// - Find compatible operations and group them together in one rank
/// specialization region.
std::unique_ptr<FunctionPass> createRankSpecializationClusterPass();
std::unique_ptr<FunctionPass> createOptimizeMhloPass();
std::unique_ptr<FunctionPass> createLowerComplexPass();
std::unique_ptr<::mlir::Pass> createLegalizeGeneralDotPass();

View File

@@ -100,6 +100,10 @@ void PopulateMoveUpDynamicBroadcastsForFusionLegality(ConversionTarget *target);
void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
MLIRContext *context, OwningRewritePatternList *patterns);
/// Populate rank specialization clustering patterns.
void PopulateRankSpecializationClusterPatterns(
MLIRContext *context, OwningRewritePatternList *patterns);
} // namespace mhlo
namespace chlo {

View File

@@ -48,7 +48,6 @@ add_mlir_library(ChloPasses
)
add_mlir_library(MhloPasses
move_up_dynamic_broadcasts_for_fusion.cc
legalize_gather_to_torch_index_select.cc
legalize_trigonometric_to_approximation.cc
lower_complex.cc
@@ -57,8 +56,10 @@ add_mlir_library(MhloPasses
materialize_broadcasts.cc
materialize_broadcasts_pass.cc
mhlo_fusion.cc
move_up_dynamic_broadcasts_for_fusion.cc
optimize_mhlo.cc
optimize_mhlo_pass.cc
rank_specialization.cc
sink_constants_to_control_flow.cc
test_infer_shaped_type_pass.cc
transform_unranked_hlo.cc

View File

@@ -0,0 +1,179 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
// Needed to build `llvm::SmallSet`s of `mlir::Value`s: order values by their
// opaque pointer representation.
static bool operator<(const Value &lhs, const Value &rhs) {
  const void *lhs_ptr = lhs.getAsOpaquePointer();
  const void *rhs_ptr = rhs.getAsOpaquePointer();
  return lhs_ptr < rhs_ptr;
}
namespace mhlo {
namespace {
/// Identify clusters of operations that can be rank-specialized together. The
/// required traits for clustered operations are:
/// - Element-wise: All operations in the group must be element-wise. This
/// allows to reshape operands before applying the operations as well as
/// reshaping the result to the desired shape afterwards. This way, we can,
/// e.g., apply unary ops to a completely flattened operand and restore the
/// original shape afterwards.
/// - Broadcasting semantics: All operations must implement broadcasting
/// semantics. Most importantly, this allows extending operand shapes such
/// that they match in rank.
/// - Shape reification: All operations must implement
/// `InferShapedTypeOpInterface`. This is later needed to compute and to
/// restore the desired result shape.
// Returns true if `op` can participate in a rank-specialization cluster:
// it must reify its result shape (`InferShapedTypeOpInterface`), and be
// element-wise (unary case) or broadcasting element-wise (n-ary case).
bool IsClusterable(Operation *op) {
  // Shape reification is required to restore the result shape later.
  if (!llvm::isa<InferShapedTypeOpInterface>(op)) return false;
  switch (op->getNumOperands()) {
    case 0:
      // Nothing to rank-specialize without operands.
      return false;
    case 1:
      return op->hasTrait<OpTrait::Elementwise>();
    default:
      return op->hasTrait<chlo::OpTrait::BroadcastingElementwise>() &&
             op->hasTrait<chlo::OpTrait::Broadcasting>();
  }
}
// Clusters compatible (C)HLO ops into a single
// `chlo.rank_specialization_cluster` op: the clustered ops are cloned into
// the cluster's body, escaping results are yielded, and all external uses are
// redirected to the cluster's results.
struct RankSpecializationClusterPattern : public RewritePattern {
  explicit RankSpecializationClusterPattern(MLIRContext *ctx)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

  LogicalResult matchAndRewrite(Operation *root_op,
                                PatternRewriter &rewriter) const override {
    // Only apply to operations that have not been clustered yet.
    if (root_op->getParentOfType<chlo::RankSpecializationClusterOp>()) {
      return failure();
    }

    // Only cluster when rank specialization is needed, i.e. at least one
    // operand is an unranked tensor.
    if (!IsClusterable(root_op) ||
        !llvm::any_of(root_op->getOperandTypes(),
                      [](Type ty) { return ty.isa<UnrankedTensorType>(); })) {
      return failure();
    }

    // Collect all collectively rank specializable ops by walking backwards
    // over the immediately preceding clusterable ops.
    SmallVector<Operation *, 16> cluster;
    llvm::SmallSet<Value, 16> operand_set;
    llvm::SmallSet<Value, 16> result_set;
    Operation *new_op = root_op;
    while (new_op != nullptr && IsClusterable(new_op)) {
      // Find results that escape the cluster (used by an op outside it).
      for (OpOperand &use : new_op->getUses()) {
        if (!llvm::is_contained(cluster, use.getOwner()))
          result_set.insert(use.get());
      }

      // Update cluster operands: values produced within the cluster are no
      // longer external operands; the new member's operands may be.
      for (OpResult v : new_op->getResults()) operand_set.erase(Value(v));
      for (OpOperand &v : new_op->getOpOperands()) operand_set.insert(v.get());

      cluster.push_back(new_op);
      new_op = new_op->getPrevNode();
    }

    // Create `RankSpecializationClusterOp`.
    auto operands = llvm::to_vector<16>(operand_set);
    auto results = llvm::to_vector<16>(result_set);
    auto result_types = llvm::to_vector<16>(
        llvm::map_range(result_set, [](Value v) { return v.getType(); }));
    Location loc = root_op->getLoc();
    auto cluster_op = rewriter.create<chlo::RankSpecializationClusterOp>(
        loc, result_types, operands);

    // Create body block with one argument per external operand.
    auto operand_types = llvm::to_vector<16>(
        llvm::map_range(operand_set, [](Value v) { return v.getType(); }));
    Block *block = rewriter.createBlock(&cluster_op.body(), {}, operand_types);

    // Copy operations into the body, remapping external operands to block
    // arguments. The cluster was collected back to front, so clone in reverse
    // to preserve the original program order.
    BlockAndValueMapping bvm;
    for (auto it : llvm::zip(operands, block->getArguments()))
      bvm.map(std::get<0>(it), std::get<1>(it));
    rewriter.setInsertionPointToStart(block);
    for (Operation *it : llvm::reverse(cluster)) rewriter.clone(*it, bvm);

    // Create `RankSpecializationClusterYieldOp` yielding the escaping values.
    auto mapped_results = llvm::to_vector<16>(
        llvm::map_range(results, [&](Value v) { return bvm.lookup(v); }));
    rewriter.create<chlo::RankSpecializationClusterYieldOp>(loc,
                                                            mapped_results);

    // Replace original ops with the new results. First map each escaping
    // value to the corresponding cluster result so `bvm.lookup` below
    // resolves to the replacement values.
    for (auto it : llvm::zip(results, cluster_op.results()))
      bvm.map(std::get<0>(it), std::get<1>(it));
    for (Operation *it : cluster) {
      if (it->getUses().empty()) {
        rewriter.eraseOp(it);
        continue;
      }
      auto replacements = llvm::to_vector<16>(llvm::map_range(
          it->getResults(), [&](Value v) { return bvm.lookup(v); }));
      // Bug fix: replace the current cluster member `it`. Previously this
      // called `rewriter.replaceOp(root_op, replacements)`, which replaced
      // `root_op` on every loop iteration (with the wrong values for non-root
      // members) and left the other cluster members' uses dangling.
      rewriter.replaceOp(it, replacements);
    }

    return success();
  }
};
// Function pass that greedily applies the clustering pattern to group
// rank-specializable (C)HLO ops into `chlo.rank_specialization_cluster` ops.
struct RankSpecializationClusterPass
    : public PassWrapper<RankSpecializationClusterPass, FunctionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    // The pattern creates CHLO cluster ops and may clone MHLO ops.
    registry.insert<mhlo::MhloDialect, chlo::HloClientDialect>();
  }

  void runOnFunction() override {
    MLIRContext *context = &getContext();
    RewritePatternSet cluster_patterns(context);
    mhlo::PopulateRankSpecializationClusterPatterns(context,
                                                    &cluster_patterns);
    if (failed(applyPatternsAndFoldGreedily(getFunction(),
                                            std::move(cluster_patterns)))) {
      signalPassFailure();
    }
  }
};
} // namespace
// Populates `patterns` with the rank specialization clustering pattern
// (declared in rewriters.h).
void PopulateRankSpecializationClusterPatterns(
MLIRContext *context, OwningRewritePatternList *patterns) {
patterns->insert<RankSpecializationClusterPattern>(context);
}
// Factory for the clustering pass (declared in passes.h; referenced by the
// `RankSpecializationClusterPass` tablegen constructor).
std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() {
return std::make_unique<RankSpecializationClusterPass>();
}
} // namespace mhlo
} // namespace mlir

View File

@@ -0,0 +1,19 @@
// RUN: mlir-hlo-opt %s --mhlo-rank-specialization-cluster | FileCheck %s

// Both chlo ops operate on unranked tensors, so the pass should fuse them
// into a single `chlo.rank_specialization_cluster` whose yield returns the
// final result.
// CHECK-LABEL: @add_mul
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
%arg2 : tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG2]], %[[ARG0]], %[[ARG1]]) ( {
// CHECK: ^bb0(%[[ARG2_:.*]]: tensor<*xf32>, %[[ARG0_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>):
// CHECK: %[[TMP:.*]] = chlo.broadcast_multiply %[[ARG0_]], %[[ARG1_]]
// CHECK: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[ARG2_]]
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[INNER_RES]])
// CHECK: }) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %[[RES]]
%0 = chlo.broadcast_multiply %arg0, %arg1
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
%1 = chlo.broadcast_add %0, %arg2
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %1 : tensor<*xf32>
}