From 9f331ed5ecc449921485295b93fd8c18559f1d89 Mon Sep 17 00:00:00 2001 From: chxin66 <57057788+chxin66@users.noreply.github.com> Date: Tue, 19 Jul 2022 12:33:09 +0800 Subject: [PATCH] Added batch dims in gather (#435) Signed-off-by: Chen Xin Co-authored-by: Chen Xin --- include/tim/vx/ops/gather.h | 3 +- .../transform/ops/gather_layout_inference.h | 3 +- src/tim/vx/ops/gather.cc | 7 +- src/tim/vx/ops/gather_test.cc | 77 +++++++++++++++++++ 4 files changed, 85 insertions(+), 5 deletions(-) create mode 100644 src/tim/vx/ops/gather_test.cc diff --git a/include/tim/vx/ops/gather.h b/include/tim/vx/ops/gather.h index 1728ac2..704a9e0 100644 --- a/include/tim/vx/ops/gather.h +++ b/include/tim/vx/ops/gather.h @@ -37,12 +37,13 @@ namespace ops { class Gather : public DirectMapOp { public: - Gather(Graph* Graph, int axis); + Gather(Graph* Graph, int axis, int batch_dims = 0); std::shared_ptr Clone(std::shared_ptr& graph) const override; protected: int axis_; + int batch_dims_; }; } // namespace ops diff --git a/src/tim/transform/ops/gather_layout_inference.h b/src/tim/transform/ops/gather_layout_inference.h index db4db9d..99ec6aa 100644 --- a/src/tim/transform/ops/gather_layout_inference.h +++ b/src/tim/transform/ops/gather_layout_inference.h @@ -41,7 +41,8 @@ class GatherLayoutInfer : public OpLayoutInfer { ReverseInputsPermuteVector(); auto gather = context_->infer_graph_->CreateOperation( - op_->impl()->node()->nn_param.gather.axis); + op_->impl()->node()->nn_param.gather.axis, + op_->impl()->node()->nn_param.gather.batch_dims); int32_t output_rank = -1; for (const auto& i_src : op_->impl()->InputsTensor()) { (*gather).BindInput(context_->GetMapedTensor(i_src)); diff --git a/src/tim/vx/ops/gather.cc b/src/tim/vx/ops/gather.cc index b247249..c62c987 100644 --- a/src/tim/vx/ops/gather.cc +++ b/src/tim/vx/ops/gather.cc @@ -30,13 +30,14 @@ namespace tim { namespace vx { namespace ops { -Gather::Gather(Graph* graph, int axis) - : DirectMapOp(graph, VSI_NN_OP_GATHER), axis_(axis) { +Gather::Gather(Graph* graph, int axis, int batch_dims) + : DirectMapOp(graph, VSI_NN_OP_GATHER), axis_(axis), batch_dims_(batch_dims) { this->impl()->node()->nn_param.gather.axis = axis_; + this->impl()->node()->nn_param.gather.batch_dims = batch_dims_; } std::shared_ptr Gather::Clone(std::shared_ptr& graph) const { - return graph->CreateOperation(this->axis_); + return graph->CreateOperation(this->axis_, this->batch_dims_); } } // namespace ops diff --git a/src/tim/vx/ops/gather_test.cc b/src/tim/vx/ops/gather_test.cc new file mode 100644 index 0000000..3f35768 --- /dev/null +++ b/src/tim/vx/ops/gather_test.cc @@ -0,0 +1,77 @@ +/**************************************************************************** +* +* Copyright (c) 2022 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#include "tim/vx/context.h" +#include "tim/vx/graph.h" +#include "tim/vx/ops/gather.h" +#include +#include "gtest/gtest.h" +#include "test_utils.h" + +TEST(Gather, shape_5_3_2_2_int32_axis_1_batchdims_1) { + auto ctx = tim::vx::Context::Create(); + auto graph = ctx->CreateGraph(); + + tim::vx::ShapeType in_shape({5, 3, 2, 2}); + tim::vx::ShapeType indices_shape({2, 2, 2}); + tim::vx::ShapeType out_shape({5, 2, 2, 2, 2}); + tim::vx::TensorSpec input_spec(tim::vx::DataType::INT8, in_shape, + tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec indices_spec(tim::vx::DataType::INT32, indices_shape, + tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec output_spec(tim::vx::DataType::INT8, out_shape, + tim::vx::TensorAttribute::OUTPUT); + + auto input_tensor = graph->CreateTensor(input_spec); + auto indices_tensor = graph->CreateTensor(indices_spec); + auto output_tensor = graph->CreateTensor(output_spec); + + std::vector in_data = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, + 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, + }; + //The index value greater than rank-1 is regarded as rank-1 + std::vector indices = {1, 0, 0, 1, 1, 0, 0, 1}; + std::vector golden = { + 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 20, 21, 22, 23, 24, 15, 16, 17, 18, 19, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 35, 36, 37, 38, 39, 30, 31, 32, + 33, 34, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 50, 51, 52, 53, + 54, 45, 46, 47, 48, 49, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54}; + + EXPECT_TRUE( + input_tensor->CopyDataToTensor(in_data.data(), in_data.size())); + EXPECT_TRUE( + indices_tensor->CopyDataToTensor(indices.data(), indices.size() * 4)); + auto op = graph->CreateOperation(1,1); + (*op).BindInputs({input_tensor, indices_tensor}).BindOutputs({output_tensor}); + + EXPECT_TRUE(graph->Compile()); + EXPECT_TRUE(graph->Run()); + + std::vector output(golden.size()); + EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); + EXPECT_EQ(golden, output); +}