diff --git a/include/tim/vx/ops/gather.h b/include/tim/vx/ops/gather.h index a9cb27f..297e72e 100644 --- a/include/tim/vx/ops/gather.h +++ b/include/tim/vx/ops/gather.h @@ -33,6 +33,7 @@ namespace ops { * ## Gather * * Gather slices from input, **axis** according to **indices**. + * batch_dims means in which dimension to repeat the value according to indices. */ class Gather : public BuiltinOp { diff --git a/src/tim/vx/ops/gather_elements_test.cc b/src/tim/vx/ops/gather_elements_test.cc index 81fd8d7..4251927 100644 --- a/src/tim/vx/ops/gather_elements_test.cc +++ b/src/tim/vx/ops/gather_elements_test.cc @@ -50,7 +50,7 @@ TEST(GatherElements, shape_3_2_1_int32_axis_0) { std::vector in_data = { 1, 2, 3, 4, 5, 6, }; - //The index value greater than rank-1 is regarded as rank-1 + std::vector indices = { 1, 2, @@ -97,7 +97,7 @@ TEST(GatherElements, shape_3_2_1_int32_axis_1) { std::vector in_data = { 1, 2, 3, 4, 5, 6, }; - //The index value greater than rank-1 is regarded as rank-1 + std::vector indices = { 1, 2, @@ -144,7 +144,7 @@ TEST(GatherElements, shape_3_2_1_float32_axis_2) { std::vector in_data = { 1, 2, 3, 4, 5, 6, }; - //The index value greater than rank-1 is regarded as rank-1 + std::vector indices = { 1, 2, diff --git a/src/tim/vx/ops/gather_test.cc b/src/tim/vx/ops/gather_test.cc index 8b2dd7a..8989b1c 100644 --- a/src/tim/vx/ops/gather_test.cc +++ b/src/tim/vx/ops/gather_test.cc @@ -54,7 +54,7 @@ TEST(Gather, shape_5_3_2_2_int32_axis_1_batchdims_1) { 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, }; - //The index value greater than rank-1 is regarded as rank-1 + std::vector indices = {1, 0, 0, 1, 1, 0, 0, 1}; std::vector golden = { 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, @@ -63,11 +63,9 @@ TEST(Gather, shape_5_3_2_2_int32_axis_1_batchdims_1) { 33, 34, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 50, 51, 52, 53, 54, 45, 46, 47, 48, 49, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54}; - EXPECT_TRUE( - input_tensor->CopyDataToTensor(in_data.data(), in_data.size())); - EXPECT_TRUE( - indices_tensor->CopyDataToTensor(indices.data(), indices.size() * 4)); - auto op = graph->CreateOperation(1,1); + EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(), in_data.size())); + EXPECT_TRUE(indices_tensor->CopyDataToTensor(indices.data(), indices.size())); + auto op = graph->CreateOperation(1, 1); (*op).BindInputs({input_tensor, indices_tensor}).BindOutputs({output_tensor}); EXPECT_TRUE(graph->Compile()); @@ -77,3 +75,41 @@ TEST(Gather, shape_5_3_2_2_int32_axis_1_batchdims_1) { EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); EXPECT_EQ(golden, output); } + +TEST(Gather, shape_2_2_indices_2) { + auto ctx = tim::vx::Context::Create(); + + if (ctx->isClOnly()) GTEST_SKIP(); + auto graph = ctx->CreateGraph(); + + tim::vx::ShapeType in_shape({2, 2}); + tim::vx::ShapeType indices_shape({2}); + tim::vx::ShapeType out_shape({2, 2}); + tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, in_shape, + tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec indices_spec(tim::vx::DataType::INT32, indices_shape, + tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, out_shape, + tim::vx::TensorAttribute::OUTPUT); + + auto input_tensor = graph->CreateTensor(input_spec); + auto indices_tensor = graph->CreateTensor(indices_spec); + auto output_tensor = graph->CreateTensor(output_spec); + + std::vector in_data = {-2.0f, 0.2f, 0.7f, 0.8f}; + + std::vector indices = {1, 0}; + std::vector golden = {0.7f, 0.8f, -2.0f, 0.2f}; + + EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(), in_data.size())); + EXPECT_TRUE(indices_tensor->CopyDataToTensor(indices.data(), indices.size())); + auto op = graph->CreateOperation(1, 0); + (*op).BindInputs({input_tensor, indices_tensor}).BindOutputs({output_tensor}); + + EXPECT_TRUE(graph->Compile()); + EXPECT_TRUE(graph->Run()); + + std::vector output(golden.size()); + EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); + EXPECT_EQ(golden, output); +}