[MLIR][HLO] Generalize rank specialization with single operand

The pattern can be generalized to also rank specialize operations with a single
non-scalar operand. Also extract helper functions that can be reused in
following specializations.

PiperOrigin-RevId: 374198381
This commit is contained in:
A. Unique TensorFlower 2021-05-17 08:11:59 -07:00 committed by TensorFlow MLIR Team
parent b86b18489c
commit 474e419729
2 changed files with 140 additions and 49 deletions

View File

@ -175,6 +175,11 @@ struct RankSpecializationClusterPass
/// Lower rank specialization cluster to SCF. /// Lower rank specialization cluster to SCF.
/// Returns true iff `ty` is a ranked tensor type of rank 0, i.e. a scalar
/// tensor. Unranked tensor types and non-tensor types return false.
bool IsScalarTensorType(Type ty) {
  if (auto ranked_ty = ty.dyn_cast<RankedTensorType>())
    return ranked_ty.getRank() == 0;
  return false;
}
Type DeriveRankedTensorTypes(Type ty, int64_t rank) { Type DeriveRankedTensorTypes(Type ty, int64_t rank) {
auto unranked_ty = ty.dyn_cast<UnrankedTensorType>(); auto unranked_ty = ty.dyn_cast<UnrankedTensorType>();
if (!unranked_ty) return ty; if (!unranked_ty) return ty;
@ -182,65 +187,112 @@ Type DeriveRankedTensorTypes(Type ty, int64_t rank) {
return RankedTensorType::get(shape, unranked_ty.getElementType()); return RankedTensorType::get(shape, unranked_ty.getElementType());
} }
/// Unary element-wise operations on unranked tensors can be applied to the Type DeriveUnrankedTensorTypes(Type ty) {
/// flattened tensor and reshaped to the expected shape afterwards. if (auto ranked_ty = ty.dyn_cast<RankedTensorType>())
struct LowerUnaryRankSpecializationClusterPattern return UnrankedTensorType::get(ranked_ty.getElementType());
: public OpRewritePattern<chlo::RankSpecializationClusterOp> { return ty;
using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern; }
LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op, Optional<Value> FindUniqueNonScalar(ValueRange values) {
PatternRewriter &rewriter) const override { Value unique_non_scalar;
// Only apply this to unary operations. for (Value v : values) {
if (op.operands().size() != 1) return failure(); if (!IsScalarTensorType(v.getType())) {
if (unique_non_scalar) return llvm::None;
unique_non_scalar = v;
}
}
if (!unique_non_scalar) return llvm::None;
return unique_non_scalar;
}
// Compute flattened operand shape. SmallVector<Value, 8> MaterializeRankedOperations(
Location loc = op.getLoc(); OpBuilder &b, Location loc, BlockAndValueMapping &bvm,
Value arg = op.operands().front(); chlo::RankSpecializationClusterOp &op, int64_t target_rank) {
Value shape = rewriter.create<shape::ShapeOfOp>(loc, arg); // Create ranked operations.
Value flat_shape = rewriter.create<tensor::FromElementsOp>(
loc,
rewriter
.create<shape::NumElementsOp>(loc, rewriter.getIndexType(), shape)
.result());
// Flatten operand.
Value flat_arg = rewriter.create<mhlo::DynamicReshapeOp>(
loc, DeriveRankedTensorTypes(arg.getType(), /*rank=*/1), arg,
flat_shape);
// Materialize ranked versions of the element-wise operations.
BlockAndValueMapping bvm;
bvm.map(op.getBody()->getArguments().front(), flat_arg);
for (Operation &nested_op : op.getBody()->without_terminator()) { for (Operation &nested_op : op.getBody()->without_terminator()) {
auto mapped_operands = llvm::to_vector<4>(llvm::map_range( auto mapped_operands = llvm::to_vector<4>(llvm::map_range(
nested_op.getOperands(), [&](Value v) { return bvm.lookup(v); })); nested_op.getOperands(), [&](Value v) { return bvm.lookup(v); }));
auto ranked_result_types = llvm::to_vector<2>(llvm::map_range( auto ranked_result_types = llvm::to_vector<2>(llvm::map_range(
nested_op.getResultTypes(), nested_op.getResultTypes(),
[](Type ty) { return DeriveRankedTensorTypes(ty, /*rank=*/1); })); [&](Type ty) { return DeriveRankedTensorTypes(ty, target_rank); }));
OperationState ranked_op_state(loc, nested_op.getName().getStringRef(), OperationState ranked_op_state(loc, nested_op.getName().getStringRef(),
mapped_operands, ranked_result_types, mapped_operands, ranked_result_types,
nested_op.getAttrs()); nested_op.getAttrs());
Operation *ranked_op = rewriter.createOperation(ranked_op_state); Operation *ranked_op = b.createOperation(ranked_op_state);
for (auto it : for (auto it : llvm::zip(nested_op.getResults(), ranked_op->getResults()))
llvm::zip(nested_op.getResults(), ranked_op->getResults())) {
bvm.map(std::get<0>(it), std::get<1>(it)); bvm.map(std::get<0>(it), std::get<1>(it));
} }
}
// Collect results and restore their shape. We don't have to reify a shape // Collect ranked results.
// computation in the unary case as the operand shapes to all the auto yield_op = llvm::cast<chlo::RankSpecializationClusterYieldOp>(
// element-wise ops can only be the unique input shape. op.getBody()->getTerminator());
SmallVector<Value> results; return llvm::to_vector<8>(llvm::map_range(
for (Value v : llvm::cast<chlo::RankSpecializationClusterYieldOp>( yield_op.results(), [&](Value v) { return bvm.lookup(v); }));
op.getBody()->getTerminator()) }
.results()) {
Value flat_result = bvm.lookup(v);
Value result = rewriter.create<mhlo::DynamicReshapeOp>(
loc, v.getType(), flat_result, shape);
results.push_back(result);
}
// Replace the rank specialization cluster. SmallVector<Value, 8> MaterializeFinalReshape(
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
ValueRange unshaped_results) {
// Compute result shape.
auto non_scalar_operands = llvm::make_filter_range(
op.operands(), [](Value v) { return !IsScalarTensorType(v.getType()); });
SmallVector<Value, 8> results;
auto operand_shapes =
llvm::to_vector<8>(llvm::map_range(non_scalar_operands, [&](Value v) {
return b.create<shape::ShapeOfOp>(loc, v).result();
}));
auto shape = b.create<shape::BroadcastOp>(
loc, shape::getExtentTensorType(b.getContext()), operand_shapes);
// Reshape results.
return llvm::to_vector<8>(
llvm::map_range(unshaped_results, [&](Value unshaped) {
return b
.create<mhlo::DynamicReshapeOp>(
loc, DeriveUnrankedTensorTypes(unshaped.getType()), unshaped,
shape)
.result();
}));
}
/// Lowers a rank specialization cluster in which at most one operand is a
/// non-scalar tensor. In that case every element-wise operation in the cluster
/// effectively operates on the unique non-scalar shape, so the cluster can be
/// applied to the flattened (rank-1) operand and the results reshaped back
/// afterwards without reifying any broadcast logic.
struct LowerSingleNonScalarOperandPattern
: public OpRewritePattern<chlo::RankSpecializationClusterOp> {
using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern;
LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
PatternRewriter &rewriter) const override {
// Only apply this pattern if we can statically know that all operands have
// the same shape or are scalars, i.e. all but one operand are scalars.
Optional<Value> non_scalar_operand = FindUniqueNonScalar(op.operands());
if (!non_scalar_operand) return failure();
// Flatten the non-scalar operand: build a 1-D shape holding its total
// number of elements and dynamically reshape the operand to that shape.
Location loc = op.getLoc();
Value flat_shape = rewriter.create<tensor::FromElementsOp>(
loc,
rewriter
.create<shape::NumElementsOp>(
loc, rewriter.getIndexType(),
rewriter.create<shape::ShapeOfOp>(loc, *non_scalar_operand))
.result());
Value flat_non_scalar_operand = rewriter.create<mhlo::DynamicReshapeOp>(
loc, DeriveRankedTensorTypes(non_scalar_operand->getType(), /*rank=*/1),
*non_scalar_operand, flat_shape);
// Materialize ranked variants for the element-wise operations. Map the
// cluster's block arguments to the flattened operand (for the non-scalar)
// or to the original scalar operands (which need no reshaping).
BlockAndValueMapping bvm;
for (auto it : llvm::zip(op.getBody()->getArguments(), op.operands())) {
Value operand = std::get<1>(it);
bvm.map(std::get<0>(it), operand == *non_scalar_operand
? flat_non_scalar_operand
: operand);
}
SmallVector<Value, 8> unshaped_results =
MaterializeRankedOperations(rewriter, loc, bvm, op, /*target_rank=*/1);
// Restore the results' expected shape.
SmallVector<Value, 8> results =
MaterializeFinalReshape(rewriter, loc, op, unshaped_results);
rewriter.replaceOp(op, results); rewriter.replaceOp(op, results);
return success(); return success();
} }
@ -273,7 +325,7 @@ void PopulateRankSpecializationClusterPatterns(
void PopulateRankSpecializationToSCFPatterns( void PopulateRankSpecializationToSCFPatterns(
MLIRContext *context, OwningRewritePatternList *patterns) { MLIRContext *context, OwningRewritePatternList *patterns) {
patterns->insert<LowerUnaryRankSpecializationClusterPattern>(context); patterns->insert<LowerSingleNonScalarOperandPattern>(context);
} }
std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() { std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() {

View File

@ -46,13 +46,14 @@ func @sqrt(%arg : tensor<*xf32>) -> tensor<*xf32> {
// CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> // CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-SCF: %[[TMP0:.*]] = "mhlo.sqrt"(%[[FLAT_ARG]]) : (tensor<?xf32>) // CHECK-SCF: %[[TMP0:.*]] = "mhlo.sqrt"(%[[FLAT_ARG]]) : (tensor<?xf32>)
// CHECK-SCF: %[[TMP1:.*]] = "mhlo.sqrt"(%[[TMP0]]) : (tensor<?xf32>) // CHECK-SCF: %[[TMP1:.*]] = "mhlo.sqrt"(%[[TMP0]]) : (tensor<?xf32>)
// CHECK-SCF: %[[TMP2:.*]] = "mhlo.sqrt"(%[[TMP1]]) : (tensor<?xf32>) // CHECK-SCF: %[[UNSHAPED_RES:.*]] = "mhlo.sqrt"(%[[TMP1]]) : (tensor<?xf32>)
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[TMP2]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> // CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-SCF: return %[[RES]] // CHECK-SCF: return %[[RES]]
// ----- // -----
// Don't cluster single ranked operation. // Don't cluster ranked operations.
// CHECK-LABEL: @sqrt_ranked // CHECK-LABEL: @sqrt_ranked
// CHECK-SAME: (%[[ARG:.*]]: tensor<3x?xf32>) // CHECK-SAME: (%[[ARG:.*]]: tensor<3x?xf32>)
func @sqrt_ranked(%arg: tensor<3x?xf32>) -> tensor<3x?xf32> { func @sqrt_ranked(%arg: tensor<3x?xf32>) -> tensor<3x?xf32> {
@ -84,7 +85,7 @@ func @select_mixed(%pred: tensor<*xi1>, %arg1: tensor<*xf32>,
// Unary CHLO operation. // Unary CHLO operation.
// CHECK-LABEL: @tan // CHECK-LABEL: @tan
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32> // CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
func @tan(%arg : tensor<*xf32>) -> tensor<*xf32> { func @tan(%arg : tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]]) ( { // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]]) ( {
// CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>) // CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>)
@ -99,6 +100,19 @@ func @tan(%arg : tensor<*xf32>) -> tensor<*xf32> {
return %2 : tensor<*xf32> return %2 : tensor<*xf32>
} }
// CHECK-SCF-LABEL: @tan
// CHECK-SCF-SAME: (%[[ARG:.*]]: tensor<*xf32>)
// CHECK-SCF: %[[SHAPE:.*]] = shape.shape_of %[[ARG]]
// CHECK-SCF: %[[N:.*]] = shape.num_elements %[[SHAPE]]
// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
// CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-SCF: %[[TMP0:.*]] = chlo.tan %[[FLAT_ARG]] : tensor<?xf32>
// CHECK-SCF: %[[TMP1:.*]] = chlo.tan %[[TMP0]] : tensor<?xf32>
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = chlo.tan %[[TMP1]] : tensor<?xf32>
// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-SCF: return %[[RES]]
// ----- // -----
// Composition of unary/binary CHLO and unary MHLO ops. // Composition of unary/binary CHLO and unary MHLO ops.
@ -145,6 +159,18 @@ func @relu(%arg : tensor<*xf32>) -> tensor<*xf32> {
return %1 : tensor<*xf32> return %1 : tensor<*xf32>
} }
// CHECK-SCF-LABEL: @relu
// CHECK-SCF-SAME: (%[[ARG:.*]]: tensor<*xf32>)
// CHECK-SCF: %[[C0:.*]] = mhlo.constant dense<0.000000e+00>
// CHECK-SCF: %[[SHAPE:.*]] = shape.shape_of %[[ARG]]
// CHECK-SCF: %[[N:.*]] = shape.num_elements %[[SHAPE]]
// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
// CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = chlo.broadcast_maximum %[[FLAT_ARG]], %[[C0]] : (tensor<?xf32>, tensor<f32>)
// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-SCF: return %[[RES]]
// ----- // -----
// Cluster with binary non-broadcasting operation. // Cluster with binary non-broadcasting operation.
@ -164,6 +190,19 @@ func @angle(%arg : tensor<*xcomplex<f32>>) -> tensor<*xf32> {
return %2 : tensor<*xf32> return %2 : tensor<*xf32>
} }
// CHECK-SCF-LABEL: @angle
// CHECK-SCF-SAME: (%[[ARG:.*]]: tensor<*xcomplex<f32>>)
// CHECK-SCF: %[[SHAPE:.*]] = shape.shape_of %[[ARG]]
// CHECK-SCF: %[[N:.*]] = shape.num_elements %[[SHAPE]]
// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
// CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xcomplex<f32>>, tensor<1xindex>) -> tensor<?xcomplex<f32>>
// CHECK-SCF: %[[IMAG:.*]] = "mhlo.imag"(%[[FLAT_ARG]]) : (tensor<?xcomplex<f32>>)
// CHECK-SCF: %[[REAL:.*]] = "mhlo.real"(%[[FLAT_ARG]]) : (tensor<?xcomplex<f32>>)
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = mhlo.atan2 %[[IMAG]], %[[REAL]] : tensor<?xf32>
// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-SCF: return %[[RES]]
// ----- // -----
// CHECK-LABEL: @xlogy // CHECK-LABEL: @xlogy