From 6a949bb315b7d711cf6f73056f1022ae23d32579 Mon Sep 17 00:00:00 2001 From: Kainan Cha Date: Tue, 3 Aug 2021 10:52:51 +0800 Subject: [PATCH] Add align_corners support for SpatialTransformer Signed-off-by: Kainan Cha --- include/tim/vx/ops/spatial_transformer.h | 7 ++++--- src/tim/vx/ops/spatial_transformer.cc | 10 ++++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/include/tim/vx/ops/spatial_transformer.h b/include/tim/vx/ops/spatial_transformer.h index 6ba359f..3972b07 100644 --- a/include/tim/vx/ops/spatial_transformer.h +++ b/include/tim/vx/ops/spatial_transformer.h @@ -34,7 +34,7 @@ namespace ops { * * 'Spatial Transformer Networks', Jaderberg et. al, * (https://arxiv.org/abs/1506.02025) - * + * * - theta : Affine transform tensor of shape (B, 6). Permits cropping, translation and isotropic scaling. Initialize to identity matrix. It is the output of the localization network. @@ -46,7 +46,7 @@ class SpatialTransformer : public Operation { bool has_theta_1_1, bool has_theta_1_2, bool has_theta_1_3, bool has_theta_2_1, bool has_theta_2_2, bool has_theta_2_3, float theta_1_1, float theta_1_2, float theta_1_3, - float theta_2_1, float theta_2_2, float theta_2_3); + float theta_2_1, float theta_2_2, float theta_2_3, bool align_corners = false); std::shared_ptr Clone(std::shared_ptr& graph) const override; @@ -65,10 +65,11 @@ class SpatialTransformer : public Operation { float theta_2_1_; float theta_2_2_; float theta_2_3_; + bool align_corners_; }; } // namespace ops } // namespace vx } // namespace tim -#endif /* TIM_VX_OPS_SPATIAL_TRANSFORMER_H_ */ \ No newline at end of file +#endif /* TIM_VX_OPS_SPATIAL_TRANSFORMER_H_ */ diff --git a/src/tim/vx/ops/spatial_transformer.cc b/src/tim/vx/ops/spatial_transformer.cc index c86cfd1..b30c082 100644 --- a/src/tim/vx/ops/spatial_transformer.cc +++ b/src/tim/vx/ops/spatial_transformer.cc @@ -34,12 +34,13 @@ SpatialTransformer::SpatialTransformer(Graph* graph, uint32_t output_h, uint32_t bool has_theta_1_1, bool has_theta_1_2, bool has_theta_1_3, bool has_theta_2_1, bool has_theta_2_2, bool has_theta_2_3, float theta_1_1, float theta_1_2, float theta_1_3, - float theta_2_1, float theta_2_2, float theta_2_3) + float theta_2_1, float theta_2_2, float theta_2_3, bool align_corners) : Operation(graph, VSI_NN_OP_SPATIAL_TRANSFORMER, 2, 1), output_h_(output_h), output_w_(output_w), has_theta_1_1_(has_theta_1_1), has_theta_1_2_(has_theta_1_2), has_theta_1_3_(has_theta_1_3), has_theta_2_1_(has_theta_2_1), has_theta_2_2_(has_theta_2_2), has_theta_2_3_(has_theta_2_3), theta_1_1_(theta_1_1), theta_1_2_(theta_1_2), theta_1_3_(theta_1_3), - theta_2_1_(theta_2_1), theta_2_2_(theta_2_2), theta_2_3_(theta_2_3) { + theta_2_1_(theta_2_1), theta_2_2_(theta_2_2), theta_2_3_(theta_2_3), + align_corners_(align_corners) { this->impl()->node()->nn_param.spatial_transformer.output_H = output_h_; this->impl()->node()->nn_param.spatial_transformer.output_W = output_w_; this->impl()->node()->nn_param.spatial_transformer.has_theta_1_1 = has_theta_1_1_; @@ -54,6 +55,7 @@ SpatialTransformer::SpatialTransformer(Graph* graph, uint32_t output_h, uint32_t this->impl()->node()->nn_param.spatial_transformer.theta_2_1 = theta_2_1_; this->impl()->node()->nn_param.spatial_transformer.theta_2_2 = theta_2_2_; this->impl()->node()->nn_param.spatial_transformer.theta_2_3 = theta_2_3_; + this->impl()->node()->nn_param.spatial_transformer.align_corners = align_corners_; } std::shared_ptr SpatialTransformer::Clone( @@ -63,9 +65,9 @@ std::shared_ptr SpatialTransformer::Clone( this->has_theta_1_2_, this->has_theta_1_3_, this->has_theta_2_1_, this->has_theta_2_2_, this->has_theta_2_3_, this->theta_1_1_, this->theta_1_2_, this->theta_1_3_, this->theta_2_1_, this->theta_2_2_, - this->theta_2_3_); + this->theta_2_3_, this->align_corners_); } } // namespace ops } // namespace vx -} // namespace tim \ No newline at end of file +} // namespace tim