diff --git a/BUILD b/BUILD
index 549ee59..06133e2 100644
--- a/BUILD
+++ b/BUILD
@@ -706,6 +706,28 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "rank_specialization",
+    srcs = ["lib/Dialect/mhlo/transforms/rank_specialization.cc"],
+    hdrs = [
+        "include/mlir-hlo/Dialect/mhlo/transforms/passes.h",
+        "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h",
+    ],
+    deps = [
+        ":hlo",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:InferTypeOpInterface",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:Shape",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:Transforms",
+    ],
+    alwayslink = 1,
+)
+
 cc_library(
     name = "lhlo_legalize_to_gpu",
     srcs = ["lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc"],
@@ -1079,6 +1101,7 @@ cc_library(
         ":mhlo_fusion",
         ":mhlo_to_mhlo_lowering_patterns",
         ":move_up_dynamic_broadcasts_for_fusion",
+        ":rank_specialization",
        ":sink_constants_to_control_flow",
        ":test_passes",
        ":transform_unranked_hlo",
diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h
index 87f3025..b314bd4 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h
+++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h
@@ -48,6 +48,22 @@ class HloClientDialect : public Dialect {
 }  // namespace chlo
 }  // namespace mlir
 
+namespace mlir {
+namespace chlo {
+namespace OpTrait {
+
+template <typename ConcreteType>
+class BroadcastingElementwise
+    : public mlir::OpTrait::TraitBase<ConcreteType, BroadcastingElementwise> {};
+
+template <typename ConcreteType>
+class Broadcasting
+    : public mlir::OpTrait::TraitBase<ConcreteType, Broadcasting> {};
+
+}  // namespace OpTrait
+}  // namespace chlo
+}  // namespace mlir
+
 #define GET_OP_CLASSES
 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc"
 
diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
index ce384d2..b7c5417 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
@@ -61,6 +61,17 @@ class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
   let verifier = [{ return Verify(*this); }];
 }
 
+class HLOClient_NativeOpTrait<string name> : NativeOpTrait<name> {
+  let cppNamespace = "::mlir::chlo::OpTrait";
+}
+
+def HLOClient_Broadcasting : HLOClient_NativeOpTrait<"Broadcasting"> {
+}
+
+def HLOClient_BroadcastingElementwise
+    : HLOClient_NativeOpTrait<"BroadcastingElementwise"> {
+}
+
 //===----------------------------------------------------------------------===//
 // CHLO binary elementwise op definitions.
 // From the client perspective, each of these support both explicit rank
@@ -78,6 +89,7 @@ class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
 
 class HLOClient_BroadcastBinaryElementwiseOp<
     string mnemonic, list<OpTrait> traits> : HLOClient_Op<mnemonic, traits # [
+      HLOClient_Broadcasting, HLOClient_BroadcastingElementwise,
       DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
           ["inferReturnTypeComponents", "reifyReturnTypeShapes"]>]> {
   let arguments = (ins
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td b/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
index 34e7722..b3ea455 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
@@ -121,3 +121,10 @@ def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "FuncOp"> {
   let summary = "Test pass for materializing 'broadcast_dimensions' attributes.";
   let constructor = "createTestUnfuseBatchNormPass()";
 }
+
+/// Rank specialization passes.
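+/// - Find compatible operations and group them together in one rank
+///   specialization region.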
+
+def RankSpecializationClusterPass
+    : Pass<"mhlo-rank-specialization-cluster", "FuncOp"> {
+  let constructor = "createRankSpecializationClusterPass()";
+}
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
index 82b3d1d..76e27b2 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
@@ -69,6 +69,11 @@ createLegalizeTrigonometricToApproximationPass();
 
 std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass();
 
+/// Rank specialization passes.
+/// - Find compatible operations and group them together in one rank
+///   specialization region.
+std::unique_ptr<FunctionPass> createRankSpecializationClusterPass();
+
 std::unique_ptr<FunctionPass> createOptimizeMhloPass();
 std::unique_ptr<FunctionPass> createLowerComplexPass();
 std::unique_ptr<::mlir::Pass> createLegalizeGeneralDotPass();
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
index bd12379..3da5b97 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
@@ -100,6 +100,10 @@ void PopulateMoveUpDynamicBroadcastsForFusionLegality(ConversionTarget *target);
 void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns);
 
+/// Populate rank specialization clustering patterns.
+void PopulateRankSpecializationClusterPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns);
+
 }  // namespace mhlo
 
 namespace chlo {
diff --git a/lib/Dialect/mhlo/transforms/CMakeLists.txt b/lib/Dialect/mhlo/transforms/CMakeLists.txt
index f27b7a1..8949ac3 100644
--- a/lib/Dialect/mhlo/transforms/CMakeLists.txt
+++ b/lib/Dialect/mhlo/transforms/CMakeLists.txt
@@ -48,7 +48,6 @@ add_mlir_library(ChloPasses
 )
 
 add_mlir_library(MhloPasses
-  move_up_dynamic_broadcasts_for_fusion.cc
   legalize_gather_to_torch_index_select.cc
   legalize_trigonometric_to_approximation.cc
   lower_complex.cc
@@ -57,8 +56,10 @@ add_mlir_library(MhloPasses
   materialize_broadcasts.cc
   materialize_broadcasts_pass.cc
   mhlo_fusion.cc
+  move_up_dynamic_broadcasts_for_fusion.cc
   optimize_mhlo.cc
   optimize_mhlo_pass.cc
+  rank_specialization.cc
   sink_constants_to_control_flow.cc
   test_infer_shaped_type_pass.cc
   transform_unranked_hlo.cc
diff --git a/lib/Dialect/mhlo/transforms/rank_specialization.cc b/lib/Dialect/mhlo/transforms/rank_specialization.cc
new file mode 100644
index 0000000..88e4dd4
--- /dev/null
+++ b/lib/Dialect/mhlo/transforms/rank_specialization.cc
@@ -0,0 +1,179 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+==============================================================================*/
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+
+// Needed to build `llvm::SmallSet`s of `mlir::Value`s.
+static bool operator<(const Value &lhs, const Value &rhs) {
+  return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
+}
+
+namespace mhlo {
+namespace {
+
+/// Identify clusters of operations that can be rank-specialized together. The
+/// required traits for clustered operations are:
+///   - Element-wise: All operations in the group must be element-wise. This
+///     allows us to reshape operands before applying the operations, as well
+///     as to reshape the result to the desired shape afterwards. This way, we
+///     can, e.g., apply unary ops to a completely flattened operand and
+///     restore the original shape afterwards.
+///   - Broadcasting semantics: All operations must implement broadcasting
+///     semantics. Most importantly, this allows extending operand shapes such
+///     that they match in rank.
+///   - Shape reification: All operations must implement
+///     `InferShapedTypeOpInterface`. This is later needed to compute and to
+///     restore the desired result shape.
+
+bool IsClusterable(Operation *op) {
+  if (!llvm::isa<InferShapedTypeOpInterface>(op)) return false;
+  unsigned int num_operands = op->getNumOperands();
+  if (num_operands == 0) return false;
+  if (num_operands == 1) return op->hasTrait<mlir::OpTrait::Elementwise>();
+  return op->hasTrait<chlo::OpTrait::Broadcasting>() &&
+         op->hasTrait<chlo::OpTrait::BroadcastingElementwise>();
+}
+
+struct RankSpecializationClusterPattern : public RewritePattern {
+  explicit RankSpecializationClusterPattern(MLIRContext *ctx)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *root_op,
+                                PatternRewriter &rewriter) const override {
+    // Only apply to operations that have not been clustered yet.
+    if (root_op->getParentOfType<chlo::RankSpecializationClusterOp>()) {
+      return failure();
+    }
+
+    // Only cluster when rank specialization is needed.
+    if (!IsClusterable(root_op) ||
+        !llvm::any_of(root_op->getOperandTypes(),
+                      [](Type ty) { return ty.isa<UnrankedTensorType>(); })) {
+      return failure();
+    }
+
+    // Collect all collectively rank-specializable ops.
+    SmallVector<Operation *, 16> cluster;
+    llvm::SmallSet<Value, 16> operand_set;
+    llvm::SmallSet<Value, 16> result_set;
+    Operation *new_op = root_op;
+    while (new_op != nullptr && IsClusterable(new_op)) {
+      // Find results that escape the cluster.
+      for (OpOperand &use : new_op->getUses()) {
+        if (!llvm::is_contained(cluster, use.getOwner()))
+          result_set.insert(use.get());
+      }
+
+      // Update cluster operands.
+      for (OpResult v : new_op->getResults()) operand_set.erase(Value(v));
+      for (OpOperand &v : new_op->getOpOperands()) operand_set.insert(v.get());
+
+      cluster.push_back(new_op);
+      new_op = new_op->getPrevNode();
+    }
+
+    // Create `RankSpecializationClusterOp`.
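+    // The cluster materializes as a single region operation. Illustration
+    // only (SSA names invented here); compare tests/rank-specialization.mlir:
+    //
+    //   %res = "chlo.rank_specialization_cluster"(%operand0, %operand1) ( {
+    //   ^bb0(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>):
+    //     ... cloned cluster operations ...
+    //     "chlo.rank_specialization_cluster_yield"(%inner_res)
+    //   }) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>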
+    auto operands = llvm::to_vector<16>(operand_set);
+    auto results = llvm::to_vector<16>(result_set);
+    auto result_types = llvm::to_vector<16>(
+        llvm::map_range(result_set, [](Value v) { return v.getType(); }));
+    Location loc = root_op->getLoc();
+    auto cluster_op = rewriter.create<chlo::RankSpecializationClusterOp>(
+        loc, result_types, operands);
+
+    // Create body block.
+    auto operand_types = llvm::to_vector<16>(
+        llvm::map_range(operand_set, [](Value v) { return v.getType(); }));
+    Block *block = rewriter.createBlock(&cluster_op.body(), {}, operand_types);
+
+    // Copy operations into the body.
+    BlockAndValueMapping bvm;
+    for (auto it : llvm::zip(operands, block->getArguments()))
+      bvm.map(std::get<0>(it), std::get<1>(it));
+    rewriter.setInsertionPointToStart(block);
+    for (Operation *it : llvm::reverse(cluster)) rewriter.clone(*it, bvm);
+
+    // Create `RankSpecializationClusterYieldOp`.
+    auto mapped_results = llvm::to_vector<16>(
+        llvm::map_range(results, [&](Value v) { return bvm.lookup(v); }));
+    rewriter.create<chlo::RankSpecializationClusterYieldOp>(loc,
+                                                            mapped_results);
+
+    // Replace original ops with the new results.
+    for (auto it : llvm::zip(results, cluster_op.results()))
+      bvm.map(std::get<0>(it), std::get<1>(it));
+    for (Operation *it : cluster) {
+      if (it->getUses().empty()) {
+        rewriter.eraseOp(it);
+        continue;
+      }
+      auto replacements = llvm::to_vector<16>(llvm::map_range(
+          it->getResults(), [&](Value v) { return bvm.lookup(v); }));
+      rewriter.replaceOp(it, replacements);
+    }
+
+    return success();
+  }
+};
+
+struct RankSpecializationClusterPass
+    : public PassWrapper<RankSpecializationClusterPass, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<mhlo::MhloDialect, chlo::HloClientDialect>();
+  }
+
+  void runOnFunction() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    mhlo::PopulateRankSpecializationClusterPatterns(ctx, &patterns);
+    if (failed(
+            applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+void PopulateRankSpecializationClusterPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns) {
+  patterns->insert<RankSpecializationClusterPattern>(context);
+}
+
+std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() {
+  return std::make_unique<RankSpecializationClusterPass>();
+}
+
+}  // namespace mhlo
+}  // namespace mlir
diff --git a/tests/rank-specialization.mlir b/tests/rank-specialization.mlir
new file mode 100644
index 0000000..8e8b646
--- /dev/null
+++ b/tests/rank-specialization.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-hlo-opt %s --mhlo-rank-specialization-cluster | FileCheck %s
+
+// CHECK-LABEL: @add_mul
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
+func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
+    %arg2 : tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG2]], %[[ARG0]], %[[ARG1]]) ( {
+  // CHECK: ^bb0(%[[ARG2_:.*]]: tensor<*xf32>, %[[ARG0_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>):
+  // CHECK:   %[[TMP:.*]] = chlo.broadcast_multiply %[[ARG0_]], %[[ARG1_]]
+  // CHECK:   %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[ARG2_]]
+  // CHECK:   "chlo.rank_specialization_cluster_yield"(%[[INNER_RES]])
+  // CHECK: }) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: return %[[RES]]
+  %0 = chlo.broadcast_multiply %arg0, %arg1
+      : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  %1 = chlo.broadcast_add %0, %arg2
+      : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %1 : tensor<*xf32>
+}