From 8a15abf12b0622e7e8fce660be2eea32cfb2f04a Mon Sep 17 00:00:00 2001 From: "zhao.xia" Date: Thu, 3 Jun 2021 10:55:05 +0800 Subject: [PATCH] Add ScatterND Signed-off-by: zhao.xia --- include/tim/vx/ops/scatternd.h | 52 ++++++++ .../include/ops/vsi_nn_op_scatter_nd.h | 4 +- src/tim/vx/ops/README.md | 2 +- src/tim/vx/ops/scatternd.cc | 40 ++++++ src/tim/vx/ops/scatternd_test.cc | 124 ++++++++++++++++++ 5 files changed, 219 insertions(+), 3 deletions(-) create mode 100644 include/tim/vx/ops/scatternd.h create mode 100644 src/tim/vx/ops/scatternd.cc create mode 100644 src/tim/vx/ops/scatternd_test.cc diff --git a/include/tim/vx/ops/scatternd.h b/include/tim/vx/ops/scatternd.h new file mode 100644 index 0000000..ed28ba8 --- /dev/null +++ b/include/tim/vx/ops/scatternd.h @@ -0,0 +1,52 @@ +/**************************************************************************** +* +* Copyright (c) 2021 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#ifndef TIM_VX_OPS_SCATTERND_H_ +#define TIM_VX_OPS_SCATTERND_H_ +#include "tim/vx/operation.h" + +namespace tim { +namespace vx { +namespace ops { + +/** + * ## ScatterND + * + * Scatter updates into a new tensor according to indices. + * + * - shape : The shape of the resulting tensor. + */ + +class ScatterND : public Operation { + public: + ScatterND(Graph* graph, const std::vector& shape); + + protected: + const std::vector shape_; +}; + +} // namespace ops +} // namespace vx +} // namespace tim + +#endif /* TIM_VX_OPS_SCATTERND_H_ */ diff --git a/src/tim/vx/internal/include/ops/vsi_nn_op_scatter_nd.h b/src/tim/vx/internal/include/ops/vsi_nn_op_scatter_nd.h index 9464f76..b0b3b24 100644 --- a/src/tim/vx/internal/include/ops/vsi_nn_op_scatter_nd.h +++ b/src/tim/vx/internal/include/ops/vsi_nn_op_scatter_nd.h @@ -32,8 +32,8 @@ extern "C" { typedef struct _vsi_nn_scatter_nd_param { - uint32_t dim_num; - uint32_t* shape; + uint32_t dim_num; + const uint32_t* shape; } vsi_nn_scatter_nd_param; #ifdef __cplusplus diff --git a/src/tim/vx/ops/README.md b/src/tim/vx/ops/README.md index 93a3485..448a208 100644 --- a/src/tim/vx/ops/README.md +++ b/src/tim/vx/ops/README.md @@ -90,7 +90,7 @@ Mish|MISH|Mapped|[tfa.activations.mish](https://tensorflow.google.cn/addons/api_ Resize1d|RESIZE_1D|Mapped|[Onnx.resize 1D image](https://github.com/onnx/onnx/blob/master/docs/Operators.md#resize) |Linear|LINEAR|Unmapped|[tf.keras.activations.linear](https://www.tensorflow.org/api_docs/python/tf/keras/activations/linear) ||MOMENTS|Unmapped|[tf.moments](https://tensorflow.google.cn/api_docs/python/tf/nn/moments) -||SCATTER_ND|Unmapped|[tf.scatter_nd](https://tensorflow.google.cn/api_docs/python/tf/scatter_nd) +ScatterND|SCATTER_ND|Mapped|[tf.scatter_nd](https://tensorflow.google.cn/api_docs/python/tf/scatter_nd) ||PROPOSAL|Unmapped|[Faster-RCNN Proposal Layer](https://github.com/intel/caffe/blob/master/examples/faster-rcnn/lib/rpn/proposal_layer.py) ||MATRIXMUL|Unmapped|[tf.experimental.numpy.matmul](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy/matmul) ||SIGNAL_FRAME|Unmapped|[tf.signal.frame](https://tensorflow.google.cn/api_docs/python/tf/signal/frame) diff --git a/src/tim/vx/ops/scatternd.cc b/src/tim/vx/ops/scatternd.cc new file mode 100644 index 0000000..d979d81 --- /dev/null +++ b/src/tim/vx/ops/scatternd.cc @@ -0,0 +1,40 @@ +/**************************************************************************** +* +* Copyright (c) 2021 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#include "tim/vx/ops/scatternd.h" + +#include "operation_private.h" +#include "vsi_nn_pub.h" + +namespace tim { +namespace vx { +namespace ops { + +ScatterND::ScatterND(Graph* graph, const std::vector& shape) + : Operation(graph, VSI_NN_OP_SCATTER_ND), shape_(shape) { + this->impl()->node()->nn_param.scatter_nd.dim_num = shape_.size(); + this->impl()->node()->nn_param.scatter_nd.shape = shape_.data(); +} +} // namespace ops +} // namespace vx +} // namespace tim \ No newline at end of file diff --git a/src/tim/vx/ops/scatternd_test.cc b/src/tim/vx/ops/scatternd_test.cc new file mode 100644 index 0000000..1abd3f2 --- /dev/null +++ b/src/tim/vx/ops/scatternd_test.cc @@ -0,0 +1,124 @@ +/**************************************************************************** +* +* Copyright (c) 2021 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#include "tim/vx/context.h" +#include "tim/vx/graph.h" +#include "tim/vx/ops/scatternd.h" + +#include "gtest/gtest.h" + +TEST(ScatterND, shape_4_4_4) { + auto ctx = tim::vx::Context::Create(); + auto graph = ctx->CreateGraph(); + + tim::vx::ShapeType indices_shape({1,2}); + tim::vx::ShapeType updates_shape({4,4,2}); + tim::vx::ShapeType out_shape({4, 4, 4}); + tim::vx::TensorSpec indices_spec(tim::vx::DataType::INT32, + indices_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec updates_spec(tim::vx::DataType::FLOAT32, + updates_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, + out_shape, tim::vx::TensorAttribute::OUTPUT); + + auto indices_tensor = graph->CreateTensor(indices_spec); + auto updates_tensor = graph->CreateTensor(updates_spec); + auto output_tensor = graph->CreateTensor(output_spec); + + std::vector indices_data = { 0, 2 }; + std::vector updates_data = { + 5,5,5,5, 6,6,6,6, + 7,7,7,7, 8,8,8,8, + 1,1,1,1, 2,2,2,2, + 3,3,3,3, 4,4,4,4, + }; + std::vector golden = { + 5,5,5,5, 6,6,6,6, + 7,7,7,7, 8,8,8,8, + 0,0,0,0, 0,0,0,0, + 0,0,0,0, 0,0,0,0, + 1,1,1,1, 2,2,2,2, + 3,3,3,3, 4,4,4,4, + 0,0,0,0, 0,0,0,0, + 0,0,0,0, 0,0,0,0, + }; + + EXPECT_TRUE(indices_tensor->CopyDataToTensor( + indices_data.data(), indices_data.size()*sizeof(int32_t))); + EXPECT_TRUE(updates_tensor->CopyDataToTensor( + updates_data.data(), updates_data.size()*sizeof(int32_t))); + std::vector shape = {4, 4, 4}; + auto op = graph->CreateOperation(shape); + (*op).BindInputs({indices_tensor, updates_tensor}).BindOutputs({output_tensor}); + + EXPECT_TRUE(graph->Compile()); + EXPECT_TRUE(graph->Run()); + std::vector output(golden.size()); + + EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); + EXPECT_EQ(golden, output); +} + +TEST(ScatterND, shape_9) { + auto ctx = tim::vx::Context::Create(); + auto graph = ctx->CreateGraph(); + + tim::vx::ShapeType indices_shape({4}); + tim::vx::ShapeType updates_shape({4}); + tim::vx::ShapeType out_shape({9}); + tim::vx::Quantization updates_quant(tim::vx::QuantType::ASYMMETRIC, 0.5, 0); + tim::vx::Quantization output_quant(tim::vx::QuantType::ASYMMETRIC, 0.5, 0); + tim::vx::TensorSpec indices_spec(tim::vx::DataType::INT32, + indices_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec updates_spec(tim::vx::DataType::UINT8, + updates_shape, tim::vx::TensorAttribute::INPUT, updates_quant); + tim::vx::TensorSpec output_spec(tim::vx::DataType::UINT8, + out_shape, tim::vx::TensorAttribute::OUTPUT, output_quant); + + auto indices_tensor = graph->CreateTensor(indices_spec); + auto updates_tensor = graph->CreateTensor(updates_spec); + auto output_tensor = graph->CreateTensor(output_spec); + + std::vector indices_data = { 4, 3, 1, 7 }; + std::vector updates_data = { + 18, 20, 22, 24 + }; + std::vector golden = { + 0, 22, 0, 20, 18, 0, 0, 24, 0 + }; + + EXPECT_TRUE(indices_tensor->CopyDataToTensor( + indices_data.data(), indices_data.size())); + EXPECT_TRUE(updates_tensor->CopyDataToTensor( + updates_data.data(), updates_data.size())); + std::vector shape = {9}; + auto op = graph->CreateOperation(shape); + (*op).BindInputs({indices_tensor, updates_tensor}).BindOutputs({output_tensor}); + + EXPECT_TRUE(graph->Compile()); + EXPECT_TRUE(graph->Run()); + std::vector output(golden.size()); + + EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); + EXPECT_EQ(golden, output); +}