[MLIR][HLO] Support all smaller ranks in rank specialization cases

Rank specialization cases can be applied to all argument tensors of smaller
ranks than the expected maximum rank. This is crucial if all operands are
effectively scalars and the maximum reduced rank is 0.

PiperOrigin-RevId: 375712020
This commit is contained in:
A. Unique TensorFlower 2021-05-25 08:37:51 -07:00 committed by TensorFlow MLIR Team
parent c5af02fd8d
commit cb46298a07
2 changed files with 37 additions and 38 deletions

View File

@ -423,14 +423,14 @@ Value RecusivelyMaterializeTargetRankSpecializationCases(
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
const SmallVector<Value, 8> &shapes, Value max_rank,
int64_t min_target_rank, int64_t max_target_rank) {
Value min_target_rank_predicate =
b.create<CmpIOp>(loc, CmpIPredicate::eq, max_rank,
Value condition =
b.create<CmpIOp>(loc, CmpIPredicate::ule, max_rank,
b.create<ConstantIndexOp>(loc, min_target_rank));
// If only a unique target rank is left, we can lower to an assert instead
// of the usual if operation.
if (min_target_rank == max_target_rank) {
b.create<AssertOp>(loc, min_target_rank_predicate,
b.create<AssertOp>(loc, condition,
"Input for dynamic binary or n-ary op lowering was of "
"a rank greater than " +
std::to_string(max_target_rank));
@ -439,9 +439,8 @@ Value RecusivelyMaterializeTargetRankSpecializationCases(
}
// Materialize IR for the smallest considered target rank.
auto if_op =
b.create<scf::IfOp>(loc, op->getResultTypes(), min_target_rank_predicate,
/*withElseRegion=*/true);
auto if_op = b.create<scf::IfOp>(loc, op->getResultTypes(), condition,
/*withElseRegion=*/true);
auto then_builder = if_op.getThenBodyBuilder();
then_builder.create<scf::YieldOp>(
loc, MaterializeTargetRankSpecializationCase(then_builder, loc, op,

View File

@ -67,8 +67,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
// CHECK-SCF-DAG: %[[R20_GT_R1:.*]] = cmpi sgt, %[[R20]], %[[REDUCED_RANK1]]
// CHECK-SCF-DAG: %[[MAX_RED_RANK:.*]] = select %[[R20_GT_R1]], %[[R20]], %[[REDUCED_RANK1]]
// Generic case 1:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_1:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C1]]
// CHECK-SCF: %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_EQ_1]]
// CHECK-SCF: %[[MAX_RED_RANK_LE_1:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C1]]
// CHECK-SCF: %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_LE_1]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_1]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]]
@ -84,8 +84,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 2:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_2:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C2]]
// CHECK-SCF: %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_EQ_2]]
// CHECK-SCF: %[[MAX_RED_RANK_LE_2:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C2]]
// CHECK-SCF: %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_LE_2]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_2]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_2]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_2]]
@ -101,8 +101,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 3:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_3:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C3]]
// CHECK-SCF: %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_EQ_3]]
// CHECK-SCF: %[[MAX_RED_RANK_LE_3:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C3]]
// CHECK-SCF: %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_LE_3]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_3]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]]
@ -118,8 +118,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 4:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_4:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C4]]
// CHECK-SCF: %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_EQ_4]]
// CHECK-SCF: %[[MAX_RED_RANK_LE_4:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C4]]
// CHECK-SCF: %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_LE_4]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_4]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_4]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_4]]
@ -135,8 +135,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 5:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_5:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C5]]
// CHECK-SCF: %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_EQ_5]]
// CHECK-SCF: %[[MAX_RED_RANK_LE_5:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C5]]
// CHECK-SCF: %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_LE_5]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_5]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_5]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_5]]
@ -152,8 +152,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 6:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_6:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C6]]
// CHECK-SCF: %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_EQ_6]]
// CHECK-SCF: %[[MAX_RED_RANK_LE_6:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C6]]
// CHECK-SCF: %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_LE_6]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_6]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_6]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_6]]
@ -169,8 +169,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 7:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_7:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C7]]
// CHECK-SCF: %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_EQ_7]]
// CHECK-SCF: %[[MAX_RED_RANK_LE_7:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C7]]
// CHECK-SCF: %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_LE_7]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_7]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_7]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_7]]
@ -186,8 +186,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 8:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_8:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C8]]
// CHECK-SCF: assert %[[MAX_RED_RANK_EQ_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8"
// CHECK-SCF: %[[MAX_RED_RANK_LE_8:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C8]]
// CHECK-SCF: assert %[[MAX_RED_RANK_LE_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8"
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_8]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_8]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_8]]
@ -524,8 +524,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
// CHECK-SCF-DAG: %[[R0_GT_R1:.*]] = cmpi sgt, %[[REDUCED_RANK0]], %[[REDUCED_RANK1]]
// CHECK-SCF-DAG: %[[MAX_RED_RANK:.*]] = select %[[R0_GT_R1]], %[[REDUCED_RANK0]], %[[REDUCED_RANK1]]
// Generic case 1:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_1:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C1]]
// CHECK-SCF: %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_EQ_1]]
// CHECK-SCF: %[[MAX_RED_RANK_LE_1:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C1]]
// CHECK-SCF: %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_LE_1]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
@ -537,8 +537,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 2:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_2:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C2]]
// CHECK-SCF: %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_EQ_2]]
// CHECK-SCF: %[[MAX_RED_RANK_LE_2:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C2]]
// CHECK-SCF: %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_LE_2]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_2]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_2]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
@ -550,8 +550,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 3:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_3:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C3]]
// CHECK-SCF: %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_EQ_3]]
// CHECK-SCF: %[[MAX_RED_RANK_LE_3:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C3]]
// CHECK-SCF: %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_LE_3]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
@ -563,8 +563,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 4:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_4:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C4]]
// CHECK-SCF: %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_EQ_4]]
// CHECK-SCF: %[[MAX_RED_RANK_LE_4:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C4]]
// CHECK-SCF: %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_LE_4]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_4]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_4]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
@ -576,8 +576,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 5:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_5:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C5]]
// CHECK-SCF: %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_EQ_5]]
// CHECK-SCF: %[[MAX_RED_RANK_LE_5:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C5]]
// CHECK-SCF: %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_LE_5]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_5]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_5]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
@ -589,8 +589,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 6:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_6:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C6]]
// CHECK-SCF: %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_EQ_6]]
// CHECK-SCF: %[[MAX_RED_RANK_LE_6:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C6]]
// CHECK-SCF: %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_LE_6]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_6]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_6]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
@ -602,8 +602,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 7:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_7:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C7]]
// CHECK-SCF: %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_EQ_7]]
// CHECK-SCF: %[[MAX_RED_RANK_LE_7:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C7]]
// CHECK-SCF: %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_LE_7]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_7]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_7]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
@ -615,8 +615,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 8:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_8:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C8]]
// CHECK-SCF: assert %[[MAX_RED_RANK_EQ_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8"
// CHECK-SCF: %[[MAX_RED_RANK_LE_8:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C8]]
// CHECK-SCF: assert %[[MAX_RED_RANK_LE_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8"
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_8]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_8]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]