Fix bug using std.rank instead of shape.rank
PiperOrigin-RevId: 339890070
This commit is contained in:
parent
76b30fd426
commit
81e8d778c4
|
@ -377,9 +377,9 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||||
Value rhs_shape =
|
Value rhs_shape =
|
||||||
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
|
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
|
||||||
Value lhs_rank =
|
Value lhs_rank =
|
||||||
rewriter.create<RankOp>(loc, rewriter.getIndexType(), lhs_shape);
|
rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), lhs_shape);
|
||||||
Value rhs_rank =
|
Value rhs_rank =
|
||||||
rewriter.create<RankOp>(loc, rewriter.getIndexType(), rhs_shape);
|
rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), rhs_shape);
|
||||||
Value greater_rank_lhs =
|
Value greater_rank_lhs =
|
||||||
rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
|
rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
|
||||||
Value greater_rank =
|
Value greater_rank =
|
||||||
|
|
|
@ -158,9 +158,9 @@ func @addUnrankedUnranked(
|
||||||
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
|
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
|
||||||
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
|
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
|
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[RANK_LHS:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
|
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
|
||||||
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
|
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
|
||||||
// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index
|
// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[LHS_RANK]], %[[C0]] : index
|
||||||
// Handle scalar LHS case
|
// Handle scalar LHS case
|
||||||
// CHECK-NEXT: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
|
// CHECK-NEXT: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
|
||||||
// CHECK-NEXT: %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor<f32>
|
// CHECK-NEXT: %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor<f32>
|
||||||
|
@ -173,8 +173,8 @@ func @addUnrankedUnranked(
|
||||||
// CHECK-NEXT: scf.yield %[[RESHAPED_LHS_SCALAR_RESULT]] : tensor<*xf32>
|
// CHECK-NEXT: scf.yield %[[RESHAPED_LHS_SCALAR_RESULT]] : tensor<*xf32>
|
||||||
// CHECK-NEXT: } else {
|
// CHECK-NEXT: } else {
|
||||||
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
|
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[RANK_RHS:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
|
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
|
||||||
// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index
|
// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RHS_RANK]], %[[C0]] : index
|
||||||
// Handle scalar RHS case
|
// Handle scalar RHS case
|
||||||
// CHECK-NEXT: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
|
// CHECK-NEXT: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
|
||||||
// CHECK-NEXT: %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor<f32>
|
// CHECK-NEXT: %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor<f32>
|
||||||
|
@ -197,8 +197,6 @@ func @addUnrankedUnranked(
|
||||||
// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||||
// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
|
// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
|
||||||
// CHECK-NEXT: } else {
|
// CHECK-NEXT: } else {
|
||||||
// CHECK-NEXT: %[[LHS_RANK:.*]] = rank %[[LHS_SHAPE]] : tensor<?xindex>
|
|
||||||
// CHECK-NEXT: %[[RHS_RANK:.*]] = rank %[[RHS_SHAPE]] : tensor<?xindex>
|
|
||||||
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index
|
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||||
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
|
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||||
// CHECK-NEXT: %[[C2:.*]] = constant 2 : index
|
// CHECK-NEXT: %[[C2:.*]] = constant 2 : index
|
||||||
|
|
Loading…
Reference in New Issue