Add align_corners support for SpatialTransformer

Signed-off-by: Kainan Cha <kainan.zha@verisilicon.com>
This commit is contained in:
Kainan Cha 2021-08-03 10:52:51 +08:00
parent 4d4bc08d6a
commit 6a949bb315
2 changed files with 10 additions and 7 deletions

View File

@ -34,7 +34,7 @@ namespace ops {
* *
* 'Spatial Transformer Networks', Jaderberg et. al, * 'Spatial Transformer Networks', Jaderberg et. al,
* (https://arxiv.org/abs/1506.02025) * (https://arxiv.org/abs/1506.02025)
* *
* - theta : Affine transform tensor of shape (B, 6). Permits cropping, * - theta : Affine transform tensor of shape (B, 6). Permits cropping,
translation and isotropic scaling. Initialize to identity matrix. translation and isotropic scaling. Initialize to identity matrix.
It is the output of the localization network. It is the output of the localization network.
@ -46,7 +46,7 @@ class SpatialTransformer : public Operation {
bool has_theta_1_1, bool has_theta_1_2, bool has_theta_1_3, bool has_theta_1_1, bool has_theta_1_2, bool has_theta_1_3,
bool has_theta_2_1, bool has_theta_2_2, bool has_theta_2_3, bool has_theta_2_1, bool has_theta_2_2, bool has_theta_2_3,
float theta_1_1, float theta_1_2, float theta_1_3, float theta_1_1, float theta_1_2, float theta_1_3,
float theta_2_1, float theta_2_2, float theta_2_3); float theta_2_1, float theta_2_2, float theta_2_3, bool align_corners = false);
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override; std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
@ -65,10 +65,11 @@ class SpatialTransformer : public Operation {
float theta_2_1_; float theta_2_1_;
float theta_2_2_; float theta_2_2_;
float theta_2_3_; float theta_2_3_;
bool align_corners_;
}; };
} // namespace ops } // namespace ops
} // namespace vx } // namespace vx
} // namespace tim } // namespace tim
#endif /* TIM_VX_OPS_SPATIAL_TRANSFORMER_H_ */ #endif /* TIM_VX_OPS_SPATIAL_TRANSFORMER_H_ */

View File

@ -34,12 +34,13 @@ SpatialTransformer::SpatialTransformer(Graph* graph, uint32_t output_h, uint32_t
bool has_theta_1_1, bool has_theta_1_2, bool has_theta_1_3, bool has_theta_1_1, bool has_theta_1_2, bool has_theta_1_3,
bool has_theta_2_1, bool has_theta_2_2, bool has_theta_2_3, bool has_theta_2_1, bool has_theta_2_2, bool has_theta_2_3,
float theta_1_1, float theta_1_2, float theta_1_3, float theta_1_1, float theta_1_2, float theta_1_3,
float theta_2_1, float theta_2_2, float theta_2_3) float theta_2_1, float theta_2_2, float theta_2_3, bool align_corners)
: Operation(graph, VSI_NN_OP_SPATIAL_TRANSFORMER, 2, 1), output_h_(output_h), output_w_(output_w), : Operation(graph, VSI_NN_OP_SPATIAL_TRANSFORMER, 2, 1), output_h_(output_h), output_w_(output_w),
has_theta_1_1_(has_theta_1_1), has_theta_1_2_(has_theta_1_2), has_theta_1_3_(has_theta_1_3), has_theta_1_1_(has_theta_1_1), has_theta_1_2_(has_theta_1_2), has_theta_1_3_(has_theta_1_3),
has_theta_2_1_(has_theta_2_1), has_theta_2_2_(has_theta_2_2), has_theta_2_3_(has_theta_2_3), has_theta_2_1_(has_theta_2_1), has_theta_2_2_(has_theta_2_2), has_theta_2_3_(has_theta_2_3),
theta_1_1_(theta_1_1), theta_1_2_(theta_1_2), theta_1_3_(theta_1_3), theta_1_1_(theta_1_1), theta_1_2_(theta_1_2), theta_1_3_(theta_1_3),
theta_2_1_(theta_2_1), theta_2_2_(theta_2_2), theta_2_3_(theta_2_3) { theta_2_1_(theta_2_1), theta_2_2_(theta_2_2), theta_2_3_(theta_2_3),
align_corners_(align_corners) {
this->impl()->node()->nn_param.spatial_transformer.output_H = output_h_; this->impl()->node()->nn_param.spatial_transformer.output_H = output_h_;
this->impl()->node()->nn_param.spatial_transformer.output_W = output_w_; this->impl()->node()->nn_param.spatial_transformer.output_W = output_w_;
this->impl()->node()->nn_param.spatial_transformer.has_theta_1_1 = has_theta_1_1_; this->impl()->node()->nn_param.spatial_transformer.has_theta_1_1 = has_theta_1_1_;
@ -54,6 +55,7 @@ SpatialTransformer::SpatialTransformer(Graph* graph, uint32_t output_h, uint32_t
this->impl()->node()->nn_param.spatial_transformer.theta_2_1 = theta_2_1_; this->impl()->node()->nn_param.spatial_transformer.theta_2_1 = theta_2_1_;
this->impl()->node()->nn_param.spatial_transformer.theta_2_2 = theta_2_2_; this->impl()->node()->nn_param.spatial_transformer.theta_2_2 = theta_2_2_;
this->impl()->node()->nn_param.spatial_transformer.theta_2_3 = theta_2_3_; this->impl()->node()->nn_param.spatial_transformer.theta_2_3 = theta_2_3_;
this->impl()->node()->nn_param.spatial_transformer.align_corners = align_corners_;
} }
std::shared_ptr<Operation> SpatialTransformer::Clone( std::shared_ptr<Operation> SpatialTransformer::Clone(
@ -63,9 +65,9 @@ std::shared_ptr<Operation> SpatialTransformer::Clone(
this->has_theta_1_2_, this->has_theta_1_3_, this->has_theta_2_1_, this->has_theta_1_2_, this->has_theta_1_3_, this->has_theta_2_1_,
this->has_theta_2_2_, this->has_theta_2_3_, this->theta_1_1_, this->has_theta_2_2_, this->has_theta_2_3_, this->theta_1_1_,
this->theta_1_2_, this->theta_1_3_, this->theta_2_1_, this->theta_2_2_, this->theta_1_2_, this->theta_1_3_, this->theta_2_1_, this->theta_2_2_,
this->theta_2_3_); this->theta_2_3_, this->align_corners_);
} }
} // namespace ops } // namespace ops
} // namespace vx } // namespace vx
} // namespace tim } // namespace tim