[MLIR][HLO] Remove redundant casts from unranked to ranked transformation

The transformation of unranked to ranked operations no longer generates cast
operations for shapes and sizes. Instead, we use the newly introduced support
for extent tensor and index types directly.

PiperOrigin-RevId: 325057440
commit 5d3cc2105e
parent 37c36a4389
Author: A. Unique TensorFlower
Date:   2020-08-05 11:10:20 -07:00
Committed by: TensorFlow MLIR Team

2 changed files with 29 additions and 44 deletions
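Concretely, the old lowering materialized identity casts: shape.size_to_index on a value that is already an index, and shape.to_extent_tensor from tensor<?xindex> to tensor<?xindex>. A minimal before/after sketch in MLIR assembly, distilled from the test updates below (%shape and %flat_b are assumed to be defined as in those tests):

// Before: redundant identity casts around the flat shape and the restored shape.
%num_elements = shape.num_elements %shape : tensor<?xindex> -> index
%num_elements_as_index = shape.size_to_index %num_elements : index
%flat_shape = tensor_from_elements(%num_elements_as_index) : tensor<1xindex>
%shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor<?xindex> -> tensor<?xindex>
%b = "mhlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor)
    : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>

// After: the index and extent tensor results are used directly.
%num_elements = shape.num_elements %shape : tensor<?xindex> -> index
%flat_shape = tensor_from_elements(%num_elements) : tensor<1xindex>
%b = "mhlo.dynamic_reshape"(%flat_b, %shape)
    : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>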


@@ -46,7 +46,6 @@ namespace {
       sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp)             \
           sep fn(ShiftRightLogicalOp) sep fn(SubOp)
 
-// TODO(frgossen): Make it variadic.
 template <typename OpTy>
 inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
   target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
@@ -75,28 +74,24 @@ struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
 
     // Generate IR to flatten the operand.
     auto loc = op.getLoc();
-    Value shape = rewriter.create<shape::ShapeOfOp>(loc, operand);
-    Value numElements = rewriter.create<shape::NumElementsOp>(loc, shape);
-    Value numElementsAsIndex =
-        rewriter.create<shape::SizeToIndexOp>(loc, numElements);
-    Value flatShapeAsDimTensor =
-        rewriter.create<TensorFromElementsOp>(loc, numElementsAsIndex);
+    Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
+    Value shape =
+        rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand);
+    Type indexTy = rewriter.getIndexType();
+    Value numElements =
+        rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
+    Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements);
     auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
                                               operandTy.getElementType());
     Value flatOperand = rewriter.create<mhlo::DynamicReshapeOp>(
-        loc, flatTensorTy, operand, flatShapeAsDimTensor);
+        loc, flatTensorTy, operand, flatShape);
 
     // Generate IR for the actual operation.
     Value flatResult = rewriter.create<OpTy>(loc, flatTensorTy, flatOperand);
 
     // Generate IR to restore the original shape.
-    auto extentTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
-                                                rewriter.getIndexType());
-    Value shapeAsExtentTensor =
-        rewriter.create<shape::ToExtentTensorOp>(loc, extentTensorTy, shape);
-    Value result = rewriter.create<mhlo::DynamicReshapeOp>(
-        loc, operandTy, flatResult, shapeAsExtentTensor);
-    rewriter.replaceOp(op, result);
+    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, operandTy,
+                                                        flatResult, shape);
 
     return success();
   }
@@ -122,17 +117,18 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
     }
 
     // Flatten operands.
-    Type shapeTy = shape::ShapeType::get(rewriter.getContext());
     auto loc = op.getLoc();
-    Value shapeLhs = rewriter.create<shape::ShapeOfOp>(loc, op.lhs());
-    Value shapeRhs = rewriter.create<shape::ShapeOfOp>(loc, op.rhs());
-    Value shape = rewriter.create<shape::AnyOp>(loc, shapeTy,
+    Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
+    Value shapeLhs =
+        rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, op.lhs());
+    Value shapeRhs =
+        rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, op.rhs());
+    Value shape = rewriter.create<shape::AnyOp>(loc, extentTensorTy,
                                                 ValueRange{shapeLhs, shapeRhs});
-    Value numElements = rewriter.create<shape::NumElementsOp>(loc, shape);
-    Value numElementsAsIndex =
-        rewriter.create<shape::SizeToIndexOp>(loc, numElements);
-    Value flatShape =
-        rewriter.create<TensorFromElementsOp>(loc, numElementsAsIndex);
+    Type indexTy = rewriter.getIndexType();
+    Value numElements =
+        rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
+    Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements);
     TensorType lhsTy = op.lhs().getType().template cast<TensorType>();
     Type flatLhsTy = RankedTensorType::get({ShapedType::kDynamicSize},
                                            lhsTy.getElementType());
@@ -148,13 +144,8 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
     Value flatResult = rewriter.create<OpTy>(loc, flatLhs, flatRhs);
 
     // Restore original shape.
-    auto extentTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
-                                                rewriter.getIndexType());
-    Value shapeAsExtentTensor =
-        rewriter.create<shape::ToExtentTensorOp>(loc, extentTensorTy, shape);
-    Value result = rewriter.create<DynamicReshapeOp>(
-        loc, op.getType(), flatResult, shapeAsExtentTensor);
-    rewriter.replaceOp(op, result);
+    rewriter.replaceOpWithNewOp<DynamicReshapeOp>(op, op.getType(), flatResult,
+                                                  shape);
 
     return success();
   }
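For reference, the shape computation now shared by the unary and binary patterns, extracted from the hunks above as a standalone sketch (loc, rewriter, and operand are assumed to be in scope of a matchAndRewrite, as in the surrounding code):

// Request the extent tensor type tensor<?xindex> from shape.shape_of directly.
Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
Value shape = rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand);

// Request an index-typed element count, so no shape.size_to_index or
// shape.to_extent_tensor casts are needed downstream.
Type indexTy = rewriter.getIndexType();
Value numElements = rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);

// The resulting tensor<1xindex> feeds mhlo.dynamic_reshape unchanged.
Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements);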


@@ -7,8 +7,7 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
   // Flatten operand shape.
   %shape = shape.shape_of %a : tensor<*xf32> -> tensor<?xindex>
   %num_elements = shape.num_elements %shape : tensor<?xindex> -> index
-  %num_elements_as_index = shape.size_to_index %num_elements : index
-  %flat_shape = tensor_from_elements(%num_elements_as_index) : tensor<1xindex>
+  %flat_shape = tensor_from_elements(%num_elements) : tensor<1xindex>
   %flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape)
       : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
@@ -16,8 +15,7 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
   %flat_b = "mhlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>
 
   // Restore original shape.
-  %shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor<?xindex> -> tensor<?xindex>
-  %b = "mhlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor)
+  %b = "mhlo.dynamic_reshape"(%flat_b, %shape)
       : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 
   return %b : tensor<*xf32>
@@ -29,14 +27,12 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-LABEL: @sqrt
 // CHECK-SAME: (%[[A:.*]]: tensor<*xf32>)
 func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
-// CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32>
+// CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor<?xindex>
 // CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
-// CHECK-NEXT: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
-// CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
+// CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex>
 // CHECK-NEXT: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK-NEXT: %[[FLAT_B:.*]] = "mhlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
-// CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex>
-// CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 // CHECK-NEXT: return %[[B]] : tensor<*xf32>
   %b = "mhlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>
   return %b : tensor<*xf32>
@@ -75,13 +71,11 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
 // CHECK: %[[SHAPE_B:.*]] = shape.shape_of %[[B]]
 // CHECK: %[[SHAPE:.*]] = "shape.any"(%[[SHAPE_A]], %[[SHAPE_B]])
 // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
-// CHECK: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
-// CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
+// CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex>
 // CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32>
-// CHECK: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]]
-// CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 // CHECK: return %[[RESULT]] : tensor<*xf32>
   %result = mhlo.add %a, %b : tensor<*xf32>
   return %result : tensor<*xf32>