[MLIR][HLO] Remove redundant casts from unranked to ranked transformation
The transformation of unranked to ranked operations no longer generates cast operations for shapes and sizes. Instead, we use the newly introduced support for extent tensor and index types directly.
PiperOrigin-RevId: 325057440
This commit is contained in:
parent
37c36a4389
commit
5d3cc2105e
|
@ -46,7 +46,6 @@ namespace {
|
||||||
sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \
|
sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \
|
||||||
sep fn(ShiftRightLogicalOp) sep fn(SubOp)
|
sep fn(ShiftRightLogicalOp) sep fn(SubOp)
|
||||||
|
|
||||||
// TODO(frgossen): Make it variadic.
|
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
||||||
target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
|
target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
|
||||||
|
@ -75,28 +74,24 @@ struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
||||||
|
|
||||||
// Generate IR to flatten the operand.
|
// Generate IR to flatten the operand.
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
Value shape = rewriter.create<shape::ShapeOfOp>(loc, operand);
|
Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
|
||||||
Value numElements = rewriter.create<shape::NumElementsOp>(loc, shape);
|
Value shape =
|
||||||
Value numElementsAsIndex =
|
rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand);
|
||||||
rewriter.create<shape::SizeToIndexOp>(loc, numElements);
|
Type indexTy = rewriter.getIndexType();
|
||||||
Value flatShapeAsDimTensor =
|
Value numElements =
|
||||||
rewriter.create<TensorFromElementsOp>(loc, numElementsAsIndex);
|
rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
|
||||||
|
Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements);
|
||||||
auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
|
auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||||
operandTy.getElementType());
|
operandTy.getElementType());
|
||||||
Value flatOperand = rewriter.create<mhlo::DynamicReshapeOp>(
|
Value flatOperand = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||||
loc, flatTensorTy, operand, flatShapeAsDimTensor);
|
loc, flatTensorTy, operand, flatShape);
|
||||||
|
|
||||||
// Generate IR for the actual operation.
|
// Generate IR for the actual operation.
|
||||||
Value flatResult = rewriter.create<OpTy>(loc, flatTensorTy, flatOperand);
|
Value flatResult = rewriter.create<OpTy>(loc, flatTensorTy, flatOperand);
|
||||||
|
|
||||||
// Generate IR to restore the original shape.
|
// Generate IR to restore the original shape.
|
||||||
auto extentTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
|
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, operandTy,
|
||||||
rewriter.getIndexType());
|
flatResult, shape);
|
||||||
Value shapeAsExtentTensor =
|
|
||||||
rewriter.create<shape::ToExtentTensorOp>(loc, extentTensorTy, shape);
|
|
||||||
Value result = rewriter.create<mhlo::DynamicReshapeOp>(
|
|
||||||
loc, operandTy, flatResult, shapeAsExtentTensor);
|
|
||||||
rewriter.replaceOp(op, result);
|
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -122,17 +117,18 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flatten operands.
|
// Flatten operands.
|
||||||
Type shapeTy = shape::ShapeType::get(rewriter.getContext());
|
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
Value shapeLhs = rewriter.create<shape::ShapeOfOp>(loc, op.lhs());
|
Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
|
||||||
Value shapeRhs = rewriter.create<shape::ShapeOfOp>(loc, op.rhs());
|
Value shapeLhs =
|
||||||
Value shape = rewriter.create<shape::AnyOp>(loc, shapeTy,
|
rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, op.lhs());
|
||||||
|
Value shapeRhs =
|
||||||
|
rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, op.rhs());
|
||||||
|
Value shape = rewriter.create<shape::AnyOp>(loc, extentTensorTy,
|
||||||
ValueRange{shapeLhs, shapeRhs});
|
ValueRange{shapeLhs, shapeRhs});
|
||||||
Value numElements = rewriter.create<shape::NumElementsOp>(loc, shape);
|
Type indexTy = rewriter.getIndexType();
|
||||||
Value numElementsAsIndex =
|
Value numElements =
|
||||||
rewriter.create<shape::SizeToIndexOp>(loc, numElements);
|
rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
|
||||||
Value flatShape =
|
Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements);
|
||||||
rewriter.create<TensorFromElementsOp>(loc, numElementsAsIndex);
|
|
||||||
TensorType lhsTy = op.lhs().getType().template cast<TensorType>();
|
TensorType lhsTy = op.lhs().getType().template cast<TensorType>();
|
||||||
Type flatLhsTy = RankedTensorType::get({ShapedType::kDynamicSize},
|
Type flatLhsTy = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||||
lhsTy.getElementType());
|
lhsTy.getElementType());
|
||||||
|
@ -148,13 +144,8 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
||||||
Value flatResult = rewriter.create<OpTy>(loc, flatLhs, flatRhs);
|
Value flatResult = rewriter.create<OpTy>(loc, flatLhs, flatRhs);
|
||||||
|
|
||||||
// Restore original shape.
|
// Restore original shape.
|
||||||
auto extentTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
|
rewriter.replaceOpWithNewOp<DynamicReshapeOp>(op, op.getType(), flatResult,
|
||||||
rewriter.getIndexType());
|
shape);
|
||||||
Value shapeAsExtentTensor =
|
|
||||||
rewriter.create<shape::ToExtentTensorOp>(loc, extentTensorTy, shape);
|
|
||||||
Value result = rewriter.create<DynamicReshapeOp>(
|
|
||||||
loc, op.getType(), flatResult, shapeAsExtentTensor);
|
|
||||||
rewriter.replaceOp(op, result);
|
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,8 +7,7 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
// Flatten operand shape.
|
// Flatten operand shape.
|
||||||
%shape = shape.shape_of %a : tensor<*xf32> -> tensor<?xindex>
|
%shape = shape.shape_of %a : tensor<*xf32> -> tensor<?xindex>
|
||||||
%num_elements = shape.num_elements %shape : tensor<?xindex> -> index
|
%num_elements = shape.num_elements %shape : tensor<?xindex> -> index
|
||||||
%num_elements_as_index = shape.size_to_index %num_elements : index
|
%flat_shape = tensor_from_elements(%num_elements) : tensor<1xindex>
|
||||||
%flat_shape = tensor_from_elements(%num_elements_as_index) : tensor<1xindex>
|
|
||||||
%flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape)
|
%flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape)
|
||||||
: (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
: (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
|
|
||||||
|
@ -16,8 +15,7 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
%flat_b = "mhlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>
|
%flat_b = "mhlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>
|
||||||
|
|
||||||
// Restore original shape.
|
// Restore original shape.
|
||||||
%shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor<?xindex> -> tensor<?xindex>
|
%b = "mhlo.dynamic_reshape"(%flat_b, %shape)
|
||||||
%b = "mhlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor)
|
|
||||||
: (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
: (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||||
|
|
||||||
return %b : tensor<*xf32>
|
return %b : tensor<*xf32>
|
||||||
|
@ -29,14 +27,12 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
// CHECK-LABEL: @sqrt
|
// CHECK-LABEL: @sqrt
|
||||||
// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>)
|
// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>)
|
||||||
func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
|
func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
// CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32>
|
// CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
|
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
|
||||||
// CHECK-NEXT: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
|
// CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex>
|
||||||
// CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
|
|
||||||
// CHECK-NEXT: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
// CHECK-NEXT: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
// CHECK-NEXT: %[[FLAT_B:.*]] = "mhlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
|
// CHECK-NEXT: %[[FLAT_B:.*]] = "mhlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
|
||||||
// CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex>
|
// CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||||
// CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
|
||||||
// CHECK-NEXT: return %[[B]] : tensor<*xf32>
|
// CHECK-NEXT: return %[[B]] : tensor<*xf32>
|
||||||
%b = "mhlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>
|
%b = "mhlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
return %b : tensor<*xf32>
|
return %b : tensor<*xf32>
|
||||||
|
@ -75,13 +71,11 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
|
||||||
// CHECK: %[[SHAPE_B:.*]] = shape.shape_of %[[B]]
|
// CHECK: %[[SHAPE_B:.*]] = shape.shape_of %[[B]]
|
||||||
// CHECK: %[[SHAPE:.*]] = "shape.any"(%[[SHAPE_A]], %[[SHAPE_B]])
|
// CHECK: %[[SHAPE:.*]] = "shape.any"(%[[SHAPE_A]], %[[SHAPE_B]])
|
||||||
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
|
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
|
||||||
// CHECK: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
|
// CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex>
|
||||||
// CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
|
|
||||||
// CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
// CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
// CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
// CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
// CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32>
|
// CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32>
|
||||||
// CHECK: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]]
|
// CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||||
// CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
|
||||||
// CHECK: return %[[RESULT]] : tensor<*xf32>
|
// CHECK: return %[[RESULT]] : tensor<*xf32>
|
||||||
%result = mhlo.add %a, %b : tensor<*xf32>
|
%result = mhlo.add %a, %b : tensor<*xf32>
|
||||||
return %result : tensor<*xf32>
|
return %result : tensor<*xf32>
|
||||||
|
|
Loading…
Reference in New Issue