diff --git a/lib/Dialect/mhlo/transforms/rank_specialization.cc b/lib/Dialect/mhlo/transforms/rank_specialization.cc
index 7f286f9..b449e1b 100644
--- a/lib/Dialect/mhlo/transforms/rank_specialization.cc
+++ b/lib/Dialect/mhlo/transforms/rank_specialization.cc
@@ -175,6 +175,11 @@ struct RankSpecializationClusterPass
 
 /// Lower rank specialization cluster to SCF.
 
+bool IsScalarTensorType(Type ty) {
+  auto ranked_ty = ty.dyn_cast<RankedTensorType>();
+  return ranked_ty && ranked_ty.getRank() == 0;
+}
+
 Type DeriveRankedTensorTypes(Type ty, int64_t rank) {
   auto unranked_ty = ty.dyn_cast<UnrankedTensorType>();
   if (!unranked_ty) return ty;
@@ -182,65 +187,112 @@ Type DeriveRankedTensorTypes(Type ty, int64_t rank) {
   return RankedTensorType::get(shape, unranked_ty.getElementType());
 }
 
-/// Unary element-wise operations on unranked tensors can be applied to the
-/// flattened tensor and reshaped to the expected shape afterwards.
-struct LowerUnaryRankSpecializationClusterPattern
+Type DeriveUnrankedTensorTypes(Type ty) {
+  if (auto ranked_ty = ty.dyn_cast<RankedTensorType>())
+    return UnrankedTensorType::get(ranked_ty.getElementType());
+  return ty;
+}
+
+Optional<Value> FindUniqueNonScalar(ValueRange values) {
+  Value unique_non_scalar;
+  for (Value v : values) {
+    if (!IsScalarTensorType(v.getType())) {
+      if (unique_non_scalar) return llvm::None;
+      unique_non_scalar = v;
+    }
+  }
+  if (!unique_non_scalar) return llvm::None;
+  return unique_non_scalar;
+}
+
+SmallVector<Value, 8> MaterializeRankedOperations(
+    OpBuilder &b, Location loc, BlockAndValueMapping &bvm,
+    chlo::RankSpecializationClusterOp &op, int64_t target_rank) {
+  // Create ranked operations.
+  for (Operation &nested_op : op.getBody()->without_terminator()) {
+    auto mapped_operands = llvm::to_vector<4>(llvm::map_range(
+        nested_op.getOperands(), [&](Value v) { return bvm.lookup(v); }));
+    auto ranked_result_types = llvm::to_vector<2>(llvm::map_range(
+        nested_op.getResultTypes(),
+        [&](Type ty) { return DeriveRankedTensorTypes(ty, target_rank); }));
+    OperationState ranked_op_state(loc, nested_op.getName().getStringRef(),
+                                   mapped_operands, ranked_result_types,
+                                   nested_op.getAttrs());
+    Operation *ranked_op = b.createOperation(ranked_op_state);
+    for (auto it : llvm::zip(nested_op.getResults(), ranked_op->getResults()))
+      bvm.map(std::get<0>(it), std::get<1>(it));
+  }
+
+  // Collect ranked results.
+  auto yield_op = llvm::cast<chlo::RankSpecializationClusterYieldOp>(
+      op.getBody()->getTerminator());
+  return llvm::to_vector<8>(llvm::map_range(
+      yield_op.results(), [&](Value v) { return bvm.lookup(v); }));
+}
+
+SmallVector<Value, 8> MaterializeFinalReshape(
+    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
+    ValueRange unshaped_results) {
+  // Compute result shape.
+  auto non_scalar_operands = llvm::make_filter_range(
+      op.operands(), [](Value v) { return !IsScalarTensorType(v.getType()); });
+  SmallVector<Value, 8> results;
+  auto operand_shapes =
+      llvm::to_vector<8>(llvm::map_range(non_scalar_operands, [&](Value v) {
+        return b.create<shape::ShapeOfOp>(loc, v).result();
+      }));
+  auto shape = b.create<shape::AnyOp>(
+      loc, shape::getExtentTensorType(b.getContext()), operand_shapes);
+
+  // Reshape results.
+  return llvm::to_vector<8>(
+      llvm::map_range(unshaped_results, [&](Value unshaped) {
+        return b
+            .create<mhlo::DynamicReshapeOp>(
+                loc, DeriveUnrankedTensorTypes(unshaped.getType()), unshaped,
+                shape)
+            .result();
+      }));
+}
+
+struct LowerSingleNonScalarOperandPattern
     : public OpRewritePattern<chlo::RankSpecializationClusterOp> {
   using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
                                 PatternRewriter &rewriter) const override {
-    // Only apply this to unary operations.
-    if (op.operands().size() != 1) return failure();
+    // Only apply this pattern if we can statically know that all operands have
+    // the same shape or are scalars, i.e. all operands but one are scalars.
+    Optional<Value> non_scalar_operand = FindUniqueNonScalar(op.operands());
+    if (!non_scalar_operand) return failure();
 
-    // Compute flattened operand shape.
+    // Flatten the non-scalar operand.
     Location loc = op.getLoc();
-    Value arg = op.operands().front();
-    Value shape = rewriter.create<shape::ShapeOfOp>(loc, arg);
     Value flat_shape = rewriter.create<tensor::FromElementsOp>(
         loc,
         rewriter
-            .create<shape::NumElementsOp>(loc, rewriter.getIndexType(), shape)
+            .create<shape::NumElementsOp>(
+                loc, rewriter.getIndexType(),
+                rewriter.create<shape::ShapeOfOp>(loc, *non_scalar_operand))
             .result());
+    Value flat_non_scalar_operand = rewriter.create<mhlo::DynamicReshapeOp>(
+        loc, DeriveRankedTensorTypes(non_scalar_operand->getType(), /*rank=*/1),
+        *non_scalar_operand, flat_shape);
 
-    // Flatten operand.
-    Value flat_arg = rewriter.create<mhlo::DynamicReshapeOp>(
-        loc, DeriveRankedTensorTypes(arg.getType(), /*rank=*/1), arg,
-        flat_shape);
-
-    // Materialize ranked versions of the element-wise operations.
+    // Materialize ranked variants for the element-wise operations.
     BlockAndValueMapping bvm;
-    bvm.map(op.getBody()->getArguments().front(), flat_arg);
-    for (Operation &nested_op : op.getBody()->without_terminator()) {
-      auto mapped_operands = llvm::to_vector<4>(llvm::map_range(
-          nested_op.getOperands(), [&](Value v) { return bvm.lookup(v); }));
-      auto ranked_result_types = llvm::to_vector<2>(llvm::map_range(
-          nested_op.getResultTypes(),
-          [](Type ty) { return DeriveRankedTensorTypes(ty, /*rank=*/1); }));
-      OperationState ranked_op_state(loc, nested_op.getName().getStringRef(),
-                                     mapped_operands, ranked_result_types,
-                                     nested_op.getAttrs());
-      Operation *ranked_op = rewriter.createOperation(ranked_op_state);
-      for (auto it :
-           llvm::zip(nested_op.getResults(), ranked_op->getResults())) {
-        bvm.map(std::get<0>(it), std::get<1>(it));
-      }
+    for (auto it : llvm::zip(op.getBody()->getArguments(), op.operands())) {
+      Value operand = std::get<1>(it);
+      bvm.map(std::get<0>(it), operand == *non_scalar_operand
+                                   ? flat_non_scalar_operand
+                                   : operand);
     }
+    SmallVector<Value, 8> unshaped_results =
+        MaterializeRankedOperations(rewriter, loc, bvm, op, /*target_rank=*/1);
 
-    // Collect results and restore their shape. We don't have to reify a shape
-    // computation in the unary case as the operand shapes to all the
-    // element-wise ops can only be the unique input shape.
-    SmallVector<Value, 8> results;
-    for (Value v : llvm::cast<chlo::RankSpecializationClusterYieldOp>(
-                       op.getBody()->getTerminator())
-                       .results()) {
-      Value flat_result = bvm.lookup(v);
-      Value result = rewriter.create<mhlo::DynamicReshapeOp>(
-          loc, v.getType(), flat_result, shape);
-      results.push_back(result);
-    }
-
-    // Replace the rank specialization cluster.
+    // Restore the results' expected shape.
+    SmallVector<Value, 8> results =
+        MaterializeFinalReshape(rewriter, loc, op, unshaped_results);
     rewriter.replaceOp(op, results);
     return success();
   }
@@ -273,7 +325,7 @@ void PopulateRankSpecializationClusterPatterns(
 
 void PopulateRankSpecializationToSCFPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
-  patterns->insert<LowerUnaryRankSpecializationClusterPattern>(context);
+  patterns->insert<LowerSingleNonScalarOperandPattern>(context);
 }
 
 std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() {
diff --git a/tests/rank-specialization.mlir b/tests/rank-specialization.mlir
index 5850062..521118c 100644
--- a/tests/rank-specialization.mlir
+++ b/tests/rank-specialization.mlir
@@ -46,13 +46,14 @@ func @sqrt(%arg : tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK-SCF: %[[TMP0:.*]] = "mhlo.sqrt"(%[[FLAT_ARG]]) : (tensor<?xf32>)
 // CHECK-SCF: %[[TMP1:.*]] = "mhlo.sqrt"(%[[TMP0]]) : (tensor<?xf32>)
-// CHECK-SCF: %[[TMP2:.*]] = "mhlo.sqrt"(%[[TMP1]]) : (tensor<?xf32>)
-// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[TMP2]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-SCF: %[[UNSHAPED_RES:.*]] = "mhlo.sqrt"(%[[TMP1]]) : (tensor<?xf32>)
+// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
+// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 // CHECK-SCF: return %[[RES]]
 
 // -----
 
-// Don't cluster single ranked operation.
+// Don't cluster ranked operations.
 // CHECK-LABEL: @sqrt_ranked
 // CHECK-SAME: (%[[ARG:.*]]: tensor<3x?xf32>)
 func @sqrt_ranked(%arg: tensor<3x?xf32>) -> tensor<3x?xf32> {
@@ -84,7 +85,7 @@ func @select_mixed(%pred: tensor<*xi1>, %arg1: tensor<*xf32>,
 
 // Unary CHLO operation.
 // CHECK-LABEL: @tan
-// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32>
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
 func @tan(%arg : tensor<*xf32>) -> tensor<*xf32> {
   // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]]) ( {
   // CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>)
@@ -99,6 +100,19 @@ func @tan(%arg : tensor<*xf32>) -> tensor<*xf32> {
   return %2 : tensor<*xf32>
 }
 
+// CHECK-SCF-LABEL: @tan
+// CHECK-SCF-SAME: (%[[ARG:.*]]: tensor<*xf32>)
+// CHECK-SCF: %[[SHAPE:.*]] = shape.shape_of %[[ARG]]
+// CHECK-SCF: %[[N:.*]] = shape.num_elements %[[SHAPE]]
+// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
+// CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK-SCF: %[[TMP0:.*]] = chlo.tan %[[FLAT_ARG]] : tensor<?xf32>
+// CHECK-SCF: %[[TMP1:.*]] = chlo.tan %[[TMP0]] : tensor<?xf32>
+// CHECK-SCF: %[[UNSHAPED_RES:.*]] = chlo.tan %[[TMP1]] : tensor<?xf32>
+// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
+// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-SCF: return %[[RES]]
+
 // -----
 
 // Composition of unary/binary CHLO and unary MHLO ops.
@@ -145,6 +159,18 @@ func @relu(%arg : tensor<*xf32>) -> tensor<*xf32> {
   return %1 : tensor<*xf32>
 }
 
+// CHECK-SCF-LABEL: @relu
+// CHECK-SCF-SAME: (%[[ARG:.*]]: tensor<*xf32>)
+// CHECK-SCF: %[[C0:.*]] = mhlo.constant dense<0.000000e+00>
+// CHECK-SCF: %[[SHAPE:.*]] = shape.shape_of %[[ARG]]
+// CHECK-SCF: %[[N:.*]] = shape.num_elements %[[SHAPE]]
+// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
+// CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK-SCF: %[[UNSHAPED_RES:.*]] = chlo.broadcast_maximum %[[FLAT_ARG]], %[[C0]] : (tensor<?xf32>, tensor<f32>)
+// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
+// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-SCF: return %[[RES]]
+
 // -----
 
 // Cluster with binary non-broadcasting operation.
@@ -164,6 +190,19 @@ func @angle(%arg : tensor<*xcomplex<f32>>) -> tensor<*xf32> {
   return %2 : tensor<*xf32>
 }
 
+// CHECK-SCF-LABEL: @angle
+// CHECK-SCF-SAME: (%[[ARG:.*]]: tensor<*xcomplex<f32>>)
+// CHECK-SCF: %[[SHAPE:.*]] = shape.shape_of %[[ARG]]
+// CHECK-SCF: %[[N:.*]] = shape.num_elements %[[SHAPE]]
+// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
+// CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xcomplex<f32>>, tensor<1xindex>) -> tensor<?xcomplex<f32>>
+// CHECK-SCF: %[[IMAG:.*]] = "mhlo.imag"(%[[FLAT_ARG]]) : (tensor<?xcomplex<f32>>)
+// CHECK-SCF: %[[REAL:.*]] = "mhlo.real"(%[[FLAT_ARG]]) : (tensor<?xcomplex<f32>>)
+// CHECK-SCF: %[[UNSHAPED_RES:.*]] = mhlo.atan2 %[[IMAG]], %[[REAL]] : tensor<?xf32>
+// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
+// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-SCF: return %[[RES]]
+
 // -----
 
 // CHECK-LABEL: @xlogy
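For reference, the new CHECK-SCF expectations above correspond to IR of roughly the following shape for the unary @sqrt case. This is an illustrative sketch assembled from the test expectations, not output copied from the patch; exact value names, folding, and type annotations may differ after canonicalization.

// Input: a rank specialization cluster with a single (non-scalar) operand.
func @sqrt(%arg : tensor<*xf32>) -> tensor<*xf32> {
  %0 = "chlo.rank_specialization_cluster"(%arg) ( {
  ^bb0(%arg_ : tensor<*xf32>):
    %1 = "mhlo.sqrt"(%arg_) : (tensor<*xf32>) -> tensor<*xf32>
    "chlo.rank_specialization_cluster_yield"(%1) : (tensor<*xf32>) -> ()
  }) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// Approximate result of LowerSingleNonScalarOperandPattern: flatten the
// operand to rank 1, apply the ranked op, and restore the original shape.
func @sqrt(%arg : tensor<*xf32>) -> tensor<*xf32> {
  %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
  %n = shape.num_elements %shape : tensor<?xindex> -> index
  %flat_shape = tensor.from_elements %n : tensor<1xindex>
  %flat = "mhlo.dynamic_reshape"(%arg, %flat_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
  %sqrt = "mhlo.sqrt"(%flat) : (tensor<?xf32>) -> tensor<?xf32>
  %res_shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
  %res = "mhlo.dynamic_reshape"(%sqrt, %res_shape) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
  return %res : tensor<*xf32>
}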