Add support for lowering mhlo.iota/dynamic_iota to Linalg on unsigned types.
PiperOrigin-RevId: 377956338
This commit is contained in:
parent
5315997402
commit
25b93c8d66
|
@ -922,6 +922,8 @@ class IotaConverter : public OpConversionPattern<OpTy> {
|
||||||
ConversionPatternRewriter& rewriter) const final {
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
ShapedType result_shaped_type = GetHloOpResultType<isLHLO>(iota_op);
|
ShapedType result_shaped_type = GetHloOpResultType<isLHLO>(iota_op);
|
||||||
if (!result_shaped_type) return failure();
|
if (!result_shaped_type) return failure();
|
||||||
|
result_shaped_type = this->typeConverter->convertType(result_shaped_type)
|
||||||
|
.template dyn_cast<ShapedType>();
|
||||||
|
|
||||||
auto result_element_type = result_shaped_type.getElementType();
|
auto result_element_type = result_shaped_type.getElementType();
|
||||||
if (!result_element_type.isSignlessIntOrFloat()) return failure();
|
if (!result_element_type.isSignlessIntOrFloat()) return failure();
|
||||||
|
|
|
@ -906,8 +906,8 @@ func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
// CHECK-LABEL: func @iota
|
// CHECK-LABEL: func @iota_f32
|
||||||
func @iota() -> tensor<7x10xf32> {
|
func @iota_f32() -> tensor<7x10xf32> {
|
||||||
%result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xf32>)
|
%result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xf32>)
|
||||||
return %result : tensor<7x10xf32>
|
return %result : tensor<7x10xf32>
|
||||||
}
|
}
|
||||||
|
@ -922,10 +922,43 @@ func @iota() -> tensor<7x10xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
|
// CHECK-LABEL: func @iota_i32
|
||||||
|
func @iota_i32() -> tensor<7x10xi32> {
|
||||||
|
%result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xi32>)
|
||||||
|
return %result : tensor<7x10xi32>
|
||||||
|
}
|
||||||
|
// CHECK: linalg.init_tensor
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
|
||||||
|
// CHECK-NEXT: ^bb0(%{{.*}}: i32):
|
||||||
|
// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1
|
||||||
|
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[INDEX]] : index to i32
|
||||||
|
// CHECK-NEXT: linalg.yield %[[INT_CAST]] : i32
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
|
// CHECK-LABEL: func @iota_ui32
|
||||||
|
func @iota_ui32() -> tensor<7x10xui32> {
|
||||||
|
%result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xui32>)
|
||||||
|
return %result : tensor<7x10xui32>
|
||||||
|
}
|
||||||
|
// CHECK: linalg.init_tensor
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
|
||||||
|
// CHECK-NEXT: ^bb0(%{{.*}}: i32):
|
||||||
|
// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1
|
||||||
|
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[INDEX]] : index to i32
|
||||||
|
// CHECK-NEXT: linalg.yield %[[INT_CAST]] : i32
|
||||||
|
// CHECK: unrealized_conversion_cast %{{.*}} : tensor<7x10xi32> to tensor<7x10xui32>
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||||
// CHECK-LABEL: func @iota
|
// CHECK-LABEL: func @dynamic_iota_f32
|
||||||
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xi32>
|
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xi32>
|
||||||
func @iota(%shape: tensor<?xi32>) -> tensor<?x?x8xf32> {
|
func @dynamic_iota_f32(%shape: tensor<?xi32>) -> tensor<?x?x8xf32> {
|
||||||
%result = "mhlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<?xi32>) -> (tensor<?x?x8xf32>)
|
%result = "mhlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<?xi32>) -> (tensor<?x?x8xf32>)
|
||||||
return %result : tensor<?x?x8xf32>
|
return %result : tensor<?x?x8xf32>
|
||||||
}
|
}
|
||||||
|
@ -944,6 +977,28 @@ func @iota(%shape: tensor<?xi32>) -> tensor<?x?x8xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||||
|
// CHECK-LABEL: func @dyanmic_iota_ui32
|
||||||
|
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xi32>
|
||||||
|
func @dyanmic_iota_ui32(%shape: tensor<?xi32>) -> tensor<?x?x8xui32> {
|
||||||
|
%result = "mhlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<?xi32>) -> (tensor<?x?x8xui32>)
|
||||||
|
return %result : tensor<?x?x8xui32>
|
||||||
|
}
|
||||||
|
// CHECK: %[[E1:.*]] = tensor.extract %[[SHAPE]][%c0] : tensor<?xi32>
|
||||||
|
// CHECK: %[[I1:.*]] = index_cast %[[E1]] : i32 to index
|
||||||
|
// CHECK: %[[E2:.*]] = tensor.extract %[[SHAPE]][%c1] : tensor<?xi32>
|
||||||
|
// CHECK: %[[I2:.*]] = index_cast %[[E2]] : i32 to index
|
||||||
|
// CHECK: linalg.init_tensor [%[[I1]], %[[I2]], 8] : tensor<?x?x8xi32>
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
|
||||||
|
// CHECK-NEXT: ^bb0(%{{.*}}: i32):
|
||||||
|
// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1
|
||||||
|
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[INDEX]] : index to i32
|
||||||
|
// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : i32
|
||||||
|
// CHECK: unrealized_conversion_cast %{{.*}} : tensor<?x?x8xi32> to tensor<?x?x8xui32>
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @shift_left(%lhs: tensor<2x2xi32>,
|
func @shift_left(%lhs: tensor<2x2xi32>,
|
||||||
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||||
%result = "mhlo.shift_left"(%lhs, %rhs)
|
%result = "mhlo.shift_left"(%lhs, %rhs)
|
||||||
|
|
Loading…
Reference in New Issue