From 944fdfad8f660c69c3245761558512ecda566c55 Mon Sep 17 00:00:00 2001 From: Chen Xin Date: Mon, 8 Aug 2022 19:07:06 +0800 Subject: [PATCH] Mapped GRUCell & unit test Signed-off-by: Chen Xin --- include/tim/vx/ops/grucell.h | 75 ++++++++++++++++++++ src/tim/vx/ops/README.md | 2 +- src/tim/vx/ops/grucell.cc | 53 ++++++++++++++ src/tim/vx/ops/grucell_test.cc | 122 +++++++++++++++++++++++++++++++++ 4 files changed, 251 insertions(+), 1 deletion(-) create mode 100644 include/tim/vx/ops/grucell.h create mode 100644 src/tim/vx/ops/grucell.cc create mode 100644 src/tim/vx/ops/grucell_test.cc diff --git a/include/tim/vx/ops/grucell.h b/include/tim/vx/ops/grucell.h new file mode 100644 index 0000000..f6f2c56 --- /dev/null +++ b/include/tim/vx/ops/grucell.h @@ -0,0 +1,75 @@ +/**************************************************************************** +* +* Copyright (c) 2022 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#ifndef TIM_VX_OPS_GRUCELL_H_ +#define TIM_VX_OPS_GRUCELL_H_ + +#include +#include "tim/vx/direct_map_op.h" +#include "vsi_nn_pub.h" + +namespace tim { +namespace vx { +namespace ops { + +/** + * ## GRUCell + * + * - num_units : dimensionality of the output space. + * - activation : Activation function to use. + * - recurrent_activation : Activation function to use for the recurrent step. + * - reset_after : whether to apply reset gate after or before matrix multiplication. + * False = "before", True = "after". + */ + +class GRUCell : public DirectMapOp { + public: + enum ActivationType { + kNONE = 0, + kRELU = 1, + kRELU6 = 3, + kTANH = 4, + kSIGMOID = 6, + kHARDSIGMOID = 31, /* temporary use 31 */ + }; + + GRUCell(Graph* graph, uint32_t num_units, + ActivationType activation = ActivationType::kTANH, + ActivationType recurrent_activation = ActivationType::kSIGMOID, + vsi_bool reset_after = TRUE); + + std::shared_ptr Clone( + std::shared_ptr& graph) const override; + + protected: + const uint32_t num_units_; + const ActivationType activation_; + const ActivationType recurrent_activation_; + const int32_t reset_after_; +}; + +} // namespace ops +} // namespace vx +} // namespace tim + +#endif /* TIM_VX_OPS_GRUCELL_H_ */ \ No newline at end of file diff --git a/src/tim/vx/ops/README.md b/src/tim/vx/ops/README.md index 73aa1b4..056a149 100644 --- a/src/tim/vx/ops/README.md +++ b/src/tim/vx/ops/README.md @@ -111,7 +111,7 @@ GroupedConv1d|GROUPED_CONV1D|Mapped|[tf.keras.layers.Conv1D](https://tensorflow. ||ROI_POOL|Planned 22Q4|[ANEURALNETWORKS_ROI_POOLING](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a6736198af337b2efbdb0b6b64dee7fe4) ROI_Align||ROI_ALIGN|Mapped|[ANEURALNETWORKS_ROI_ALIGN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a2848b39dd4bfba78f2438fda0d9397a4) TopK||TOPK|Mapped (limited support)|[tf.math.top_k](https://tensorflow.google.cn/api_docs/python/tf/math/top_k) -|GRUCell|GRUCELL_OVXLIB|Planned 22Q3|[tf.keras.layers.GRUCell](https://tensorflow.google.cn/api_docs/python/tf/keras/layers/GRUCell?hl=en) +|GRUCell|GRUCELL_OVXLIB|Mapped|[tf.keras.layers.GRUCell](https://tensorflow.google.cn/api_docs/python/tf/keras/layers/GRUCell?hl=en) |UnidirectionalSequenceGRU|GRU_OVXLIB|Planned 22Q3|[tf.keras.layers.GRU](https://tensorflow.google.cn/api_docs/python/tf/keras/layers/GRUCell?hl=en) |UnidirectionalSequenceRNN|UNIDIRECTIONAL_SEQUENCE_RNN|Planned 22Q3|[ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_RNN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0ae11aa1d461d2abaa117f6ee2cb503dd8) |BidirectionalSequenceRNN|BIDIRECTIONAL_SEQUENCE_RNN|Planned 22Q3|[ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_RNN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a487fc5ae247de828f13e62b99f259f3c) diff --git a/src/tim/vx/ops/grucell.cc b/src/tim/vx/ops/grucell.cc new file mode 100644 index 0000000..3314731 --- /dev/null +++ b/src/tim/vx/ops/grucell.cc @@ -0,0 +1,53 @@ +/**************************************************************************** +* +* Copyright (c) 2022 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#include "tim/vx/ops/grucell.h" +#include "direct_map_op_impl.h" +#include "type_utils.h" +#include "vsi_nn_pub.h" + +namespace tim { +namespace vx { +namespace ops { +GRUCell::GRUCell(Graph* graph, uint32_t num_units, ActivationType activation, + ActivationType recurrent_activation, vsi_bool reset_after) + : DirectMapOp(graph, VSI_NN_OP_GRUCELL), + num_units_(num_units), + activation_(activation), + recurrent_activation_(recurrent_activation), + reset_after_(reset_after) { + this->impl()->node()->nn_param.grucell.num_units = num_units; + this->impl()->node()->nn_param.grucell.activation = activation; + this->impl()->node()->nn_param.grucell.recurrent_activation = recurrent_activation; + this->impl()->node()->nn_param.grucell.reset_after = reset_after; +} + +std::shared_ptr GRUCell::Clone(std::shared_ptr& graph) const { + return graph->CreateOperation(this->num_units_, this->activation_, + this->recurrent_activation_, + this->reset_after_); +} + +} // namespace ops +} // namespace vx +} // namespace tim \ No newline at end of file diff --git a/src/tim/vx/ops/grucell_test.cc b/src/tim/vx/ops/grucell_test.cc new file mode 100644 index 0000000..31d76d9 --- /dev/null +++ b/src/tim/vx/ops/grucell_test.cc @@ -0,0 +1,122 @@ +/**************************************************************************** +* +* Copyright (c) 2022 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#include "tim/vx/context.h" +#include "tim/vx/graph.h" +#include "tim/vx/ops/grucell.h" + +#include "gtest/gtest.h" +#include "test_utils.h" + +std::shared_ptr make_empty_tensor( + std::shared_ptr graph, const tim::vx::ShapeType& shape, + const tim::vx::TensorAttribute& role); + +TEST(GRUCell, unit_4) { + auto ctx = tim::vx::Context::Create(); + auto graph = ctx->CreateGraph(); + + uint32_t num_units = 2; + uint32_t feature = 4; + uint32_t batchs = 1; + tim::vx::ShapeType in_shape({feature, batchs}); + tim::vx::ShapeType hstate_and_out_shape({num_units, batchs}); + + tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, in_shape, + tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, + hstate_and_out_shape, + tim::vx::TensorAttribute::OUTPUT); + + tim::vx::TensorSpec kernel_i_spec(tim::vx::DataType::FLOAT32, + tim::vx::ShapeType({feature, num_units}), + tim::vx::TensorAttribute::CONSTANT); + + tim::vx::TensorSpec kernel_r_spec(tim::vx::DataType::FLOAT32, + tim::vx::ShapeType({num_units, num_units}), + tim::vx::TensorAttribute::CONSTANT); + + std::vector kernel_i2z = {-0.1201707124710083, 0.051147401332855225, + -0.02161085605621338, 0.2582472562789917, + -0.7641150951385498, 0.27272117137908936, + 0.4013441801071167, -0.43467071652412415}; + std::vector kernel_i2r = {-0.34522661566734314, 0.11888366937637329, + 0.6542353630065918, 0.6331415176391602, + -0.2489457130432129, -0.47332942485809326, + -0.7532100081443787, 0.46069061756134033}; + std::vector kernel_i2h = { + -0.0012096166610717773, -0.05206263065338135, -0.418102502822876, + -0.20800292491912842, -0.5549647808074951, -0.1337134838104248, + 0.14222955703735352, -0.21347862482070923}; + std::vector kernel_r2z = {-0.49559473991394043, -0.10428445041179657, + 0.39165210723876953, 0.38152191042900085}; + std::vector kernel_r2r = {0.03387263044714928, -0.39444485306739807, + 0.4542817771434784, -0.4098765254020691}; + std::vector kernel_r2h = {-0.5441233515739441, -0.35663682222366333, + -0.3120974004268646, 0.6267299056053162}; + + auto input_tensor = graph->CreateTensor(input_spec); + auto output_tensor = graph->CreateTensor(output_spec); + auto h_state_tensor = graph->CreateTensor(output_spec); + + auto kernel_i2z_tensor = + graph->CreateTensor(kernel_i_spec, kernel_i2z.data()); + auto kernel_i2r_tensor = + graph->CreateTensor(kernel_i_spec, kernel_i2r.data()); + auto kernel_i2h_tensor = + graph->CreateTensor(kernel_i_spec, kernel_i2h.data()); + + auto kernel_r2z_tensor = + graph->CreateTensor(kernel_r_spec, kernel_r2z.data()); + auto kernel_r2r_tensor = + graph->CreateTensor(kernel_r_spec, kernel_r2r.data()); + auto kernel_r2h_tensor = + graph->CreateTensor(kernel_r_spec, kernel_r2h.data()); + + std::vector in_data = {1, 2, 3, 4}; + std::vector golden = {-0.2719525, -0.5766771}; + + EXPECT_TRUE( + input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * 4)); + auto op = graph->CreateOperation(num_units); + (*op) + .BindInputs({ + input_tensor, + make_empty_tensor(graph, tim::vx::ShapeType(hstate_and_out_shape), + tim::vx::TensorAttribute::INPUT), //h_state + kernel_i2z_tensor, //KERNEL_I2 + kernel_i2r_tensor, //KERNEL_I2 + kernel_i2h_tensor, //KERNEL_I2 + kernel_r2z_tensor, //KERNEL_R2 + kernel_r2r_tensor, //KERNEL_R2 + kernel_r2h_tensor, //KERNEL_R2 + }) + .BindOutputs({output_tensor, h_state_tensor}); + + EXPECT_TRUE(graph->Compile()); + EXPECT_TRUE(graph->Run()); + + std::vector output(golden.size()); + EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); + EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f)); +}