[MLIR][HLO] Add `rank-specialization-cluster` pass
Add a pass to cluster unranked CHLO/HLO operations into a single `chlo.rank_specialization_cluster` op. The clustered operations are moved into the body of that op, so that later passes can rank-specialize all of them together.

PiperOrigin-RevId: 373336725
parent 7f84779868
commit 313d24bc8f

BUILD (23 lines changed):
@@ -706,6 +706,28 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "rank_specialization",
+    srcs = ["lib/Dialect/mhlo/transforms/rank_specialization.cc"],
+    hdrs = [
+        "include/mlir-hlo/Dialect/mhlo/transforms/passes.h",
+        "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h",
+    ],
+    deps = [
+        ":hlo",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:InferTypeOpInterface",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:Shape",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:Transforms",
+    ],
+    alwayslink = 1,
+)
+
 cc_library(
     name = "lhlo_legalize_to_gpu",
     srcs = ["lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc"],
@@ -1079,6 +1101,7 @@ cc_library(
         ":mhlo_fusion",
         ":mhlo_to_mhlo_lowering_patterns",
         ":move_up_dynamic_broadcasts_for_fusion",
+        ":rank_specialization",
         ":sink_constants_to_control_flow",
         ":test_passes",
         ":transform_unranked_hlo",
@@ -48,6 +48,22 @@ class HloClientDialect : public Dialect {
 }  // namespace chlo
 }  // namespace mlir
 
+namespace mlir {
+namespace chlo {
+namespace OpTrait {
+
+template <typename ConcreteType>
+class BroadcastingElementwise
+    : public mlir::OpTrait::TraitBase<ConcreteType, BroadcastingElementwise> {};
+
+template <typename ConcreteType>
+class Broadcasting
+    : public mlir::OpTrait::TraitBase<ConcreteType, Broadcasting> {};
+
+}  // namespace OpTrait
+}  // namespace chlo
+}  // namespace mlir
+
 #define GET_OP_CLASSES
 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc"
 
@@ -61,6 +61,17 @@ class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
   let verifier = [{ return Verify(*this); }];
 }
 
+class HLOClient_NativeOpTrait<string name> : NativeOpTrait<name> {
+  let cppNamespace = "::mlir::chlo::OpTrait";
+}
+
+def HLOClient_Broadcasting : HLOClient_NativeOpTrait<"Broadcasting"> {
+}
+
+def HLOClient_BroadcastingElementwise
+    : HLOClient_NativeOpTrait<"BroadcastingElementwise"> {
+}
+
 //===----------------------------------------------------------------------===//
 // CHLO binary elementwise op definitions.
 // From the client perspective, each of these support both explicit rank
@@ -78,6 +89,7 @@ class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
 
 class HLOClient_BroadcastBinaryElementwiseOp<
   string mnemonic, list<OpTrait> traits> : HLOClient_Op<mnemonic, traits # [
+      HLOClient_BroadcastingElementwise, HLOClient_Broadcasting,
       DeclareOpInterfaceMethods<InferShapedTypeOpInterface, [
       "inferReturnTypeComponents", "reifyReturnTypeShapes"]>]> {
   let arguments = (ins
@@ -121,3 +121,10 @@ def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "FuncOp"> {
   let summary = "Test pass for materializing 'broadcast_dimensions' attributes.";
   let constructor = "createTestUnfuseBatchNormPass()";
 }
+
+/// Rank specialization passes.
+
+def RankSpecializationClusterPass
+    : Pass<"mhlo-rank-specialization-cluster", "FuncOp"> {
+  let constructor = "createRankSpecializationClusterPass()";
+}
@@ -69,6 +69,11 @@ createLegalizeTrigonometricToApproximationPass();
 
 std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass();
 
+/// Rank specialization passes.
+///   - Find compatible operations and group them together in one rank
+///     specialization region.
+std::unique_ptr<FunctionPass> createRankSpecializationClusterPass();
+
 std::unique_ptr<FunctionPass> createOptimizeMhloPass();
 std::unique_ptr<FunctionPass> createLowerComplexPass();
 std::unique_ptr<::mlir::Pass> createLegalizeGeneralDotPass();
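For orientation, here is a minimal sketch of how a pipeline might schedule the new pass declared above (it assumes the MLIR pass-manager APIs at this revision; the helper function name is purely illustrative):

#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Hypothetical helper: group unranked CHLO/HLO ops into
// `chlo.rank_specialization_cluster` regions before any later
// rank-specialization lowering runs.
void AddRankSpecializationClustering(mlir::PassManager &pm) {
  // The pass is a FunctionPass, so nest it on the functions of the module.
  pm.addNestedPass<mlir::FuncOp>(
      mlir::mhlo::createRankSpecializationClusterPass());
}

From the command line, the same transformation is exposed as `-mhlo-rank-specialization-cluster` (see the pass definition above and the new test at the end of this change).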
				
			
@@ -100,6 +100,10 @@ void PopulateMoveUpDynamicBroadcastsForFusionLegality(ConversionTarget *target);
 void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns);
 
+/// Populate rank specialization clustering patterns.
+void PopulateRankSpecializationClusterPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns);
+
 }  // namespace mhlo
 
 namespace chlo {
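The populate hook above is meant for pipelines that want to mix the clustering pattern into their own rewrite set. A rough sketch of that usage, mirroring the pass implementation added later in this change (the pass name here is illustrative):

#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical function pass that reuses the clustering pattern alongside its
// own rewrites and applies everything with the greedy pattern rewrite driver.
struct MyClusteringPass
    : public mlir::PassWrapper<MyClusteringPass, mlir::FunctionPass> {
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    // The clustering pattern creates CHLO cluster ops.
    registry.insert<mlir::mhlo::MhloDialect, mlir::chlo::HloClientDialect>();
  }

  void runOnFunction() override {
    mlir::MLIRContext *ctx = &getContext();
    mlir::RewritePatternSet patterns(ctx);
    mlir::mhlo::PopulateRankSpecializationClusterPatterns(ctx, &patterns);
    // Additional patterns could be inserted here before applying them.
    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getFunction(),
                                                        std::move(patterns))))
      signalPassFailure();
  }
};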
				
			
@@ -48,7 +48,6 @@ add_mlir_library(ChloPasses
 )
 
 add_mlir_library(MhloPasses
-  move_up_dynamic_broadcasts_for_fusion.cc
   legalize_gather_to_torch_index_select.cc
   legalize_trigonometric_to_approximation.cc
   lower_complex.cc
@@ -57,8 +56,10 @@ add_mlir_library(MhloPasses
   materialize_broadcasts.cc
   materialize_broadcasts_pass.cc
   mhlo_fusion.cc
+  move_up_dynamic_broadcasts_for_fusion.cc
   optimize_mhlo.cc
   optimize_mhlo_pass.cc
+  rank_specialization.cc
   sink_constants_to_control_flow.cc
   test_infer_shaped_type_pass.cc
   transform_unranked_hlo.cc
lib/Dialect/mhlo/transforms/rank_specialization.cc (new file)
@@ -0,0 +1,179 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+==============================================================================*/
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+
+// Needed to build `llvm::SmallSet`s of `mlir::Value`s.
+static bool operator<(const Value &lhs, const Value &rhs) {
+  return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
+}
+
+namespace mhlo {
+namespace {
+
+/// Identify clusters of operations that can be rank-specialized together. The
+/// required traits for clustered operations are:
+///   - Element-wise: All operations in the group must be element-wise. This
+///     allows reshaping operands before applying the operations, as well as
+///     reshaping the result to the desired shape afterwards. This way, we can,
+///     e.g., apply unary ops to a completely flattened operand and restore the
+///     original shape afterwards.
+///   - Broadcasting semantics: All operations must implement broadcasting
+///     semantics. Most importantly, this allows extending operand shapes such
+///     that they match in rank.
+///   - Shape reification: All operations must implement
+///     `InferShapedTypeOpInterface`. This is later needed to compute and to
+///     restore the desired result shape.
+
+bool IsClusterable(Operation *op) {
+  if (!llvm::isa<InferShapedTypeOpInterface>(op)) return false;
+  unsigned int num_operands = op->getNumOperands();
+  if (num_operands == 0) return false;
+  if (num_operands == 1) return op->hasTrait<OpTrait::Elementwise>();
+  return op->hasTrait<chlo::OpTrait::BroadcastingElementwise>() &&
+         op->hasTrait<chlo::OpTrait::Broadcasting>();
+}
+
+struct RankSpecializationClusterPattern : public RewritePattern {
+  explicit RankSpecializationClusterPattern(MLIRContext *ctx)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *root_op,
+                                PatternRewriter &rewriter) const override {
+    // Only apply to operations that have not been clustered yet.
+    if (root_op->getParentOfType<chlo::RankSpecializationClusterOp>()) {
+      return failure();
+    }
+
+    // Only cluster when rank specialization is needed.
+    if (!IsClusterable(root_op) ||
+        !llvm::any_of(root_op->getOperandTypes(),
+                      [](Type ty) { return ty.isa<UnrankedTensorType>(); })) {
+      return failure();
+    }
+
+    // Collect all collectively rank-specializable ops.
+    SmallVector<Operation *, 16> cluster;
+    llvm::SmallSet<Value, 16> operand_set;
+    llvm::SmallSet<Value, 16> result_set;
+    Operation *new_op = root_op;
+    while (new_op != nullptr && IsClusterable(new_op)) {
+      // Find results that escape the cluster.
+      for (OpOperand &use : new_op->getUses()) {
+        if (!llvm::is_contained(cluster, use.getOwner()))
+          result_set.insert(use.get());
+      }
+
+      // Update cluster operands.
+      for (OpResult v : new_op->getResults()) operand_set.erase(Value(v));
+      for (OpOperand &v : new_op->getOpOperands()) operand_set.insert(v.get());
+
+      cluster.push_back(new_op);
+      new_op = new_op->getPrevNode();
+    }
+
+    // Create `RankSpecializationClusterOp`.
+    auto operands = llvm::to_vector<16>(operand_set);
+    auto results = llvm::to_vector<16>(result_set);
+    auto result_types = llvm::to_vector<16>(
+        llvm::map_range(result_set, [](Value v) { return v.getType(); }));
+    Location loc = root_op->getLoc();
+    auto cluster_op = rewriter.create<chlo::RankSpecializationClusterOp>(
+        loc, result_types, operands);
+
+    // Create body block.
+    auto operand_types = llvm::to_vector<16>(
+        llvm::map_range(operand_set, [](Value v) { return v.getType(); }));
+    Block *block = rewriter.createBlock(&cluster_op.body(), {}, operand_types);
+
+    // Copy operations into the body.
+    BlockAndValueMapping bvm;
+    for (auto it : llvm::zip(operands, block->getArguments()))
+      bvm.map(std::get<0>(it), std::get<1>(it));
+    rewriter.setInsertionPointToStart(block);
+    for (Operation *it : llvm::reverse(cluster)) rewriter.clone(*it, bvm);
+
+    // Create `RankSpecializationClusterYieldOp`.
+    auto mapped_results = llvm::to_vector<16>(
+        llvm::map_range(results, [&](Value v) { return bvm.lookup(v); }));
+    rewriter.create<chlo::RankSpecializationClusterYieldOp>(loc,
+                                                            mapped_results);
+
+    // Replace original ops with the new results.
+    for (auto it : llvm::zip(results, cluster_op.results()))
+      bvm.map(std::get<0>(it), std::get<1>(it));
+    for (Operation *it : cluster) {
+      if (it->getUses().empty()) {
+        rewriter.eraseOp(it);
+        continue;
+      }
+      auto replacements = llvm::to_vector<16>(llvm::map_range(
+          it->getResults(), [&](Value v) { return bvm.lookup(v); }));
+      rewriter.replaceOp(it, replacements);
+    }
+
+    return success();
+  }
+};
+
+struct RankSpecializationClusterPass
+    : public PassWrapper<RankSpecializationClusterPass, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<mhlo::MhloDialect, chlo::HloClientDialect>();
+  }
+
+  void runOnFunction() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    mhlo::PopulateRankSpecializationClusterPatterns(ctx, &patterns);
+    if (failed(
+            applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+void PopulateRankSpecializationClusterPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns) {
+  patterns->insert<RankSpecializationClusterPattern>(context);
+}
+
+std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() {
+  return std::make_unique<RankSpecializationClusterPass>();
+}
+
+}  // namespace mhlo
+}  // namespace mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-hlo-opt %s --mhlo-rank-specialization-cluster | FileCheck %s
+
+// CHECK-LABEL: @add_mul
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
+func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
+    %arg2 : tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG2]], %[[ARG0]], %[[ARG1]]) ( {
+  // CHECK: ^bb0(%[[ARG2_:.*]]: tensor<*xf32>, %[[ARG0_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>):
+  // CHECK:   %[[TMP:.*]] = chlo.broadcast_multiply %[[ARG0_]], %[[ARG1_]]
+  // CHECK:   %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[ARG2_]]
+  // CHECK:   "chlo.rank_specialization_cluster_yield"(%[[INNER_RES]])
+  // CHECK: }) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: return %[[RES]]
+  %0 = chlo.broadcast_multiply %arg0, %arg1
+      : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  %1 = chlo.broadcast_add %0, %arg2
+      : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %1 : tensor<*xf32>
+}