Update reshape to reshape2 (#310)

Update built-in op reshape to reshape2

Signed-off-by: xiang.zhang <xiang.zhang@verisilicon.com>
This commit is contained in:
Sven 2022-03-01 17:04:02 +08:00 committed by GitHub
parent c8a25d32ad
commit e63059857b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 9 deletions

View File

@ -37,14 +37,11 @@ namespace ops {
* - size : defining the shape of the output tensor.
*/
class Reshape : public DirectMapOp {
class Reshape : public Operation{
public:
Reshape(Graph* graph, const std::vector<uint32_t>& size);
Reshape(Graph* graph, const std::vector<uint32_t>& target_shape);
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
protected:
std::vector<uint32_t> size_;
};
} // namespace ops

View File

@ -26,19 +26,54 @@
#include "direct_map_op_impl.h"
#include "vsi_nn_pub.h"
#include <algorithm>
namespace tim {
namespace vx {
namespace ops {
class ReshapeImpl : public DirectMapOpImpl {
public:
ReshapeImpl(Graph* graph, const std::vector<vsi_size_t>& shape)
: DirectMapOpImpl(graph,
#ifdef _VSI_NN_OP_RESHAPE2_H
VSI_NN_OP_RESHAPE2
#else
VSI_NN_OP_RESHAPE
#endif
, 1, 1), shape_(shape) {}
std::vector<vsi_size_t> shape_;
};
Reshape::Reshape(Graph* graph, const std::vector<uint32_t>& size)
: DirectMapOp(graph, VSI_NN_OP_RESHAPE), size_(std::move(size)) {
this->impl()->node()->nn_param.reshape.size = size_.data();
this->impl()->node()->nn_param.reshape.dim_num = size_.size();
{
std::vector<vsi_size_t> shape;
std::transform(size.begin(), size.end(), std::back_inserter(shape), [](const uint32_t& d){
return static_cast<vsi_size_t>(d);
});
auto lcl_impl = std::make_unique<ReshapeImpl>(graph, shape);
#ifdef _VSI_NN_OP_RESHAPE2_H
lcl_impl->node()->nn_param.reshape2.size = lcl_impl->shape_.data();
lcl_impl->node()->nn_param.reshape2.dim_num = size.size();
#else
lcl_impl->node()->nn_param.reshape.size = lcl_impl->shape_.data();
lcl_impl->node()->nn_param.reshape.dim_num = size.size();
#endif
impl_.reset(dynamic_cast<OpImpl*>(lcl_impl.release()));
}
std::shared_ptr<Operation> Reshape::Clone(
std::shared_ptr<Graph>& graph) const {
return graph->CreateOperation<Reshape>(this->size_);
std::vector<uint32_t> size;
const ReshapeImpl* lcl_impl = (dynamic_cast<ReshapeImpl*>(impl_.get()));
std::transform(lcl_impl->shape_.begin(), lcl_impl->shape_.end(), std::back_inserter(size), [](const vsi_size_t& d){
return static_cast<uint32_t>(d);
});
return graph->CreateOperation<Reshape>(size);
}
} // namespace ops