From c62fd89663c0491ebd39a12bc8ad8611c0a5d77a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 19 May 2021 05:37:49 -0700 Subject: [PATCH] [MLIR][HLO] Add equal shapes case to rank specialization Also restructure lowering implementation to facilitate the addition or removal of special cases. PiperOrigin-RevId: 374626365 --- .../mhlo/transforms/rank_specialization.cc | 303 +++++++++++------- tests/rank-specialization.mlir | 206 ++++++------ 2 files changed, 296 insertions(+), 213 deletions(-) diff --git a/lib/Dialect/mhlo/transforms/rank_specialization.cc b/lib/Dialect/mhlo/transforms/rank_specialization.cc index b076ae2..0f59779 100644 --- a/lib/Dialect/mhlo/transforms/rank_specialization.cc +++ b/lib/Dialect/mhlo/transforms/rank_specialization.cc @@ -182,10 +182,10 @@ bool IsScalarTensorType(Type ty) { } Type DeriveRankedTensorTypes(Type ty, int64_t rank) { - auto unranked_ty = ty.dyn_cast(); - if (!unranked_ty) return ty; + auto tensor_ty = ty.dyn_cast(); + if (!tensor_ty) return ty; SmallVector shape(rank, ShapedType::kDynamicSize); - return RankedTensorType::get(shape, unranked_ty.getElementType()); + return RankedTensorType::get(shape, tensor_ty.getElementType()); } Type DeriveUnrankedTensorTypes(Type ty) { @@ -208,7 +208,7 @@ Optional FindUniqueNonScalar(ValueRange values) { SmallVector MaterializeRankedOperations( OpBuilder &b, Location loc, BlockAndValueMapping &bvm, - chlo::RankSpecializationClusterOp &op, int64_t target_rank) { + chlo::RankSpecializationClusterOp op, int64_t target_rank) { // Create ranked operations. for (Operation &nested_op : op.getBody()->without_terminator()) { auto mapped_operands = llvm::to_vector<4>(llvm::map_range( @@ -256,59 +256,101 @@ SmallVector MaterializeFinalReshape( })); } -struct LowerSingleNonScalarOperandPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +SmallVector MaterializeRankSpecializationForSingleNonScalarOperand( + OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, + Value non_scalar_operand) { + // Flatten the non-scalar operand. + Value flat_shape = b.create( + loc, b.create( + loc, b.getIndexType(), + b.create(loc, non_scalar_operand)) + .result()); + Value flat_non_scalar_operand = b.create( + loc, DeriveRankedTensorTypes(non_scalar_operand.getType(), /*rank=*/1), + non_scalar_operand, flat_shape); - LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op, - PatternRewriter &rewriter) const override { - // Only apply this pattern if we can statically know that all operands have - // the same shape or are scalars, i.e. all but one operands are scalars. - Optional non_scalar_operand = FindUniqueNonScalar(op.operands()); - if (!non_scalar_operand) return failure(); - - // Flatten the non-scalar operand. - Location loc = op.getLoc(); - Value flat_shape = rewriter.create( - loc, - rewriter - .create( - loc, rewriter.getIndexType(), - rewriter.create(loc, *non_scalar_operand)) - .result()); - Value flat_non_scalar_operand = rewriter.create( - loc, DeriveRankedTensorTypes(non_scalar_operand->getType(), /*rank=*/1), - *non_scalar_operand, flat_shape); - - // Materialize ranked variants for the element-wise operations. - BlockAndValueMapping bvm; - for (auto it : llvm::zip(op.getBody()->getArguments(), op.operands())) { - Value operand = std::get<1>(it); - bvm.map(std::get<0>(it), operand == *non_scalar_operand - ? 
flat_non_scalar_operand - : operand); - } - SmallVector unshaped_results = - MaterializeRankedOperations(rewriter, loc, bvm, op, /*target_rank=*/1); - - // Restore the results' expected shape. - SmallVector results = - MaterializeFinalReshape(rewriter, loc, op, unshaped_results); - rewriter.replaceOp(op, results); - return success(); + // Materialize ranked variants for the element-wise operations. + BlockAndValueMapping bvm; + for (auto it : llvm::zip(op.getBody()->getArguments(), op.operands())) { + Value operand = std::get<1>(it); + bvm.map(std::get<0>(it), + operand == non_scalar_operand ? flat_non_scalar_operand : operand); } -}; + SmallVector unshaped_results = + MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1); -Value MaterializeRankSpecialization(OpBuilder &b, Location loc, - chlo::RankSpecializationClusterOp op, - const SmallVector &shapes, - int64_t target_rank) { + // Restore the results' expected shape. + return MaterializeFinalReshape(b, loc, op, unshaped_results); +} + +Value MaterializeEqualShapesRankSpecializationCase( + OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, + const SmallVector &shapes, + function_ref else_builder_fn) { + assert(shapes.size() >= 2 && + "This strategy should only be materialized if there are at least two " + "shapes involved."); + + // Materialize all shapes equal predicate. + Value all_shapes_eq; + for (Value s : llvm::drop_begin(shapes)) { + auto literal = b.create(loc, shapes.front(), s); + all_shapes_eq = + all_shapes_eq + ? b.create(loc, all_shapes_eq, literal).result() + : literal; + } + + auto if_op = b.create( + loc, op->getResultTypes(), all_shapes_eq, + [&](OpBuilder &b, Location loc) { + // Flatten operands. + Value shape = shapes.front(); + for (Value s : llvm::drop_begin(shapes)) { + shape = b.create(loc, shape.getType(), + ValueRange{shape, s}); + } + Value flat_shape = b.create( + loc, b.create(loc, b.getIndexType(), shape) + .result()); + SmallVector flat_operands = + llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) { + return b + .create( + loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/1), v, + flat_shape) + .result(); + })); + + // Materialize ranked variants for the element-wise operations. + // TODO(frgossen): Materializae non-broadcasting equivalents instead. + BlockAndValueMapping bvm; + for (auto it : llvm::zip(op.getBody()->getArguments(), flat_operands)) + bvm.map(std::get<0>(it), std::get<1>(it)); + Value unshaped_result = + MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1) + .front(); + + // Return as unranked tensor for compatibility with the other cases. + b.create( + loc, b.create( + loc, DeriveUnrankedTensorTypes(unshaped_result.getType()), + unshaped_result) + .dest()); + }, + else_builder_fn); + + return if_op.results().front(); +} + +Value MaterializeTargetRankSpecializationCase( + OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, + const SmallVector &shapes, int64_t target_rank) { // Reshape operands to match the target rank. 
- MLIRContext *ctx = op->getContext(); llvm::SmallVector ranked_ty_dynamic_dims( target_rank, RankedTensorType::kDynamicSize); RankedTensorType extent_tensor_ty = - shape::getExtentTensorType(ctx, target_rank); + shape::getExtentTensorType(b.getContext(), target_rank); Value all_ones_shape = b.create( loc, extent_tensor_ty, mlir::DenseIntElementsAttr::get(extent_tensor_ty, @@ -319,7 +361,8 @@ Value MaterializeRankSpecialization(OpBuilder &b, Location loc, std::tie(operand, shape) = it; Value ranked_shape = b.create( loc, extent_tensor_ty, - b.create(loc, shape::getExtentTensorType(ctx), + b.create(loc, + shape::getExtentTensorType(b.getContext()), shape, all_ones_shape, /*error=*/nullptr)); Type element_ty = operand.getType().dyn_cast().getElementType(); @@ -341,11 +384,10 @@ Value MaterializeRankSpecialization(OpBuilder &b, Location loc, unshaped_result); } -Value MaterializeAllRankSpecializations(OpBuilder &b, Location loc, - chlo::RankSpecializationClusterOp op, - const SmallVector &shapes, - Value max_rank, int64_t min_target_rank, - int64_t max_target_rank) { +Value RecusivelyMaterializeTargetRankSpecializationCases( + OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, + const SmallVector &shapes, Value max_rank, + int64_t min_target_rank, int64_t max_target_rank) { Value min_target_rank_predicate = b.create(loc, CmpIPredicate::eq, max_rank, b.create(loc, min_target_rank)); @@ -357,7 +399,8 @@ Value MaterializeAllRankSpecializations(OpBuilder &b, Location loc, "Input for dynamic binary or n-ary op lowering was of " "a rank greater than " + std::to_string(max_target_rank)); - return MaterializeRankSpecialization(b, loc, op, shapes, min_target_rank); + return MaterializeTargetRankSpecializationCase(b, loc, op, shapes, + min_target_rank); } // Materialize IR for the smallest considered target rank. @@ -366,84 +409,107 @@ Value MaterializeAllRankSpecializations(OpBuilder &b, Location loc, /*withElseRegion=*/true); auto then_builder = if_op.getThenBodyBuilder(); then_builder.create( - loc, MaterializeRankSpecialization(then_builder, loc, op, shapes, - min_target_rank)); + loc, MaterializeTargetRankSpecializationCase(then_builder, loc, op, + shapes, min_target_rank)); - // Recur for all remaining target ranks. + // Recurse for all remaining target ranks. auto else_builder = if_op.getElseBodyBuilder(); else_builder.create( - loc, - MaterializeAllRankSpecializations(else_builder, loc, op, shapes, max_rank, - min_target_rank + 1, max_target_rank)); + loc, RecusivelyMaterializeTargetRankSpecializationCases( + else_builder, loc, op, shapes, max_rank, min_target_rank + 1, + max_target_rank)); return if_op.results().front(); } -struct LowerMultipleNonScalarOperandPattern +Value MaterializeGenericRankSpecializationCases( + OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, + const SmallVector &shapes) { + // Get the minimum broadcast shapes of the operands. + ValueRange reduced_shapes = + b.create( + loc, + SmallVector(shapes.size(), + shape::getExtentTensorType(b.getContext())), + shapes) + .results(); + // TODO(frgossen): Avoid this reshape if it is redundant in all cases. + SmallVector reshaped_args; + for (auto it : llvm::zip(op.operands(), reduced_shapes)) { + Value arg = std::get<0>(it); + Value reduced_shape = std::get<1>(it); + reshaped_args.push_back(b.create( + loc, arg.getType(), arg, reduced_shape)); + } + + // Find the maximum rank among the reduced operand shapes. 
+ Value max_rank; + for (Value shape : reduced_shapes) { + Value rank = b.create(loc, b.getIndexType(), shape); + if (!max_rank) { + max_rank = rank; + } else { + max_rank = b.create( + loc, b.create(loc, CmpIPredicate::sgt, max_rank, rank), + max_rank, rank); + } + } + + // Materialize rank specialization for ranks 1, ..., 8. + // TODO(frgossen): For clusters w/o a select operation, consider only ranks + // 1, ..., 5. + const int64_t kMinTargetRank = 1; + const int64_t kMaxTargetRank = 8; + return RecusivelyMaterializeTargetRankSpecializationCases( + b, loc, op, reduced_shapes, max_rank, kMinTargetRank, kMaxTargetRank); +} + +Value MaterializeDefaultRankSpecialization( + OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op) { + auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) { + return b.create(loc, v).result(); + })); + + // Materialize all the different cases. + Value unshaped_result = MaterializeEqualShapesRankSpecializationCase( + b, loc, op, shapes, [&](OpBuilder &b, Location loc) { + b.create( + loc, MaterializeGenericRankSpecializationCases(b, loc, op, shapes)); + }); + + // Materialize final reshape once and for all rank specialization cases. + return MaterializeFinalReshape(b, loc, op, unshaped_result).front(); +} + +struct LowerRankSpecializationClusterPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op, PatternRewriter &rewriter) const override { - // We have a specialized pattern for the case in which all but one operands - // are scalars. - if (FindUniqueNonScalar(op.operands())) return failure(); + // TODO(frgossen): If there is a single operand, we can flatten it + // completely and apply a non-broadcasting operation. + + // If there is only one unranked operand and all others are known scalars, + // we can flatten the operands to rank 1. + Location loc = op.getLoc(); + if (Optional non_scalar_operand = + FindUniqueNonScalar(op.operands())) { + rewriter.replaceOp(op, + MaterializeRankSpecializationForSingleNonScalarOperand( + rewriter, loc, op, *non_scalar_operand)); + return success(); + } // Restoring the result shape currently relies on all operands being used // for a single result. The result shape is then the broadcasted shape of // all operands. if (op.getNumResults() != 1) return failure(); - // Get the minimum broadcast shapes of the operands. - Location loc = op.getLoc(); - SmallVector shapes = - llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) { - return rewriter.create(loc, v).result(); - })); - ValueRange reduced_shapes = - rewriter - .create( - loc, - SmallVector(shapes.size(), - shape::getExtentTensorType(getContext())), - shapes) - .results(); - // TODO(frgossen): Avoid this reshape if it is redundant in all cases. - SmallVector reshaped_args; - for (auto it : llvm::zip(op.operands(), reduced_shapes)) { - Value arg = std::get<0>(it); - Value reduced_shape = std::get<1>(it); - reshaped_args.push_back(rewriter.create( - loc, arg.getType(), arg, reduced_shape)); - } - - // Find the maximum rank among the reduced operand shapes. - Value max_rank; - for (Value shape : reduced_shapes) { - Value rank = - rewriter.create(loc, rewriter.getIndexType(), shape); - if (!max_rank) { - max_rank = rank; - } else { - max_rank = rewriter.create( - loc, - rewriter.create(loc, CmpIPredicate::sgt, max_rank, rank), - max_rank, rank); - } - } - - // Materialize rank specialization for ranks 1, ..., 8. 
- // TODO(frgossen): For clusters w/o a select operation, consider only ranks - // 1, ..., 5. - const int64_t kMinTargetRank = 1; - const int64_t kMaxTargetRank = 8; - Value unshaped_result = MaterializeAllRankSpecializations( - rewriter, loc, op, reduced_shapes, max_rank, kMinTargetRank, - kMaxTargetRank); - - // Materialize final reshape once and for all rank specialization cases. - rewriter.replaceOp( - op, MaterializeFinalReshape(rewriter, loc, op, unshaped_result)); + // For all other cases, reshape the operands to match in rank, apply the + // operation, and restore the expected shape. + rewriter.replaceOp(op, + MaterializeDefaultRankSpecialization(rewriter, loc, op)); return success(); } }; @@ -475,8 +541,7 @@ void PopulateRankSpecializationClusterPatterns( void PopulateRankSpecializationToSCFPatterns( MLIRContext *context, OwningRewritePatternList *patterns) { - patterns->insert(context); + patterns->insert(context); } std::unique_ptr createRankSpecializationClusterPass() { diff --git a/tests/rank-specialization.mlir b/tests/rank-specialization.mlir index dc33d8f..1365e88 100644 --- a/tests/rank-specialization.mlir +++ b/tests/rank-specialization.mlir @@ -40,162 +40,180 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, // CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]] // CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]] // CHECK-SCF-DAG: %[[SHAPE_ARG2:.*]] = shape.shape_of %[[ARG2]] -// Find maximum reduced rank. -// CHECK-SCF-DAG: %[[REDUCED_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[SHAPE_ARG2]], %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[REDUCED_RANK0:.*]] = shape.rank %[[REDUCED_SHAPES]]#1 -// CHECK-SCF-DAG: %[[REDUCED_RANK1:.*]] = shape.rank %[[REDUCED_SHAPES]]#2 -// CHECK-SCF-DAG: %[[REDUCED_RANK2:.*]] = shape.rank %[[REDUCED_SHAPES]]#0 -// CHECK-SCF-DAG: %[[R2_GT_R0:.*]] = cmpi sgt, %[[REDUCED_RANK2]], %[[REDUCED_RANK0]] -// CHECK-SCF-DAG: %[[R20:.*]] = select %[[R2_GT_R0]], %[[REDUCED_RANK2]], %[[REDUCED_RANK0]] -// CHECK-SCF-DAG: %[[R20_GT_R1:.*]] = cmpi sgt, %[[R20]], %[[REDUCED_RANK1]] -// CHECK-SCF-DAG: %[[MAX_RED_RANK:.*]] = select %[[R20_GT_R1]], %15, %[[REDUCED_RANK1]] -// Case 1: -// CHECK-SCF: %[[MAX_RED_RANK_EQ_1:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C1]] -// CHECK-SCF: %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_EQ_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]] -// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]]) -// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]]) -// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]]) -// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) +// Equal shapes case: +// CHECK-SCF-DAG: %[[EQ20:.*]] = shape.shape_eq %[[SHAPE_ARG2]], %[[SHAPE_ARG0]] +// CHECK-SCF-DAG: %[[EQ21:.*]] = shape.shape_eq 
%[[SHAPE_ARG2]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[SHAPES_EQ:.*]] = and %[[EQ20]], %[[EQ21]]
+// CHECK-SCF: %[[UNSHAPED_RES_EQ_SHAPES:.*]] = scf.if %[[SHAPES_EQ]]
+// CHECK-SCF-DAG: %[[S20:.*]] = shape.any %[[SHAPE_ARG2]], %[[SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[S201:.*]] = shape.any %[[S20]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[S201]]
+// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
+// CHECK-SCF-DAG: %[[FLAT_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[FLAT_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[FLAT_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[FLAT_ARG0]], %[[FLAT_ARG1]] : (tensor<?xf32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[FLAT_ARG2]] : (tensor<?xf32>, tensor<?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
-// Case 2:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_2:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C2]]
-// CHECK-SCF: %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_EQ_2]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_2]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_2]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_2]]
+// Find maximum reduced rank.
+// CHECK-SCF-DAG: %[[REDUCED_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[SHAPE_ARG2]], %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_RANK0:.*]] = shape.rank %[[REDUCED_SHAPES]]#1
+// CHECK-SCF-DAG: %[[REDUCED_RANK1:.*]] = shape.rank %[[REDUCED_SHAPES]]#2
+// CHECK-SCF-DAG: %[[REDUCED_RANK2:.*]] = shape.rank %[[REDUCED_SHAPES]]#0
+// CHECK-SCF-DAG: %[[R2_GT_R0:.*]] = cmpi sgt, %[[REDUCED_RANK2]], %[[REDUCED_RANK0]]
+// CHECK-SCF-DAG: %[[R20:.*]] = select %[[R2_GT_R0]], %[[REDUCED_RANK2]], %[[REDUCED_RANK0]]
+// CHECK-SCF-DAG: %[[R20_GT_R1:.*]] = cmpi sgt, %[[R20]], %[[REDUCED_RANK1]]
+// CHECK-SCF-DAG: %[[MAX_RED_RANK:.*]] = select %[[R20_GT_R1]], %[[R20]], %[[REDUCED_RANK1]]
+// Generic case 1:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_1:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C1]]
+// CHECK-SCF: %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_EQ_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
-// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?xf32>, tensor<?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?xf32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?xf32>, tensor<?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
-// Case 3:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_3:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C3]]
-// CHECK-SCF: %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_EQ_3]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_3]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]]
+// Generic case 2:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_2:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C2]]
+// CHECK-SCF: %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_EQ_2]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_2]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_2]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_2]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
-// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?xf32>, tensor<?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
-// Case 4:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_4:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C4]]
-// CHECK-SCF: %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_EQ_4]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_4]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_4]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_4]]
+// Generic case 3:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_3:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C3]]
+// CHECK-SCF: %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_EQ_3]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_3]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = 
"mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]]) // CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]]) -// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) // CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] // CHECK-SCF: scf.yield %[[INNER_RES_]] // CHECK-SCF: else -// Case 5: -// CHECK-SCF: %[[MAX_RED_RANK_EQ_5:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C5]] -// CHECK-SCF: %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_EQ_5]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_5]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_5]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_5]] +// Generic case 4: +// CHECK-SCF: %[[MAX_RED_RANK_EQ_4:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C4]] +// CHECK-SCF: %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_EQ_4]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_4]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_4]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_4]] // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]] // CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]]) // CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]]) // CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]]) -// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) // CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] // CHECK-SCF: scf.yield %[[INNER_RES_]] // CHECK-SCF: else -// Case 6: -// CHECK-SCF: %[[MAX_RED_RANK_EQ_6:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C6]] -// CHECK-SCF: %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_EQ_6]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_6]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_6]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_6]] +// Generic case 5: +// CHECK-SCF: %[[MAX_RED_RANK_EQ_5:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C5]] +// CHECK-SCF: %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_EQ_5]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_5]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_5]] 
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_5]] // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]] // CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]]) // CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]]) // CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]]) -// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) // CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] // CHECK-SCF: scf.yield %[[INNER_RES_]] // CHECK-SCF: else -// Case 7: -// CHECK-SCF: %[[MAX_RED_RANK_EQ_7:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C7]] -// CHECK-SCF: %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_EQ_7]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_7]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_7]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_7]] +// Generic case 6: +// CHECK-SCF: %[[MAX_RED_RANK_EQ_6:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C6]] +// CHECK-SCF: %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_EQ_6]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_6]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_6]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_6]] // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]] // CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]]) // CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]]) // CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]]) -// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) // CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] // CHECK-SCF: scf.yield %[[INNER_RES_]] // CHECK-SCF: else -// Case 8: -// CHECK-SCF: %[[MAX_RED_RANK_EQ_8:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C8]] -// CHECK-SCF: assert %[[MAX_RED_RANK_EQ_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8" -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_8]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = 
shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_8]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_8]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]] -// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]]) -// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]]) -// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]]) -// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: scf.yield %[[UNSHAPED_RES_7]] -// CHECK-SCF: scf.yield %[[UNSHAPED_RES_6]] -// CHECK-SCF: scf.yield %[[UNSHAPED_RES_5]] -// CHECK-SCF: scf.yield %[[UNSHAPED_RES_4]] -// CHECK-SCF: scf.yield %[[UNSHAPED_RES_3]] -// CHECK-SCF: scf.yield %[[UNSHAPED_RES_2]] +// Generic case 7: +// CHECK-SCF: %[[MAX_RED_RANK_EQ_7:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C7]] +// CHECK-SCF: %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_EQ_7]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_7]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_7]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_7]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]] +// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]]) +// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]]) +// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]]) +// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) +// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] +// CHECK-SCF: scf.yield %[[INNER_RES_]] +// CHECK-SCF: else +// Generic case 8: +// CHECK-SCF: %[[MAX_RED_RANK_EQ_8:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C8]] +// CHECK-SCF: assert %[[MAX_RED_RANK_EQ_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8" +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_8]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_8]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_8]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] +// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]] +// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]]) +// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], 
%[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
+// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_7]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_6]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_5]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_4]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_3]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_2]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_1]]
// Reshape the result.
// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]]
// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]]
// CHECK-SCF-DAG: %[[SHAPE_ARG2:.*]] = shape.shape_of %[[ARG2]]
// CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[SHAPE_ARG2]], %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
-// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES_1]], %[[RES_SHAPE]])
+// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES_EQ_SHAPES]], %[[RES_SHAPE]])
// CHECK-SCF: return %[[RES]]

// -----