[MLIR][HLO] Remove redundant casts from unranked to ranked transformation
The transformation of unranked to ranked operations no longer generates cast operations for shapes and sizes. Instead, we use the newly introduced support for extent tensor and index types directly. PiperOrigin-RevId: 325057440
This commit is contained in:
parent
37c36a4389
commit
5d3cc2105e
|
@ -46,7 +46,6 @@ namespace {
|
|||
sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \
|
||||
sep fn(ShiftRightLogicalOp) sep fn(SubOp)
|
||||
|
||||
// TODO(frgossen): Make it variadic.
|
||||
template <typename OpTy>
|
||||
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
||||
target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
|
||||
|
@ -75,28 +74,24 @@ struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
|||
|
||||
// Generate IR to flatten the operand.
|
||||
auto loc = op.getLoc();
|
||||
Value shape = rewriter.create<shape::ShapeOfOp>(loc, operand);
|
||||
Value numElements = rewriter.create<shape::NumElementsOp>(loc, shape);
|
||||
Value numElementsAsIndex =
|
||||
rewriter.create<shape::SizeToIndexOp>(loc, numElements);
|
||||
Value flatShapeAsDimTensor =
|
||||
rewriter.create<TensorFromElementsOp>(loc, numElementsAsIndex);
|
||||
Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
|
||||
Value shape =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand);
|
||||
Type indexTy = rewriter.getIndexType();
|
||||
Value numElements =
|
||||
rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
|
||||
Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements);
|
||||
auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||
operandTy.getElementType());
|
||||
Value flatOperand = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, flatTensorTy, operand, flatShapeAsDimTensor);
|
||||
loc, flatTensorTy, operand, flatShape);
|
||||
|
||||
// Generate IR for the actual operation.
|
||||
Value flatResult = rewriter.create<OpTy>(loc, flatTensorTy, flatOperand);
|
||||
|
||||
// Generate IR to restore the original shape.
|
||||
auto extentTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||
rewriter.getIndexType());
|
||||
Value shapeAsExtentTensor =
|
||||
rewriter.create<shape::ToExtentTensorOp>(loc, extentTensorTy, shape);
|
||||
Value result = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, operandTy, flatResult, shapeAsExtentTensor);
|
||||
rewriter.replaceOp(op, result);
|
||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, operandTy,
|
||||
flatResult, shape);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -122,17 +117,18 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
|||
}
|
||||
|
||||
// Flatten operands.
|
||||
Type shapeTy = shape::ShapeType::get(rewriter.getContext());
|
||||
auto loc = op.getLoc();
|
||||
Value shapeLhs = rewriter.create<shape::ShapeOfOp>(loc, op.lhs());
|
||||
Value shapeRhs = rewriter.create<shape::ShapeOfOp>(loc, op.rhs());
|
||||
Value shape = rewriter.create<shape::AnyOp>(loc, shapeTy,
|
||||
Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
|
||||
Value shapeLhs =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, op.lhs());
|
||||
Value shapeRhs =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, op.rhs());
|
||||
Value shape = rewriter.create<shape::AnyOp>(loc, extentTensorTy,
|
||||
ValueRange{shapeLhs, shapeRhs});
|
||||
Value numElements = rewriter.create<shape::NumElementsOp>(loc, shape);
|
||||
Value numElementsAsIndex =
|
||||
rewriter.create<shape::SizeToIndexOp>(loc, numElements);
|
||||
Value flatShape =
|
||||
rewriter.create<TensorFromElementsOp>(loc, numElementsAsIndex);
|
||||
Type indexTy = rewriter.getIndexType();
|
||||
Value numElements =
|
||||
rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
|
||||
Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements);
|
||||
TensorType lhsTy = op.lhs().getType().template cast<TensorType>();
|
||||
Type flatLhsTy = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||
lhsTy.getElementType());
|
||||
|
@ -148,13 +144,8 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
|||
Value flatResult = rewriter.create<OpTy>(loc, flatLhs, flatRhs);
|
||||
|
||||
// Restore original shape.
|
||||
auto extentTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||
rewriter.getIndexType());
|
||||
Value shapeAsExtentTensor =
|
||||
rewriter.create<shape::ToExtentTensorOp>(loc, extentTensorTy, shape);
|
||||
Value result = rewriter.create<DynamicReshapeOp>(
|
||||
loc, op.getType(), flatResult, shapeAsExtentTensor);
|
||||
rewriter.replaceOp(op, result);
|
||||
rewriter.replaceOpWithNewOp<DynamicReshapeOp>(op, op.getType(), flatResult,
|
||||
shape);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -7,8 +7,7 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
|
|||
// Flatten operand shape.
|
||||
%shape = shape.shape_of %a : tensor<*xf32> -> tensor<?xindex>
|
||||
%num_elements = shape.num_elements %shape : tensor<?xindex> -> index
|
||||
%num_elements_as_index = shape.size_to_index %num_elements : index
|
||||
%flat_shape = tensor_from_elements(%num_elements_as_index) : tensor<1xindex>
|
||||
%flat_shape = tensor_from_elements(%num_elements) : tensor<1xindex>
|
||||
%flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape)
|
||||
: (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
|
||||
|
@ -16,8 +15,7 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
|
|||
%flat_b = "mhlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
|
||||
// Restore original shape.
|
||||
%shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor<?xindex> -> tensor<?xindex>
|
||||
%b = "mhlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor)
|
||||
%b = "mhlo.dynamic_reshape"(%flat_b, %shape)
|
||||
: (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
|
||||
return %b : tensor<*xf32>
|
||||
|
@ -29,14 +27,12 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
|
|||
// CHECK-LABEL: @sqrt
|
||||
// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>)
|
||||
func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32>
|
||||
// CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
|
||||
// CHECK-NEXT: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
|
||||
// CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
|
||||
// CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex>
|
||||
// CHECK-NEXT: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[FLAT_B:.*]] = "mhlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex>
|
||||
// CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: return %[[B]] : tensor<*xf32>
|
||||
%b = "mhlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %b : tensor<*xf32>
|
||||
|
@ -75,13 +71,11 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
|
|||
// CHECK: %[[SHAPE_B:.*]] = shape.shape_of %[[B]]
|
||||
// CHECK: %[[SHAPE:.*]] = "shape.any"(%[[SHAPE_A]], %[[SHAPE_B]])
|
||||
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
|
||||
// CHECK: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
|
||||
// CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
|
||||
// CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex>
|
||||
// CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32>
|
||||
// CHECK: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]]
|
||||
// CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK: return %[[RESULT]] : tensor<*xf32>
|
||||
%result = mhlo.add %a, %b : tensor<*xf32>
|
||||
return %result : tensor<*xf32>
|
||||
|
|
Loading…
Reference in New Issue