Handle rank 1 broadcasts in unranked kernel lowering.
Previously this started at rank 2 after checking for scalars and equal shapes. This resulted in cases such as <1xf32> + <2xf32> being treated as impossible. PiperOrigin-RevId: 341043965
This commit is contained in:
parent
c30ea47682
commit
af4c9774dc
|
@ -386,14 +386,14 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||||
rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);
|
rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);
|
||||||
|
|
||||||
// Generate a list of nested if/else statements to handle rank
|
// Generate a list of nested if/else statements to handle rank
|
||||||
// specializations from 2-6.
|
// specializations from 1-6.
|
||||||
scf::IfOp if_op = createRankSpecializedBroadcastAndOp(rewriter, op, lhs,
|
scf::IfOp if_op = createRankSpecializedBroadcastAndOp(rewriter, op, lhs,
|
||||||
rhs, greater_rank, 2);
|
rhs, greater_rank, 1);
|
||||||
|
|
||||||
// Put each subsequent rank specialization inside the else statement of the
|
// Put each subsequent rank specialization inside the else statement of the
|
||||||
// previous one.
|
// previous one.
|
||||||
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
|
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
|
||||||
for (int i = 3; i < max_rank_specialization; i++) {
|
for (int i = 2; i < max_rank_specialization; i++) {
|
||||||
auto inner_if = createRankSpecializedBroadcastAndOp(else_builder, op, lhs,
|
auto inner_if = createRankSpecializedBroadcastAndOp(else_builder, op, lhs,
|
||||||
rhs, greater_rank, i);
|
rhs, greater_rank, i);
|
||||||
|
|
||||||
|
|
|
@ -199,6 +199,21 @@ func @addUnrankedUnranked(
|
||||||
// CHECK-NEXT: } else {
|
// CHECK-NEXT: } else {
|
||||||
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index
|
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||||
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
|
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||||
|
// Handle rank 1 specialization
|
||||||
|
// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
|
||||||
|
// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C1]] : index
|
||||||
|
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
|
||||||
|
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
|
||||||
|
// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||||
|
// CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor_cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
|
||||||
|
// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||||
|
// CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor_cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
|
||||||
|
// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
|
// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
|
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||||
|
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor_cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
|
||||||
|
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
|
||||||
|
// CHECK-NEXT: } else {
|
||||||
// CHECK-NEXT: %[[C2:.*]] = constant 2 : index
|
// CHECK-NEXT: %[[C2:.*]] = constant 2 : index
|
||||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_2:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C2]] : index
|
// CHECK-NEXT: %[[GREATEST_RANK_IS_2:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C2]] : index
|
||||||
// Handle rank 2 specialization
|
// Handle rank 2 specialization
|
||||||
|
@ -292,5 +307,7 @@ func @addUnrankedUnranked(
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK-NEXT: scf.yield %[[VAL_70:.*]] : tensor<*xf32>
|
// CHECK-NEXT: scf.yield %[[VAL_70:.*]] : tensor<*xf32>
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK-NEXT: return %[[VAL_71:.*]] : tensor<*xf32>
|
// CHECK-NEXT: scf.yield %[[VAL_71:.*]] : tensor<*xf32>
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: return %[[VAL_72:.*]] : tensor<*xf32>
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
|
|
Loading…
Reference in New Issue