Update reshape to reshape2 (#310)
Update built-in op reshape to reshape2 Signed-off-by: xiang.zhang <xiang.zhang@verisilicon.com>
This commit is contained in:
parent
c8a25d32ad
commit
e63059857b
|
|
@ -37,14 +37,11 @@ namespace ops {
|
|||
* - size : defining the shape of the output tensor.
|
||||
*/
|
||||
|
||||
class Reshape : public DirectMapOp {
|
||||
class Reshape : public Operation{
|
||||
public:
|
||||
Reshape(Graph* graph, const std::vector<uint32_t>& size);
|
||||
Reshape(Graph* graph, const std::vector<uint32_t>& target_shape);
|
||||
|
||||
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
|
||||
|
||||
protected:
|
||||
std::vector<uint32_t> size_;
|
||||
};
|
||||
|
||||
} // namespace ops
|
||||
|
|
|
|||
|
|
@ -26,19 +26,54 @@
|
|||
#include "direct_map_op_impl.h"
|
||||
#include "vsi_nn_pub.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
namespace ops {
|
||||
class ReshapeImpl : public DirectMapOpImpl {
|
||||
public:
|
||||
ReshapeImpl(Graph* graph, const std::vector<vsi_size_t>& shape)
|
||||
: DirectMapOpImpl(graph,
|
||||
#ifdef _VSI_NN_OP_RESHAPE2_H
|
||||
VSI_NN_OP_RESHAPE2
|
||||
#else
|
||||
VSI_NN_OP_RESHAPE
|
||||
#endif
|
||||
, 1, 1), shape_(shape) {}
|
||||
|
||||
std::vector<vsi_size_t> shape_;
|
||||
};
|
||||
|
||||
Reshape::Reshape(Graph* graph, const std::vector<uint32_t>& size)
|
||||
: DirectMapOp(graph, VSI_NN_OP_RESHAPE), size_(std::move(size)) {
|
||||
this->impl()->node()->nn_param.reshape.size = size_.data();
|
||||
this->impl()->node()->nn_param.reshape.dim_num = size_.size();
|
||||
{
|
||||
std::vector<vsi_size_t> shape;
|
||||
std::transform(size.begin(), size.end(), std::back_inserter(shape), [](const uint32_t& d){
|
||||
return static_cast<vsi_size_t>(d);
|
||||
});
|
||||
|
||||
auto lcl_impl = std::make_unique<ReshapeImpl>(graph, shape);
|
||||
|
||||
#ifdef _VSI_NN_OP_RESHAPE2_H
|
||||
lcl_impl->node()->nn_param.reshape2.size = lcl_impl->shape_.data();
|
||||
lcl_impl->node()->nn_param.reshape2.dim_num = size.size();
|
||||
#else
|
||||
lcl_impl->node()->nn_param.reshape.size = lcl_impl->shape_.data();
|
||||
lcl_impl->node()->nn_param.reshape.dim_num = size.size();
|
||||
#endif
|
||||
|
||||
impl_.reset(dynamic_cast<OpImpl*>(lcl_impl.release()));
|
||||
}
|
||||
|
||||
std::shared_ptr<Operation> Reshape::Clone(
|
||||
std::shared_ptr<Graph>& graph) const {
|
||||
return graph->CreateOperation<Reshape>(this->size_);
|
||||
std::vector<uint32_t> size;
|
||||
const ReshapeImpl* lcl_impl = (dynamic_cast<ReshapeImpl*>(impl_.get()));
|
||||
std::transform(lcl_impl->shape_.begin(), lcl_impl->shape_.end(), std::back_inserter(size), [](const vsi_size_t& d){
|
||||
return static_cast<uint32_t>(d);
|
||||
});
|
||||
|
||||
return graph->CreateOperation<Reshape>(size);
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
|
|
|
|||
Loading…
Reference in New Issue