Added case for gather (#599)

Signed-off-by: Chen <jack.chen@verisilicon.com>
Co-authored-by: Chen <jack.chen@verisilicon.com>
This commit is contained in:
chxin66 2023-06-26 09:15:08 +08:00 committed by GitHub
parent 233eb439e1
commit 34812fe40e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 46 additions and 9 deletions

View File

@ -33,6 +33,7 @@ namespace ops {
* ## Gather
*
* Gather slices from input, **axis** according to **indices**.
* batch_dims means in which dimension to repeat the value according to indices.
*/
class Gather : public BuiltinOp {

View File

@ -50,7 +50,7 @@ TEST(GatherElements, shape_3_2_1_int32_axis_0) {
std::vector<int32_t> in_data = {
1, 2, 3, 4, 5, 6,
};
//The index value greater than rank-1 is regarded as rank-1
std::vector<int32_t> indices = {
1,
2,
@ -97,7 +97,7 @@ TEST(GatherElements, shape_3_2_1_int32_axis_1) {
std::vector<int32_t> in_data = {
1, 2, 3, 4, 5, 6,
};
//The index value greater than rank-1 is regarded as rank-1
std::vector<int32_t> indices = {
1,
2,
@ -144,7 +144,7 @@ TEST(GatherElements, shape_3_2_1_float32_axis_2) {
std::vector<float> in_data = {
1, 2, 3, 4, 5, 6,
};
//The index value greater than rank-1 is regarded as rank-1
std::vector<int32_t> indices = {
1,
2,

View File

@ -54,7 +54,7 @@ TEST(Gather, shape_5_3_2_2_int32_axis_1_batchdims_1) {
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
};
//The index value greater than rank-1 is regarded as rank-1
std::vector<int32_t> indices = {1, 0, 0, 1, 1, 0, 0, 1};
std::vector<int8_t> golden = {
5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5,
@ -63,11 +63,9 @@ TEST(Gather, shape_5_3_2_2_int32_axis_1_batchdims_1) {
33, 34, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 50, 51, 52, 53,
54, 45, 46, 47, 48, 49, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54};
EXPECT_TRUE(
input_tensor->CopyDataToTensor(in_data.data(), in_data.size()));
EXPECT_TRUE(
indices_tensor->CopyDataToTensor(indices.data(), indices.size() * 4));
auto op = graph->CreateOperation<tim::vx::ops::Gather>(1,1);
EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(), in_data.size()));
EXPECT_TRUE(indices_tensor->CopyDataToTensor(indices.data(), indices.size()));
auto op = graph->CreateOperation<tim::vx::ops::Gather>(1, 1);
(*op).BindInputs({input_tensor, indices_tensor}).BindOutputs({output_tensor});
EXPECT_TRUE(graph->Compile());
@ -77,3 +75,41 @@ TEST(Gather, shape_5_3_2_2_int32_axis_1_batchdims_1) {
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
EXPECT_EQ(golden, output);
}
TEST(Gather, shape_2_2_indices_2) {
auto ctx = tim::vx::Context::Create();
if (ctx->isClOnly()) GTEST_SKIP();
auto graph = ctx->CreateGraph();
tim::vx::ShapeType in_shape({2, 2});
tim::vx::ShapeType indices_shape({2});
tim::vx::ShapeType out_shape({2, 2});
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, in_shape,
tim::vx::TensorAttribute::INPUT);
tim::vx::TensorSpec indices_spec(tim::vx::DataType::INT32, indices_shape,
tim::vx::TensorAttribute::INPUT);
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, out_shape,
tim::vx::TensorAttribute::OUTPUT);
auto input_tensor = graph->CreateTensor(input_spec);
auto indices_tensor = graph->CreateTensor(indices_spec);
auto output_tensor = graph->CreateTensor(output_spec);
std::vector<float> in_data = {-2.0f, 0.2f, 0.7f, 0.8f};
std::vector<int32_t> indices = {1, 0};
std::vector<float> golden = {0.7f, 0.8f, -2.0f, 0.2f};
EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(), in_data.size()));
EXPECT_TRUE(indices_tensor->CopyDataToTensor(indices.data(), indices.size()));
auto op = graph->CreateOperation<tim::vx::ops::Gather>(1, 0);
(*op).BindInputs({input_tensor, indices_tensor}).BindOutputs({output_tensor});
EXPECT_TRUE(graph->Compile());
EXPECT_TRUE(graph->Run());
std::vector<float> output(golden.size());
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
EXPECT_EQ(golden, output);
}