diff --git a/lib/Dialect/mhlo/transforms/rank_specialization.cc b/lib/Dialect/mhlo/transforms/rank_specialization.cc index b449e1b..b076ae2 100644 --- a/lib/Dialect/mhlo/transforms/rank_specialization.cc +++ b/lib/Dialect/mhlo/transforms/rank_specialization.cc @@ -21,6 +21,7 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -298,11 +299,160 @@ struct LowerSingleNonScalarOperandPattern } }; +Value MaterializeRankSpecialization(OpBuilder &b, Location loc, + chlo::RankSpecializationClusterOp op, + const SmallVector &shapes, + int64_t target_rank) { + // Reshape operands to match the target rank. + MLIRContext *ctx = op->getContext(); + llvm::SmallVector ranked_ty_dynamic_dims( + target_rank, RankedTensorType::kDynamicSize); + RankedTensorType extent_tensor_ty = + shape::getExtentTensorType(ctx, target_rank); + Value all_ones_shape = b.create( + loc, extent_tensor_ty, + mlir::DenseIntElementsAttr::get(extent_tensor_ty, + SmallVector(target_rank, 1))); + SmallVector ranked_operands; + for (auto it : llvm::zip(op.operands(), shapes)) { + Value operand, shape; + std::tie(operand, shape) = it; + Value ranked_shape = b.create( + loc, extent_tensor_ty, + b.create(loc, shape::getExtentTensorType(ctx), + shape, all_ones_shape, + /*error=*/nullptr)); + Type element_ty = operand.getType().dyn_cast().getElementType(); + auto ranked_ty = RankedTensorType::get(ranked_ty_dynamic_dims, element_ty); + ranked_operands.push_back(b.create( + loc, ranked_ty, operand, ranked_shape)); + } + + // Materialize ranked versions of the element-wise operations. + BlockAndValueMapping bvm; + for (auto it : llvm::zip(op.body().front().getArguments(), ranked_operands)) + bvm.map(std::get<0>(it), std::get<1>(it)); + + // Return as unranked for compatibility with other target ranks. + auto unshaped_result = + MaterializeRankedOperations(b, loc, bvm, op, target_rank).front(); + return b.create( + loc, DeriveUnrankedTensorTypes(unshaped_result.getType()), + unshaped_result); +} + +Value MaterializeAllRankSpecializations(OpBuilder &b, Location loc, + chlo::RankSpecializationClusterOp op, + const SmallVector &shapes, + Value max_rank, int64_t min_target_rank, + int64_t max_target_rank) { + Value min_target_rank_predicate = + b.create(loc, CmpIPredicate::eq, max_rank, + b.create(loc, min_target_rank)); + + // If only a unique target rank is left, we can lower to an assert instead + // of the usual if operation. + if (min_target_rank == max_target_rank) { + b.create(loc, min_target_rank_predicate, + "Input for dynamic binary or n-ary op lowering was of " + "a rank greater than " + + std::to_string(max_target_rank)); + return MaterializeRankSpecialization(b, loc, op, shapes, min_target_rank); + } + + // Materialize IR for the smallest considered target rank. + auto if_op = + b.create(loc, op->getResultTypes(), min_target_rank_predicate, + /*withElseRegion=*/true); + auto then_builder = if_op.getThenBodyBuilder(); + then_builder.create( + loc, MaterializeRankSpecialization(then_builder, loc, op, shapes, + min_target_rank)); + + // Recur for all remaining target ranks. + auto else_builder = if_op.getElseBodyBuilder(); + else_builder.create( + loc, + MaterializeAllRankSpecializations(else_builder, loc, op, shapes, max_rank, + min_target_rank + 1, max_target_rank)); + + return if_op.results().front(); +} + +struct LowerMultipleNonScalarOperandPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op, + PatternRewriter &rewriter) const override { + // We have a specialized pattern for the case in which all but one operands + // are scalars. + if (FindUniqueNonScalar(op.operands())) return failure(); + + // Restoring the result shape currently relies on all operands being used + // for a single result. The result shape is then the broadcasted shape of + // all operands. + if (op.getNumResults() != 1) return failure(); + + // Get the minimum broadcast shapes of the operands. + Location loc = op.getLoc(); + SmallVector shapes = + llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) { + return rewriter.create(loc, v).result(); + })); + ValueRange reduced_shapes = + rewriter + .create( + loc, + SmallVector(shapes.size(), + shape::getExtentTensorType(getContext())), + shapes) + .results(); + // TODO(frgossen): Avoid this reshape if it is redundant in all cases. + SmallVector reshaped_args; + for (auto it : llvm::zip(op.operands(), reduced_shapes)) { + Value arg = std::get<0>(it); + Value reduced_shape = std::get<1>(it); + reshaped_args.push_back(rewriter.create( + loc, arg.getType(), arg, reduced_shape)); + } + + // Find the maximum rank among the reduced operand shapes. + Value max_rank; + for (Value shape : reduced_shapes) { + Value rank = + rewriter.create(loc, rewriter.getIndexType(), shape); + if (!max_rank) { + max_rank = rank; + } else { + max_rank = rewriter.create( + loc, + rewriter.create(loc, CmpIPredicate::sgt, max_rank, rank), + max_rank, rank); + } + } + + // Materialize rank specialization for ranks 1, ..., 8. + // TODO(frgossen): For clusters w/o a select operation, consider only ranks + // 1, ..., 5. + const int64_t kMinTargetRank = 1; + const int64_t kMaxTargetRank = 8; + Value unshaped_result = MaterializeAllRankSpecializations( + rewriter, loc, op, reduced_shapes, max_rank, kMinTargetRank, + kMaxTargetRank); + + // Materialize final reshape once and for all rank specialization cases. + rewriter.replaceOp( + op, MaterializeFinalReshape(rewriter, loc, op, unshaped_result)); + return success(); + } +}; + struct RankSpecializationToSCFPass : public PassWrapper { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); + shape::ShapeDialect, scf::SCFDialect>(); } void runOnFunction() override { @@ -325,7 +475,8 @@ void PopulateRankSpecializationClusterPatterns( void PopulateRankSpecializationToSCFPatterns( MLIRContext *context, OwningRewritePatternList *patterns) { - patterns->insert(context); + patterns->insert(context); } std::unique_ptr createRankSpecializationClusterPass() { diff --git a/tests/rank-specialization.mlir b/tests/rank-specialization.mlir index 521118c..dc33d8f 100644 --- a/tests/rank-specialization.mlir +++ b/tests/rank-specialization.mlir @@ -19,6 +19,185 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, return %1 : tensor<*xf32> } +// CHECK-SCF-LABEL: @add_mul +// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>) +// CHECK-SCF-DAG: %[[C1:.*]] = constant 1 +// CHECK-SCF-DAG: %[[C2:.*]] = constant 2 +// CHECK-SCF-DAG: %[[C3:.*]] = constant 3 +// CHECK-SCF-DAG: %[[C4:.*]] = constant 4 +// CHECK-SCF-DAG: %[[C5:.*]] = constant 5 +// CHECK-SCF-DAG: %[[C6:.*]] = constant 6 +// CHECK-SCF-DAG: %[[C7:.*]] = constant 7 +// CHECK-SCF-DAG: %[[C8:.*]] = constant 8 +// CHECK-SCF-DAG: %[[ONE_SHAPE_1:.*]] = shape.const_shape [1] +// CHECK-SCF-DAG: %[[ONE_SHAPE_2:.*]] = shape.const_shape [1, 1] +// CHECK-SCF-DAG: %[[ONE_SHAPE_3:.*]] = shape.const_shape [1, 1, 1] +// CHECK-SCF-DAG: %[[ONE_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1] +// CHECK-SCF-DAG: %[[ONE_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1] +// CHECK-SCF-DAG: %[[ONE_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1] +// CHECK-SCF-DAG: %[[ONE_SHAPE_7:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1, 1] +// CHECK-SCF-DAG: %[[ONE_SHAPE_8:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1, 1, 1] +// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]] +// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]] +// CHECK-SCF-DAG: %[[SHAPE_ARG2:.*]] = shape.shape_of %[[ARG2]] +// Find maximum reduced rank. +// CHECK-SCF-DAG: %[[REDUCED_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[SHAPE_ARG2]], %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] +// CHECK-SCF-DAG: %[[REDUCED_RANK0:.*]] = shape.rank %[[REDUCED_SHAPES]]#1 +// CHECK-SCF-DAG: %[[REDUCED_RANK1:.*]] = shape.rank %[[REDUCED_SHAPES]]#2 +// CHECK-SCF-DAG: %[[REDUCED_RANK2:.*]] = shape.rank %[[REDUCED_SHAPES]]#0 +// CHECK-SCF-DAG: %[[R2_GT_R0:.*]] = cmpi sgt, %[[REDUCED_RANK2]], %[[REDUCED_RANK0]] +// CHECK-SCF-DAG: %[[R20:.*]] = select %[[R2_GT_R0]], %[[REDUCED_RANK2]], %[[REDUCED_RANK0]] +// CHECK-SCF-DAG: %[[R20_GT_R1:.*]] = cmpi sgt, %[[R20]], %[[REDUCED_RANK1]] +// CHECK-SCF-DAG: %[[MAX_RED_RANK:.*]] = select %[[R20_GT_R1]], %15, %[[REDUCED_RANK1]] +// Case 1: +// CHECK-SCF: %[[MAX_RED_RANK_EQ_1:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C1]] +// CHECK-SCF: %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_EQ_1]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_1]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]] +// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]]) +// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]]) +// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]]) +// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] +// CHECK-SCF: scf.yield %[[INNER_RES_]] +// CHECK-SCF: else +// Case 2: +// CHECK-SCF: %[[MAX_RED_RANK_EQ_2:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C2]] +// CHECK-SCF: %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_EQ_2]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_2]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_2]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_2]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]] +// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]]) +// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]]) +// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]]) +// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] +// CHECK-SCF: scf.yield %[[INNER_RES_]] +// CHECK-SCF: else +// Case 3: +// CHECK-SCF: %[[MAX_RED_RANK_EQ_3:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C3]] +// CHECK-SCF: %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_EQ_3]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_3]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]] +// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]]) +// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]]) +// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]]) +// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] +// CHECK-SCF: scf.yield %[[INNER_RES_]] +// CHECK-SCF: else +// Case 4: +// CHECK-SCF: %[[MAX_RED_RANK_EQ_4:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C4]] +// CHECK-SCF: %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_EQ_4]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_4]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_4]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_4]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]] +// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]]) +// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]]) +// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]]) +// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] +// CHECK-SCF: scf.yield %[[INNER_RES_]] +// CHECK-SCF: else +// Case 5: +// CHECK-SCF: %[[MAX_RED_RANK_EQ_5:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C5]] +// CHECK-SCF: %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_EQ_5]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_5]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_5]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_5]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]] +// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]]) +// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]]) +// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]]) +// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] +// CHECK-SCF: scf.yield %[[INNER_RES_]] +// CHECK-SCF: else +// Case 6: +// CHECK-SCF: %[[MAX_RED_RANK_EQ_6:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C6]] +// CHECK-SCF: %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_EQ_6]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_6]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_6]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_6]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]] +// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]]) +// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]]) +// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]]) +// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] +// CHECK-SCF: scf.yield %[[INNER_RES_]] +// CHECK-SCF: else +// Case 7: +// CHECK-SCF: %[[MAX_RED_RANK_EQ_7:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C7]] +// CHECK-SCF: %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_EQ_7]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_7]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_7]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_7]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]] +// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]]) +// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]]) +// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]]) +// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] +// CHECK-SCF: scf.yield %[[INNER_RES_]] +// CHECK-SCF: else +// Case 8: +// CHECK-SCF: %[[MAX_RED_RANK_EQ_8:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C8]] +// CHECK-SCF: assert %[[MAX_RED_RANK_EQ_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8" +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_8]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_8]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_8]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]] +// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]]) +// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]]) +// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]]) +// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] +// CHECK-SCF: scf.yield %[[INNER_RES_]] +// CHECK-SCF: scf.yield %[[UNSHAPED_RES_7]] +// CHECK-SCF: scf.yield %[[UNSHAPED_RES_6]] +// CHECK-SCF: scf.yield %[[UNSHAPED_RES_5]] +// CHECK-SCF: scf.yield %[[UNSHAPED_RES_4]] +// CHECK-SCF: scf.yield %[[UNSHAPED_RES_3]] +// CHECK-SCF: scf.yield %[[UNSHAPED_RES_2]] +// Reshape the result. +// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]] +// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]] +// CHECK-SCF-DAG: %[[SHAPE_ARG2:.*]] = shape.shape_of %[[ARG2]] +// CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[SHAPE_ARG2]], %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] +// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES_1]], %[[RES_SHAPE]]) +// CHECK-SCF: return %[[RES]] + // ----- // Unary MHLO operation. @@ -64,6 +243,10 @@ func @sqrt_ranked(%arg: tensor<3x?xf32>) -> tensor<3x?xf32> { return %2 : tensor<3x?xf32> } +// CHECK-SCF-LABEL: @sqrt_ranked +// CHECK-SCF-NOT: dynamic_reshape +// CHECK-SCF: return + // ----- // Ternary operation. @@ -81,6 +264,9 @@ func @select_mixed(%pred: tensor<*xi1>, %arg1: tensor<*xf32>, return %0 : tensor<*xf32> } +// CHECK-SCF-LABEL: @select_mixed +// CHECK-SCF: chlo.broadcast_select %{{.*}}, %{{.*}}, %{{.*}} : (tensor, tensor, tensor) + // ----- // Unary CHLO operation. @@ -141,6 +327,14 @@ func @mixed(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>) return %5 : tensor<*xf32> } +// CHECK-SCF-LABEL: @mixed +// CHECK-SCF-DAG: %[[TMP0:.*]] = chlo.tan %{{.*}} : tensor +// CHECK-SCF-DAG: %[[TMP1:.*]] = "mhlo.sqrt"(%{{.*}}) : (tensor) +// CHECK-SCF-DAG: %[[TMP2:.*]] = chlo.broadcast_multiply %[[TMP0]], %[[TMP1]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[TMP3:.*]] = chlo.broadcast_add %[[TMP2]], %{{.*}} : (tensor, tensor) +// CHECK-SCF-DAG: %[[TMP4:.*]] = "mhlo.sqrt"(%[[TMP3]]) : (tensor) +// CHECK-SCF: chlo.tan %[[TMP4]] : tensor + // ----- // Constant cluster operand. @@ -228,3 +422,9 @@ func @xlogy(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> { : (tensor<*xi1>, tensor, tensor<*xf32>) -> tensor<*xf32> return %5 : tensor<*xf32> } + +// CHECK-SCF-LABEL: @xlogy +// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %{{.*}}, %{{.*}} {comparison_direction = "EQ"} : (tensor, tensor) +// CHECK-SCF-DAG: %[[TMP0:.*]] = "mhlo.log"(%{{.*}}) : (tensor) +// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %{{.*}}, %[[TMP0]] : (tensor, tensor) +// CHECK-SCF: chlo.broadcast_select %[[PRED]], %{{.*}}, %[[TMP1]] : (tensor, tensor, tensor)