[MLIR][HLO] Support all smaller ranks in rank specialization cases
Rank specialization cases can now be applied to argument tensors of any rank smaller than the expected maximum rank, not only to exact matches. This is crucial when all operands are effectively scalars and the maximum reduced rank is 0.

PiperOrigin-RevId: 375712020
parent c5af02fd8d
commit cb46298a07
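The core of the change is the predicate guarding each rank specialization case: `CmpIPredicate::eq` becomes `CmpIPredicate::ule`, so the case for a given target rank also covers all operand sets whose maximum reduced rank is smaller, down to rank 0 when every operand is effectively a scalar. The standalone C++ sketch below only models that selection logic under those assumptions; `SelectTargetRank` and the explicit target-rank list are illustrative names, not part of the pass.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: returns the target rank whose specialization case
// would fire for a given maximum reduced operand rank, or -1 when the
// trailing assert would trigger instead.
int64_t SelectTargetRank(int64_t max_reduced_rank,
                         const std::vector<int64_t> &target_ranks) {
  for (int64_t target_rank : target_ranks) {
    // Mirrors the new guard `cmpi ule, %max_rank, %c<target_rank>`; with the
    // old `eq` predicate only an exact rank match would take this case.
    if (max_reduced_rank <= target_rank) return target_rank;
  }
  return -1;
}

int main() {
  const std::vector<int64_t> kTargetRanks = {1, 2, 3, 4, 5, 6, 7, 8};
  std::cout << SelectTargetRank(0, kTargetRanks) << "\n";  // 1: all scalars.
  std::cout << SelectTargetRank(3, kTargetRanks) << "\n";  // 3: exact match.
  std::cout << SelectTargetRank(9, kTargetRanks) << "\n";  // -1: assert path.
  return 0;
}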
@@ -423,14 +423,14 @@ Value RecusivelyMaterializeTargetRankSpecializationCases(
     OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
     const SmallVector<Value, 8> &shapes, Value max_rank,
     int64_t min_target_rank, int64_t max_target_rank) {
-  Value min_target_rank_predicate =
-      b.create<CmpIOp>(loc, CmpIPredicate::eq, max_rank,
+  Value condition =
+      b.create<CmpIOp>(loc, CmpIPredicate::ule, max_rank,
                        b.create<ConstantIndexOp>(loc, min_target_rank));

   // If only a unique target rank is left, we can lower to an assert instead
   // of the usual if operation.
   if (min_target_rank == max_target_rank) {
-    b.create<AssertOp>(loc, min_target_rank_predicate,
+    b.create<AssertOp>(loc, condition,
                        "Input for dynamic binary or n-ary op lowering was of "
                        "a rank greater than " +
                            std::to_string(max_target_rank));
@@ -439,9 +439,8 @@ Value RecusivelyMaterializeTargetRankSpecializationCases(
   }

   // Materialize IR for the smallest considered target rank.
-  auto if_op =
-      b.create<scf::IfOp>(loc, op->getResultTypes(), min_target_rank_predicate,
-                          /*withElseRegion=*/true);
+  auto if_op = b.create<scf::IfOp>(loc, op->getResultTypes(), condition,
+                                   /*withElseRegion=*/true);
   auto then_builder = if_op.getThenBodyBuilder();
   then_builder.create<scf::YieldOp>(
       loc, MaterializeTargetRankSpecializationCase(then_builder, loc, op,
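The two hunks above also show the structure of `RecusivelyMaterializeTargetRankSpecializationCases`: the recursion works through the `[min_target_rank, max_target_rank]` range, and once a single target rank remains it emits an assert instead of another `scf.if`. The sketch below is a rough model of that cascade under the assumption that the recursion advances one target rank at a time; `MaterializeCases` is a hypothetical stand-in, not the actual builder code.

#include <cstdint>
#include <iostream>
#include <string>

// Hypothetical stand-in for the recursive case materialization: returns a
// textual description of the emitted cascade instead of building IR.
std::string MaterializeCases(int64_t min_target_rank,
                             int64_t max_target_rank) {
  std::string guard = "max_rank <= " + std::to_string(min_target_rank);
  if (min_target_rank == max_target_rank) {
    // Terminal case: an assert replaces the final scf.if.
    return "assert(" + guard + "); case_" + std::to_string(min_target_rank);
  }
  return "if (" + guard + ") { case_" + std::to_string(min_target_rank) +
         " } else { " +
         MaterializeCases(min_target_rank + 1, max_target_rank) + " }";
}

int main() {
  // Prints a nested if/else chain ending in an assert, mirroring the
  // scf.if cascade that the tests below check for ranks 1 through 8.
  std::cout << MaterializeCases(1, 3) << "\n";
  return 0;
}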
@@ -67,8 +67,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
 // CHECK-SCF-DAG: %[[R20_GT_R1:.*]] = cmpi sgt, %[[R20]], %[[REDUCED_RANK1]]
 // CHECK-SCF-DAG: %[[MAX_RED_RANK:.*]] = select %[[R20_GT_R1]], %[[R20]], %[[REDUCED_RANK1]]
 // Generic case 1:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_1:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C1]]
-// CHECK-SCF: %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_EQ_1]]
+// CHECK-SCF: %[[MAX_RED_RANK_LE_1:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C1]]
+// CHECK-SCF: %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_LE_1]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_1]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]]
@@ -84,8 +84,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
 // CHECK-SCF: else
 // Generic case 2:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_2:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C2]]
-// CHECK-SCF: %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_EQ_2]]
+// CHECK-SCF: %[[MAX_RED_RANK_LE_2:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C2]]
+// CHECK-SCF: %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_LE_2]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_2]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_2]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_2]]
@@ -101,8 +101,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
 // CHECK-SCF: else
 // Generic case 3:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_3:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C3]]
-// CHECK-SCF: %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_EQ_3]]
+// CHECK-SCF: %[[MAX_RED_RANK_LE_3:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C3]]
+// CHECK-SCF: %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_LE_3]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_3]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]]
@@ -118,8 +118,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
 // CHECK-SCF: else
 // Generic case 4:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_4:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C4]]
-// CHECK-SCF: %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_EQ_4]]
+// CHECK-SCF: %[[MAX_RED_RANK_LE_4:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C4]]
+// CHECK-SCF: %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_LE_4]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_4]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_4]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_4]]
@@ -135,8 +135,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
 // CHECK-SCF: else
 // Generic case 5:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_5:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C5]]
-// CHECK-SCF: %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_EQ_5]]
+// CHECK-SCF: %[[MAX_RED_RANK_LE_5:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C5]]
+// CHECK-SCF: %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_LE_5]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_5]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_5]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_5]]
@@ -152,8 +152,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
 // CHECK-SCF: else
 // Generic case 6:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_6:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C6]]
-// CHECK-SCF: %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_EQ_6]]
+// CHECK-SCF: %[[MAX_RED_RANK_LE_6:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C6]]
+// CHECK-SCF: %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_LE_6]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_6]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_6]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_6]]
@@ -169,8 +169,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
 // CHECK-SCF: else
 // Generic case 7:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_7:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C7]]
-// CHECK-SCF: %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_EQ_7]]
+// CHECK-SCF: %[[MAX_RED_RANK_LE_7:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C7]]
+// CHECK-SCF: %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_LE_7]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_7]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_7]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_7]]
@@ -186,8 +186,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
 // CHECK-SCF: else
 // Generic case 8:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_8:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C8]]
-// CHECK-SCF: assert %[[MAX_RED_RANK_EQ_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8"
+// CHECK-SCF: %[[MAX_RED_RANK_LE_8:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C8]]
+// CHECK-SCF: assert %[[MAX_RED_RANK_LE_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8"
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_8]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_8]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_8]]
@@ -524,8 +524,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-SCF-DAG: %[[R0_GT_R1:.*]] = cmpi sgt, %[[REDUCED_RANK0]], %[[REDUCED_RANK1]]
 // CHECK-SCF-DAG: %[[MAX_RED_RANK:.*]] = select %[[R0_GT_R1]], %[[REDUCED_RANK0]], %[[REDUCED_RANK1]]
 // Generic case 1:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_1:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C1]]
-// CHECK-SCF: %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_EQ_1]]
+// CHECK-SCF: %[[MAX_RED_RANK_LE_1:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C1]]
+// CHECK-SCF: %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_LE_1]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
@@ -537,8 +537,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
 // CHECK-SCF: else
 // Generic case 2:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_2:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C2]]
-// CHECK-SCF: %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_EQ_2]]
+// CHECK-SCF: %[[MAX_RED_RANK_LE_2:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C2]]
+// CHECK-SCF: %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_LE_2]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_2]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_2]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
@@ -550,8 +550,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
 // CHECK-SCF: else
 // Generic case 3:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_3:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C3]]
-// CHECK-SCF: %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_EQ_3]]
+// CHECK-SCF: %[[MAX_RED_RANK_LE_3:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C3]]
+// CHECK-SCF: %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_LE_3]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
@@ -563,8 +563,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
 // CHECK-SCF: else
 // Generic case 4:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_4:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C4]]
-// CHECK-SCF: %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_EQ_4]]
+// CHECK-SCF: %[[MAX_RED_RANK_LE_4:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C4]]
+// CHECK-SCF: %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_LE_4]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_4]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_4]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
@@ -576,8 +576,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
 // CHECK-SCF: else
 // Generic case 5:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_5:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C5]]
-// CHECK-SCF: %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_EQ_5]]
+// CHECK-SCF: %[[MAX_RED_RANK_LE_5:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C5]]
+// CHECK-SCF: %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_LE_5]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_5]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_5]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
@@ -589,8 +589,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
 // CHECK-SCF: else
 // Generic case 6:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_6:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C6]]
-// CHECK-SCF: %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_EQ_6]]
+// CHECK-SCF: %[[MAX_RED_RANK_LE_6:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C6]]
+// CHECK-SCF: %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_LE_6]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_6]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_6]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
@@ -602,8 +602,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
 // CHECK-SCF: else
 // Generic case 7:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_7:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C7]]
-// CHECK-SCF: %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_EQ_7]]
+// CHECK-SCF: %[[MAX_RED_RANK_LE_7:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C7]]
+// CHECK-SCF: %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_LE_7]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_7]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_7]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
@@ -615,8 +615,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
 // CHECK-SCF: else
 // Generic case 8:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_8:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C8]]
-// CHECK-SCF: assert %[[MAX_RED_RANK_EQ_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8"
+// CHECK-SCF: %[[MAX_RED_RANK_LE_8:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C8]]
+// CHECK-SCF: assert %[[MAX_RED_RANK_LE_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8"
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_8]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_8]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]