modify GatherElements (#406)

This commit is contained in:
MESeraph 2022-05-29 22:25:14 +08:00 committed by GitHub
parent 1b4c30e572
commit 6d0c6b01b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 18 additions and 18 deletions

View File

@ -42,6 +42,7 @@
#include "tim/vx/ops/erf.h"
#include "tim/vx/ops/fullyconnected.h"
#include "tim/vx/ops/gather.h"
#include "tim/vx/ops/gather_elements.h"
#include "tim/vx/ops/gathernd.h"
#include "tim/vx/ops/groupedconv2d.h"
#include "tim/vx/ops/instancenormalization.h"

View File

@ -21,8 +21,8 @@
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_VX_OPS_GATHER_H_
#define TIM_VX_OPS_GATHER_H_
#ifndef TIM_VX_OPS_GATHER_ELEMENTS_H_
#define TIM_VX_OPS_GATHER_ELEMENTS_H_
#include "tim/vx/direct_map_op.h"
namespace tim {
@ -30,18 +30,18 @@ namespace vx {
namespace ops {
/**
* ## Gather_elements
* ## GatherElements
*
* Gather_elements slices from input, **axis** according to **indices**.
* GatherElements slices from input, **axis** according to **indices**.
* out[i][j][k] = input[index[i][j][k]][j][k] if axis = 0,
* out[i][j][k] = input[i][index[i][j][k]][k] if axis = 1,
* out[i][j][k] = input[i][j][index[i][j][k]] if axis = 2,
* https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherElements
*/
class Gather_elements : public DirectMapOp {
class GatherElements : public DirectMapOp {
public:
Gather_elements(Graph* Graph, int axis);
GatherElements(Graph* Graph, int axis);
std::shared_ptr<Operation> Clone(
std::shared_ptr<Graph>& graph) const override;
@ -54,4 +54,4 @@ class Gather_elements : public DirectMapOp {
} // namespace vx
} // namespace tim
#endif /* TIM_VX_OPS_GATHER_H_ */
#endif /* TIM_VX_OPS_GATHER_ELEMENTS_H_ */

View File

@ -30,14 +30,14 @@ namespace tim {
namespace vx {
namespace ops {
#ifdef _VSI_NN_OP_GATHER_ELEMENTS_H
Gather_elements::Gather_elements(Graph* graph, int axis)
GatherElements::GatherElements(Graph* graph, int axis)
: DirectMapOp(graph, VSI_NN_OP_GATHER_ELEMENTS), axis_(axis) {
this->impl()->node()->nn_param.gather_elements.axis = axis_;
}
std::shared_ptr<Operation> Gather_elements::Clone(
std::shared_ptr<Operation> GatherElements::Clone(
std::shared_ptr<Graph>& graph) const {
return graph->CreateOperation<Gather_elements>(this->axis_);
return graph->CreateOperation<GatherElements>(this->axis_);
}
#endif

View File

@ -28,9 +28,8 @@
#include "gtest/gtest.h"
#include "test_utils.h"
#ifdef _VSI_NN_OP_GATHER_ELEMENTS_H
TEST(Gather_elements, shape_3_2_1_int32_axis_0) {
TEST(GatherElements, shape_3_2_1_int32_axis_0) {
auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph();
@ -66,7 +65,7 @@ TEST(Gather_elements, shape_3_2_1_int32_axis_0) {
input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * 4));
EXPECT_TRUE(
indices_tensor->CopyDataToTensor(indices.data(), indices.size() * 4));
auto op = graph->CreateOperation<tim::vx::ops::Gather_elements>(0);
auto op = graph->CreateOperation<tim::vx::ops::GatherElements>(0);
(*op).BindInputs({input_tensor, indices_tensor}).BindOutputs({output_tensor});
EXPECT_TRUE(graph->Compile());
@ -77,7 +76,7 @@ TEST(Gather_elements, shape_3_2_1_int32_axis_0) {
EXPECT_EQ(golden, output);
}
TEST(Gather_elements, shape_3_2_1_int32_axis_1) {
TEST(GatherElements, shape_3_2_1_int32_axis_1) {
auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph();
@ -113,7 +112,7 @@ TEST(Gather_elements, shape_3_2_1_int32_axis_1) {
input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * 4));
EXPECT_TRUE(
indices_tensor->CopyDataToTensor(indices.data(), indices.size() * 4));
auto op = graph->CreateOperation<tim::vx::ops::Gather_elements>(1);
auto op = graph->CreateOperation<tim::vx::ops::GatherElements>(1);
(*op).BindInputs({input_tensor, indices_tensor}).BindOutputs({output_tensor});
EXPECT_TRUE(graph->Compile());
@ -124,7 +123,7 @@ TEST(Gather_elements, shape_3_2_1_int32_axis_1) {
EXPECT_EQ(golden, output);
}
TEST(Gather_elements, shape_3_2_1_float32_axis_2) {
TEST(GatherElements, shape_3_2_1_float32_axis_2) {
auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph();
@ -160,7 +159,7 @@ TEST(Gather_elements, shape_3_2_1_float32_axis_2) {
input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * 4));
EXPECT_TRUE(
indices_tensor->CopyDataToTensor(indices.data(), indices.size() * 4));
auto op = graph->CreateOperation<tim::vx::ops::Gather_elements>(2);
auto op = graph->CreateOperation<tim::vx::ops::GatherElements>(2);
(*op).BindInputs({input_tensor, indices_tensor}).BindOutputs({output_tensor});
EXPECT_TRUE(graph->Compile());
@ -170,4 +169,4 @@ TEST(Gather_elements, shape_3_2_1_float32_axis_2) {
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
EXPECT_EQ(golden, output);
}
#endif