diff --git a/include/tim/vx/ops/spatial_transformer.h b/include/tim/vx/ops/spatial_transformer.h new file mode 100644 index 0000000..fdfab0c --- /dev/null +++ b/include/tim/vx/ops/spatial_transformer.h @@ -0,0 +1,71 @@ +/**************************************************************************** +* +* Copyright (c) 2021 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#ifndef TIM_VX_OPS_SPATIAL_TRANSFORMER_H_ +#define TIM_VX_OPS_SPATIAL_TRANSFORMER_H_ +#include "tim/vx/operation.h" + +namespace tim { +namespace vx { +namespace ops { + +/** + * ## Spatial Transformer + * + * 'Spatial Transformer Networks', Jaderberg et. al, + * (https://arxiv.org/abs/1506.02025) + * + * - theta : Affine transform tensor of shape (B, 6). Permits cropping, + translation and isotropic scaling. Initialize to identity matrix. + It is the output of the localization network. + */ + +class SpatialTransformer : public Operation { + public: + SpatialTransformer(Graph* graph, uint32_t output_h, uint32_t output_w, + bool has_theta_1_1, bool has_theta_1_2, bool has_theta_1_3, + bool has_theta_2_1, bool has_theta_2_2, bool has_theta_2_3, + float theta_1_1, float theta_1_2, float theta_1_3, + float theta_2_1, float theta_2_2, float theta_2_3); + protected: + const uint32_t output_h_; + const uint32_t output_w_; + bool has_theta_1_1_; + bool has_theta_1_2_; + bool has_theta_1_3_; + bool has_theta_2_1_; + bool has_theta_2_2_; + bool has_theta_2_3_; + float theta_1_1_; + float theta_1_2_; + float theta_1_3_; + float theta_2_1_; + float theta_2_2_; + float theta_2_3_; +}; + +} // namespace ops +} // namespace vx +} // namespace tim + +#endif /* TIM_VX_OPS_SPATIAL_TRANSFORMER_H_ */ \ No newline at end of file diff --git a/src/tim/vx/ops/README.md b/src/tim/vx/ops/README.md index 6a3357a..2ae5658 100644 --- a/src/tim/vx/ops/README.md +++ b/src/tim/vx/ops/README.md @@ -97,6 +97,7 @@ ScatterND|SCATTER_ND|Mapped|[tf.scatter_nd](https://tensorflow.google.cn/api_doc Unstack|UNSTACK|Mapped|[tf.unstack](https://tensorflow.google.cn/api_docs/python/tf/unstack) Tile|TILE|Mapped|[tf.tile](https://tensorflow.google.cn/api_docs/python/tf/tile) GroupedConv2d|GROUPED_CONV2D|Mapped|[ANEURALNETWORKS_GROUPED_CONV_2D](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a847acf8d9f3d2343328c3dbe6d447c50) +SpatialTransformer|SPATIAL_TRANSFORMER|Mapped|[SpatialTransformer](https://github.com/daerduoCarey/SpatialTransformerLayer) ||PROPOSAL|Planned 21Q3|[Faster-RCNN Proposal Layer](https://github.com/intel/caffe/blob/master/examples/faster-rcnn/lib/rpn/proposal_layer.py) ||ROI_POOL|Planned 21Q3|[ANEURALNETWORKS_ROI_POOLING](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a6736198af337b2efbdb0b6b64dee7fe4) ||ROI_ALIGN|Planned 21Q3|[ANEURALNETWORKS_ROI_ALIGN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a2848b39dd4bfba78f2438fda0d9397a4) @@ -151,7 +152,6 @@ GroupedConv2d|GROUPED_CONV2D|Mapped|[ANEURALNETWORKS_GROUPED_CONV_2D](https://de ||PRE_PROCESS_TENSOR|InternalOnly ||IMAGEPROCESS|Deprecated ||POST_PROCESS|InternalOnly -||SPATIAL_TRANSFORMER|InternalOnly|[SpatialTransformer](https://github.com/daerduoCarey/SpatialTransformerLayer) ||EXTRA_ENDING|InternalOnly ||SYNC_HOST|InternalOnly ||BATCHNORM_SINGLE|InternalOnly| diff --git a/src/tim/vx/ops/spatial_transformer.cc b/src/tim/vx/ops/spatial_transformer.cc new file mode 100644 index 0000000..d1369df --- /dev/null +++ b/src/tim/vx/ops/spatial_transformer.cc @@ -0,0 +1,61 @@ +/**************************************************************************** +* +* Copyright (c) 2021 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#include "tim/vx/ops/spatial_transformer.h" + +#include "operation_private.h" +#include "vsi_nn_pub.h" + +namespace tim { +namespace vx { +namespace ops { + +SpatialTransformer::SpatialTransformer(Graph* graph, uint32_t output_h, uint32_t output_w, + bool has_theta_1_1, bool has_theta_1_2, bool has_theta_1_3, + bool has_theta_2_1, bool has_theta_2_2, bool has_theta_2_3, + float theta_1_1, float theta_1_2, float theta_1_3, + float theta_2_1, float theta_2_2, float theta_2_3) + : Operation(graph, VSI_NN_OP_SPATIAL_TRANSFORMER, 2, 1), output_h_(output_h), output_w_(output_w), + has_theta_1_1_(has_theta_1_1), has_theta_1_2_(has_theta_1_2), has_theta_1_3_(has_theta_1_3), + has_theta_2_1_(has_theta_2_1), has_theta_2_2_(has_theta_2_2), has_theta_2_3_(has_theta_2_3), + theta_1_1_(theta_1_1), theta_1_2_(theta_1_2), theta_1_3_(theta_1_3), + theta_2_1_(theta_2_1), theta_2_2_(theta_2_2), theta_2_3_(theta_2_3) { + this->impl()->node()->nn_param.spatial_transformer.output_H = output_h_; + this->impl()->node()->nn_param.spatial_transformer.output_W = output_w_; + this->impl()->node()->nn_param.spatial_transformer.has_theta_1_1 = has_theta_1_1_; + this->impl()->node()->nn_param.spatial_transformer.has_theta_1_2 = has_theta_1_2_; + this->impl()->node()->nn_param.spatial_transformer.has_theta_1_3 = has_theta_1_3_; + this->impl()->node()->nn_param.spatial_transformer.has_theta_2_1 = has_theta_2_1_; + this->impl()->node()->nn_param.spatial_transformer.has_theta_2_2 = has_theta_2_2_; + this->impl()->node()->nn_param.spatial_transformer.has_theta_2_3 = has_theta_2_3_; + this->impl()->node()->nn_param.spatial_transformer.theta_1_1 = theta_1_1_; + this->impl()->node()->nn_param.spatial_transformer.theta_1_2 = theta_1_2_; + this->impl()->node()->nn_param.spatial_transformer.theta_1_3 = theta_1_3_; + this->impl()->node()->nn_param.spatial_transformer.theta_2_1 = theta_2_1_; + this->impl()->node()->nn_param.spatial_transformer.theta_2_2 = theta_2_2_; + this->impl()->node()->nn_param.spatial_transformer.theta_2_3 = theta_2_3_; +} + +} // namespace ops +} // namespace vx +} // namespace tim \ No newline at end of file diff --git a/src/tim/vx/ops/spatial_transformer_test.cc b/src/tim/vx/ops/spatial_transformer_test.cc new file mode 100644 index 0000000..9e32bfc --- /dev/null +++ b/src/tim/vx/ops/spatial_transformer_test.cc @@ -0,0 +1,75 @@ +/**************************************************************************** +* +* Copyright (c) 2021 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#include "tim/vx/context.h" +#include "tim/vx/graph.h" +#include "tim/vx/ops/spatial_transformer.h" + +#include "gtest/gtest.h" + +TEST(SpatialTransformer, shape_1_3_3_1_u8) { + auto ctx = tim::vx::Context::Create(); + auto graph = ctx->CreateGraph(); + + tim::vx::ShapeType in_shape({1, 3, 3, 1}); + tim::vx::ShapeType theta_shape({6}); + tim::vx::ShapeType out_shape({1, 3, 3, 1}); + tim::vx::Quantization io_quant(tim::vx::QuantType::ASYMMETRIC, 0.5, 0); + tim::vx::TensorSpec input_spec(tim::vx::DataType::UINT8, + in_shape, tim::vx::TensorAttribute::INPUT, io_quant); + tim::vx::TensorSpec theta_spec(tim::vx::DataType::UINT8, + theta_shape, tim::vx::TensorAttribute::INPUT, io_quant); + tim::vx::TensorSpec output_spec(tim::vx::DataType::UINT8, + out_shape, tim::vx::TensorAttribute::OUTPUT, io_quant); + + auto input_tensor = graph->CreateTensor(input_spec); + auto theta_tensor = graph->CreateTensor(input_spec); + auto output_tensor = graph->CreateTensor(output_spec); + + std::vector in_data = { + 2, 4, 6, + 2, 4, 6, + 2, 4, 6 }; + std::vector theta_data = { + 2, 2, 2, + 2, 2, 2 }; + std::vector values_golden = { + 2,3,2, + 2,3,2, + 2,3,2 }; + + EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(), in_data.size())); + EXPECT_TRUE(theta_tensor->CopyDataToTensor(theta_data.data(), theta_data.size())); + auto op = graph->CreateOperation( + 3, 3, true, true, true, true, true, true, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 + ); + (*op).BindInputs({input_tensor, theta_tensor}).BindOutputs({output_tensor}); + + EXPECT_TRUE(graph->Compile()); + EXPECT_TRUE(graph->Run()); + std::vector output_values(values_golden.size()); + + EXPECT_TRUE(output_tensor->CopyDataFromTensor(output_values.data())); + EXPECT_EQ(values_golden, output_values); +}