modify GatherElements (#406)
This commit is contained in:
parent
1b4c30e572
commit
6d0c6b01b5
|
|
@ -42,6 +42,7 @@
|
|||
#include "tim/vx/ops/erf.h"
|
||||
#include "tim/vx/ops/fullyconnected.h"
|
||||
#include "tim/vx/ops/gather.h"
|
||||
#include "tim/vx/ops/gather_elements.h"
|
||||
#include "tim/vx/ops/gathernd.h"
|
||||
#include "tim/vx/ops/groupedconv2d.h"
|
||||
#include "tim/vx/ops/instancenormalization.h"
|
||||
|
|
|
|||
|
|
@ -21,8 +21,8 @@
|
|||
* DEALINGS IN THE SOFTWARE.
|
||||
*
|
||||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_GATHER_H_
|
||||
#define TIM_VX_OPS_GATHER_H_
|
||||
#ifndef TIM_VX_OPS_GATHER_ELEMENTS_H_
|
||||
#define TIM_VX_OPS_GATHER_ELEMENTS_H_
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
|
|
@ -30,18 +30,18 @@ namespace vx {
|
|||
namespace ops {
|
||||
|
||||
/**
|
||||
* ## Gather_elements
|
||||
* ## GatherElements
|
||||
*
|
||||
* Gather_elements slices from input, **axis** according to **indices**.
|
||||
* GatherElements slices from input, **axis** according to **indices**.
|
||||
* out[i][j][k] = input[index[i][j][k]][j][k] if axis = 0,
|
||||
* out[i][j][k] = input[i][index[i][j][k]][k] if axis = 1,
|
||||
* out[i][j][k] = input[i][j][index[i][j][k]] if axis = 2,
|
||||
* https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherElements
|
||||
*/
|
||||
|
||||
class Gather_elements : public DirectMapOp {
|
||||
class GatherElements : public DirectMapOp {
|
||||
public:
|
||||
Gather_elements(Graph* Graph, int axis);
|
||||
GatherElements(Graph* Graph, int axis);
|
||||
|
||||
std::shared_ptr<Operation> Clone(
|
||||
std::shared_ptr<Graph>& graph) const override;
|
||||
|
|
@ -54,4 +54,4 @@ class Gather_elements : public DirectMapOp {
|
|||
} // namespace vx
|
||||
} // namespace tim
|
||||
|
||||
#endif /* TIM_VX_OPS_GATHER_H_ */
|
||||
#endif /* TIM_VX_OPS_GATHER_ELEMENTS_H_ */
|
||||
|
|
@ -30,14 +30,14 @@ namespace tim {
|
|||
namespace vx {
|
||||
namespace ops {
|
||||
#ifdef _VSI_NN_OP_GATHER_ELEMENTS_H
|
||||
Gather_elements::Gather_elements(Graph* graph, int axis)
|
||||
GatherElements::GatherElements(Graph* graph, int axis)
|
||||
: DirectMapOp(graph, VSI_NN_OP_GATHER_ELEMENTS), axis_(axis) {
|
||||
this->impl()->node()->nn_param.gather_elements.axis = axis_;
|
||||
}
|
||||
|
||||
std::shared_ptr<Operation> Gather_elements::Clone(
|
||||
std::shared_ptr<Operation> GatherElements::Clone(
|
||||
std::shared_ptr<Graph>& graph) const {
|
||||
return graph->CreateOperation<Gather_elements>(this->axis_);
|
||||
return graph->CreateOperation<GatherElements>(this->axis_);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -28,9 +28,8 @@
|
|||
#include "gtest/gtest.h"
|
||||
#include "test_utils.h"
|
||||
|
||||
#ifdef _VSI_NN_OP_GATHER_ELEMENTS_H
|
||||
|
||||
TEST(Gather_elements, shape_3_2_1_int32_axis_0) {
|
||||
TEST(GatherElements, shape_3_2_1_int32_axis_0) {
|
||||
auto ctx = tim::vx::Context::Create();
|
||||
auto graph = ctx->CreateGraph();
|
||||
|
||||
|
|
@ -66,7 +65,7 @@ TEST(Gather_elements, shape_3_2_1_int32_axis_0) {
|
|||
input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * 4));
|
||||
EXPECT_TRUE(
|
||||
indices_tensor->CopyDataToTensor(indices.data(), indices.size() * 4));
|
||||
auto op = graph->CreateOperation<tim::vx::ops::Gather_elements>(0);
|
||||
auto op = graph->CreateOperation<tim::vx::ops::GatherElements>(0);
|
||||
(*op).BindInputs({input_tensor, indices_tensor}).BindOutputs({output_tensor});
|
||||
|
||||
EXPECT_TRUE(graph->Compile());
|
||||
|
|
@ -77,7 +76,7 @@ TEST(Gather_elements, shape_3_2_1_int32_axis_0) {
|
|||
EXPECT_EQ(golden, output);
|
||||
}
|
||||
|
||||
TEST(Gather_elements, shape_3_2_1_int32_axis_1) {
|
||||
TEST(GatherElements, shape_3_2_1_int32_axis_1) {
|
||||
auto ctx = tim::vx::Context::Create();
|
||||
auto graph = ctx->CreateGraph();
|
||||
|
||||
|
|
@ -113,7 +112,7 @@ TEST(Gather_elements, shape_3_2_1_int32_axis_1) {
|
|||
input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * 4));
|
||||
EXPECT_TRUE(
|
||||
indices_tensor->CopyDataToTensor(indices.data(), indices.size() * 4));
|
||||
auto op = graph->CreateOperation<tim::vx::ops::Gather_elements>(1);
|
||||
auto op = graph->CreateOperation<tim::vx::ops::GatherElements>(1);
|
||||
(*op).BindInputs({input_tensor, indices_tensor}).BindOutputs({output_tensor});
|
||||
|
||||
EXPECT_TRUE(graph->Compile());
|
||||
|
|
@ -124,7 +123,7 @@ TEST(Gather_elements, shape_3_2_1_int32_axis_1) {
|
|||
EXPECT_EQ(golden, output);
|
||||
}
|
||||
|
||||
TEST(Gather_elements, shape_3_2_1_float32_axis_2) {
|
||||
TEST(GatherElements, shape_3_2_1_float32_axis_2) {
|
||||
auto ctx = tim::vx::Context::Create();
|
||||
auto graph = ctx->CreateGraph();
|
||||
|
||||
|
|
@ -160,7 +159,7 @@ TEST(Gather_elements, shape_3_2_1_float32_axis_2) {
|
|||
input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * 4));
|
||||
EXPECT_TRUE(
|
||||
indices_tensor->CopyDataToTensor(indices.data(), indices.size() * 4));
|
||||
auto op = graph->CreateOperation<tim::vx::ops::Gather_elements>(2);
|
||||
auto op = graph->CreateOperation<tim::vx::ops::GatherElements>(2);
|
||||
(*op).BindInputs({input_tensor, indices_tensor}).BindOutputs({output_tensor});
|
||||
|
||||
EXPECT_TRUE(graph->Compile());
|
||||
|
|
@ -170,4 +169,4 @@ TEST(Gather_elements, shape_3_2_1_float32_axis_2) {
|
|||
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||
EXPECT_EQ(golden, output);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue