Add support for lowering mhlo.iota/dynamic_iota to Linalg on unsigned types.

PiperOrigin-RevId: 377956338
This commit is contained in:
Hanhan Wang 2021-06-07 10:58:21 -07:00 committed by TensorFlow MLIR Team
parent 5315997402
commit 25b93c8d66
2 changed files with 61 additions and 4 deletions

View File

@ -922,6 +922,8 @@ class IotaConverter : public OpConversionPattern<OpTy> {
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
ShapedType result_shaped_type = GetHloOpResultType<isLHLO>(iota_op); ShapedType result_shaped_type = GetHloOpResultType<isLHLO>(iota_op);
if (!result_shaped_type) return failure(); if (!result_shaped_type) return failure();
result_shaped_type = this->typeConverter->convertType(result_shaped_type)
.template dyn_cast<ShapedType>();
auto result_element_type = result_shaped_type.getElementType(); auto result_element_type = result_shaped_type.getElementType();
if (!result_element_type.isSignlessIntOrFloat()) return failure(); if (!result_element_type.isSignlessIntOrFloat()) return failure();

View File

@ -906,8 +906,8 @@ func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {
// ----- // -----
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @iota // CHECK-LABEL: func @iota_f32
func @iota() -> tensor<7x10xf32> { func @iota_f32() -> tensor<7x10xf32> {
%result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xf32>) %result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xf32>)
return %result : tensor<7x10xf32> return %result : tensor<7x10xf32>
} }
@ -922,10 +922,43 @@ func @iota() -> tensor<7x10xf32> {
// ----- // -----
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @iota_i32
func @iota_i32() -> tensor<7x10xi32> {
%result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xi32>)
return %result : tensor<7x10xi32>
}
// CHECK: linalg.init_tensor
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%{{.*}}: i32):
// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[INDEX]] : index to i32
// CHECK-NEXT: linalg.yield %[[INT_CAST]] : i32
// -----
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @iota_ui32
func @iota_ui32() -> tensor<7x10xui32> {
%result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xui32>)
return %result : tensor<7x10xui32>
}
// CHECK: linalg.init_tensor
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%{{.*}}: i32):
// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[INDEX]] : index to i32
// CHECK-NEXT: linalg.yield %[[INT_CAST]] : i32
// CHECK: unrealized_conversion_cast %{{.*}} : tensor<7x10xi32> to tensor<7x10xui32>
// -----
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @iota // CHECK-LABEL: func @dynamic_iota_f32
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xi32> // CHECK-SAME: %[[SHAPE:.*]]: tensor<?xi32>
func @iota(%shape: tensor<?xi32>) -> tensor<?x?x8xf32> { func @dynamic_iota_f32(%shape: tensor<?xi32>) -> tensor<?x?x8xf32> {
%result = "mhlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<?xi32>) -> (tensor<?x?x8xf32>) %result = "mhlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<?xi32>) -> (tensor<?x?x8xf32>)
return %result : tensor<?x?x8xf32> return %result : tensor<?x?x8xf32>
} }
@ -944,6 +977,28 @@ func @iota(%shape: tensor<?xi32>) -> tensor<?x?x8xf32> {
// ----- // -----
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @dyanmic_iota_ui32
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xi32>
func @dyanmic_iota_ui32(%shape: tensor<?xi32>) -> tensor<?x?x8xui32> {
%result = "mhlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<?xi32>) -> (tensor<?x?x8xui32>)
return %result : tensor<?x?x8xui32>
}
// CHECK: %[[E1:.*]] = tensor.extract %[[SHAPE]][%c0] : tensor<?xi32>
// CHECK: %[[I1:.*]] = index_cast %[[E1]] : i32 to index
// CHECK: %[[E2:.*]] = tensor.extract %[[SHAPE]][%c1] : tensor<?xi32>
// CHECK: %[[I2:.*]] = index_cast %[[E2]] : i32 to index
// CHECK: linalg.init_tensor [%[[I1]], %[[I2]], 8] : tensor<?x?x8xi32>
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%{{.*}}: i32):
// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[INDEX]] : index to i32
// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : i32
// CHECK: unrealized_conversion_cast %{{.*}} : tensor<?x?x8xi32> to tensor<?x?x8xui32>
// -----
func @shift_left(%lhs: tensor<2x2xi32>, func @shift_left(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
%result = "mhlo.shift_left"(%lhs, %rhs) %result = "mhlo.shift_left"(%lhs, %rhs)