From bd9c5df70a28befdf75f440b5e122ae40d2d1d88 Mon Sep 17 00:00:00 2001 From: "zhao.xia" Date: Wed, 2 Jun 2021 15:35:53 +0800 Subject: [PATCH] Add pad parameter to pool2d Signed-off-by: zhao.xia --- include/tim/vx/ops/pool2d.h | 7 +++++++ src/tim/vx/ops/pool2d.cc | 25 ++++++++++++++++++++++++- 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/include/tim/vx/ops/pool2d.h b/include/tim/vx/ops/pool2d.h index 879a7d3..19c2864 100644 --- a/include/tim/vx/ops/pool2d.h +++ b/include/tim/vx/ops/pool2d.h @@ -52,6 +52,12 @@ class Pool2d : public Operation { const std::array& stride, RoundType round_type = RoundType::FLOOR, DataLayout layout = DataLayout::WHCN); + Pool2d(Graph* graph, PoolType type, + const std::array& pad, + const std::array& ksize, + const std::array& stride, + RoundType round_type = RoundType::FLOOR, + DataLayout layout = DataLayout::WHCN); protected: const PoolType type_; @@ -59,6 +65,7 @@ class Pool2d : public Operation { const std::array ksize_; const std::array stride_; const RoundType round_type_; + const std::array pad_; }; } // namespace ops diff --git a/src/tim/vx/ops/pool2d.cc b/src/tim/vx/ops/pool2d.cc index c3771dc..b63a5a2 100644 --- a/src/tim/vx/ops/pool2d.cc +++ b/src/tim/vx/ops/pool2d.cc @@ -40,7 +40,8 @@ Pool2d::Pool2d(Graph* graph, PoolType type, PadType padding, padding_(padding), ksize_(ksize), stride_(stride), - round_type_(round_type) { + round_type_(round_type), + pad_({0,0,0,0}) { this->impl()->node()->nn_param.pool.type = TranslatePoolType(type_); this->impl()->node()->nn_param.pool.round_type = TranslateRoundType(round_type_); @@ -52,6 +53,28 @@ Pool2d::Pool2d(Graph* graph, PoolType type, PadType padding, this->SetRoundingPolicy(OverflowPolicy::SATURATE, RoundingPolicy::RTNE, round_type_); } +Pool2d::Pool2d(Graph* graph, PoolType type, + const std::array& pad, + const std::array& ksize, + const std::array& stride, RoundType round_type, + DataLayout layout) + : Operation(graph, VSI_NN_OP_POOL, 1, 1, layout), + type_(type), padding_(PadType::AUTO), ksize_(ksize), stride_(stride), + round_type_(round_type), pad_(pad) { + this->impl()->node()->nn_param.pool.type = TranslatePoolType(type_); + this->impl()->node()->nn_param.pool.round_type = + TranslateRoundType(round_type_); + this->impl()->node()->nn_param.pool.ksize[0] = ksize_[0]; + this->impl()->node()->nn_param.pool.ksize[1] = ksize_[1]; + this->impl()->node()->nn_param.pool.stride[0] = stride_[0]; + this->impl()->node()->nn_param.pool.stride[1] = stride_[1]; + this->impl()->node()->nn_param.pool.pad[0] = pad_[0]; + this->impl()->node()->nn_param.pool.pad[1] = pad_[1]; + this->impl()->node()->nn_param.pool.pad[2] = pad_[2]; + this->impl()->node()->nn_param.pool.pad[3] = pad_[3]; + this->SetRoundingPolicy(OverflowPolicy::SATURATE, RoundingPolicy::RTNE, round_type_); +} + } // namespace ops } // namespace vx } // namespace tim \ No newline at end of file