From e63059857b06dacdde6f1c5aee98a3dd7df55f4d Mon Sep 17 00:00:00 2001 From: Sven Date: Tue, 1 Mar 2022 17:04:02 +0800 Subject: [PATCH] Update reshape to reshape2 (#310) Update built-in op reshape to reshape2 Signed-off-by: xiang.zhang --- include/tim/vx/ops/reshape.h | 7 ++---- src/tim/vx/ops/reshape.cc | 43 ++++++++++++++++++++++++++++++++---- 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/include/tim/vx/ops/reshape.h b/include/tim/vx/ops/reshape.h index b02ccc9..85843de 100644 --- a/include/tim/vx/ops/reshape.h +++ b/include/tim/vx/ops/reshape.h @@ -37,14 +37,11 @@ namespace ops { * - size : defining the shape of the output tensor. */ -class Reshape : public DirectMapOp { +class Reshape : public Operation{ public: - Reshape(Graph* graph, const std::vector& size); + Reshape(Graph* graph, const std::vector& target_shape); std::shared_ptr Clone(std::shared_ptr& graph) const override; - - protected: - std::vector size_; }; } // namespace ops diff --git a/src/tim/vx/ops/reshape.cc b/src/tim/vx/ops/reshape.cc index d3e80f0..a197abb 100644 --- a/src/tim/vx/ops/reshape.cc +++ b/src/tim/vx/ops/reshape.cc @@ -26,19 +26,54 @@ #include "direct_map_op_impl.h" #include "vsi_nn_pub.h" +#include + namespace tim { namespace vx { namespace ops { +class ReshapeImpl : public DirectMapOpImpl { + public: + ReshapeImpl(Graph* graph, const std::vector& shape) + : DirectMapOpImpl(graph, + #ifdef _VSI_NN_OP_RESHAPE2_H + VSI_NN_OP_RESHAPE2 + #else + VSI_NN_OP_RESHAPE + #endif + , 1, 1), shape_(shape) {} + + std::vector shape_; +}; Reshape::Reshape(Graph* graph, const std::vector& size) - : DirectMapOp(graph, VSI_NN_OP_RESHAPE), size_(std::move(size)) { - this->impl()->node()->nn_param.reshape.size = size_.data(); - this->impl()->node()->nn_param.reshape.dim_num = size_.size(); +{ + std::vector shape; + std::transform(size.begin(), size.end(), std::back_inserter(shape), [](const uint32_t& d){ + return static_cast(d); + }); + + auto lcl_impl = std::make_unique(graph, shape); + + #ifdef _VSI_NN_OP_RESHAPE2_H + lcl_impl->node()->nn_param.reshape2.size = lcl_impl->shape_.data(); + lcl_impl->node()->nn_param.reshape2.dim_num = size.size(); + #else + lcl_impl->node()->nn_param.reshape.size = lcl_impl->shape_.data(); + lcl_impl->node()->nn_param.reshape.dim_num = size.size(); + #endif + + impl_.reset(dynamic_cast(lcl_impl.release())); } std::shared_ptr Reshape::Clone( std::shared_ptr& graph) const { - return graph->CreateOperation(this->size_); + std::vector size; + const ReshapeImpl* lcl_impl = (dynamic_cast(impl_.get())); + std::transform(lcl_impl->shape_.begin(), lcl_impl->shape_.end(), std::back_inserter(size), [](const vsi_size_t& d){ + return static_cast(d); + }); + + return graph->CreateOperation(size); } } // namespace ops