diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 4a62a40..24afdb3 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -2014,7 +2014,9 @@ struct TorchIndexSelectOpOnTensorsConversion if (batch < 0) batch += num_indices; Location loc = op.getLoc(); - auto result_type = op.getResult().getType().cast(); + auto result_type = + this->typeConverter->convertType(op.getResult().getType()) + .cast(); int rank = static_cast(result_type.getRank()); SmallVector indexing_maps; diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 5085d49..955187e 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -2076,7 +2076,7 @@ func @reduce_window_sum_max_nhwc(%arg0: tensor<1x18x18x64xf32>, // ----- -func @torch_select_index(%arg0: tensor<5x1x5xi32>, +func @torch_index_select(%arg0: tensor<5x1x5xi32>, %arg1: tensor<2xi32>) -> tensor<2x1x5xi32> { %0 = "mhlo.torch_index_select"(%arg0, %arg1) { dim = 0 : i64, @@ -2086,7 +2086,7 @@ func @torch_select_index(%arg0: tensor<5x1x5xi32>, } // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: func @torch_select_index +// CHECK: func @torch_index_select // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]] // CHECK: linalg.indexed_generic { @@ -2103,7 +2103,37 @@ func @torch_select_index(%arg0: tensor<5x1x5xi32>, // ----- -func @torch_select_index_scalar(%arg0: tensor<4x8xf32>, +func @torch_index_select_unsigned(%arg0: tensor<5x1x5xui32>, + %arg1: tensor<2xi32>) -> tensor<2x1x5xui32> { + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<5x1x5xui32>, tensor<2xi32>) -> tensor<2x1x5xui32> + return %0 : tensor<2x1x5xui32> +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: func @torch_index_select_unsigned +// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]] +// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]] +// CHECK: %[[INPUT_SIGNLESS:.*]] = unrealized_conversion_cast %[[INPUT]] : tensor<5x1x5xui32> to tensor<5x1x5xi32> +// CHECK: %[[RES:.+]] = linalg.indexed_generic { +// CHECK-SAME: indexing_maps +// CHECK-SAME: #[[MAP0]], #[[MAP1]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[INDEX]] : tensor<2xi32>) +// CHECK: ^{{.+}}( +// CHECK-SAME: %[[I:.+]]: index, %[[J:.+]]: index, %[[K:.+]]: index +// CHECK-SAME: %[[VAL:.+]]: i32, %{{.+}}: i32): +// CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index +// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT_SIGNLESS]][%[[CAST]], %[[J]], %[[K]]] : tensor<5x1x5xi32> +// CHECK: linalg.yield %[[VAL2]] : i32 +// CHECK: %[[RES_UNSIGNED:.+]] = unrealized_conversion_cast %[[RES]] : tensor<2x1x5xi32> to tensor<2x1x5xui32> +// CHECK: return %[[RES_UNSIGNED]] + +// ----- + +func @torch_index_select_scalar(%arg0: tensor<4x8xf32>, %arg1: tensor) -> tensor<8xf32> { %0 = "mhlo.torch_index_select"(%arg0, %arg1) { batch_dims = 0 : i64, @@ -2114,7 +2144,7 @@ func @torch_select_index_scalar(%arg0: tensor<4x8xf32>, // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> ()> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)> -// CHECK: func @torch_select_index_scalar +// CHECK: func @torch_index_select_scalar // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]] // CHECK: %[[T0:.+]] = linalg.init_tensor [8] : tensor<8xf32> @@ -2131,7 +2161,7 @@ func @torch_select_index_scalar(%arg0: tensor<4x8xf32>, // ----- -func @torch_select_index_batch(%arg0: tensor<4x7x8x2xf32>, +func @torch_index_select_batch(%arg0: tensor<4x7x8x2xf32>, %arg1: tensor<4x1xi32>) -> tensor<4x7x1x2xf32> { %0 = "mhlo.torch_index_select"(%arg0, %arg1) { dim = 2 : i64, @@ -2141,7 +2171,7 @@ func @torch_select_index_batch(%arg0: tensor<4x7x8x2xf32>, } // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK: func @torch_select_index_batch +// CHECK: func @torch_index_select_batch // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]] // CHECK: linalg.indexed_generic {