modify GatherElements (#406)

This commit is contained in:
MESeraph 2022-05-29 22:25:14 +08:00 committed by GitHub
parent 1b4c30e572
commit 6d0c6b01b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 18 additions and 18 deletions

View File

@ -42,6 +42,7 @@
#include "tim/vx/ops/erf.h" #include "tim/vx/ops/erf.h"
#include "tim/vx/ops/fullyconnected.h" #include "tim/vx/ops/fullyconnected.h"
#include "tim/vx/ops/gather.h" #include "tim/vx/ops/gather.h"
#include "tim/vx/ops/gather_elements.h"
#include "tim/vx/ops/gathernd.h" #include "tim/vx/ops/gathernd.h"
#include "tim/vx/ops/groupedconv2d.h" #include "tim/vx/ops/groupedconv2d.h"
#include "tim/vx/ops/instancenormalization.h" #include "tim/vx/ops/instancenormalization.h"

View File

@ -21,8 +21,8 @@
* DEALINGS IN THE SOFTWARE. * DEALINGS IN THE SOFTWARE.
* *
*****************************************************************************/ *****************************************************************************/
#ifndef TIM_VX_OPS_GATHER_H_ #ifndef TIM_VX_OPS_GATHER_ELEMENTS_H_
#define TIM_VX_OPS_GATHER_H_ #define TIM_VX_OPS_GATHER_ELEMENTS_H_
#include "tim/vx/direct_map_op.h" #include "tim/vx/direct_map_op.h"
namespace tim { namespace tim {
@ -30,18 +30,18 @@ namespace vx {
namespace ops { namespace ops {
/** /**
* ## Gather_elements * ## GatherElements
* *
* Gather_elements slices from input, **axis** according to **indices**. * GatherElements slices from input, **axis** according to **indices**.
* out[i][j][k] = input[index[i][j][k]][j][k] if axis = 0, * out[i][j][k] = input[index[i][j][k]][j][k] if axis = 0,
* out[i][j][k] = input[i][index[i][j][k]][k] if axis = 1, * out[i][j][k] = input[i][index[i][j][k]][k] if axis = 1,
* out[i][j][k] = input[i][j][index[i][j][k]] if axis = 2, * out[i][j][k] = input[i][j][index[i][j][k]] if axis = 2,
* https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherElements * https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherElements
*/ */
class Gather_elements : public DirectMapOp { class GatherElements : public DirectMapOp {
public: public:
Gather_elements(Graph* Graph, int axis); GatherElements(Graph* Graph, int axis);
std::shared_ptr<Operation> Clone( std::shared_ptr<Operation> Clone(
std::shared_ptr<Graph>& graph) const override; std::shared_ptr<Graph>& graph) const override;
@ -54,4 +54,4 @@ class Gather_elements : public DirectMapOp {
} // namespace vx } // namespace vx
} // namespace tim } // namespace tim
#endif /* TIM_VX_OPS_GATHER_H_ */ #endif /* TIM_VX_OPS_GATHER_ELEMENTS_H_ */

View File

@ -30,14 +30,14 @@ namespace tim {
namespace vx { namespace vx {
namespace ops { namespace ops {
#ifdef _VSI_NN_OP_GATHER_ELEMENTS_H #ifdef _VSI_NN_OP_GATHER_ELEMENTS_H
Gather_elements::Gather_elements(Graph* graph, int axis) GatherElements::GatherElements(Graph* graph, int axis)
: DirectMapOp(graph, VSI_NN_OP_GATHER_ELEMENTS), axis_(axis) { : DirectMapOp(graph, VSI_NN_OP_GATHER_ELEMENTS), axis_(axis) {
this->impl()->node()->nn_param.gather_elements.axis = axis_; this->impl()->node()->nn_param.gather_elements.axis = axis_;
} }
std::shared_ptr<Operation> Gather_elements::Clone( std::shared_ptr<Operation> GatherElements::Clone(
std::shared_ptr<Graph>& graph) const { std::shared_ptr<Graph>& graph) const {
return graph->CreateOperation<Gather_elements>(this->axis_); return graph->CreateOperation<GatherElements>(this->axis_);
} }
#endif #endif

View File

@ -28,9 +28,8 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "test_utils.h" #include "test_utils.h"
#ifdef _VSI_NN_OP_GATHER_ELEMENTS_H
TEST(Gather_elements, shape_3_2_1_int32_axis_0) { TEST(GatherElements, shape_3_2_1_int32_axis_0) {
auto ctx = tim::vx::Context::Create(); auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph(); auto graph = ctx->CreateGraph();
@ -66,7 +65,7 @@ TEST(Gather_elements, shape_3_2_1_int32_axis_0) {
input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * 4)); input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * 4));
EXPECT_TRUE( EXPECT_TRUE(
indices_tensor->CopyDataToTensor(indices.data(), indices.size() * 4)); indices_tensor->CopyDataToTensor(indices.data(), indices.size() * 4));
auto op = graph->CreateOperation<tim::vx::ops::Gather_elements>(0); auto op = graph->CreateOperation<tim::vx::ops::GatherElements>(0);
(*op).BindInputs({input_tensor, indices_tensor}).BindOutputs({output_tensor}); (*op).BindInputs({input_tensor, indices_tensor}).BindOutputs({output_tensor});
EXPECT_TRUE(graph->Compile()); EXPECT_TRUE(graph->Compile());
@ -77,7 +76,7 @@ TEST(Gather_elements, shape_3_2_1_int32_axis_0) {
EXPECT_EQ(golden, output); EXPECT_EQ(golden, output);
} }
TEST(Gather_elements, shape_3_2_1_int32_axis_1) { TEST(GatherElements, shape_3_2_1_int32_axis_1) {
auto ctx = tim::vx::Context::Create(); auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph(); auto graph = ctx->CreateGraph();
@ -113,7 +112,7 @@ TEST(Gather_elements, shape_3_2_1_int32_axis_1) {
input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * 4)); input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * 4));
EXPECT_TRUE( EXPECT_TRUE(
indices_tensor->CopyDataToTensor(indices.data(), indices.size() * 4)); indices_tensor->CopyDataToTensor(indices.data(), indices.size() * 4));
auto op = graph->CreateOperation<tim::vx::ops::Gather_elements>(1); auto op = graph->CreateOperation<tim::vx::ops::GatherElements>(1);
(*op).BindInputs({input_tensor, indices_tensor}).BindOutputs({output_tensor}); (*op).BindInputs({input_tensor, indices_tensor}).BindOutputs({output_tensor});
EXPECT_TRUE(graph->Compile()); EXPECT_TRUE(graph->Compile());
@ -124,7 +123,7 @@ TEST(Gather_elements, shape_3_2_1_int32_axis_1) {
EXPECT_EQ(golden, output); EXPECT_EQ(golden, output);
} }
TEST(Gather_elements, shape_3_2_1_float32_axis_2) { TEST(GatherElements, shape_3_2_1_float32_axis_2) {
auto ctx = tim::vx::Context::Create(); auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph(); auto graph = ctx->CreateGraph();
@ -160,7 +159,7 @@ TEST(Gather_elements, shape_3_2_1_float32_axis_2) {
input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * 4)); input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * 4));
EXPECT_TRUE( EXPECT_TRUE(
indices_tensor->CopyDataToTensor(indices.data(), indices.size() * 4)); indices_tensor->CopyDataToTensor(indices.data(), indices.size() * 4));
auto op = graph->CreateOperation<tim::vx::ops::Gather_elements>(2); auto op = graph->CreateOperation<tim::vx::ops::GatherElements>(2);
(*op).BindInputs({input_tensor, indices_tensor}).BindOutputs({output_tensor}); (*op).BindInputs({input_tensor, indices_tensor}).BindOutputs({output_tensor});
EXPECT_TRUE(graph->Compile()); EXPECT_TRUE(graph->Compile());
@ -170,4 +169,4 @@ TEST(Gather_elements, shape_3_2_1_float32_axis_2) {
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
EXPECT_EQ(golden, output); EXPECT_EQ(golden, output);
} }
#endif