[MLIR][HLO] Generalize rank specialization with single operand
The pattern can be generalized to also rank specialize operations with a single non-scalar operand. Also extract helper functions that can be reused in following specializations. PiperOrigin-RevId: 374198381
This commit is contained in:
parent
b86b18489c
commit
474e419729
|
@ -175,6 +175,11 @@ struct RankSpecializationClusterPass
|
|||
|
||||
/// Lower rank specialization cluster to SCF.
|
||||
|
||||
bool IsScalarTensorType(Type ty) {
|
||||
auto ranked_ty = ty.dyn_cast<RankedTensorType>();
|
||||
return ranked_ty && ranked_ty.getRank() == 0;
|
||||
}
|
||||
|
||||
Type DeriveRankedTensorTypes(Type ty, int64_t rank) {
|
||||
auto unranked_ty = ty.dyn_cast<UnrankedTensorType>();
|
||||
if (!unranked_ty) return ty;
|
||||
|
@ -182,65 +187,112 @@ Type DeriveRankedTensorTypes(Type ty, int64_t rank) {
|
|||
return RankedTensorType::get(shape, unranked_ty.getElementType());
|
||||
}
|
||||
|
||||
/// Unary element-wise operations on unranked tensors can be applied to the
|
||||
/// flattened tensor and reshaped to the expected shape afterwards.
|
||||
struct LowerUnaryRankSpecializationClusterPattern
|
||||
: public OpRewritePattern<chlo::RankSpecializationClusterOp> {
|
||||
using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern;
|
||||
Type DeriveUnrankedTensorTypes(Type ty) {
|
||||
if (auto ranked_ty = ty.dyn_cast<RankedTensorType>())
|
||||
return UnrankedTensorType::get(ranked_ty.getElementType());
|
||||
return ty;
|
||||
}
|
||||
|
||||
LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Only apply this to unary operations.
|
||||
if (op.operands().size() != 1) return failure();
|
||||
Optional<Value> FindUniqueNonScalar(ValueRange values) {
|
||||
Value unique_non_scalar;
|
||||
for (Value v : values) {
|
||||
if (!IsScalarTensorType(v.getType())) {
|
||||
if (unique_non_scalar) return llvm::None;
|
||||
unique_non_scalar = v;
|
||||
}
|
||||
}
|
||||
if (!unique_non_scalar) return llvm::None;
|
||||
return unique_non_scalar;
|
||||
}
|
||||
|
||||
// Compute flattened operand shape.
|
||||
Location loc = op.getLoc();
|
||||
Value arg = op.operands().front();
|
||||
Value shape = rewriter.create<shape::ShapeOfOp>(loc, arg);
|
||||
Value flat_shape = rewriter.create<tensor::FromElementsOp>(
|
||||
loc,
|
||||
rewriter
|
||||
.create<shape::NumElementsOp>(loc, rewriter.getIndexType(), shape)
|
||||
.result());
|
||||
|
||||
// Flatten operand.
|
||||
Value flat_arg = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, DeriveRankedTensorTypes(arg.getType(), /*rank=*/1), arg,
|
||||
flat_shape);
|
||||
|
||||
// Materialize ranked versions of the element-wise operations.
|
||||
BlockAndValueMapping bvm;
|
||||
bvm.map(op.getBody()->getArguments().front(), flat_arg);
|
||||
SmallVector<Value, 8> MaterializeRankedOperations(
|
||||
OpBuilder &b, Location loc, BlockAndValueMapping &bvm,
|
||||
chlo::RankSpecializationClusterOp &op, int64_t target_rank) {
|
||||
// Create ranked operations.
|
||||
for (Operation &nested_op : op.getBody()->without_terminator()) {
|
||||
auto mapped_operands = llvm::to_vector<4>(llvm::map_range(
|
||||
nested_op.getOperands(), [&](Value v) { return bvm.lookup(v); }));
|
||||
auto ranked_result_types = llvm::to_vector<2>(llvm::map_range(
|
||||
nested_op.getResultTypes(),
|
||||
[](Type ty) { return DeriveRankedTensorTypes(ty, /*rank=*/1); }));
|
||||
[&](Type ty) { return DeriveRankedTensorTypes(ty, target_rank); }));
|
||||
OperationState ranked_op_state(loc, nested_op.getName().getStringRef(),
|
||||
mapped_operands, ranked_result_types,
|
||||
nested_op.getAttrs());
|
||||
Operation *ranked_op = rewriter.createOperation(ranked_op_state);
|
||||
for (auto it :
|
||||
llvm::zip(nested_op.getResults(), ranked_op->getResults())) {
|
||||
Operation *ranked_op = b.createOperation(ranked_op_state);
|
||||
for (auto it : llvm::zip(nested_op.getResults(), ranked_op->getResults()))
|
||||
bvm.map(std::get<0>(it), std::get<1>(it));
|
||||
}
|
||||
}
|
||||
|
||||
// Collect results and restore their shape. We don't have to reify a shape
|
||||
// computation in the unary case as the operand shapes to all the
|
||||
// element-wise ops can only be the unique input shape.
|
||||
SmallVector<Value> results;
|
||||
for (Value v : llvm::cast<chlo::RankSpecializationClusterYieldOp>(
|
||||
op.getBody()->getTerminator())
|
||||
.results()) {
|
||||
Value flat_result = bvm.lookup(v);
|
||||
Value result = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, v.getType(), flat_result, shape);
|
||||
results.push_back(result);
|
||||
}
|
||||
// Collect ranked results.
|
||||
auto yield_op = llvm::cast<chlo::RankSpecializationClusterYieldOp>(
|
||||
op.getBody()->getTerminator());
|
||||
return llvm::to_vector<8>(llvm::map_range(
|
||||
yield_op.results(), [&](Value v) { return bvm.lookup(v); }));
|
||||
}
|
||||
|
||||
// Replace the rank specialization cluster.
|
||||
SmallVector<Value, 8> MaterializeFinalReshape(
|
||||
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
|
||||
ValueRange unshaped_results) {
|
||||
// Compute result shape.
|
||||
auto non_scalar_operands = llvm::make_filter_range(
|
||||
op.operands(), [](Value v) { return !IsScalarTensorType(v.getType()); });
|
||||
SmallVector<Value, 8> results;
|
||||
auto operand_shapes =
|
||||
llvm::to_vector<8>(llvm::map_range(non_scalar_operands, [&](Value v) {
|
||||
return b.create<shape::ShapeOfOp>(loc, v).result();
|
||||
}));
|
||||
auto shape = b.create<shape::BroadcastOp>(
|
||||
loc, shape::getExtentTensorType(b.getContext()), operand_shapes);
|
||||
|
||||
// Reshape results.
|
||||
return llvm::to_vector<8>(
|
||||
llvm::map_range(unshaped_results, [&](Value unshaped) {
|
||||
return b
|
||||
.create<mhlo::DynamicReshapeOp>(
|
||||
loc, DeriveUnrankedTensorTypes(unshaped.getType()), unshaped,
|
||||
shape)
|
||||
.result();
|
||||
}));
|
||||
}
|
||||
|
||||
struct LowerSingleNonScalarOperandPattern
|
||||
: public OpRewritePattern<chlo::RankSpecializationClusterOp> {
|
||||
using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Only apply this pattern if we can statically know that all operands have
|
||||
// the same shape or are scalars, i.e. all but one operands are scalars.
|
||||
Optional<Value> non_scalar_operand = FindUniqueNonScalar(op.operands());
|
||||
if (!non_scalar_operand) return failure();
|
||||
|
||||
// Flatten the non-scalar operand.
|
||||
Location loc = op.getLoc();
|
||||
Value flat_shape = rewriter.create<tensor::FromElementsOp>(
|
||||
loc,
|
||||
rewriter
|
||||
.create<shape::NumElementsOp>(
|
||||
loc, rewriter.getIndexType(),
|
||||
rewriter.create<shape::ShapeOfOp>(loc, *non_scalar_operand))
|
||||
.result());
|
||||
Value flat_non_scalar_operand = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, DeriveRankedTensorTypes(non_scalar_operand->getType(), /*rank=*/1),
|
||||
*non_scalar_operand, flat_shape);
|
||||
|
||||
// Materialize ranked variants for the element-wise operations.
|
||||
BlockAndValueMapping bvm;
|
||||
for (auto it : llvm::zip(op.getBody()->getArguments(), op.operands())) {
|
||||
Value operand = std::get<1>(it);
|
||||
bvm.map(std::get<0>(it), operand == *non_scalar_operand
|
||||
? flat_non_scalar_operand
|
||||
: operand);
|
||||
}
|
||||
SmallVector<Value, 8> unshaped_results =
|
||||
MaterializeRankedOperations(rewriter, loc, bvm, op, /*target_rank=*/1);
|
||||
|
||||
// Restore the results' expected shape.
|
||||
SmallVector<Value, 8> results =
|
||||
MaterializeFinalReshape(rewriter, loc, op, unshaped_results);
|
||||
rewriter.replaceOp(op, results);
|
||||
return success();
|
||||
}
|
||||
|
@ -273,7 +325,7 @@ void PopulateRankSpecializationClusterPatterns(
|
|||
|
||||
void PopulateRankSpecializationToSCFPatterns(
|
||||
MLIRContext *context, OwningRewritePatternList *patterns) {
|
||||
patterns->insert<LowerUnaryRankSpecializationClusterPattern>(context);
|
||||
patterns->insert<LowerSingleNonScalarOperandPattern>(context);
|
||||
}
|
||||
|
||||
std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() {
|
||||
|
|
|
@ -46,13 +46,14 @@ func @sqrt(%arg : tensor<*xf32>) -> tensor<*xf32> {
|
|||
// CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-SCF: %[[TMP0:.*]] = "mhlo.sqrt"(%[[FLAT_ARG]]) : (tensor<?xf32>)
|
||||
// CHECK-SCF: %[[TMP1:.*]] = "mhlo.sqrt"(%[[TMP0]]) : (tensor<?xf32>)
|
||||
// CHECK-SCF: %[[TMP2:.*]] = "mhlo.sqrt"(%[[TMP1]]) : (tensor<?xf32>)
|
||||
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[TMP2]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = "mhlo.sqrt"(%[[TMP1]]) : (tensor<?xf32>)
|
||||
// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
|
||||
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-SCF: return %[[RES]]
|
||||
|
||||
// -----
|
||||
|
||||
// Don't cluster single ranked operation.
|
||||
// Don't cluster ranked operations.
|
||||
// CHECK-LABEL: @sqrt_ranked
|
||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<3x?xf32>)
|
||||
func @sqrt_ranked(%arg: tensor<3x?xf32>) -> tensor<3x?xf32> {
|
||||
|
@ -84,7 +85,7 @@ func @select_mixed(%pred: tensor<*xi1>, %arg1: tensor<*xf32>,
|
|||
|
||||
// Unary CHLO operation.
|
||||
// CHECK-LABEL: @tan
|
||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
|
||||
func @tan(%arg : tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]]) ( {
|
||||
// CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>)
|
||||
|
@ -99,6 +100,19 @@ func @tan(%arg : tensor<*xf32>) -> tensor<*xf32> {
|
|||
return %2 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-SCF-LABEL: @tan
|
||||
// CHECK-SCF-SAME: (%[[ARG:.*]]: tensor<*xf32>)
|
||||
// CHECK-SCF: %[[SHAPE:.*]] = shape.shape_of %[[ARG]]
|
||||
// CHECK-SCF: %[[N:.*]] = shape.num_elements %[[SHAPE]]
|
||||
// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
|
||||
// CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-SCF: %[[TMP0:.*]] = chlo.tan %[[FLAT_ARG]] : tensor<?xf32>
|
||||
// CHECK-SCF: %[[TMP1:.*]] = chlo.tan %[[TMP0]] : tensor<?xf32>
|
||||
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = chlo.tan %[[TMP1]] : tensor<?xf32>
|
||||
// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
|
||||
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-SCF: return %[[RES]]
|
||||
|
||||
// -----
|
||||
|
||||
// Composition of unary/binary CHLO and unary MHLO ops.
|
||||
|
@ -145,6 +159,18 @@ func @relu(%arg : tensor<*xf32>) -> tensor<*xf32> {
|
|||
return %1 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-SCF-LABEL: @relu
|
||||
// CHECK-SCF-SAME: (%[[ARG:.*]]: tensor<*xf32>)
|
||||
// CHECK-SCF: %[[C0:.*]] = mhlo.constant dense<0.000000e+00>
|
||||
// CHECK-SCF: %[[SHAPE:.*]] = shape.shape_of %[[ARG]]
|
||||
// CHECK-SCF: %[[N:.*]] = shape.num_elements %[[SHAPE]]
|
||||
// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
|
||||
// CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = chlo.broadcast_maximum %[[FLAT_ARG]], %[[C0]] : (tensor<?xf32>, tensor<f32>)
|
||||
// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
|
||||
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-SCF: return %[[RES]]
|
||||
|
||||
// -----
|
||||
|
||||
// Cluster with binary non-broadcasting operation.
|
||||
|
@ -164,6 +190,19 @@ func @angle(%arg : tensor<*xcomplex<f32>>) -> tensor<*xf32> {
|
|||
return %2 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-SCF-LABEL: @angle
|
||||
// CHECK-SCF-SAME: (%[[ARG:.*]]: tensor<*xcomplex<f32>>)
|
||||
// CHECK-SCF: %[[SHAPE:.*]] = shape.shape_of %[[ARG]]
|
||||
// CHECK-SCF: %[[N:.*]] = shape.num_elements %[[SHAPE]]
|
||||
// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
|
||||
// CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xcomplex<f32>>, tensor<1xindex>) -> tensor<?xcomplex<f32>>
|
||||
// CHECK-SCF: %[[IMAG:.*]] = "mhlo.imag"(%[[FLAT_ARG]]) : (tensor<?xcomplex<f32>>)
|
||||
// CHECK-SCF: %[[REAL:.*]] = "mhlo.real"(%[[FLAT_ARG]]) : (tensor<?xcomplex<f32>>)
|
||||
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = mhlo.atan2 %[[IMAG]], %[[REAL]] : tensor<?xf32>
|
||||
// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
|
||||
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-SCF: return %[[RES]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @xlogy
|
||||
|
|
Loading…
Reference in New Issue