diff --git a/include/tim/vx/ops.h b/include/tim/vx/ops.h index 253f698..d93b313 100644 --- a/include/tim/vx/ops.h +++ b/include/tim/vx/ops.h @@ -42,6 +42,7 @@ #include "tim/vx/ops/erf.h" #include "tim/vx/ops/fullyconnected.h" #include "tim/vx/ops/gather.h" +#include "tim/vx/ops/gather_elements.h" #include "tim/vx/ops/gathernd.h" #include "tim/vx/ops/groupedconv2d.h" #include "tim/vx/ops/instancenormalization.h" diff --git a/include/tim/vx/ops/gather_elements.h b/include/tim/vx/ops/gather_elements.h index ae1603b..509e8d9 100644 --- a/include/tim/vx/ops/gather_elements.h +++ b/include/tim/vx/ops/gather_elements.h @@ -21,8 +21,8 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ -#ifndef TIM_VX_OPS_GATHER_H_ -#define TIM_VX_OPS_GATHER_H_ +#ifndef TIM_VX_OPS_GATHER_ELEMENTS_H_ +#define TIM_VX_OPS_GATHER_ELEMENTS_H_ #include "tim/vx/direct_map_op.h" namespace tim { @@ -30,18 +30,18 @@ namespace vx { namespace ops { /** - * ## Gather_elements + * ## GatherElements * - * Gather_elements slices from input, **axis** according to **indices**. + * GatherElements slices from input, **axis** according to **indices**. * out[i][j][k] = input[index[i][j][k]][j][k] if axis = 0, * out[i][j][k] = input[i][index[i][j][k]][k] if axis = 1, * out[i][j][k] = input[i][j][index[i][j][k]] if axis = 2, * https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherElements */ -class Gather_elements : public DirectMapOp { +class GatherElements : public DirectMapOp { public: - Gather_elements(Graph* Graph, int axis); + GatherElements(Graph* Graph, int axis); std::shared_ptr Clone( std::shared_ptr& graph) const override; @@ -54,4 +54,4 @@ class Gather_elements : public DirectMapOp { } // namespace vx } // namespace tim -#endif /* TIM_VX_OPS_GATHER_H_ */ \ No newline at end of file +#endif /* TIM_VX_OPS_GATHER_ELEMENTS_H_ */ \ No newline at end of file diff --git a/src/tim/vx/ops/gather_elements.cc b/src/tim/vx/ops/gather_elements.cc index 89876a9..aebe962 100644 --- a/src/tim/vx/ops/gather_elements.cc +++ b/src/tim/vx/ops/gather_elements.cc @@ -30,14 +30,14 @@ namespace tim { namespace vx { namespace ops { #ifdef _VSI_NN_OP_GATHER_ELEMENTS_H -Gather_elements::Gather_elements(Graph* graph, int axis) +GatherElements::GatherElements(Graph* graph, int axis) : DirectMapOp(graph, VSI_NN_OP_GATHER_ELEMENTS), axis_(axis) { this->impl()->node()->nn_param.gather_elements.axis = axis_; } -std::shared_ptr Gather_elements::Clone( +std::shared_ptr GatherElements::Clone( std::shared_ptr& graph) const { - return graph->CreateOperation(this->axis_); + return graph->CreateOperation(this->axis_); } #endif diff --git a/src/tim/vx/ops/gather_elements_test.cc b/src/tim/vx/ops/gather_elements_test.cc index 6977518..d39ec48 100644 --- a/src/tim/vx/ops/gather_elements_test.cc +++ b/src/tim/vx/ops/gather_elements_test.cc @@ -28,9 +28,8 @@ #include "gtest/gtest.h" #include "test_utils.h" -#ifdef _VSI_NN_OP_GATHER_ELEMENTS_H -TEST(Gather_elements, shape_3_2_1_int32_axis_0) { +TEST(GatherElements, shape_3_2_1_int32_axis_0) { auto ctx = tim::vx::Context::Create(); auto graph = ctx->CreateGraph(); @@ -66,7 +65,7 @@ TEST(Gather_elements, shape_3_2_1_int32_axis_0) { input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * 4)); EXPECT_TRUE( indices_tensor->CopyDataToTensor(indices.data(), indices.size() * 4)); - auto op = graph->CreateOperation(0); + auto op = graph->CreateOperation(0); (*op).BindInputs({input_tensor, indices_tensor}).BindOutputs({output_tensor}); EXPECT_TRUE(graph->Compile()); @@ -77,7 +76,7 @@ TEST(Gather_elements, shape_3_2_1_int32_axis_0) { EXPECT_EQ(golden, output); } -TEST(Gather_elements, shape_3_2_1_int32_axis_1) { +TEST(GatherElements, shape_3_2_1_int32_axis_1) { auto ctx = tim::vx::Context::Create(); auto graph = ctx->CreateGraph(); @@ -113,7 +112,7 @@ TEST(Gather_elements, shape_3_2_1_int32_axis_1) { input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * 4)); EXPECT_TRUE( indices_tensor->CopyDataToTensor(indices.data(), indices.size() * 4)); - auto op = graph->CreateOperation(1); + auto op = graph->CreateOperation(1); (*op).BindInputs({input_tensor, indices_tensor}).BindOutputs({output_tensor}); EXPECT_TRUE(graph->Compile()); @@ -124,7 +123,7 @@ TEST(Gather_elements, shape_3_2_1_int32_axis_1) { EXPECT_EQ(golden, output); } -TEST(Gather_elements, shape_3_2_1_float32_axis_2) { +TEST(GatherElements, shape_3_2_1_float32_axis_2) { auto ctx = tim::vx::Context::Create(); auto graph = ctx->CreateGraph(); @@ -160,7 +159,7 @@ TEST(Gather_elements, shape_3_2_1_float32_axis_2) { input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * 4)); EXPECT_TRUE( indices_tensor->CopyDataToTensor(indices.data(), indices.size() * 4)); - auto op = graph->CreateOperation(2); + auto op = graph->CreateOperation(2); (*op).BindInputs({input_tensor, indices_tensor}).BindOutputs({output_tensor}); EXPECT_TRUE(graph->Compile()); @@ -170,4 +169,4 @@ TEST(Gather_elements, shape_3_2_1_float32_axis_2) { EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); EXPECT_EQ(golden, output); } -#endif +