Add align_corners support for SpatialTransformer
Signed-off-by: Kainan Cha <kainan.zha@verisilicon.com>
This commit is contained in:
parent
4d4bc08d6a
commit
6a949bb315
|
|
@ -34,7 +34,7 @@ namespace ops {
|
||||||
*
|
*
|
||||||
* 'Spatial Transformer Networks', Jaderberg et. al,
|
* 'Spatial Transformer Networks', Jaderberg et. al,
|
||||||
* (https://arxiv.org/abs/1506.02025)
|
* (https://arxiv.org/abs/1506.02025)
|
||||||
*
|
*
|
||||||
* - theta : Affine transform tensor of shape (B, 6). Permits cropping,
|
* - theta : Affine transform tensor of shape (B, 6). Permits cropping,
|
||||||
translation and isotropic scaling. Initialize to identity matrix.
|
translation and isotropic scaling. Initialize to identity matrix.
|
||||||
It is the output of the localization network.
|
It is the output of the localization network.
|
||||||
|
|
@ -46,7 +46,7 @@ class SpatialTransformer : public Operation {
|
||||||
bool has_theta_1_1, bool has_theta_1_2, bool has_theta_1_3,
|
bool has_theta_1_1, bool has_theta_1_2, bool has_theta_1_3,
|
||||||
bool has_theta_2_1, bool has_theta_2_2, bool has_theta_2_3,
|
bool has_theta_2_1, bool has_theta_2_2, bool has_theta_2_3,
|
||||||
float theta_1_1, float theta_1_2, float theta_1_3,
|
float theta_1_1, float theta_1_2, float theta_1_3,
|
||||||
float theta_2_1, float theta_2_2, float theta_2_3);
|
float theta_2_1, float theta_2_2, float theta_2_3, bool align_corners = false);
|
||||||
|
|
||||||
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
|
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
|
||||||
|
|
||||||
|
|
@ -65,10 +65,11 @@ class SpatialTransformer : public Operation {
|
||||||
float theta_2_1_;
|
float theta_2_1_;
|
||||||
float theta_2_2_;
|
float theta_2_2_;
|
||||||
float theta_2_3_;
|
float theta_2_3_;
|
||||||
|
bool align_corners_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace vx
|
} // namespace vx
|
||||||
} // namespace tim
|
} // namespace tim
|
||||||
|
|
||||||
#endif /* TIM_VX_OPS_SPATIAL_TRANSFORMER_H_ */
|
#endif /* TIM_VX_OPS_SPATIAL_TRANSFORMER_H_ */
|
||||||
|
|
|
||||||
|
|
@ -34,12 +34,13 @@ SpatialTransformer::SpatialTransformer(Graph* graph, uint32_t output_h, uint32_t
|
||||||
bool has_theta_1_1, bool has_theta_1_2, bool has_theta_1_3,
|
bool has_theta_1_1, bool has_theta_1_2, bool has_theta_1_3,
|
||||||
bool has_theta_2_1, bool has_theta_2_2, bool has_theta_2_3,
|
bool has_theta_2_1, bool has_theta_2_2, bool has_theta_2_3,
|
||||||
float theta_1_1, float theta_1_2, float theta_1_3,
|
float theta_1_1, float theta_1_2, float theta_1_3,
|
||||||
float theta_2_1, float theta_2_2, float theta_2_3)
|
float theta_2_1, float theta_2_2, float theta_2_3, bool align_corners)
|
||||||
: Operation(graph, VSI_NN_OP_SPATIAL_TRANSFORMER, 2, 1), output_h_(output_h), output_w_(output_w),
|
: Operation(graph, VSI_NN_OP_SPATIAL_TRANSFORMER, 2, 1), output_h_(output_h), output_w_(output_w),
|
||||||
has_theta_1_1_(has_theta_1_1), has_theta_1_2_(has_theta_1_2), has_theta_1_3_(has_theta_1_3),
|
has_theta_1_1_(has_theta_1_1), has_theta_1_2_(has_theta_1_2), has_theta_1_3_(has_theta_1_3),
|
||||||
has_theta_2_1_(has_theta_2_1), has_theta_2_2_(has_theta_2_2), has_theta_2_3_(has_theta_2_3),
|
has_theta_2_1_(has_theta_2_1), has_theta_2_2_(has_theta_2_2), has_theta_2_3_(has_theta_2_3),
|
||||||
theta_1_1_(theta_1_1), theta_1_2_(theta_1_2), theta_1_3_(theta_1_3),
|
theta_1_1_(theta_1_1), theta_1_2_(theta_1_2), theta_1_3_(theta_1_3),
|
||||||
theta_2_1_(theta_2_1), theta_2_2_(theta_2_2), theta_2_3_(theta_2_3) {
|
theta_2_1_(theta_2_1), theta_2_2_(theta_2_2), theta_2_3_(theta_2_3),
|
||||||
|
align_corners_(align_corners) {
|
||||||
this->impl()->node()->nn_param.spatial_transformer.output_H = output_h_;
|
this->impl()->node()->nn_param.spatial_transformer.output_H = output_h_;
|
||||||
this->impl()->node()->nn_param.spatial_transformer.output_W = output_w_;
|
this->impl()->node()->nn_param.spatial_transformer.output_W = output_w_;
|
||||||
this->impl()->node()->nn_param.spatial_transformer.has_theta_1_1 = has_theta_1_1_;
|
this->impl()->node()->nn_param.spatial_transformer.has_theta_1_1 = has_theta_1_1_;
|
||||||
|
|
@ -54,6 +55,7 @@ SpatialTransformer::SpatialTransformer(Graph* graph, uint32_t output_h, uint32_t
|
||||||
this->impl()->node()->nn_param.spatial_transformer.theta_2_1 = theta_2_1_;
|
this->impl()->node()->nn_param.spatial_transformer.theta_2_1 = theta_2_1_;
|
||||||
this->impl()->node()->nn_param.spatial_transformer.theta_2_2 = theta_2_2_;
|
this->impl()->node()->nn_param.spatial_transformer.theta_2_2 = theta_2_2_;
|
||||||
this->impl()->node()->nn_param.spatial_transformer.theta_2_3 = theta_2_3_;
|
this->impl()->node()->nn_param.spatial_transformer.theta_2_3 = theta_2_3_;
|
||||||
|
this->impl()->node()->nn_param.spatial_transformer.align_corners = align_corners_;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<Operation> SpatialTransformer::Clone(
|
std::shared_ptr<Operation> SpatialTransformer::Clone(
|
||||||
|
|
@ -63,9 +65,9 @@ std::shared_ptr<Operation> SpatialTransformer::Clone(
|
||||||
this->has_theta_1_2_, this->has_theta_1_3_, this->has_theta_2_1_,
|
this->has_theta_1_2_, this->has_theta_1_3_, this->has_theta_2_1_,
|
||||||
this->has_theta_2_2_, this->has_theta_2_3_, this->theta_1_1_,
|
this->has_theta_2_2_, this->has_theta_2_3_, this->theta_1_1_,
|
||||||
this->theta_1_2_, this->theta_1_3_, this->theta_2_1_, this->theta_2_2_,
|
this->theta_1_2_, this->theta_1_3_, this->theta_2_1_, this->theta_2_2_,
|
||||||
this->theta_2_3_);
|
this->theta_2_3_, this->align_corners_);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace vx
|
} // namespace vx
|
||||||
} // namespace tim
|
} // namespace tim
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue