[MHLO:Linalg] Add support for lowering torch_index_select of unsigned tensors
Also fixes typos in tests. PiperOrigin-RevId: 374979460
This commit is contained in:
parent
a1445aa0fa
commit
cd8f585cf7
|
@ -2014,7 +2014,9 @@ struct TorchIndexSelectOpOnTensorsConversion
|
||||||
if (batch < 0) batch += num_indices;
|
if (batch < 0) batch += num_indices;
|
||||||
|
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
auto result_type = op.getResult().getType().cast<ShapedType>();
|
auto result_type =
|
||||||
|
this->typeConverter->convertType(op.getResult().getType())
|
||||||
|
.cast<ShapedType>();
|
||||||
int rank = static_cast<int>(result_type.getRank());
|
int rank = static_cast<int>(result_type.getRank());
|
||||||
|
|
||||||
SmallVector<AffineMap, 2> indexing_maps;
|
SmallVector<AffineMap, 2> indexing_maps;
|
||||||
|
|
|
@ -2076,7 +2076,7 @@ func @reduce_window_sum_max_nhwc(%arg0: tensor<1x18x18x64xf32>,
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @torch_select_index(%arg0: tensor<5x1x5xi32>,
|
func @torch_index_select(%arg0: tensor<5x1x5xi32>,
|
||||||
%arg1: tensor<2xi32>) -> tensor<2x1x5xi32> {
|
%arg1: tensor<2xi32>) -> tensor<2x1x5xi32> {
|
||||||
%0 = "mhlo.torch_index_select"(%arg0, %arg1) {
|
%0 = "mhlo.torch_index_select"(%arg0, %arg1) {
|
||||||
dim = 0 : i64,
|
dim = 0 : i64,
|
||||||
|
@ -2086,7 +2086,7 @@ func @torch_select_index(%arg0: tensor<5x1x5xi32>,
|
||||||
}
|
}
|
||||||
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0)>
|
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0)>
|
||||||
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||||
// CHECK: func @torch_select_index
|
// CHECK: func @torch_index_select
|
||||||
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
|
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
|
||||||
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
|
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
|
||||||
// CHECK: linalg.indexed_generic {
|
// CHECK: linalg.indexed_generic {
|
||||||
|
@ -2103,7 +2103,37 @@ func @torch_select_index(%arg0: tensor<5x1x5xi32>,
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @torch_select_index_scalar(%arg0: tensor<4x8xf32>,
|
func @torch_index_select_unsigned(%arg0: tensor<5x1x5xui32>,
|
||||||
|
%arg1: tensor<2xi32>) -> tensor<2x1x5xui32> {
|
||||||
|
%0 = "mhlo.torch_index_select"(%arg0, %arg1) {
|
||||||
|
dim = 0 : i64,
|
||||||
|
batch_dims = 0 : i64
|
||||||
|
} : (tensor<5x1x5xui32>, tensor<2xi32>) -> tensor<2x1x5xui32>
|
||||||
|
return %0 : tensor<2x1x5xui32>
|
||||||
|
}
|
||||||
|
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0)>
|
||||||
|
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||||
|
// CHECK: func @torch_index_select_unsigned
|
||||||
|
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
|
||||||
|
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
|
||||||
|
// CHECK: %[[INPUT_SIGNLESS:.*]] = unrealized_conversion_cast %[[INPUT]] : tensor<5x1x5xui32> to tensor<5x1x5xi32>
|
||||||
|
// CHECK: %[[RES:.+]] = linalg.indexed_generic {
|
||||||
|
// CHECK-SAME: indexing_maps
|
||||||
|
// CHECK-SAME: #[[MAP0]], #[[MAP1]]
|
||||||
|
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
|
||||||
|
// CHECK-SAME: ins(%[[INDEX]] : tensor<2xi32>)
|
||||||
|
// CHECK: ^{{.+}}(
|
||||||
|
// CHECK-SAME: %[[I:.+]]: index, %[[J:.+]]: index, %[[K:.+]]: index
|
||||||
|
// CHECK-SAME: %[[VAL:.+]]: i32, %{{.+}}: i32):
|
||||||
|
// CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
|
||||||
|
// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT_SIGNLESS]][%[[CAST]], %[[J]], %[[K]]] : tensor<5x1x5xi32>
|
||||||
|
// CHECK: linalg.yield %[[VAL2]] : i32
|
||||||
|
// CHECK: %[[RES_UNSIGNED:.+]] = unrealized_conversion_cast %[[RES]] : tensor<2x1x5xi32> to tensor<2x1x5xui32>
|
||||||
|
// CHECK: return %[[RES_UNSIGNED]]
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @torch_index_select_scalar(%arg0: tensor<4x8xf32>,
|
||||||
%arg1: tensor<i32>) -> tensor<8xf32> {
|
%arg1: tensor<i32>) -> tensor<8xf32> {
|
||||||
%0 = "mhlo.torch_index_select"(%arg0, %arg1) {
|
%0 = "mhlo.torch_index_select"(%arg0, %arg1) {
|
||||||
batch_dims = 0 : i64,
|
batch_dims = 0 : i64,
|
||||||
|
@ -2114,7 +2144,7 @@ func @torch_select_index_scalar(%arg0: tensor<4x8xf32>,
|
||||||
|
|
||||||
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> ()>
|
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> ()>
|
||||||
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
|
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
|
||||||
// CHECK: func @torch_select_index_scalar
|
// CHECK: func @torch_index_select_scalar
|
||||||
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
|
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
|
||||||
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
|
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
|
||||||
// CHECK: %[[T0:.+]] = linalg.init_tensor [8] : tensor<8xf32>
|
// CHECK: %[[T0:.+]] = linalg.init_tensor [8] : tensor<8xf32>
|
||||||
|
@ -2131,7 +2161,7 @@ func @torch_select_index_scalar(%arg0: tensor<4x8xf32>,
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @torch_select_index_batch(%arg0: tensor<4x7x8x2xf32>,
|
func @torch_index_select_batch(%arg0: tensor<4x7x8x2xf32>,
|
||||||
%arg1: tensor<4x1xi32>) -> tensor<4x7x1x2xf32> {
|
%arg1: tensor<4x1xi32>) -> tensor<4x7x1x2xf32> {
|
||||||
%0 = "mhlo.torch_index_select"(%arg0, %arg1) {
|
%0 = "mhlo.torch_index_select"(%arg0, %arg1) {
|
||||||
dim = 2 : i64,
|
dim = 2 : i64,
|
||||||
|
@ -2141,7 +2171,7 @@ func @torch_select_index_batch(%arg0: tensor<4x7x8x2xf32>,
|
||||||
}
|
}
|
||||||
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
|
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
|
||||||
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||||
// CHECK: func @torch_select_index_batch
|
// CHECK: func @torch_index_select_batch
|
||||||
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
|
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
|
||||||
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
|
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
|
||||||
// CHECK: linalg.indexed_generic {
|
// CHECK: linalg.indexed_generic {
|
||||||
|
|
Loading…
Reference in New Issue