[MLIR][HLO] Generalize rank specialization with single operand
The pattern can be generalized to also rank specialize operations with a single non-scalar operand. Also extract helper functions that can be reused in following specializations. PiperOrigin-RevId: 374198381
This commit is contained in:
parent
b86b18489c
commit
474e419729
|
@ -175,6 +175,11 @@ struct RankSpecializationClusterPass
|
||||||
|
|
||||||
/// Lower rank specialization cluster to SCF.
|
/// Lower rank specialization cluster to SCF.
|
||||||
|
|
||||||
|
bool IsScalarTensorType(Type ty) {
|
||||||
|
auto ranked_ty = ty.dyn_cast<RankedTensorType>();
|
||||||
|
return ranked_ty && ranked_ty.getRank() == 0;
|
||||||
|
}
|
||||||
|
|
||||||
Type DeriveRankedTensorTypes(Type ty, int64_t rank) {
|
Type DeriveRankedTensorTypes(Type ty, int64_t rank) {
|
||||||
auto unranked_ty = ty.dyn_cast<UnrankedTensorType>();
|
auto unranked_ty = ty.dyn_cast<UnrankedTensorType>();
|
||||||
if (!unranked_ty) return ty;
|
if (!unranked_ty) return ty;
|
||||||
|
@ -182,65 +187,112 @@ Type DeriveRankedTensorTypes(Type ty, int64_t rank) {
|
||||||
return RankedTensorType::get(shape, unranked_ty.getElementType());
|
return RankedTensorType::get(shape, unranked_ty.getElementType());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Unary element-wise operations on unranked tensors can be applied to the
|
Type DeriveUnrankedTensorTypes(Type ty) {
|
||||||
/// flattened tensor and reshaped to the expected shape afterwards.
|
if (auto ranked_ty = ty.dyn_cast<RankedTensorType>())
|
||||||
struct LowerUnaryRankSpecializationClusterPattern
|
return UnrankedTensorType::get(ranked_ty.getElementType());
|
||||||
: public OpRewritePattern<chlo::RankSpecializationClusterOp> {
|
return ty;
|
||||||
using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern;
|
}
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
|
Optional<Value> FindUniqueNonScalar(ValueRange values) {
|
||||||
PatternRewriter &rewriter) const override {
|
Value unique_non_scalar;
|
||||||
// Only apply this to unary operations.
|
for (Value v : values) {
|
||||||
if (op.operands().size() != 1) return failure();
|
if (!IsScalarTensorType(v.getType())) {
|
||||||
|
if (unique_non_scalar) return llvm::None;
|
||||||
|
unique_non_scalar = v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!unique_non_scalar) return llvm::None;
|
||||||
|
return unique_non_scalar;
|
||||||
|
}
|
||||||
|
|
||||||
// Compute flattened operand shape.
|
SmallVector<Value, 8> MaterializeRankedOperations(
|
||||||
Location loc = op.getLoc();
|
OpBuilder &b, Location loc, BlockAndValueMapping &bvm,
|
||||||
Value arg = op.operands().front();
|
chlo::RankSpecializationClusterOp &op, int64_t target_rank) {
|
||||||
Value shape = rewriter.create<shape::ShapeOfOp>(loc, arg);
|
// Create ranked operations.
|
||||||
Value flat_shape = rewriter.create<tensor::FromElementsOp>(
|
|
||||||
loc,
|
|
||||||
rewriter
|
|
||||||
.create<shape::NumElementsOp>(loc, rewriter.getIndexType(), shape)
|
|
||||||
.result());
|
|
||||||
|
|
||||||
// Flatten operand.
|
|
||||||
Value flat_arg = rewriter.create<mhlo::DynamicReshapeOp>(
|
|
||||||
loc, DeriveRankedTensorTypes(arg.getType(), /*rank=*/1), arg,
|
|
||||||
flat_shape);
|
|
||||||
|
|
||||||
// Materialize ranked versions of the element-wise operations.
|
|
||||||
BlockAndValueMapping bvm;
|
|
||||||
bvm.map(op.getBody()->getArguments().front(), flat_arg);
|
|
||||||
for (Operation &nested_op : op.getBody()->without_terminator()) {
|
for (Operation &nested_op : op.getBody()->without_terminator()) {
|
||||||
auto mapped_operands = llvm::to_vector<4>(llvm::map_range(
|
auto mapped_operands = llvm::to_vector<4>(llvm::map_range(
|
||||||
nested_op.getOperands(), [&](Value v) { return bvm.lookup(v); }));
|
nested_op.getOperands(), [&](Value v) { return bvm.lookup(v); }));
|
||||||
auto ranked_result_types = llvm::to_vector<2>(llvm::map_range(
|
auto ranked_result_types = llvm::to_vector<2>(llvm::map_range(
|
||||||
nested_op.getResultTypes(),
|
nested_op.getResultTypes(),
|
||||||
[](Type ty) { return DeriveRankedTensorTypes(ty, /*rank=*/1); }));
|
[&](Type ty) { return DeriveRankedTensorTypes(ty, target_rank); }));
|
||||||
OperationState ranked_op_state(loc, nested_op.getName().getStringRef(),
|
OperationState ranked_op_state(loc, nested_op.getName().getStringRef(),
|
||||||
mapped_operands, ranked_result_types,
|
mapped_operands, ranked_result_types,
|
||||||
nested_op.getAttrs());
|
nested_op.getAttrs());
|
||||||
Operation *ranked_op = rewriter.createOperation(ranked_op_state);
|
Operation *ranked_op = b.createOperation(ranked_op_state);
|
||||||
for (auto it :
|
for (auto it : llvm::zip(nested_op.getResults(), ranked_op->getResults()))
|
||||||
llvm::zip(nested_op.getResults(), ranked_op->getResults())) {
|
|
||||||
bvm.map(std::get<0>(it), std::get<1>(it));
|
bvm.map(std::get<0>(it), std::get<1>(it));
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Collect results and restore their shape. We don't have to reify a shape
|
// Collect ranked results.
|
||||||
// computation in the unary case as the operand shapes to all the
|
auto yield_op = llvm::cast<chlo::RankSpecializationClusterYieldOp>(
|
||||||
// element-wise ops can only be the unique input shape.
|
op.getBody()->getTerminator());
|
||||||
SmallVector<Value> results;
|
return llvm::to_vector<8>(llvm::map_range(
|
||||||
for (Value v : llvm::cast<chlo::RankSpecializationClusterYieldOp>(
|
yield_op.results(), [&](Value v) { return bvm.lookup(v); }));
|
||||||
op.getBody()->getTerminator())
|
}
|
||||||
.results()) {
|
|
||||||
Value flat_result = bvm.lookup(v);
|
|
||||||
Value result = rewriter.create<mhlo::DynamicReshapeOp>(
|
|
||||||
loc, v.getType(), flat_result, shape);
|
|
||||||
results.push_back(result);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Replace the rank specialization cluster.
|
SmallVector<Value, 8> MaterializeFinalReshape(
|
||||||
|
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
|
||||||
|
ValueRange unshaped_results) {
|
||||||
|
// Compute result shape.
|
||||||
|
auto non_scalar_operands = llvm::make_filter_range(
|
||||||
|
op.operands(), [](Value v) { return !IsScalarTensorType(v.getType()); });
|
||||||
|
SmallVector<Value, 8> results;
|
||||||
|
auto operand_shapes =
|
||||||
|
llvm::to_vector<8>(llvm::map_range(non_scalar_operands, [&](Value v) {
|
||||||
|
return b.create<shape::ShapeOfOp>(loc, v).result();
|
||||||
|
}));
|
||||||
|
auto shape = b.create<shape::BroadcastOp>(
|
||||||
|
loc, shape::getExtentTensorType(b.getContext()), operand_shapes);
|
||||||
|
|
||||||
|
// Reshape results.
|
||||||
|
return llvm::to_vector<8>(
|
||||||
|
llvm::map_range(unshaped_results, [&](Value unshaped) {
|
||||||
|
return b
|
||||||
|
.create<mhlo::DynamicReshapeOp>(
|
||||||
|
loc, DeriveUnrankedTensorTypes(unshaped.getType()), unshaped,
|
||||||
|
shape)
|
||||||
|
.result();
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct LowerSingleNonScalarOperandPattern
|
||||||
|
: public OpRewritePattern<chlo::RankSpecializationClusterOp> {
|
||||||
|
using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
// Only apply this pattern if we can statically know that all operands have
|
||||||
|
// the same shape or are scalars, i.e. all but one operands are scalars.
|
||||||
|
Optional<Value> non_scalar_operand = FindUniqueNonScalar(op.operands());
|
||||||
|
if (!non_scalar_operand) return failure();
|
||||||
|
|
||||||
|
// Flatten the non-scalar operand.
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
Value flat_shape = rewriter.create<tensor::FromElementsOp>(
|
||||||
|
loc,
|
||||||
|
rewriter
|
||||||
|
.create<shape::NumElementsOp>(
|
||||||
|
loc, rewriter.getIndexType(),
|
||||||
|
rewriter.create<shape::ShapeOfOp>(loc, *non_scalar_operand))
|
||||||
|
.result());
|
||||||
|
Value flat_non_scalar_operand = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||||
|
loc, DeriveRankedTensorTypes(non_scalar_operand->getType(), /*rank=*/1),
|
||||||
|
*non_scalar_operand, flat_shape);
|
||||||
|
|
||||||
|
// Materialize ranked variants for the element-wise operations.
|
||||||
|
BlockAndValueMapping bvm;
|
||||||
|
for (auto it : llvm::zip(op.getBody()->getArguments(), op.operands())) {
|
||||||
|
Value operand = std::get<1>(it);
|
||||||
|
bvm.map(std::get<0>(it), operand == *non_scalar_operand
|
||||||
|
? flat_non_scalar_operand
|
||||||
|
: operand);
|
||||||
|
}
|
||||||
|
SmallVector<Value, 8> unshaped_results =
|
||||||
|
MaterializeRankedOperations(rewriter, loc, bvm, op, /*target_rank=*/1);
|
||||||
|
|
||||||
|
// Restore the results' expected shape.
|
||||||
|
SmallVector<Value, 8> results =
|
||||||
|
MaterializeFinalReshape(rewriter, loc, op, unshaped_results);
|
||||||
rewriter.replaceOp(op, results);
|
rewriter.replaceOp(op, results);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -273,7 +325,7 @@ void PopulateRankSpecializationClusterPatterns(
|
||||||
|
|
||||||
void PopulateRankSpecializationToSCFPatterns(
|
void PopulateRankSpecializationToSCFPatterns(
|
||||||
MLIRContext *context, OwningRewritePatternList *patterns) {
|
MLIRContext *context, OwningRewritePatternList *patterns) {
|
||||||
patterns->insert<LowerUnaryRankSpecializationClusterPattern>(context);
|
patterns->insert<LowerSingleNonScalarOperandPattern>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() {
|
std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() {
|
||||||
|
|
|
@ -46,13 +46,14 @@ func @sqrt(%arg : tensor<*xf32>) -> tensor<*xf32> {
|
||||||
// CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
// CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
// CHECK-SCF: %[[TMP0:.*]] = "mhlo.sqrt"(%[[FLAT_ARG]]) : (tensor<?xf32>)
|
// CHECK-SCF: %[[TMP0:.*]] = "mhlo.sqrt"(%[[FLAT_ARG]]) : (tensor<?xf32>)
|
||||||
// CHECK-SCF: %[[TMP1:.*]] = "mhlo.sqrt"(%[[TMP0]]) : (tensor<?xf32>)
|
// CHECK-SCF: %[[TMP1:.*]] = "mhlo.sqrt"(%[[TMP0]]) : (tensor<?xf32>)
|
||||||
// CHECK-SCF: %[[TMP2:.*]] = "mhlo.sqrt"(%[[TMP1]]) : (tensor<?xf32>)
|
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = "mhlo.sqrt"(%[[TMP1]]) : (tensor<?xf32>)
|
||||||
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[TMP2]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
|
||||||
|
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||||
// CHECK-SCF: return %[[RES]]
|
// CHECK-SCF: return %[[RES]]
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// Don't cluster single ranked operation.
|
// Don't cluster ranked operations.
|
||||||
// CHECK-LABEL: @sqrt_ranked
|
// CHECK-LABEL: @sqrt_ranked
|
||||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<3x?xf32>)
|
// CHECK-SAME: (%[[ARG:.*]]: tensor<3x?xf32>)
|
||||||
func @sqrt_ranked(%arg: tensor<3x?xf32>) -> tensor<3x?xf32> {
|
func @sqrt_ranked(%arg: tensor<3x?xf32>) -> tensor<3x?xf32> {
|
||||||
|
@ -84,7 +85,7 @@ func @select_mixed(%pred: tensor<*xi1>, %arg1: tensor<*xf32>,
|
||||||
|
|
||||||
// Unary CHLO operation.
|
// Unary CHLO operation.
|
||||||
// CHECK-LABEL: @tan
|
// CHECK-LABEL: @tan
|
||||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32>
|
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
|
||||||
func @tan(%arg : tensor<*xf32>) -> tensor<*xf32> {
|
func @tan(%arg : tensor<*xf32>) -> tensor<*xf32> {
|
||||||
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]]) ( {
|
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]]) ( {
|
||||||
// CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>)
|
// CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>)
|
||||||
|
@ -99,6 +100,19 @@ func @tan(%arg : tensor<*xf32>) -> tensor<*xf32> {
|
||||||
return %2 : tensor<*xf32>
|
return %2 : tensor<*xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-SCF-LABEL: @tan
|
||||||
|
// CHECK-SCF-SAME: (%[[ARG:.*]]: tensor<*xf32>)
|
||||||
|
// CHECK-SCF: %[[SHAPE:.*]] = shape.shape_of %[[ARG]]
|
||||||
|
// CHECK-SCF: %[[N:.*]] = shape.num_elements %[[SHAPE]]
|
||||||
|
// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
|
||||||
|
// CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
|
// CHECK-SCF: %[[TMP0:.*]] = chlo.tan %[[FLAT_ARG]] : tensor<?xf32>
|
||||||
|
// CHECK-SCF: %[[TMP1:.*]] = chlo.tan %[[TMP0]] : tensor<?xf32>
|
||||||
|
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = chlo.tan %[[TMP1]] : tensor<?xf32>
|
||||||
|
// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
|
||||||
|
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||||
|
// CHECK-SCF: return %[[RES]]
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// Composition of unary/binary CHLO and unary MHLO ops.
|
// Composition of unary/binary CHLO and unary MHLO ops.
|
||||||
|
@ -145,6 +159,18 @@ func @relu(%arg : tensor<*xf32>) -> tensor<*xf32> {
|
||||||
return %1 : tensor<*xf32>
|
return %1 : tensor<*xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-SCF-LABEL: @relu
|
||||||
|
// CHECK-SCF-SAME: (%[[ARG:.*]]: tensor<*xf32>)
|
||||||
|
// CHECK-SCF: %[[C0:.*]] = mhlo.constant dense<0.000000e+00>
|
||||||
|
// CHECK-SCF: %[[SHAPE:.*]] = shape.shape_of %[[ARG]]
|
||||||
|
// CHECK-SCF: %[[N:.*]] = shape.num_elements %[[SHAPE]]
|
||||||
|
// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
|
||||||
|
// CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
|
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = chlo.broadcast_maximum %[[FLAT_ARG]], %[[C0]] : (tensor<?xf32>, tensor<f32>)
|
||||||
|
// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
|
||||||
|
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||||
|
// CHECK-SCF: return %[[RES]]
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// Cluster with binary non-broadcasting operation.
|
// Cluster with binary non-broadcasting operation.
|
||||||
|
@ -164,6 +190,19 @@ func @angle(%arg : tensor<*xcomplex<f32>>) -> tensor<*xf32> {
|
||||||
return %2 : tensor<*xf32>
|
return %2 : tensor<*xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-SCF-LABEL: @angle
|
||||||
|
// CHECK-SCF-SAME: (%[[ARG:.*]]: tensor<*xcomplex<f32>>)
|
||||||
|
// CHECK-SCF: %[[SHAPE:.*]] = shape.shape_of %[[ARG]]
|
||||||
|
// CHECK-SCF: %[[N:.*]] = shape.num_elements %[[SHAPE]]
|
||||||
|
// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
|
||||||
|
// CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xcomplex<f32>>, tensor<1xindex>) -> tensor<?xcomplex<f32>>
|
||||||
|
// CHECK-SCF: %[[IMAG:.*]] = "mhlo.imag"(%[[FLAT_ARG]]) : (tensor<?xcomplex<f32>>)
|
||||||
|
// CHECK-SCF: %[[REAL:.*]] = "mhlo.real"(%[[FLAT_ARG]]) : (tensor<?xcomplex<f32>>)
|
||||||
|
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = mhlo.atan2 %[[IMAG]], %[[REAL]] : tensor<?xf32>
|
||||||
|
// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
|
||||||
|
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||||
|
// CHECK-SCF: return %[[RES]]
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @xlogy
|
// CHECK-LABEL: @xlogy
|
||||||
|
|
Loading…
Reference in New Issue