diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
index 358cb5c..e54c404 100644
--- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
+++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
@@ -377,9 +377,9 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
     Value rhs_shape =
         rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
     Value lhs_rank =
-        rewriter.create<RankOp>(loc, rewriter.getIndexType(), lhs_shape);
+        rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), lhs_shape);
     Value rhs_rank =
-        rewriter.create<RankOp>(loc, rewriter.getIndexType(), rhs_shape);
+        rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), rhs_shape);
     Value greater_rank_lhs =
         rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
     Value greater_rank =
diff --git a/tests/hlo-transform-unranked.mlir b/tests/hlo-transform-unranked.mlir
index c0baecd..8488a38 100644
--- a/tests/hlo-transform-unranked.mlir
+++ b/tests/hlo-transform-unranked.mlir
@@ -158,9 +158,9 @@ func @addUnrankedUnranked(
 // CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
 // CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
-// CHECK-NEXT: %[[RANK_LHS:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
+// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
 // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
-// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index
+// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[LHS_RANK]], %[[C0]] : index
 // Handle scalar LHS case
 // CHECK-NEXT: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
 // CHECK-NEXT: %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor<f32>
@@ -173,8 +173,8 @@ func @addUnrankedUnranked(
 // CHECK-NEXT: scf.yield %[[RESHAPED_LHS_SCALAR_RESULT]] : tensor<*xf32>
 // CHECK-NEXT: } else {
 // CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
-// CHECK-NEXT: %[[RANK_RHS:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
-// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index
+// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
+// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RHS_RANK]], %[[C0]] : index
 // Handle scalar RHS case
 // CHECK-NEXT: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
 // CHECK-NEXT: %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor<f32>
@@ -197,8 +197,6 @@ func @addUnrankedUnranked(
 // CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 // CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
 // CHECK-NEXT: } else {
-// CHECK-NEXT: %[[LHS_RANK:.*]] = rank %[[LHS_SHAPE]] : tensor<?xindex>
-// CHECK-NEXT: %[[RHS_RANK:.*]] = rank %[[RHS_SHAPE]] : tensor<?xindex>
 // CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index
 // CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
 // CHECK-NEXT: %[[C2:.*]] = constant 2 : index