diff --git a/include/mlir-hlo/utils/broadcast_utils.h b/include/mlir-hlo/utils/broadcast_utils.h
index 1c57073..1e24042 100644
--- a/include/mlir-hlo/utils/broadcast_utils.h
+++ b/include/mlir-hlo/utils/broadcast_utils.h
@@ -38,12 +38,10 @@ bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs,
 
 // Emits shape dialect ops to compute the result shape for a broadcasting
 // binary elementwise op which broadcasts according to "numpy" semantics
-// (see above), returning a `shape.shape` or an extent tensor of the resulting
-// shape. The result should only be an extent tensor in contexts that ensure
-// both operands to be broadcastable.
-Value ComputeBinaryElementwiseBroadcastingResultExtents(
-    Location loc, Value lhs, Value rhs, OpBuilder& builder,
-    bool unsafe_as_extent_tensor);
+// (see above), returning an extents tensor of the resulting shape.
+Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
+                                                        Value rhs,
+                                                        OpBuilder& builder);
 
 }  // namespace hlo
 }  // namespace mlir
diff --git a/lib/Dialect/mhlo/IR/chlo_ops.cc b/lib/Dialect/mhlo/IR/chlo_ops.cc
index 81389c3..99ed8bc 100644
--- a/lib/Dialect/mhlo/IR/chlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/chlo_ops.cc
@@ -151,7 +151,7 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
   }
 
   Value computed_shape = hlo::ComputeBinaryElementwiseBroadcastingResultExtents(
-      loc, lhs, rhs, builder, /*unsafe_as_extent_tensor=*/false);
+      loc, lhs, rhs, builder);
   if (!computed_shape) return failure();
   reifiedReturnShapes.push_back(computed_shape);
   return success();
diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
index c2db488..adbd2e5 100644
--- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
+++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
@@ -124,8 +124,8 @@ struct ConvertRankedDynamicBroadcastBinaryOp
 
     int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
     Value result_extents =
-        hlo::ComputeBinaryElementwiseBroadcastingResultExtents(
-            loc, lhs, rhs, rewriter, /*unsafe_as_extent_tensor=*/true);
+        hlo::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs,
+                                                               rewriter);
 
     // Note that we unconditionally emit DynamicBroadcastInDim ops and let
     // downstream canonicalizations fold them away if possible. This is
diff --git a/lib/utils/broadcast_utils.cc b/lib/utils/broadcast_utils.cc
index 71b1a4e..a3ce4d4 100644
--- a/lib/utils/broadcast_utils.cc
+++ b/lib/utils/broadcast_utils.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Shape/IR/Shape.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/StandardTypes.h" @@ -47,9 +46,9 @@ bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs, broadcast_dims.getIntValues().begin()); } -Value ComputeBinaryElementwiseBroadcastingResultExtents( - Location loc, Value lhs, Value rhs, OpBuilder& builder, - bool unsafe_as_extent_tensor) { +Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs, + Value rhs, + OpBuilder& builder) { auto lhs_type = lhs.getType().dyn_cast(); auto rhs_type = rhs.getType().dyn_cast(); if (!lhs_type || !rhs_type) { @@ -58,22 +57,15 @@ Value ComputeBinaryElementwiseBroadcastingResultExtents( return nullptr; } + int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank()); Value lhs_shape_v = builder.createOrFold(loc, lhs); Value rhs_shape_v = builder.createOrFold(loc, rhs); - - if (unsafe_as_extent_tensor) { - int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank()); - Value result_shape_v = builder.createOrFold( - loc, shape::getExtentTensorType(builder.getContext()), lhs_shape_v, - rhs_shape_v, nullptr /* error */); - return builder.createOrFold( - loc, RankedTensorType::get({result_rank}, builder.getIndexType()), - result_shape_v); - } - - return builder.createOrFold( - loc, builder.getType(), lhs_shape_v, rhs_shape_v, - nullptr /* error */); + Value result_shape_v = builder.createOrFold( + loc, shape::ShapeType::get(builder.getContext()), lhs_shape_v, + rhs_shape_v, nullptr /* error */); + return builder.createOrFold( + loc, RankedTensorType::get({result_rank}, builder.getIndexType()), + result_shape_v); } } // namespace hlo diff --git a/tests/chlo_infer_shape_type_methods.mlir b/tests/chlo_infer_shape_type_methods.mlir index d226c92..99aab53 100644 --- a/tests/chlo_infer_shape_type_methods.mlir +++ b/tests/chlo_infer_shape_type_methods.mlir @@ -5,14 +5,15 @@ // only test reification on an examplar op. 
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>,
 // CHECK-SAME: %[[ARG1:.+]]: tensor<?xf32>
-func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> !shape.shape {
+func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xindex> {
   // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
   // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
-  // CHECK-DAG: %[[BCAST_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] : tensor<?xindex>, tensor<?xindex> -> !shape.shape
-  // CHECK: return %[[BCAST_S]] : !shape.shape
+  // CHECK-DAG: %[[BCAST_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
+  // CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]]
+  // CHECK: return %[[EXTENTS]]
   %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  %1 = "mhlo_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> !shape.shape
-  return %1 : !shape.shape
+  %1 = "mhlo_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> tensor<1xindex>
+  return %1 : tensor<1xindex>
 }
 
 // -----
diff --git a/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tests/chlo_legalize_to_hlo_broadcasts.mlir
index 9670372..c08ead5 100644
--- a/tests/chlo_legalize_to_hlo_broadcasts.mlir
+++ b/tests/chlo_legalize_to_hlo_broadcasts.mlir
@@ -19,7 +19,7 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?
   // CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
   // CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
   // CHECK: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
-  // CHECK: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
+  // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
   // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
   // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
   // CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[ARG0_B]], %[[ARG1_B]]
@@ -40,7 +40,7 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
   // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
   // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
   // CHECK-NEXT: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
-  // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
+  // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
   // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
   // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
   // CHECK-NEXT: %[[RESULT:.+]] = "mhlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
@@ -61,7 +61,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
   // CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
   // CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
   // CHECK: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
-  // CHECK: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
+  // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
   // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
   // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
   // CHECK: %[[RESULT:.+]] = "mhlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
@@ -263,7 +263,7 @@ func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf3
 // CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
 // CHECK: %[[SCALAR_SHAPE:.*]] = shape.const_shape []
 // CHECK: %[[BROADCASTED_SHAPE:.*]] = shape.broadcast %[[SCALAR_SHAPE]], %[[SHAPE_RESHAPED]]
-// CHECK: %[[SHAPE_TENSOR:.*]] = tensor_cast %[[BROADCASTED_SHAPE]] : tensor<?xindex> to tensor<1xindex>
+// CHECK: %[[SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[BROADCASTED_SHAPE]] : !shape.shape -> tensor<1xindex>
 // CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_0]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
@@ -296,7 +296,7 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf3
 // CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<f32>
 // CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_RESHAPED]], %[[SHAPE_1]]
 // CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
-// CHECK: %[[ASTENSOR:.*]] = tensor_cast %[[SHAPE_RESHAPED]]
+// CHECK: %[[ASTENSOR:.*]] = shape.to_extent_tensor %[[SHAPE_RESHAPED]]
 // CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[ASTENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[ASTENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
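
For reference, a minimal sketch of the shape computation this change produces for a ranked dynamic broadcast, assuming tensor<?xf32> and tensor<?x?xf32> operands (SSA value names are illustrative; the ops mirror the CHECK lines in the tests above):

  %lhs_shape = shape.shape_of %lhs
  %rhs_shape = shape.shape_of %rhs
  // shape.broadcast still yields a !shape.shape value...
  %bcast = shape.broadcast %lhs_shape, %rhs_shape
  // ...but the helper now always converts it to a ranked extent tensor,
  // replacing the old tensor_cast-based path (and the StandardOps include
  // that came with it).
  %extents = shape.to_extent_tensor %bcast : !shape.shape -> tensor<2xindex>
  // %extents then feeds the "mhlo.dynamic_broadcast_in_dim" ops emitted by
  // the chlo_legalize_to_hlo patterns.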