Add Clone API for SpatialTrasformer
Signed-off-by: yuenan.li <yuenan.li@verisilicon.com>
This commit is contained in:
parent
8aa11f5f29
commit
2f8f87d1cb
|
|
@ -47,6 +47,9 @@ class SpatialTransformer : public Operation {
|
||||||
bool has_theta_2_1, bool has_theta_2_2, bool has_theta_2_3,
|
bool has_theta_2_1, bool has_theta_2_2, bool has_theta_2_3,
|
||||||
float theta_1_1, float theta_1_2, float theta_1_3,
|
float theta_1_1, float theta_1_2, float theta_1_3,
|
||||||
float theta_2_1, float theta_2_2, float theta_2_3);
|
float theta_2_1, float theta_2_2, float theta_2_3);
|
||||||
|
|
||||||
|
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
const uint32_t output_h_;
|
const uint32_t output_h_;
|
||||||
const uint32_t output_w_;
|
const uint32_t output_w_;
|
||||||
|
|
|
||||||
|
|
@ -56,6 +56,16 @@ SpatialTransformer::SpatialTransformer(Graph* graph, uint32_t output_h, uint32_t
|
||||||
this->impl()->node()->nn_param.spatial_transformer.theta_2_3 = theta_2_3_;
|
this->impl()->node()->nn_param.spatial_transformer.theta_2_3 = theta_2_3_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<Operation> SpatialTransformer::Clone(
|
||||||
|
std::shared_ptr<Graph>& graph) const {
|
||||||
|
return graph->CreateOperation<SpatialTransformer>(
|
||||||
|
this->output_h_, this->output_w_, this->has_theta_1_1_,
|
||||||
|
this->has_theta_1_2_, this->has_theta_1_3_, this->has_theta_2_1_,
|
||||||
|
this->has_theta_2_2_, this->has_theta_2_3_, this->theta_1_1_,
|
||||||
|
this->theta_1_2_, this->theta_1_3_, this->theta_2_1_, this->theta_2_2_,
|
||||||
|
this->theta_2_3_);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace vx
|
} // namespace vx
|
||||||
} // namespace tim
|
} // namespace tim
|
||||||
Loading…
Reference in New Issue