diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index d72ee6c..53cad4c 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -922,6 +922,8 @@ class IotaConverter : public OpConversionPattern { ConversionPatternRewriter& rewriter) const final { ShapedType result_shaped_type = GetHloOpResultType(iota_op); if (!result_shaped_type) return failure(); + result_shaped_type = this->typeConverter->convertType(result_shaped_type) + .template dyn_cast(); auto result_element_type = result_shaped_type.getElementType(); if (!result_element_type.isSignlessIntOrFloat()) return failure(); diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index a756a3f..37f778a 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -906,8 +906,8 @@ func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> { // ----- // CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-LABEL: func @iota -func @iota() -> tensor<7x10xf32> { +// CHECK-LABEL: func @iota_f32 +func @iota_f32() -> tensor<7x10xf32> { %result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xf32>) return %result : tensor<7x10xf32> } @@ -922,10 +922,43 @@ func @iota() -> tensor<7x10xf32> { // ----- +// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @iota_i32 +func @iota_i32() -> tensor<7x10xi32> { + %result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xi32>) + return %result : tensor<7x10xi32> +} +// CHECK: linalg.init_tensor +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%{{.*}}: i32): +// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1 +// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[INDEX]] : index to i32 +// CHECK-NEXT: linalg.yield %[[INT_CAST]] : i32 + +// ----- + +// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @iota_ui32 +func @iota_ui32() -> tensor<7x10xui32> { + %result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xui32>) + return %result : tensor<7x10xui32> +} +// CHECK: linalg.init_tensor +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%{{.*}}: i32): +// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1 +// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[INDEX]] : index to i32 +// CHECK-NEXT: linalg.yield %[[INT_CAST]] : i32 +// CHECK: unrealized_conversion_cast %{{.*}} : tensor<7x10xi32> to tensor<7x10xui32> + +// ----- + // CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK-LABEL: func @iota +// CHECK-LABEL: func @dynamic_iota_f32 // CHECK-SAME: %[[SHAPE:.*]]: tensor -func @iota(%shape: tensor) -> tensor { +func @dynamic_iota_f32(%shape: tensor) -> tensor { %result = "mhlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor) -> (tensor) return %result : tensor } @@ -944,6 +977,28 @@ func @iota(%shape: tensor) -> tensor { // ----- +// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-LABEL: func @dyanmic_iota_ui32 +// CHECK-SAME: %[[SHAPE:.*]]: tensor +func @dyanmic_iota_ui32(%shape: tensor) -> tensor { + %result = "mhlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor) -> (tensor) + return %result : tensor +} +// CHECK: %[[E1:.*]] = tensor.extract %[[SHAPE]][%c0] : tensor +// CHECK: %[[I1:.*]] = index_cast %[[E1]] : i32 to index +// CHECK: %[[E2:.*]] = tensor.extract %[[SHAPE]][%c1] : tensor +// CHECK: %[[I2:.*]] = index_cast %[[E2]] : i32 to index +// CHECK: linalg.init_tensor [%[[I1]], %[[I2]], 8] : tensor +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%{{.*}}: i32): +// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1 +// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[INDEX]] : index to i32 +// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : i32 +// CHECK: unrealized_conversion_cast %{{.*}} : tensor to tensor + +// ----- + func @shift_left(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { %result = "mhlo.shift_left"(%lhs, %rhs)