diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
index 1749bf0..0fb37fb 100644
--- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
+++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
@@ -197,8 +197,8 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
     Value shape =
         rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
     Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
-    Value size = rewriter.create<shape::SizeToIndexOp>(loc, num_elements);
-    Value size_tensor = rewriter.create<TensorFromElementsOp>(loc, size);
+    Value size_tensor =
+        rewriter.create<TensorFromElementsOp>(loc, num_elements);
     Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
         loc, RankedTensorType::get({-1}, result_type.getElementType()),
         lhs_is_scalar ? rhs : lhs, size_tensor);
@@ -211,10 +211,8 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
         loc, SmallVector<Type, 1>{reshaped.getType()}, operands, op.getAttrs());
     // Reshape the result back into an unranked tensor.
-    Value shape_tensor = rewriter.create<shape::ToExtentTensorOp>(
-        loc, RankedTensorType::get({-1}, rewriter.getIndexType()), shape);
     rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
-                                                        computed, shape_tensor);
+                                                        computed, shape);
     return success();
   }
@@ -278,18 +276,10 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
     //
     // See if shapes are equal.
     OpBuilder else_no_scalars_builder = if_rhs_scalar_op.getElseBodyBuilder();
-    auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
-                                                    rewriter.getIndexType());
     Value shape_of_lhs =
-        else_no_scalars_builder.create<shape::ToExtentTensorOp>(
-            loc, extent_tensor_type,
-            else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs)
-                .getResult());
+        else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs);
     Value shape_of_rhs =
-        else_no_scalars_builder.create<shape::ToExtentTensorOp>(
-            loc, extent_tensor_type,
-            else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs)
-                .getResult());
+        else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs);
     Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
         loc, shape_of_lhs, shape_of_rhs);
@@ -319,12 +309,8 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
   // tensor.
   Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
     auto loc = op.getLoc();
-    auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
-                                                    rewriter.getIndexType());
-    Value shape_of_tensor = rewriter.create<shape::ToExtentTensorOp>(
-        loc, extent_tensor_type,
-        rewriter.create<shape::ShapeOfOp>(loc, tensor).getResult());
+    Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor);
     Value rank_tensor = rewriter.create<shape::RankOp>(
         loc, rewriter.getIndexType(), shape_of_tensor);
     return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
diff --git a/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tests/chlo_legalize_to_hlo_broadcasts.mlir
index 418c5c7..3e24ffd 100644
--- a/tests/chlo_legalize_to_hlo_broadcasts.mlir
+++ b/tests/chlo_legalize_to_hlo_broadcasts.mlir
@@ -252,9 +252,8 @@ func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32>
 // First handle the dynamic reshaping of the unranked operand
 // to a 1D tensor.
 // CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
-// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]]
-// CHECK: %[[NUM_ELEMENTS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
-// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_INDEX]]) : tensor<1xindex>
+// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor<?xindex> -> index
+// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex>
 // CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 // The assuming region is part of the second stage of lowering
 // with ranked broadcasting logic.
@@ -272,8 +271,7 @@ func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32>
 // CHECK: }
 // As part of the unranked logic, the result is reshaped back
 // to an unranked tensor.
-// CHECK: %[[SHAPE_2:.*]] = shape.to_extent_tensor %[[SHAPE_1]] : tensor<?xindex> -> tensor<?xindex>
-// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_2]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 // CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
 // CHECK: }
@@ -289,9 +287,8 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32>
 // First handle the dynamic reshaping of the unranked operand
 // to a 1D tensor.
 // CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32>
-// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]]
-// CHECK: %[[NUM_ELEMENTS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
-// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_INDEX]]) : tensor<1xindex>
+// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor<?xindex> -> index
+// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex>
 // CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 // The assuming region is part of the second stage of lowering
 // with ranked broadcasting logic.
@@ -307,8 +304,7 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32>
 // CHECK: }
 // As part of the unranked logic, the result is reshaped back
 // to an unranked tensor.
-// CHECK: %[[SHAPE_2:.*]] = shape.to_extent_tensor %[[SHAPE_0]]
-// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_2]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_0]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 // CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
 // CHECK: }
@@ -323,9 +319,8 @@ func @addUnrankedUnranked(
 // CHECK-LABEL: func @addUnrankedUnranked(
 // CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
 // CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
-// CHECK: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32>
-// CHECK: %[[LHS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[LHS_SHAPE]] : tensor<?xindex>
-// CHECK: %[[RANK_LHS:.*]] = shape.rank %[[LHS_EXTENT_TENSOR]]
+// CHECK: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
+// CHECK: %[[RANK_LHS:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
 // CHECK: %[[C0:.*]] = constant 0 : index
 // CHECK: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index
 // Handle scalar LHS case
@@ -334,9 +329,8 @@ func @addUnrankedUnranked(
 // CHECK: %[[VAL_10:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RHS]] : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
 // CHECK: scf.yield %[[VAL_10]] : tensor<*xf32>
 // CHECK: } else {
-// CHECK: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32>
-// CHECK: %[[RHS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[RHS_SHAPE]] : tensor<?xindex>
-// CHECK: %[[RANK_RHS:.*]] = shape.rank %[[RHS_EXTENT_TENSOR]]
+// CHECK: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
+// CHECK: %[[RANK_RHS:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
 // CHECK: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index
 // Handle scalar RHS case
 // CHECK: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
@@ -344,7 +338,7 @@ func @addUnrankedUnranked(
 // CHECK: %[[VAL_16:.*]] = chlo.broadcast_add %[[LHS]], %[[SCALAR_RHS]] : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
 // CHECK: scf.yield %[[VAL_16]] : tensor<*xf32>
 // CHECK: } else {
-// CHECK: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_EXTENT_TENSOR]], %[[RHS_EXTENT_TENSOR]] : tensor<?xindex>, tensor<?xindex>
+// CHECK: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
 // Handle scalar RHS case
 // CHECK: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) {
 // CHECK: %[[VAL_19:.*]] = mhlo.add %[[LHS]], %[[RHS]] : tensor<*xf32>
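Note (not part of the patch): the reshape round trip that this change simplifies can be read directly off the CHECK lines above. The standalone sketch below restates it in isolation, under stated assumptions: the function name is made up, the ranked chlo.broadcast_add that would sit inside the shape.assuming region is elided, and the flattened value is fed straight into the restoring reshape so the snippet stays self-contained. The point is that the extent tensor produced by shape.shape_of is now consumed directly, with no shape.size_to_index / shape.to_extent_tensor in between.

func @reshape_roundtrip_sketch(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // Flatten the unranked operand into a rank-1 tensor of its element count.
  %shape = shape.shape_of %arg0 : tensor<*xf32> -> tensor<?xindex>
  %num_elements = shape.num_elements %shape : tensor<?xindex> -> index
  %size_tensor = tensor_from_elements(%num_elements) : tensor<1xindex>
  %flat = "mhlo.dynamic_reshape"(%arg0, %size_tensor)
      : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
  // (The ranked broadcast inside shape.assuming is elided in this sketch.)
  // Restore the original shape from the same extent tensor.
  %result = "mhlo.dynamic_reshape"(%flat, %shape)
      : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
  return %result : tensor<*xf32>
}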