Update reshape to reshape2 (#310)
Update built-in op reshape to reshape2 Signed-off-by: xiang.zhang <xiang.zhang@verisilicon.com>
This commit is contained in:
parent
c8a25d32ad
commit
e63059857b
|
|
@ -37,14 +37,11 @@ namespace ops {
|
||||||
* - size : defining the shape of the output tensor.
|
* - size : defining the shape of the output tensor.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
class Reshape : public DirectMapOp {
|
class Reshape : public Operation{
|
||||||
public:
|
public:
|
||||||
Reshape(Graph* graph, const std::vector<uint32_t>& size);
|
Reshape(Graph* graph, const std::vector<uint32_t>& target_shape);
|
||||||
|
|
||||||
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
|
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
|
||||||
|
|
||||||
protected:
|
|
||||||
std::vector<uint32_t> size_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
||||||
|
|
@ -26,19 +26,54 @@
|
||||||
#include "direct_map_op_impl.h"
|
#include "direct_map_op_impl.h"
|
||||||
#include "vsi_nn_pub.h"
|
#include "vsi_nn_pub.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
namespace tim {
|
namespace tim {
|
||||||
namespace vx {
|
namespace vx {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
class ReshapeImpl : public DirectMapOpImpl {
|
||||||
|
public:
|
||||||
|
ReshapeImpl(Graph* graph, const std::vector<vsi_size_t>& shape)
|
||||||
|
: DirectMapOpImpl(graph,
|
||||||
|
#ifdef _VSI_NN_OP_RESHAPE2_H
|
||||||
|
VSI_NN_OP_RESHAPE2
|
||||||
|
#else
|
||||||
|
VSI_NN_OP_RESHAPE
|
||||||
|
#endif
|
||||||
|
, 1, 1), shape_(shape) {}
|
||||||
|
|
||||||
|
std::vector<vsi_size_t> shape_;
|
||||||
|
};
|
||||||
|
|
||||||
Reshape::Reshape(Graph* graph, const std::vector<uint32_t>& size)
|
Reshape::Reshape(Graph* graph, const std::vector<uint32_t>& size)
|
||||||
: DirectMapOp(graph, VSI_NN_OP_RESHAPE), size_(std::move(size)) {
|
{
|
||||||
this->impl()->node()->nn_param.reshape.size = size_.data();
|
std::vector<vsi_size_t> shape;
|
||||||
this->impl()->node()->nn_param.reshape.dim_num = size_.size();
|
std::transform(size.begin(), size.end(), std::back_inserter(shape), [](const uint32_t& d){
|
||||||
|
return static_cast<vsi_size_t>(d);
|
||||||
|
});
|
||||||
|
|
||||||
|
auto lcl_impl = std::make_unique<ReshapeImpl>(graph, shape);
|
||||||
|
|
||||||
|
#ifdef _VSI_NN_OP_RESHAPE2_H
|
||||||
|
lcl_impl->node()->nn_param.reshape2.size = lcl_impl->shape_.data();
|
||||||
|
lcl_impl->node()->nn_param.reshape2.dim_num = size.size();
|
||||||
|
#else
|
||||||
|
lcl_impl->node()->nn_param.reshape.size = lcl_impl->shape_.data();
|
||||||
|
lcl_impl->node()->nn_param.reshape.dim_num = size.size();
|
||||||
|
#endif
|
||||||
|
|
||||||
|
impl_.reset(dynamic_cast<OpImpl*>(lcl_impl.release()));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<Operation> Reshape::Clone(
|
std::shared_ptr<Operation> Reshape::Clone(
|
||||||
std::shared_ptr<Graph>& graph) const {
|
std::shared_ptr<Graph>& graph) const {
|
||||||
return graph->CreateOperation<Reshape>(this->size_);
|
std::vector<uint32_t> size;
|
||||||
|
const ReshapeImpl* lcl_impl = (dynamic_cast<ReshapeImpl*>(impl_.get()));
|
||||||
|
std::transform(lcl_impl->shape_.begin(), lcl_impl->shape_.end(), std::back_inserter(size), [](const vsi_size_t& d){
|
||||||
|
return static_cast<uint32_t>(d);
|
||||||
|
});
|
||||||
|
|
||||||
|
return graph->CreateOperation<Reshape>(size);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue