Added case for gather (#599)
Signed-off-by: Chen <jack.chen@verisilicon.com> Co-authored-by: Chen <jack.chen@verisilicon.com>
This commit is contained in:
parent
233eb439e1
commit
34812fe40e
|
|
@ -33,6 +33,7 @@ namespace ops {
|
|||
* ## Gather
|
||||
*
|
||||
* Gather slices from input, **axis** according to **indices**.
|
||||
* batch_dims means in which dimension to repeat the value according to indices.
|
||||
*/
|
||||
|
||||
class Gather : public BuiltinOp {
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ TEST(GatherElements, shape_3_2_1_int32_axis_0) {
|
|||
std::vector<int32_t> in_data = {
|
||||
1, 2, 3, 4, 5, 6,
|
||||
};
|
||||
//The index value greater than rank-1 is regarded as rank-1
|
||||
|
||||
std::vector<int32_t> indices = {
|
||||
1,
|
||||
2,
|
||||
|
|
@ -97,7 +97,7 @@ TEST(GatherElements, shape_3_2_1_int32_axis_1) {
|
|||
std::vector<int32_t> in_data = {
|
||||
1, 2, 3, 4, 5, 6,
|
||||
};
|
||||
//The index value greater than rank-1 is regarded as rank-1
|
||||
|
||||
std::vector<int32_t> indices = {
|
||||
1,
|
||||
2,
|
||||
|
|
@ -144,7 +144,7 @@ TEST(GatherElements, shape_3_2_1_float32_axis_2) {
|
|||
std::vector<float> in_data = {
|
||||
1, 2, 3, 4, 5, 6,
|
||||
};
|
||||
//The index value greater than rank-1 is regarded as rank-1
|
||||
|
||||
std::vector<int32_t> indices = {
|
||||
1,
|
||||
2,
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ TEST(Gather, shape_5_3_2_2_int32_axis_1_batchdims_1) {
|
|||
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
|
||||
45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
|
||||
};
|
||||
//The index value greater than rank-1 is regarded as rank-1
|
||||
|
||||
std::vector<int32_t> indices = {1, 0, 0, 1, 1, 0, 0, 1};
|
||||
std::vector<int8_t> golden = {
|
||||
5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5,
|
||||
|
|
@ -63,11 +63,9 @@ TEST(Gather, shape_5_3_2_2_int32_axis_1_batchdims_1) {
|
|||
33, 34, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 50, 51, 52, 53,
|
||||
54, 45, 46, 47, 48, 49, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54};
|
||||
|
||||
EXPECT_TRUE(
|
||||
input_tensor->CopyDataToTensor(in_data.data(), in_data.size()));
|
||||
EXPECT_TRUE(
|
||||
indices_tensor->CopyDataToTensor(indices.data(), indices.size() * 4));
|
||||
auto op = graph->CreateOperation<tim::vx::ops::Gather>(1,1);
|
||||
EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(), in_data.size()));
|
||||
EXPECT_TRUE(indices_tensor->CopyDataToTensor(indices.data(), indices.size()));
|
||||
auto op = graph->CreateOperation<tim::vx::ops::Gather>(1, 1);
|
||||
(*op).BindInputs({input_tensor, indices_tensor}).BindOutputs({output_tensor});
|
||||
|
||||
EXPECT_TRUE(graph->Compile());
|
||||
|
|
@ -77,3 +75,41 @@ TEST(Gather, shape_5_3_2_2_int32_axis_1_batchdims_1) {
|
|||
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||
EXPECT_EQ(golden, output);
|
||||
}
|
||||
|
||||
TEST(Gather, shape_2_2_indices_2) {
|
||||
auto ctx = tim::vx::Context::Create();
|
||||
|
||||
if (ctx->isClOnly()) GTEST_SKIP();
|
||||
auto graph = ctx->CreateGraph();
|
||||
|
||||
tim::vx::ShapeType in_shape({2, 2});
|
||||
tim::vx::ShapeType indices_shape({2});
|
||||
tim::vx::ShapeType out_shape({2, 2});
|
||||
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, in_shape,
|
||||
tim::vx::TensorAttribute::INPUT);
|
||||
tim::vx::TensorSpec indices_spec(tim::vx::DataType::INT32, indices_shape,
|
||||
tim::vx::TensorAttribute::INPUT);
|
||||
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, out_shape,
|
||||
tim::vx::TensorAttribute::OUTPUT);
|
||||
|
||||
auto input_tensor = graph->CreateTensor(input_spec);
|
||||
auto indices_tensor = graph->CreateTensor(indices_spec);
|
||||
auto output_tensor = graph->CreateTensor(output_spec);
|
||||
|
||||
std::vector<float> in_data = {-2.0f, 0.2f, 0.7f, 0.8f};
|
||||
|
||||
std::vector<int32_t> indices = {1, 0};
|
||||
std::vector<float> golden = {0.7f, 0.8f, -2.0f, 0.2f};
|
||||
|
||||
EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(), in_data.size()));
|
||||
EXPECT_TRUE(indices_tensor->CopyDataToTensor(indices.data(), indices.size()));
|
||||
auto op = graph->CreateOperation<tim::vx::ops::Gather>(1, 0);
|
||||
(*op).BindInputs({input_tensor, indices_tensor}).BindOutputs({output_tensor});
|
||||
|
||||
EXPECT_TRUE(graph->Compile());
|
||||
EXPECT_TRUE(graph->Run());
|
||||
|
||||
std::vector<float> output(golden.size());
|
||||
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||
EXPECT_EQ(golden, output);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue