Remove unnecessary conversions between Shape and ExtentTensor.
PiperOrigin-RevId: 323981215
This commit is contained in:
parent
ffef8d6593
commit
cce4bddf4b
|
@ -197,8 +197,8 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
|
||||||
Value shape =
|
Value shape =
|
||||||
rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
|
rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
|
||||||
Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
|
Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
|
||||||
Value size = rewriter.create<shape::SizeToIndexOp>(loc, num_elements);
|
Value size_tensor =
|
||||||
Value size_tensor = rewriter.create<TensorFromElementsOp>(loc, size);
|
rewriter.create<TensorFromElementsOp>(loc, num_elements);
|
||||||
Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
|
Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||||
loc, RankedTensorType::get({-1}, result_type.getElementType()),
|
loc, RankedTensorType::get({-1}, result_type.getElementType()),
|
||||||
lhs_is_scalar ? rhs : lhs, size_tensor);
|
lhs_is_scalar ? rhs : lhs, size_tensor);
|
||||||
|
@ -211,10 +211,8 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
|
||||||
loc, SmallVector<Type, 1>{reshaped.getType()}, operands, op.getAttrs());
|
loc, SmallVector<Type, 1>{reshaped.getType()}, operands, op.getAttrs());
|
||||||
|
|
||||||
// Reshape the result back into an unranked tensor.
|
// Reshape the result back into an unranked tensor.
|
||||||
Value shape_tensor = rewriter.create<shape::ToExtentTensorOp>(
|
|
||||||
loc, RankedTensorType::get({-1}, rewriter.getIndexType()), shape);
|
|
||||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
|
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
|
||||||
computed, shape_tensor);
|
computed, shape);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -278,18 +276,10 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||||
//
|
//
|
||||||
// See if shapes are equal.
|
// See if shapes are equal.
|
||||||
OpBuilder else_no_scalars_builder = if_rhs_scalar_op.getElseBodyBuilder();
|
OpBuilder else_no_scalars_builder = if_rhs_scalar_op.getElseBodyBuilder();
|
||||||
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
|
|
||||||
rewriter.getIndexType());
|
|
||||||
Value shape_of_lhs =
|
Value shape_of_lhs =
|
||||||
else_no_scalars_builder.create<shape::ToExtentTensorOp>(
|
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs);
|
||||||
loc, extent_tensor_type,
|
|
||||||
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs)
|
|
||||||
.getResult());
|
|
||||||
Value shape_of_rhs =
|
Value shape_of_rhs =
|
||||||
else_no_scalars_builder.create<shape::ToExtentTensorOp>(
|
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs);
|
||||||
loc, extent_tensor_type,
|
|
||||||
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs)
|
|
||||||
.getResult());
|
|
||||||
Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
|
Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
|
||||||
loc, shape_of_lhs, shape_of_rhs);
|
loc, shape_of_lhs, shape_of_rhs);
|
||||||
|
|
||||||
|
@ -319,12 +309,8 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||||
// tensor.
|
// tensor.
|
||||||
Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
|
Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
|
|
||||||
rewriter.getIndexType());
|
|
||||||
|
|
||||||
Value shape_of_tensor = rewriter.create<shape::ToExtentTensorOp>(
|
Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor);
|
||||||
loc, extent_tensor_type,
|
|
||||||
rewriter.create<shape::ShapeOfOp>(loc, tensor).getResult());
|
|
||||||
Value rank_tensor = rewriter.create<shape::RankOp>(
|
Value rank_tensor = rewriter.create<shape::RankOp>(
|
||||||
loc, rewriter.getIndexType(), shape_of_tensor);
|
loc, rewriter.getIndexType(), shape_of_tensor);
|
||||||
return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
|
return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
|
||||||
|
|
|
@ -252,9 +252,8 @@ func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf3
|
||||||
// First handle the dynamic reshaping of the unranked operand
|
// First handle the dynamic reshaping of the unranked operand
|
||||||
// to a 1D tensor.
|
// to a 1D tensor.
|
||||||
// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
|
// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
|
||||||
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]]
|
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor<?xindex> -> index
|
||||||
// CHECK: %[[NUM_ELEMENTS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
|
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex>
|
||||||
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_INDEX]]) : tensor<1xindex>
|
|
||||||
// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
// The assuming region is part of the second stage of lowering
|
// The assuming region is part of the second stage of lowering
|
||||||
// with ranked broadcasting logic.
|
// with ranked broadcasting logic.
|
||||||
|
@ -272,8 +271,7 @@ func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf3
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
// As part of the unranked logic, the result is reshaped back
|
// As part of the unranked logic, the result is reshaped back
|
||||||
// to an unranked tensor.
|
// to an unranked tensor.
|
||||||
// CHECK: %[[SHAPE_2:.*]] = shape.to_extent_tensor %[[SHAPE_1]] : tensor<?xindex> -> tensor<?xindex>
|
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||||
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_2]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
|
||||||
// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
|
// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
|
|
||||||
|
@ -289,9 +287,8 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf3
|
||||||
// First handle the dynamic reshaping of the unranked operand
|
// First handle the dynamic reshaping of the unranked operand
|
||||||
// to a 1D tensor.
|
// to a 1D tensor.
|
||||||
// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32>
|
// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32>
|
||||||
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]]
|
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor<?xindex> -> index
|
||||||
// CHECK: %[[NUM_ELEMENTS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
|
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex>
|
||||||
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_INDEX]]) : tensor<1xindex>
|
|
||||||
// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
// The assuming region is part of the second stage of lowering
|
// The assuming region is part of the second stage of lowering
|
||||||
// with ranked broadcasting logic.
|
// with ranked broadcasting logic.
|
||||||
|
@ -307,8 +304,7 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf3
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
// As part of the unranked logic, the result is reshaped back
|
// As part of the unranked logic, the result is reshaped back
|
||||||
// to an unranked tensor.
|
// to an unranked tensor.
|
||||||
// CHECK: %[[SHAPE_2:.*]] = shape.to_extent_tensor %[[SHAPE_0]]
|
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_0]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||||
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_2]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
|
||||||
// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
|
// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
|
|
||||||
|
@ -323,9 +319,8 @@ func @addUnrankedUnranked(
|
||||||
// CHECK-LABEL: func @addUnrankedUnranked(
|
// CHECK-LABEL: func @addUnrankedUnranked(
|
||||||
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
|
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
|
||||||
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
|
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
// CHECK: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32>
|
// CHECK: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||||
// CHECK: %[[LHS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[LHS_SHAPE]] : tensor<?xindex>
|
// CHECK: %[[RANK_LHS:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
|
||||||
// CHECK: %[[RANK_LHS:.*]] = shape.rank %[[LHS_EXTENT_TENSOR]]
|
|
||||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||||
// CHECK: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index
|
// CHECK: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index
|
||||||
// Handle scalar LHS case
|
// Handle scalar LHS case
|
||||||
|
@ -334,9 +329,8 @@ func @addUnrankedUnranked(
|
||||||
// CHECK: %[[VAL_10:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RHS]] : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
|
// CHECK: %[[VAL_10:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RHS]] : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
// CHECK: scf.yield %[[VAL_10]] : tensor<*xf32>
|
// CHECK: scf.yield %[[VAL_10]] : tensor<*xf32>
|
||||||
// CHECK: } else {
|
// CHECK: } else {
|
||||||
// CHECK: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32>
|
// CHECK: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||||
// CHECK: %[[RHS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[RHS_SHAPE]] : tensor<?xindex>
|
// CHECK: %[[RANK_RHS:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
|
||||||
// CHECK: %[[RANK_RHS:.*]] = shape.rank %[[RHS_EXTENT_TENSOR]]
|
|
||||||
// CHECK: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index
|
// CHECK: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index
|
||||||
// Handle scalar RHS case
|
// Handle scalar RHS case
|
||||||
// CHECK: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
|
// CHECK: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
|
||||||
|
@ -344,7 +338,7 @@ func @addUnrankedUnranked(
|
||||||
// CHECK: %[[VAL_16:.*]] = chlo.broadcast_add %[[LHS]], %[[SCALAR_RHS]] : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
|
// CHECK: %[[VAL_16:.*]] = chlo.broadcast_add %[[LHS]], %[[SCALAR_RHS]] : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
|
||||||
// CHECK: scf.yield %[[VAL_16]] : tensor<*xf32>
|
// CHECK: scf.yield %[[VAL_16]] : tensor<*xf32>
|
||||||
// CHECK: } else {
|
// CHECK: } else {
|
||||||
// CHECK: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_EXTENT_TENSOR]], %[[RHS_EXTENT_TENSOR]] : tensor<?xindex>, tensor<?xindex>
|
// CHECK: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
|
||||||
// Handle scalar RHS case
|
// Handle scalar RHS case
|
||||||
// CHECK: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) {
|
// CHECK: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) {
|
||||||
// CHECK: %[[VAL_19:.*]] = mhlo.add %[[LHS]], %[[RHS]] : tensor<*xf32>
|
// CHECK: %[[VAL_19:.*]] = mhlo.add %[[LHS]], %[[RHS]] : tensor<*xf32>
|
||||||
|
|
Loading…
Reference in New Issue