[MLIR][HLO] Add `rank-specialization-to-scf` pass
Currently the lowering is only implemented for the unary case. The n-ary case will follow. PiperOrigin-RevId: 374162772
This commit is contained in:
parent
295ef229d6
commit
ccd70d5717
|
@ -128,3 +128,8 @@ def RankSpecializationClusterPass
|
||||||
: Pass<"mhlo-rank-specialization-cluster", "FuncOp"> {
|
: Pass<"mhlo-rank-specialization-cluster", "FuncOp"> {
|
||||||
let constructor = "createRankSpecializationClusterPass()";
|
let constructor = "createRankSpecializationClusterPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Declares the pass registered under the `mhlo-rank-specialization-to-scf`
// argument, anchored on `FuncOp` and constructed through
// `createRankSpecializationToSCFPass()`.
def RankSpecializationToSCFPass
|
||||||
|
: Pass<"mhlo-rank-specialization-to-scf", "FuncOp"> {
|
||||||
|
let constructor = "createRankSpecializationToSCFPass()";
|
||||||
|
}
|
||||||
|
|
|
@ -69,10 +69,12 @@ createLegalizeTrigonometricToApproximationPass();
|
||||||
|
|
||||||
std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass();
|
std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass();
|
||||||
|
|
||||||
/// Rank specialization passes.
|
/// Rank specialization passes:
|
||||||
/// - Find compatible operations and group them together in one rank
|
/// - Find compatible operations and group them together in one rank
|
||||||
/// specialization region.
|
/// specialization cluster.
|
||||||
|
/// - Lower rank specialization clusters to SCF and ranked operations.
|
||||||
std::unique_ptr<FunctionPass> createRankSpecializationClusterPass();
|
std::unique_ptr<FunctionPass> createRankSpecializationClusterPass();
|
||||||
|
std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass();
|
||||||
|
|
||||||
std::unique_ptr<FunctionPass> createOptimizeMhloPass();
|
std::unique_ptr<FunctionPass> createOptimizeMhloPass();
|
||||||
std::unique_ptr<FunctionPass> createLowerComplexPass();
|
std::unique_ptr<FunctionPass> createLowerComplexPass();
|
||||||
|
|
|
@ -100,9 +100,11 @@ void PopulateMoveUpDynamicBroadcastsForFusionLegality(ConversionTarget *target);
|
||||||
void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
|
void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
|
||||||
MLIRContext *context, OwningRewritePatternList *patterns);
|
MLIRContext *context, OwningRewritePatternList *patterns);
|
||||||
|
|
||||||
/// Populate rank specialization clustering patterns.
|
/// Populate rank specialization clustering and lowering patterns.
|
||||||
void PopulateRankSpecializationClusterPatterns(
|
void PopulateRankSpecializationClusterPatterns(
|
||||||
MLIRContext *context, OwningRewritePatternList *patterns);
|
MLIRContext *context, OwningRewritePatternList *patterns);
|
||||||
|
void PopulateRankSpecializationToSCFPatterns(
|
||||||
|
MLIRContext *context, OwningRewritePatternList *patterns);
|
||||||
|
|
||||||
} // namespace mhlo
|
} // namespace mhlo
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/BlockAndValueMapping.h"
|
#include "mlir/IR/BlockAndValueMapping.h"
|
||||||
|
@ -34,7 +35,7 @@ limitations under the License.
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
||||||
// Needed to build `llvm::SmallSet`s of `mlir::Value`s.
|
/// Needed to build `llvm::SmallSet`s of `mlir::Value`s.
|
||||||
static bool operator<(const Value &lhs, const Value &rhs) {
|
static bool operator<(const Value &lhs, const Value &rhs) {
|
||||||
return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
|
return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
|
||||||
}
|
}
|
||||||
|
@ -164,6 +165,97 @@ struct RankSpecializationClusterPass
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Lower rank specialization cluster to SCF.
|
||||||
|
|
||||||
|
/// Returns the ranked counterpart of `ty`.
///
/// An unranked tensor type is turned into a ranked tensor type with `rank`
/// dimensions, all of them dynamic, and the same element type. Every other
/// type is handed back unchanged.
Type DeriveRankedTensorTypes(Type ty, int64_t rank) {
  if (auto unranked = ty.dyn_cast<UnrankedTensorType>()) {
    SmallVector<int64_t, 8> dynamic_dims(rank, ShapedType::kDynamicSize);
    return RankedTensorType::get(dynamic_dims, unranked.getElementType());
  }
  return ty;
}
|
||||||
|
|
||||||
|
/// Unary element-wise operations on unranked tensors can be applied to the
|
||||||
|
/// flattened tensor and reshaped to the expected shape afterwards.
|
||||||
|
struct LowerUnaryRankSpecializationClusterPattern
|
||||||
|
: public OpRewritePattern<chlo::RankSpecializationClusterOp> {
|
||||||
|
using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
// Only apply this to unary operations.
|
||||||
|
if (op.operands().size() != 1) return failure();
|
||||||
|
|
||||||
|
// Compute flattened operand shape.
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
Value arg = op.operands().front();
|
||||||
|
Value shape = rewriter.create<shape::ShapeOfOp>(loc, arg);
|
||||||
|
Value flat_shape = rewriter.create<tensor::FromElementsOp>(
|
||||||
|
loc,
|
||||||
|
rewriter
|
||||||
|
.create<shape::NumElementsOp>(loc, rewriter.getIndexType(), shape)
|
||||||
|
.result());
|
||||||
|
|
||||||
|
// Flatten operand.
|
||||||
|
Value flat_arg = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||||
|
loc, DeriveRankedTensorTypes(arg.getType(), /*rank=*/1), arg,
|
||||||
|
flat_shape);
|
||||||
|
|
||||||
|
// Materialize ranked versions of the element-wise operations.
|
||||||
|
BlockAndValueMapping bvm;
|
||||||
|
bvm.map(op.getBody()->getArguments().front(), flat_arg);
|
||||||
|
for (Operation &nested_op : op.getBody()->without_terminator()) {
|
||||||
|
auto mapped_operands = llvm::to_vector<4>(llvm::map_range(
|
||||||
|
nested_op.getOperands(), [&](Value v) { return bvm.lookup(v); }));
|
||||||
|
auto ranked_result_types = llvm::to_vector<2>(llvm::map_range(
|
||||||
|
nested_op.getResultTypes(),
|
||||||
|
[](Type ty) { return DeriveRankedTensorTypes(ty, /*rank=*/1); }));
|
||||||
|
OperationState ranked_op_state(loc, nested_op.getName().getStringRef(),
|
||||||
|
mapped_operands, ranked_result_types,
|
||||||
|
nested_op.getAttrs());
|
||||||
|
Operation *ranked_op = rewriter.createOperation(ranked_op_state);
|
||||||
|
for (auto it :
|
||||||
|
llvm::zip(nested_op.getResults(), ranked_op->getResults())) {
|
||||||
|
bvm.map(std::get<0>(it), std::get<1>(it));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect results and restore their shape. We don't have to reify a shape
|
||||||
|
// computation in the unary case as the operand shapes to all the
|
||||||
|
// element-wise ops can only be the unique input shape.
|
||||||
|
SmallVector<Value> results;
|
||||||
|
for (Value v : llvm::cast<chlo::RankSpecializationClusterYieldOp>(
|
||||||
|
op.getBody()->getTerminator())
|
||||||
|
.results()) {
|
||||||
|
Value flat_result = bvm.lookup(v);
|
||||||
|
Value result = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||||
|
loc, v.getType(), flat_result, shape);
|
||||||
|
results.push_back(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace the rank specialization cluster.
|
||||||
|
rewriter.replaceOp(op, results);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct RankSpecializationToSCFPass
|
||||||
|
: public PassWrapper<RankSpecializationToSCFPass, FunctionPass> {
|
||||||
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<mhlo::MhloDialect, chlo::HloClientDialect,
|
||||||
|
shape::ShapeDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnFunction() override {
|
||||||
|
MLIRContext *ctx = &getContext();
|
||||||
|
RewritePatternSet patterns(ctx);
|
||||||
|
PopulateRankSpecializationToSCFPatterns(ctx, &patterns);
|
||||||
|
if (failed(
|
||||||
|
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void PopulateRankSpecializationClusterPatterns(
|
void PopulateRankSpecializationClusterPatterns(
|
||||||
|
@ -171,9 +263,18 @@ void PopulateRankSpecializationClusterPatterns(
|
||||||
patterns->insert<RankSpecializationClusterPattern>(context);
|
patterns->insert<RankSpecializationClusterPattern>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Inserts `LowerUnaryRankSpecializationClusterPattern`, constructed with
/// `context`, into `patterns`.
void PopulateRankSpecializationToSCFPatterns(
|
||||||
|
MLIRContext *context, OwningRewritePatternList *patterns) {
|
||||||
|
patterns->insert<LowerUnaryRankSpecializationClusterPattern>(context);
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() {
|
std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() {
|
||||||
return std::make_unique<RankSpecializationClusterPass>();
|
return std::make_unique<RankSpecializationClusterPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a newly constructed `RankSpecializationToSCFPass`.
std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass() {
|
||||||
|
return std::make_unique<RankSpecializationToSCFPass>();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mhlo
|
} // namespace mhlo
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster | FileCheck %s
|
// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster | FileCheck %s
|
||||||
|
// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster --mhlo-rank-specialization-to-scf | FileCheck %s --check-prefix CHECK-SCF
|
||||||
|
|
||||||
// CHECK-LABEL: @add_mul
|
// CHECK-LABEL: @add_mul
|
||||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
|
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
|
||||||
|
@ -37,6 +38,18 @@ func @sqrt(%arg : tensor<*xf32>) -> tensor<*xf32> {
|
||||||
return %2 : tensor<*xf32>
|
return %2 : tensor<*xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-SCF-LABEL: @sqrt
|
||||||
|
// CHECK-SCF-SAME: (%[[ARG:.*]]: tensor<*xf32>)
|
||||||
|
// CHECK-SCF: %[[SHAPE:.*]] = shape.shape_of %[[ARG]]
|
||||||
|
// CHECK-SCF: %[[N:.*]] = shape.num_elements %[[SHAPE]]
|
||||||
|
// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
|
||||||
|
// CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
|
// CHECK-SCF: %[[TMP0:.*]] = "mhlo.sqrt"(%[[FLAT_ARG]]) : (tensor<?xf32>)
|
||||||
|
// CHECK-SCF: %[[TMP1:.*]] = "mhlo.sqrt"(%[[TMP0]]) : (tensor<?xf32>)
|
||||||
|
// CHECK-SCF: %[[TMP2:.*]] = "mhlo.sqrt"(%[[TMP1]]) : (tensor<?xf32>)
|
||||||
|
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[TMP2]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||||
|
// CHECK-SCF: return %[[RES]]
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// Don't cluster single ranked operation.
|
// Don't cluster single ranked operation.
|
||||||
|
|
Loading…
Reference in New Issue