From 214cbe5874e6081782cd747e06120435c3a20092 Mon Sep 17 00:00:00 2001 From: Antkillerfarm Date: Tue, 9 Nov 2021 20:25:02 +0800 Subject: [PATCH] add Global Pool2d & Adaptive Pool2d (#210) --- include/tim/vx/ops/pool2d.h | 37 +++++++++++++++++++++++++++++++++++-- src/tim/vx/ops/pool2d.cc | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/include/tim/vx/ops/pool2d.h b/include/tim/vx/ops/pool2d.h index fbb2041..73188e7 100644 --- a/include/tim/vx/ops/pool2d.h +++ b/include/tim/vx/ops/pool2d.h @@ -36,6 +36,8 @@ namespace ops { /** * ## Pool2d * + * ### Classic Pool2d + * * Performs an 2-D pooling operation. * * - type : MAX, AVG, L2 or AVG_ANDROID. @@ -43,10 +45,27 @@ namespace ops { * - ksize : filter size. * - stride : stride along each spatial axis. * - round_type : CEILING or FLOOR. + * + * ### Global Pool2d + * + * - type : MAX, AVG, L2 or AVG_ANDROID. + * - input_size : input size(only [W, H]) + * - round_type : CEILING or FLOOR. + * + * ### Adaptive Pool2d + * + * Same as torch.nn.AdaptiveXXXPool2d. + * + * - type : MAX, AVG, L2 or AVG_ANDROID. + * - input_size : input size(only [W, H]) + * - output_size : output size(only [W, H]) + * - round_type : CEILING or FLOOR. + * */ class Pool2d : public Operation { public: + // for Classic Pool2d Pool2d(Graph* graph, PoolType type, PadType padding, const std::array& ksize, const std::array& stride, @@ -59,13 +78,27 @@ class Pool2d : public Operation { RoundType round_type = RoundType::FLOOR, DataLayout layout = DataLayout::WHCN); + // for Global Pool2d + Pool2d(Graph* graph, PoolType type, + const std::array& input_size, + RoundType round_type = RoundType::FLOOR, + DataLayout layout = DataLayout::WHCN); + + // for Adaptive Pool2d + Pool2d(Graph* graph, PoolType type, + const std::array& input_size, + const std::array& output_size, + RoundType round_type = RoundType::FLOOR, + DataLayout layout = DataLayout::WHCN); + std::shared_ptr Clone(std::shared_ptr& graph) const override; + void Init(); protected: const PoolType type_; const PadType padding_; - const std::array ksize_; - const std::array stride_; + std::array ksize_; + std::array stride_; const RoundType round_type_; const std::array pad_; }; diff --git a/src/tim/vx/ops/pool2d.cc b/src/tim/vx/ops/pool2d.cc index 7ae5d38..083d86f 100644 --- a/src/tim/vx/ops/pool2d.cc +++ b/src/tim/vx/ops/pool2d.cc @@ -31,6 +31,7 @@ namespace tim { namespace vx { namespace ops { +// for Classic Pool2d Pool2d::Pool2d(Graph* graph, PoolType type, PadType padding, const std::array& ksize, const std::array& stride, RoundType round_type, @@ -61,6 +62,37 @@ Pool2d::Pool2d(Graph* graph, PoolType type, : Operation(graph, VSI_NN_OP_POOL, 1, 1, layout), type_(type), padding_(PadType::AUTO), ksize_(ksize), stride_(stride), round_type_(round_type), pad_(pad) { + Init(); +} + +// for Global Pool2d +Pool2d::Pool2d(Graph* graph, PoolType type, + const std::array& input_size, + RoundType round_type, + DataLayout layout) + : Operation(graph, VSI_NN_OP_POOL, 1, 1, layout), + type_(type), padding_(PadType::AUTO), ksize_(input_size), stride_(input_size), + round_type_(round_type), pad_({0, 0, 0, 0}) { + Init(); +} + +// for Adaptive Pool2d +Pool2d::Pool2d(Graph* graph, PoolType type, + const std::array& input_size, + const std::array& output_size, + RoundType round_type, + DataLayout layout) + : Operation(graph, VSI_NN_OP_POOL, 1, 1, layout), + type_(type), padding_(PadType::AUTO), + round_type_(round_type), pad_({0, 0, 0, 0}) { + stride_[0] = floor(input_size[0] / (float)(output_size[0])); + stride_[1] = floor(input_size[1] / (float)(output_size[1])); + ksize_[0] = input_size[0] - (output_size[0] - 1) * stride_[0]; + ksize_[1] = input_size[1] - (output_size[1] - 1) * stride_[1]; + Init(); +} + +void Pool2d::Init() { this->impl()->node()->nn_param.pool.type = TranslatePoolType(type_); this->impl()->node()->nn_param.pool.round_type = TranslateRoundType(round_type_);