Remove unnecessary conversions between Shape and ExtentTensor.

PiperOrigin-RevId: 323981215
This commit is contained in:
Tres Popp 2020-07-30 04:59:33 -07:00 committed by TensorFlow MLIR Team
parent ffef8d6593
commit cce4bddf4b
2 changed files with 17 additions and 37 deletions

View File

@ -197,8 +197,8 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
Value shape =
rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
Value size = rewriter.create<shape::SizeToIndexOp>(loc, num_elements);
Value size_tensor = rewriter.create<TensorFromElementsOp>(loc, size);
Value size_tensor =
rewriter.create<TensorFromElementsOp>(loc, num_elements);
Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
loc, RankedTensorType::get({-1}, result_type.getElementType()),
lhs_is_scalar ? rhs : lhs, size_tensor);
@ -211,10 +211,8 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
loc, SmallVector<Type, 1>{reshaped.getType()}, operands, op.getAttrs());
// Reshape the result back into an unranked tensor.
Value shape_tensor = rewriter.create<shape::ToExtentTensorOp>(
loc, RankedTensorType::get({-1}, rewriter.getIndexType()), shape);
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
computed, shape_tensor);
computed, shape);
return success();
}
@ -278,18 +276,10 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
//
// See if shapes are equal.
OpBuilder else_no_scalars_builder = if_rhs_scalar_op.getElseBodyBuilder();
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
rewriter.getIndexType());
Value shape_of_lhs =
else_no_scalars_builder.create<shape::ToExtentTensorOp>(
loc, extent_tensor_type,
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs)
.getResult());
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs);
Value shape_of_rhs =
else_no_scalars_builder.create<shape::ToExtentTensorOp>(
loc, extent_tensor_type,
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs)
.getResult());
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs);
Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
loc, shape_of_lhs, shape_of_rhs);
@ -319,12 +309,8 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
// tensor.
Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
auto loc = op.getLoc();
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
rewriter.getIndexType());
Value shape_of_tensor = rewriter.create<shape::ToExtentTensorOp>(
loc, extent_tensor_type,
rewriter.create<shape::ShapeOfOp>(loc, tensor).getResult());
Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor);
Value rank_tensor = rewriter.create<shape::RankOp>(
loc, rewriter.getIndexType(), shape_of_tensor);
return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,

View File

@ -252,9 +252,8 @@ func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf3
// First handle the dynamic reshaping of the unranked operand
// to a 1D tensor.
// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]]
// CHECK: %[[NUM_ELEMENTS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_INDEX]]) : tensor<1xindex>
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor<?xindex> -> index
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex>
// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// The assuming region is part of the second stage of lowering
// with ranked broadcasting logic.
@ -272,8 +271,7 @@ func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf3
// CHECK: }
// As part of the unranked logic, the result is reshaped back
// to an unranked tensor.
// CHECK: %[[SHAPE_2:.*]] = shape.to_extent_tensor %[[SHAPE_1]] : tensor<?xindex> -> tensor<?xindex>
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_2]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK: }
@ -289,9 +287,8 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf3
// First handle the dynamic reshaping of the unranked operand
// to a 1D tensor.
// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32>
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]]
// CHECK: %[[NUM_ELEMENTS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_INDEX]]) : tensor<1xindex>
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor<?xindex> -> index
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex>
// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// The assuming region is part of the second stage of lowering
// with ranked broadcasting logic.
@ -307,8 +304,7 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf3
// CHECK: }
// As part of the unranked logic, the result is reshaped back
// to an unranked tensor.
// CHECK: %[[SHAPE_2:.*]] = shape.to_extent_tensor %[[SHAPE_0]]
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_2]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_0]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK: }
@ -323,9 +319,8 @@ func @addUnrankedUnranked(
// CHECK-LABEL: func @addUnrankedUnranked(
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32>
// CHECK: %[[LHS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[LHS_SHAPE]] : tensor<?xindex>
// CHECK: %[[RANK_LHS:.*]] = shape.rank %[[LHS_EXTENT_TENSOR]]
// CHECK: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK: %[[RANK_LHS:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index
// Handle scalar LHS case
@ -334,9 +329,8 @@ func @addUnrankedUnranked(
// CHECK: %[[VAL_10:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RHS]] : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: scf.yield %[[VAL_10]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32>
// CHECK: %[[RHS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[RHS_SHAPE]] : tensor<?xindex>
// CHECK: %[[RANK_RHS:.*]] = shape.rank %[[RHS_EXTENT_TENSOR]]
// CHECK: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK: %[[RANK_RHS:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
// CHECK: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index
// Handle scalar RHS case
// CHECK: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
@ -344,7 +338,7 @@ func @addUnrankedUnranked(
// CHECK: %[[VAL_16:.*]] = chlo.broadcast_add %[[LHS]], %[[SCALAR_RHS]] : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
// CHECK: scf.yield %[[VAL_16]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_EXTENT_TENSOR]], %[[RHS_EXTENT_TENSOR]] : tensor<?xindex>, tensor<?xindex>
// CHECK: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
// Handle scalar RHS case
// CHECK: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) {
// CHECK: %[[VAL_19:.*]] = mhlo.add %[[LHS]], %[[RHS]] : tensor<*xf32>