From 2f8f87d1cb72bdc7b88ef9b42e870c8cb24ad179 Mon Sep 17 00:00:00 2001 From: "yuenan.li" Date: Tue, 6 Jul 2021 16:14:41 +0800 Subject: [PATCH] Add Clone API for SpatialTrasformer Signed-off-by: yuenan.li --- include/tim/vx/ops/spatial_transformer.h | 3 +++ src/tim/vx/ops/spatial_transformer.cc | 10 ++++++++++ 2 files changed, 13 insertions(+) diff --git a/include/tim/vx/ops/spatial_transformer.h b/include/tim/vx/ops/spatial_transformer.h index fdfab0c..6ba359f 100644 --- a/include/tim/vx/ops/spatial_transformer.h +++ b/include/tim/vx/ops/spatial_transformer.h @@ -47,6 +47,9 @@ class SpatialTransformer : public Operation { bool has_theta_2_1, bool has_theta_2_2, bool has_theta_2_3, float theta_1_1, float theta_1_2, float theta_1_3, float theta_2_1, float theta_2_2, float theta_2_3); + + std::shared_ptr Clone(std::shared_ptr& graph) const override; + protected: const uint32_t output_h_; const uint32_t output_w_; diff --git a/src/tim/vx/ops/spatial_transformer.cc b/src/tim/vx/ops/spatial_transformer.cc index d1369df..c86cfd1 100644 --- a/src/tim/vx/ops/spatial_transformer.cc +++ b/src/tim/vx/ops/spatial_transformer.cc @@ -56,6 +56,16 @@ SpatialTransformer::SpatialTransformer(Graph* graph, uint32_t output_h, uint32_t this->impl()->node()->nn_param.spatial_transformer.theta_2_3 = theta_2_3_; } +std::shared_ptr SpatialTransformer::Clone( + std::shared_ptr& graph) const { + return graph->CreateOperation( + this->output_h_, this->output_w_, this->has_theta_1_1_, + this->has_theta_1_2_, this->has_theta_1_3_, this->has_theta_2_1_, + this->has_theta_2_2_, this->has_theta_2_3_, this->theta_1_1_, + this->theta_1_2_, this->theta_1_3_, this->theta_2_1_, this->theta_2_2_, + this->theta_2_3_); +} + } // namespace ops } // namespace vx } // namespace tim \ No newline at end of file