diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td b/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
index b3ea455..8bd261d 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
@@ -128,3 +128,8 @@ def RankSpecializationClusterPass
     : Pass<"mhlo-rank-specialization-cluster", "FuncOp"> {
   let constructor = "createRankSpecializationClusterPass()";
 }
+
+def RankSpecializationToSCFPass
+    : Pass<"mhlo-rank-specialization-to-scf", "FuncOp"> {
+  let constructor = "createRankSpecializationToSCFPass()";
+}
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
index 76e27b2..dcdfe11 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
@@ -69,10 +69,12 @@ createLegalizeTrigonometricToApproximationPass();
 
 std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass();
 
-/// Rank specialization passes.
+/// Rank specialization passes:
 ///   - Find compatible operations and group them together in one rank
-///     specialization region.
+///     specialization cluster.
+///   - Lower rank specialization clusters to SCF and ranked operations.
 std::unique_ptr<FunctionPass> createRankSpecializationClusterPass();
+std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass();
 
 std::unique_ptr<FunctionPass> createOptimizeMhloPass();
 std::unique_ptr<FunctionPass> createLowerComplexPass();
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
index 3da5b97..a86c56d 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
@@ -100,9 +100,11 @@ void PopulateMoveUpDynamicBroadcastsForFusionLegality(ConversionTarget *target);
 void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns);
 
-/// Populate rank specialization clustering patterns.
+/// Populate rank specialization clustering and lowering patterns.
 void PopulateRankSpecializationClusterPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns);
+void PopulateRankSpecializationToSCFPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns);
 
 }  // namespace mhlo
diff --git a/lib/Dialect/mhlo/transforms/rank_specialization.cc b/lib/Dialect/mhlo/transforms/rank_specialization.cc
index 88e4dd4..42d8e19 100644
--- a/lib/Dialect/mhlo/transforms/rank_specialization.cc
+++ b/lib/Dialect/mhlo/transforms/rank_specialization.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BlockAndValueMapping.h"
@@ -34,7 +35,7 @@ limitations under the License.
 
 namespace mlir {
 
-// Needed to build `llvm::SmallSet`s of `mlir::Value`s.
+/// Needed to build `llvm::SmallSet`s of `mlir::Value`s.
 static bool operator<(const Value &lhs, const Value &rhs) {
   return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
 }
@@ -164,6 +165,97 @@ struct RankSpecializationClusterPass
   }
 };
 
+/// Lower rank specialization cluster to SCF.
+
+Type DeriveRankedTensorTypes(Type ty, int64_t rank) {
+  auto unranked_ty = ty.dyn_cast<UnrankedTensorType>();
+  if (!unranked_ty) return ty;
+  SmallVector<int64_t> shape(rank, ShapedType::kDynamicSize);
+  return RankedTensorType::get(shape, unranked_ty.getElementType());
+}
+
+/// Unary element-wise operations on unranked tensors can be applied to the
+/// flattened tensor and reshaped to the expected shape afterwards.
+struct LowerUnaryRankSpecializationClusterPattern
+    : public OpRewritePattern<chlo::RankSpecializationClusterOp> {
+  using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only apply this to unary operations.
+    if (op.operands().size() != 1) return failure();
+
+    // Compute flattened operand shape.
+    Location loc = op.getLoc();
+    Value arg = op.operands().front();
+    Value shape = rewriter.create<shape::ShapeOfOp>(loc, arg);
+    Value flat_shape = rewriter.create<tensor::FromElementsOp>(
+        loc,
+        rewriter
+            .create<shape::NumElementsOp>(loc, rewriter.getIndexType(), shape)
+            .result());
+
+    // Flatten operand.
+    Value flat_arg = rewriter.create<mhlo::DynamicReshapeOp>(
+        loc, DeriveRankedTensorTypes(arg.getType(), /*rank=*/1), arg,
+        flat_shape);
+
+    // Materialize ranked versions of the element-wise operations.
+    BlockAndValueMapping bvm;
+    bvm.map(op.getBody()->getArguments().front(), flat_arg);
+    for (Operation &nested_op : op.getBody()->without_terminator()) {
+      auto mapped_operands = llvm::to_vector<4>(llvm::map_range(
+          nested_op.getOperands(), [&](Value v) { return bvm.lookup(v); }));
+      auto ranked_result_types = llvm::to_vector<2>(llvm::map_range(
+          nested_op.getResultTypes(),
+          [](Type ty) { return DeriveRankedTensorTypes(ty, /*rank=*/1); }));
+      OperationState ranked_op_state(loc, nested_op.getName().getStringRef(),
+                                     mapped_operands, ranked_result_types,
+                                     nested_op.getAttrs());
+      Operation *ranked_op = rewriter.createOperation(ranked_op_state);
+      for (auto it :
+           llvm::zip(nested_op.getResults(), ranked_op->getResults())) {
+        bvm.map(std::get<0>(it), std::get<1>(it));
+      }
+    }
+
+    // Collect results and restore their shape. We don't have to reify a shape
+    // computation in the unary case as the operand shapes to all the
+    // element-wise ops can only be the unique input shape.
+    SmallVector<Value, 8> results;
+    for (Value v : llvm::cast<chlo::RankSpecializationClusterYieldOp>(
+                       op.getBody()->getTerminator())
+                       .results()) {
+      Value flat_result = bvm.lookup(v);
+      Value result = rewriter.create<mhlo::DynamicReshapeOp>(
+          loc, v.getType(), flat_result, shape);
+      results.push_back(result);
+    }
+
+    // Replace the rank specialization cluster.
+    rewriter.replaceOp(op, results);
+    return success();
+  }
+};
+
+struct RankSpecializationToSCFPass
+    : public PassWrapper<RankSpecializationToSCFPass, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<mhlo::MhloDialect, chlo::HloClientDialect,
+                    shape::ShapeDialect>();
+  }
+
+  void runOnFunction() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    PopulateRankSpecializationToSCFPatterns(ctx, &patterns);
+    if (failed(
+            applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+
 }  // namespace
 
 void PopulateRankSpecializationClusterPatterns(
@@ -171,9 +263,18 @@ void PopulateRankSpecializationClusterPatterns(
   patterns->insert<RankSpecializationClusterPattern>(context);
 }
 
+void PopulateRankSpecializationToSCFPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns) {
+  patterns->insert<LowerUnaryRankSpecializationClusterPattern>(context);
+}
+
 std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() {
   return std::make_unique<RankSpecializationClusterPass>();
 }
 
+std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass() {
+  return std::make_unique<RankSpecializationToSCFPass>();
+}
+
 }  // namespace mhlo
 }  // namespace mlir
diff --git a/tests/rank-specialization.mlir b/tests/rank-specialization.mlir
index cc0e132..32c3ad2 100644
--- a/tests/rank-specialization.mlir
+++ b/tests/rank-specialization.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster | FileCheck %s
+// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster --mhlo-rank-specialization-to-scf | FileCheck %s --check-prefix CHECK-SCF
 
 // CHECK-LABEL: @add_mul
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
@@ -37,6 +38,18 @@ func @sqrt(%arg : tensor<*xf32>) -> tensor<*xf32> {
   return %2 : tensor<*xf32>
 }
 
+// CHECK-SCF-LABEL: @sqrt
+// CHECK-SCF-SAME:  (%[[ARG:.*]]: tensor<*xf32>)
+// CHECK-SCF:       %[[SHAPE:.*]] = shape.shape_of %[[ARG]]
+// CHECK-SCF:       %[[N:.*]] = shape.num_elements %[[SHAPE]]
+// CHECK-SCF:       %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
+// CHECK-SCF:       %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK-SCF:       %[[TMP0:.*]] = "mhlo.sqrt"(%[[FLAT_ARG]]) : (tensor<?xf32>)
+// CHECK-SCF:       %[[TMP1:.*]] = "mhlo.sqrt"(%[[TMP0]]) : (tensor<?xf32>)
+// CHECK-SCF:       %[[TMP2:.*]] = "mhlo.sqrt"(%[[TMP1]]) : (tensor<?xf32>)
+// CHECK-SCF:       %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[TMP2]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-SCF:       return %[[RES]]
+
 // -----
 
 // Don't cluster single ranked operation.
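
Usage note (not part of the patch): the two passes are designed to run back to
back, exactly as the new RUN line exercises them on the command line. Below is
a minimal sketch of the equivalent programmatic pipeline, assuming the standard
mlir::PassManager API of this MLIR revision; the helper name
AddRankSpecializationPasses is hypothetical.

  #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
  #include "mlir/IR/BuiltinOps.h"
  #include "mlir/Pass/PassManager.h"

  // Hypothetical helper: chains the clustering pass with the new SCF lowering.
  void AddRankSpecializationPasses(mlir::PassManager &pm) {
    // First, group compatible element-wise ops into rank specialization
    // clusters.
    pm.addNestedPass<mlir::FuncOp>(
        mlir::mhlo::createRankSpecializationClusterPass());
    // Then lower the clusters. With this patch, only unary clusters are
    // lowered (by flattening to rank 1); other clusters are left untouched.
    pm.addNestedPass<mlir::FuncOp>(
        mlir::mhlo::createRankSpecializationToSCFPass());
  }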