[MLIR][HLO] Add `rank-specialization-to-scf` pass

Currently the lowering is only implemented for the unary case. The n-ary case
will follow.

PiperOrigin-RevId: 374162772
This commit is contained in:
A. Unique TensorFlower 2021-05-17 03:55:32 -07:00 committed by TensorFlow MLIR Team
parent 295ef229d6
commit ccd70d5717
5 changed files with 127 additions and 4 deletions

View File

@ -128,3 +128,8 @@ def RankSpecializationClusterPass
: Pass<"mhlo-rank-specialization-cluster", "FuncOp"> { : Pass<"mhlo-rank-specialization-cluster", "FuncOp"> {
let constructor = "createRankSpecializationClusterPass()"; let constructor = "createRankSpecializationClusterPass()";
} }
def RankSpecializationToSCFPass
: Pass<"mhlo-rank-specialization-to-scf", "FuncOp"> {
let constructor = "createRankSpecializationToSCFPass()";
}

View File

@ -69,10 +69,12 @@ createLegalizeTrigonometricToApproximationPass();
std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass(); std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass();
/// Rank specialization passes. /// Rank specialization passes:
/// - Find compatible operations and group them together in one rank /// - Find compatible operations and group them together in one rank
/// specialization region. /// specialization cluster.
/// - Lower rank specialization clusters to SCF and ranked operations.
std::unique_ptr<FunctionPass> createRankSpecializationClusterPass(); std::unique_ptr<FunctionPass> createRankSpecializationClusterPass();
std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass();
std::unique_ptr<FunctionPass> createOptimizeMhloPass(); std::unique_ptr<FunctionPass> createOptimizeMhloPass();
std::unique_ptr<FunctionPass> createLowerComplexPass(); std::unique_ptr<FunctionPass> createLowerComplexPass();

View File

@ -100,9 +100,11 @@ void PopulateMoveUpDynamicBroadcastsForFusionLegality(ConversionTarget *target);
void PopulateMoveUpDynamicBroadcastsForFusionPatterns( void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
MLIRContext *context, OwningRewritePatternList *patterns); MLIRContext *context, OwningRewritePatternList *patterns);
/// Populate rank specialization clustering patterns. /// Populate rank specialization clustering and lowering patterns.
void PopulateRankSpecializationClusterPatterns( void PopulateRankSpecializationClusterPatterns(
MLIRContext *context, OwningRewritePatternList *patterns); MLIRContext *context, OwningRewritePatternList *patterns);
void PopulateRankSpecializationToSCFPatterns(
MLIRContext *context, OwningRewritePatternList *patterns);
} // namespace mhlo } // namespace mhlo

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BlockAndValueMapping.h"
@ -34,7 +35,7 @@ limitations under the License.
namespace mlir { namespace mlir {
// Needed to build `llvm::SmallSet`s of `mlir::Value`s. /// Needed to build `llvm::SmallSet`s of `mlir::Value`s.
static bool operator<(const Value &lhs, const Value &rhs) { static bool operator<(const Value &lhs, const Value &rhs) {
return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer(); return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
} }
@ -164,6 +165,97 @@ struct RankSpecializationClusterPass
} }
}; };
/// Lower rank specialization cluster to SCF.
Type DeriveRankedTensorTypes(Type ty, int64_t rank) {
auto unranked_ty = ty.dyn_cast<UnrankedTensorType>();
if (!unranked_ty) return ty;
SmallVector<int64_t, 8> shape(rank, ShapedType::kDynamicSize);
return RankedTensorType::get(shape, unranked_ty.getElementType());
}
/// Unary element-wise operations on unranked tensors can be applied to the
/// flattened tensor and reshaped to the expected shape afterwards.
struct LowerUnaryRankSpecializationClusterPattern
: public OpRewritePattern<chlo::RankSpecializationClusterOp> {
using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern;
LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
PatternRewriter &rewriter) const override {
// Only apply this to unary operations.
if (op.operands().size() != 1) return failure();
// Compute flattened operand shape.
Location loc = op.getLoc();
Value arg = op.operands().front();
Value shape = rewriter.create<shape::ShapeOfOp>(loc, arg);
Value flat_shape = rewriter.create<tensor::FromElementsOp>(
loc,
rewriter
.create<shape::NumElementsOp>(loc, rewriter.getIndexType(), shape)
.result());
// Flatten operand.
Value flat_arg = rewriter.create<mhlo::DynamicReshapeOp>(
loc, DeriveRankedTensorTypes(arg.getType(), /*rank=*/1), arg,
flat_shape);
// Materialize ranked versions of the element-wise operations.
BlockAndValueMapping bvm;
bvm.map(op.getBody()->getArguments().front(), flat_arg);
for (Operation &nested_op : op.getBody()->without_terminator()) {
auto mapped_operands = llvm::to_vector<4>(llvm::map_range(
nested_op.getOperands(), [&](Value v) { return bvm.lookup(v); }));
auto ranked_result_types = llvm::to_vector<2>(llvm::map_range(
nested_op.getResultTypes(),
[](Type ty) { return DeriveRankedTensorTypes(ty, /*rank=*/1); }));
OperationState ranked_op_state(loc, nested_op.getName().getStringRef(),
mapped_operands, ranked_result_types,
nested_op.getAttrs());
Operation *ranked_op = rewriter.createOperation(ranked_op_state);
for (auto it :
llvm::zip(nested_op.getResults(), ranked_op->getResults())) {
bvm.map(std::get<0>(it), std::get<1>(it));
}
}
// Collect results and restore their shape. We don't have to reify a shape
// computation in the unary case as the operand shapes to all the
// element-wise ops can only be the unique input shape.
SmallVector<Value> results;
for (Value v : llvm::cast<chlo::RankSpecializationClusterYieldOp>(
op.getBody()->getTerminator())
.results()) {
Value flat_result = bvm.lookup(v);
Value result = rewriter.create<mhlo::DynamicReshapeOp>(
loc, v.getType(), flat_result, shape);
results.push_back(result);
}
// Replace the rank specialization cluster.
rewriter.replaceOp(op, results);
return success();
}
};
struct RankSpecializationToSCFPass
: public PassWrapper<RankSpecializationToSCFPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<mhlo::MhloDialect, chlo::HloClientDialect,
shape::ShapeDialect>();
}
void runOnFunction() override {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
PopulateRankSpecializationToSCFPatterns(ctx, &patterns);
if (failed(
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
return signalPassFailure();
}
}
};
} // namespace } // namespace
void PopulateRankSpecializationClusterPatterns( void PopulateRankSpecializationClusterPatterns(
@ -171,9 +263,18 @@ void PopulateRankSpecializationClusterPatterns(
patterns->insert<RankSpecializationClusterPattern>(context); patterns->insert<RankSpecializationClusterPattern>(context);
} }
void PopulateRankSpecializationToSCFPatterns(
MLIRContext *context, OwningRewritePatternList *patterns) {
patterns->insert<LowerUnaryRankSpecializationClusterPattern>(context);
}
std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() { std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() {
return std::make_unique<RankSpecializationClusterPass>(); return std::make_unique<RankSpecializationClusterPass>();
} }
std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass() {
return std::make_unique<RankSpecializationToSCFPass>();
}
} // namespace mhlo } // namespace mhlo
} // namespace mlir } // namespace mlir

View File

@ -1,4 +1,5 @@
// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster | FileCheck %s // RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster | FileCheck %s
// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster --mhlo-rank-specialization-to-scf | FileCheck %s --check-prefix CHECK-SCF
// CHECK-LABEL: @add_mul // CHECK-LABEL: @add_mul
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>) // CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
@ -37,6 +38,18 @@ func @sqrt(%arg : tensor<*xf32>) -> tensor<*xf32> {
return %2 : tensor<*xf32> return %2 : tensor<*xf32>
} }
// CHECK-SCF-LABEL: @sqrt
// CHECK-SCF-SAME: (%[[ARG:.*]]: tensor<*xf32>)
// CHECK-SCF: %[[SHAPE:.*]] = shape.shape_of %[[ARG]]
// CHECK-SCF: %[[N:.*]] = shape.num_elements %[[SHAPE]]
// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
// CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-SCF: %[[TMP0:.*]] = "mhlo.sqrt"(%[[FLAT_ARG]]) : (tensor<?xf32>)
// CHECK-SCF: %[[TMP1:.*]] = "mhlo.sqrt"(%[[TMP0]]) : (tensor<?xf32>)
// CHECK-SCF: %[[TMP2:.*]] = "mhlo.sqrt"(%[[TMP1]]) : (tensor<?xf32>)
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[TMP2]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-SCF: return %[[RES]]
// ----- // -----
// Don't cluster single ranked operation. // Don't cluster single ranked operation.