[MLIR][HLO] Add `rank-specialization-to-scf` pass
Currently the lowering is only implemented for the unary case. The n-ary case will follow. PiperOrigin-RevId: 374162772
This commit is contained in:
parent
295ef229d6
commit
ccd70d5717
|
@ -128,3 +128,8 @@ def RankSpecializationClusterPass
|
|||
: Pass<"mhlo-rank-specialization-cluster", "FuncOp"> {
|
||||
let constructor = "createRankSpecializationClusterPass()";
|
||||
}
|
||||
|
||||
// Pass definition registered under the command-line flag
// `mhlo-rank-specialization-to-scf`. It is anchored on `FuncOp` and is
// instantiated through `createRankSpecializationToSCFPass()`.
def RankSpecializationToSCFPass
|
||||
: Pass<"mhlo-rank-specialization-to-scf", "FuncOp"> {
|
||||
let constructor = "createRankSpecializationToSCFPass()";
|
||||
}
|
||||
|
|
|
@ -69,10 +69,12 @@ createLegalizeTrigonometricToApproximationPass();
|
|||
|
||||
std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass();
|
||||
|
||||
/// Rank specialization passes.
|
||||
/// Rank specialization passes:
|
||||
/// - Find compatible operations and group them together in one rank
|
||||
/// specialization region.
|
||||
/// specialization cluster.
|
||||
/// - Lower rank specialization clusters to SCF and ranked operations.
|
||||
std::unique_ptr<FunctionPass> createRankSpecializationClusterPass();
|
||||
std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass();
|
||||
|
||||
std::unique_ptr<FunctionPass> createOptimizeMhloPass();
|
||||
std::unique_ptr<FunctionPass> createLowerComplexPass();
|
||||
|
|
|
@ -100,9 +100,11 @@ void PopulateMoveUpDynamicBroadcastsForFusionLegality(ConversionTarget *target);
|
|||
void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
|
||||
MLIRContext *context, OwningRewritePatternList *patterns);
|
||||
|
||||
/// Populate rank specialization clustering patterns.
|
||||
/// Populate rank specialization clustering and lowering patterns.
|
||||
void PopulateRankSpecializationClusterPatterns(
|
||||
MLIRContext *context, OwningRewritePatternList *patterns);
|
||||
void PopulateRankSpecializationToSCFPatterns(
|
||||
MLIRContext *context, OwningRewritePatternList *patterns);
|
||||
|
||||
} // namespace mhlo
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
|
@ -34,7 +35,7 @@ limitations under the License.
|
|||
|
||||
namespace mlir {
|
||||
|
||||
// Needed to build `llvm::SmallSet`s of `mlir::Value`s.
|
||||
/// Needed to build `llvm::SmallSet`s of `mlir::Value`s.
|
||||
static bool operator<(const Value &lhs, const Value &rhs) {
  // Compare the opaque pointer representations of the two values.
  auto *lhs_ptr = lhs.getAsOpaquePointer();
  auto *rhs_ptr = rhs.getAsOpaquePointer();
  return lhs_ptr < rhs_ptr;
}
|
||||
|
@ -164,6 +165,97 @@ struct RankSpecializationClusterPass
|
|||
}
|
||||
};
|
||||
|
||||
/// Lower rank specialization cluster to SCF.
|
||||
|
||||
/// Returns `ty` unchanged unless it is an unranked tensor type. In that case
/// the result is a ranked tensor type with `rank` dynamically sized dimensions
/// and the element type of `ty`.
Type DeriveRankedTensorTypes(Type ty, int64_t rank) {
  if (auto unranked_ty = ty.dyn_cast<UnrankedTensorType>()) {
    // Every one of the `rank` dimensions is marked as dynamic.
    SmallVector<int64_t, 8> dynamic_dims(rank, ShapedType::kDynamicSize);
    return RankedTensorType::get(dynamic_dims, unranked_ty.getElementType());
  }
  return ty;
}
|
||||
|
||||
/// Unary element-wise operations on unranked tensors can be applied to the
|
||||
/// flattened tensor and reshaped to the expected shape afterwards.
|
||||
struct LowerUnaryRankSpecializationClusterPattern
|
||||
: public OpRewritePattern<chlo::RankSpecializationClusterOp> {
|
||||
using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Only apply this to unary operations.
|
||||
if (op.operands().size() != 1) return failure();
|
||||
|
||||
// Compute flattened operand shape.
|
||||
Location loc = op.getLoc();
|
||||
Value arg = op.operands().front();
|
||||
Value shape = rewriter.create<shape::ShapeOfOp>(loc, arg);
|
||||
Value flat_shape = rewriter.create<tensor::FromElementsOp>(
|
||||
loc,
|
||||
rewriter
|
||||
.create<shape::NumElementsOp>(loc, rewriter.getIndexType(), shape)
|
||||
.result());
|
||||
|
||||
// Flatten operand.
|
||||
Value flat_arg = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, DeriveRankedTensorTypes(arg.getType(), /*rank=*/1), arg,
|
||||
flat_shape);
|
||||
|
||||
// Materialize ranked versions of the element-wise operations.
|
||||
BlockAndValueMapping bvm;
|
||||
bvm.map(op.getBody()->getArguments().front(), flat_arg);
|
||||
for (Operation &nested_op : op.getBody()->without_terminator()) {
|
||||
auto mapped_operands = llvm::to_vector<4>(llvm::map_range(
|
||||
nested_op.getOperands(), [&](Value v) { return bvm.lookup(v); }));
|
||||
auto ranked_result_types = llvm::to_vector<2>(llvm::map_range(
|
||||
nested_op.getResultTypes(),
|
||||
[](Type ty) { return DeriveRankedTensorTypes(ty, /*rank=*/1); }));
|
||||
OperationState ranked_op_state(loc, nested_op.getName().getStringRef(),
|
||||
mapped_operands, ranked_result_types,
|
||||
nested_op.getAttrs());
|
||||
Operation *ranked_op = rewriter.createOperation(ranked_op_state);
|
||||
for (auto it :
|
||||
llvm::zip(nested_op.getResults(), ranked_op->getResults())) {
|
||||
bvm.map(std::get<0>(it), std::get<1>(it));
|
||||
}
|
||||
}
|
||||
|
||||
// Collect results and restore their shape. We don't have to reify a shape
|
||||
// computation in the unary case as the operand shapes to all the
|
||||
// element-wise ops can only be the unique input shape.
|
||||
SmallVector<Value> results;
|
||||
for (Value v : llvm::cast<chlo::RankSpecializationClusterYieldOp>(
|
||||
op.getBody()->getTerminator())
|
||||
.results()) {
|
||||
Value flat_result = bvm.lookup(v);
|
||||
Value result = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, v.getType(), flat_result, shape);
|
||||
results.push_back(result);
|
||||
}
|
||||
|
||||
// Replace the rank specialization cluster.
|
||||
rewriter.replaceOp(op, results);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct RankSpecializationToSCFPass
|
||||
: public PassWrapper<RankSpecializationToSCFPass, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<mhlo::MhloDialect, chlo::HloClientDialect,
|
||||
shape::ShapeDialect>();
|
||||
}
|
||||
|
||||
void runOnFunction() override {
|
||||
MLIRContext *ctx = &getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
PopulateRankSpecializationToSCFPatterns(ctx, &patterns);
|
||||
if (failed(
|
||||
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void PopulateRankSpecializationClusterPatterns(
|
||||
|
@ -171,9 +263,18 @@ void PopulateRankSpecializationClusterPatterns(
|
|||
patterns->insert<RankSpecializationClusterPattern>(context);
|
||||
}
|
||||
|
||||
/// Adds the rank specialization lowering patterns to `patterns`. The only
/// pattern inserted is `LowerUnaryRankSpecializationClusterPattern`.
void PopulateRankSpecializationToSCFPatterns(
    MLIRContext *context, OwningRewritePatternList *patterns) {
  OwningRewritePatternList &pattern_list = *patterns;
  pattern_list.insert<LowerUnaryRankSpecializationClusterPattern>(context);
}
|
||||
|
||||
std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() {
|
||||
return std::make_unique<RankSpecializationClusterPass>();
|
||||
}
|
||||
|
||||
std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass() {
|
||||
return std::make_unique<RankSpecializationToSCFPass>();
|
||||
}
|
||||
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster | FileCheck %s
|
||||
// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster --mhlo-rank-specialization-to-scf | FileCheck %s --check-prefix CHECK-SCF
|
||||
|
||||
// CHECK-LABEL: @add_mul
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
|
||||
|
@ -37,6 +38,18 @@ func @sqrt(%arg : tensor<*xf32>) -> tensor<*xf32> {
|
|||
return %2 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-SCF-LABEL: @sqrt
|
||||
// CHECK-SCF-SAME: (%[[ARG:.*]]: tensor<*xf32>)
|
||||
// CHECK-SCF: %[[SHAPE:.*]] = shape.shape_of %[[ARG]]
|
||||
// CHECK-SCF: %[[N:.*]] = shape.num_elements %[[SHAPE]]
|
||||
// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
|
||||
// CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-SCF: %[[TMP0:.*]] = "mhlo.sqrt"(%[[FLAT_ARG]]) : (tensor<?xf32>)
|
||||
// CHECK-SCF: %[[TMP1:.*]] = "mhlo.sqrt"(%[[TMP0]]) : (tensor<?xf32>)
|
||||
// CHECK-SCF: %[[TMP2:.*]] = "mhlo.sqrt"(%[[TMP1]]) : (tensor<?xf32>)
|
||||
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[TMP2]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-SCF: return %[[RES]]
|
||||
|
||||
// -----
|
||||
|
||||
// Don't cluster single ranked operation.
|
||||
|
|
Loading…
Reference in New Issue