From cc34b5f0ead016ec190cc3416636c137406e241d Mon Sep 17 00:00:00 2001 From: meseraph <13070195207@163.com> Date: Sun, 18 Dec 2022 17:56:21 +0800 Subject: [PATCH] mapped pool1d --- include/tim/vx/ops.h | 1 + include/tim/vx/ops/pool1d.h | 110 +++++++++++++++ src/tim/vx/ops/avg_pool_test.cc | 228 ++++++++++++++++++++++++++++++++ src/tim/vx/ops/max_pool_test.cc | 228 ++++++++++++++++++++++++++++++++ src/tim/vx/ops/pool1d.cc | 110 +++++++++++++++ 5 files changed, 677 insertions(+) create mode 100644 include/tim/vx/ops/pool1d.h create mode 100644 src/tim/vx/ops/pool1d.cc diff --git a/include/tim/vx/ops.h b/include/tim/vx/ops.h index cf82e8e..655c933 100644 --- a/include/tim/vx/ops.h +++ b/include/tim/vx/ops.h @@ -63,6 +63,7 @@ #include "tim/vx/ops/onehot.h" #include "tim/vx/ops/pad.h" #include "tim/vx/ops/pad_v2.h" +#include "tim/vx/ops/pool1d.h" #include "tim/vx/ops/pool2d.h" #include "tim/vx/ops/reduce.h" #include "tim/vx/ops/relational_operations.h" diff --git a/include/tim/vx/ops/pool1d.h b/include/tim/vx/ops/pool1d.h new file mode 100644 index 0000000..ebb23ae --- /dev/null +++ b/include/tim/vx/ops/pool1d.h @@ -0,0 +1,110 @@ +/**************************************************************************** +* +* Copyright (c) 2022 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#ifndef TIM_VX_OPS_POOL1D_H_ +#define TIM_VX_OPS_POOL1D_H_ + +#include + +#include "tim/vx/builtin_op.h" +#include "tim/vx/types.h" + +namespace tim { +namespace vx { +namespace ops { + +/** + * ## Pool1d + * + * ### Classic Pool1d + * + * Performs an 1-D pooling operation. + * + * - type : MAX, AVG, L2 or AVG_ANDROID. + * - padding : AUTO, VALID or SAME. + * - pad : Specify the number of pad values for left, right. + * - ksize : filter size. + * - stride : stride along each spatial axis. + * - round_type : CEILING or FLOOR. + * + * ### Global Pool1d + * + * - type : MAX, AVG, L2 or AVG_ANDROID. + * - input_size : input size(only [W]) + * - round_type : CEILING or FLOOR. + * + * ### Adaptive Pool1d + * + * Same as torch.nn.AdaptiveXXXPool1d. + * + * - type : MAX, AVG, L2 or AVG_ANDROID. + * - input_size : input size(only [W]) + * - output_size : output size(only [W]) + * - round_type : CEILING or FLOOR. + * + */ + +class Pool1d : public BuiltinOp { + public: + /* for Classic Pool1d, pool does not support auto-completion of pad value, + you need to specify pad size explicitly, it is recommended to use the second api.*/ + Pool1d(Graph* graph, PoolType type, PadType padding, + uint32_t ksize, + uint32_t stride, + RoundType round_type = RoundType::FLOOR, + DataLayout layout = DataLayout::WCN); + Pool1d(Graph* graph, PoolType type, const std::array& pad, + uint32_t ksize, + uint32_t stride, + RoundType round_type = RoundType::FLOOR, + DataLayout layout = DataLayout::WCN); + + // for Global Pool1d + Pool1d(Graph* graph, PoolType type, uint32_t input_size, + RoundType round_type = RoundType::FLOOR, + DataLayout layout = DataLayout::WCN); + + // for Adaptive Pool1d + Pool1d(Graph* graph, PoolType type, uint32_t input_size, + uint32_t output_size, + RoundType round_type = RoundType::FLOOR, + DataLayout layout = DataLayout::WCN); + + std::shared_ptr Clone( + std::shared_ptr& graph) const override; + void Init(); + + protected: + const PoolType type_; + const PadType padding_; + uint32_t ksize_; + uint32_t stride_; + const RoundType round_type_; + const std::array pad_; +}; + +} // namespace ops +} // namespace vx +} // namespace tim + +#endif /* TIM_VX_OPS_POOL1D_H_ */ \ No newline at end of file diff --git a/src/tim/vx/ops/avg_pool_test.cc b/src/tim/vx/ops/avg_pool_test.cc index 81756bc..90f666e 100644 --- a/src/tim/vx/ops/avg_pool_test.cc +++ b/src/tim/vx/ops/avg_pool_test.cc @@ -24,10 +24,238 @@ #include "tim/vx/context.h" #include "tim/vx/graph.h" #include "tim/vx/ops/pool2d.h" +#include "tim/vx/ops/pool1d.h" #include #include "gtest/gtest.h" #include "test_utils.h" +TEST(AVG, shape_32_3_1_fp32_kernel_2_stride_1) { + auto ctx = tim::vx::Context::Create(); + auto graph = ctx->CreateGraph(); + + tim::vx::ShapeType in_shape({32, 3, 1}); + tim::vx::ShapeType out_shape({31, 3, 1}); + tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, + in_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, + out_shape, tim::vx::TensorAttribute::OUTPUT); + + auto input_tensor = graph->CreateTensor(input_spec); + auto output_tensor = graph->CreateTensor(output_spec); + + std::vector in_data = { + 1.764052391052246, + 0.40015721321105957, + 0.978738009929657, + 2.2408931255340576, + 1.8675580024719238, + -0.9772778749465942, + 0.9500884413719177, + -0.15135720372200012, + -0.10321885347366333, + 0.4105985164642334, + 0.14404356479644775, + 1.4542734622955322, + 0.7610377073287964, + 0.12167501449584961, + 0.44386324286460876, + 0.3336743414402008, + 1.4940791130065918, + -0.2051582634449005, + 0.3130677044391632, + -0.8540957570075989, + -2.5529897212982178, + 0.653618574142456, + 0.8644362092018127, + -0.7421650290489197, + 2.269754648208618, + -1.4543657302856445, + 0.04575851559638977, + -0.18718385696411133, + 1.5327792167663574, + 1.4693588018417358, + 0.154947429895401, + 0.37816253304481506, + + -0.8877857327461243, + -1.980796456336975, + -0.34791216254234314, + 0.15634897351264954, + 1.2302906513214111, + 1.202379822731018, + -0.38732680678367615, + -0.302302747964859, + -1.0485529899597168, + -1.420017957687378, + -1.7062702178955078, + 1.950775384902954, + -0.5096521973609924, + -0.4380742907524109, + -1.2527953386306763, + 0.7774903774261475, + -1.6138978004455566, + -0.21274028718471527, + -0.8954665660858154, + 0.38690251111984253, + -0.5108051300048828, + -1.18063223361969, + -0.02818222902715206, + 0.4283318817615509, + 0.06651721894741058, + 0.30247190594673157, + -0.6343221068382263, + -0.3627411723136902, + -0.6724604368209839, + -0.35955315828323364, + -0.8131462931632996, + -1.7262825965881348, + + 0.17742614448070526, + -0.4017809331417084, + -1.630198359489441, + 0.46278226375579834, + -0.9072983860969543, + 0.05194539576768875, + 0.7290905714035034, + 0.12898291647434235, + 1.1394007205963135, + -1.234825849533081, + 0.4023416340351105, + -0.6848101019859314, + -0.8707971572875977, + -0.5788496732711792, + -0.3115525245666504, + 0.056165341287851334, + -1.1651498079299927, + 0.9008265137672424, + 0.4656624495983124, + -1.5362436771392822, + 1.4882521629333496, + 1.895889163017273, + 1.1787796020507812, + -0.1799248307943344, + -1.0707526206970215, + 1.0544517040252686, + -0.4031769335269928, + 1.222445011138916, + 0.2082749754190445, + 0.9766390323638916, + 0.3563663959503174, + 0.7065731883049011 + }; + std::vector golden = { + 1.0821048021316528, + 0.6894476413726807, + 1.6098155975341797, + 2.054225444793701, + 0.4451400637626648, + -0.013594716787338257, + 0.3993656039237976, + -0.12728802859783173, + 0.15368983149528503, + 0.2773210406303406, + 0.79915851354599, + 1.1076555252075195, + 0.441356360912323, + 0.2827691435813904, + 0.3887687921524048, + 0.9138767123222351, + 0.6444604396820068, + 0.05395472049713135, + -0.27051401138305664, + -1.703542709350586, + -0.9496855735778809, + 0.759027361869812, + 0.06113559007644653, + 0.7637947797775269, + 0.4076944589614868, + -0.7043036222457886, + -0.07071267068386078, + 0.672797679901123, + 1.5010690689086914, + 0.8121531009674072, + 0.26655498147010803, + + -1.434291124343872, + -1.1643543243408203, + -0.0957815945148468, + 0.6933197975158691, + 1.2163352966308594, + 0.40752649307250977, + -0.3448147773742676, + -0.6754278540611267, + -1.2342854738235474, + -1.5631440877914429, + 0.12225258350372314, + 0.7205616235733032, + -0.47386324405670166, + -0.8454347848892212, + -0.2376524806022644, + -0.4182037115097046, + -0.9133190512657166, + -0.554103434085846, + -0.25428202748298645, + -0.06195130944252014, + -0.8457186818122864, + -0.6044072508811951, + 0.20007482171058655, + 0.24742454290390015, + 0.18449455499649048, + -0.16592510044574738, + -0.49853163957595825, + -0.5176007747650146, + -0.5160068273544312, + -0.5863497257232666, + -1.2697144746780396, + + -0.11217739433050156, + -1.0159896612167358, + -0.5837080478668213, + -0.222258061170578, + -0.4276764988899231, + 0.3905179798603058, + 0.4290367364883423, + 0.6341918110847473, + -0.04771256446838379, + -0.4162421226501465, + -0.14123423397541046, + -0.7778036594390869, + -0.7248234152793884, + -0.4452010989189148, + -0.12769359350204468, + -0.5544922351837158, + -0.13216164708137512, + 0.6832444667816162, + -0.5352905988693237, + -0.02399575710296631, + 1.692070722579956, + 1.5373344421386719, + 0.4994273781776428, + -0.6253387331962585, + -0.008150458335876465, + 0.3256374001502991, + 0.4096340537071228, + 0.7153599858283997, + 0.5924569964408875, + 0.6665027141571045, + 0.5314698219299316 + }; + + EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(), in_data.size()*4)); + uint32_t ksize = 2; + uint32_t stride = 1; + auto op = graph->CreateOperation(tim::vx::PoolType::AVG, + tim::vx::PadType::VALID, ksize, stride); + (*op).BindInputs({input_tensor}).BindOutputs({output_tensor}); + + EXPECT_TRUE(graph->Compile()); + EXPECT_TRUE(graph->Run()); + + std::vector output(golden.size()); + EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); + EXPECT_EQ(golden, output); +} + TEST(AVG, shape_3_3_1_2_fp32_kernel_2_stride_1) { auto ctx = tim::vx::Context::Create(); auto graph = ctx->CreateGraph(); diff --git a/src/tim/vx/ops/max_pool_test.cc b/src/tim/vx/ops/max_pool_test.cc index b3a59a2..fb7c2fb 100644 --- a/src/tim/vx/ops/max_pool_test.cc +++ b/src/tim/vx/ops/max_pool_test.cc @@ -24,10 +24,238 @@ #include "tim/vx/context.h" #include "tim/vx/graph.h" #include "tim/vx/ops/pool2d.h" +#include "tim/vx/ops/pool1d.h" #include #include "gtest/gtest.h" #include "test_utils.h" +TEST(MAX, shape_32_3_1_fp32_kernel_2_stride_1) { + auto ctx = tim::vx::Context::Create(); + auto graph = ctx->CreateGraph(); + + tim::vx::ShapeType in_shape({32, 3, 1}); + tim::vx::ShapeType out_shape({31, 3, 1}); + tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, + in_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, + out_shape, tim::vx::TensorAttribute::OUTPUT); + + auto input_tensor = graph->CreateTensor(input_spec); + auto output_tensor = graph->CreateTensor(output_spec); + + std::vector in_data = { + 1.764052391052246, + 0.40015721321105957, + 0.978738009929657, + 2.2408931255340576, + 1.8675580024719238, + -0.9772778749465942, + 0.9500884413719177, + -0.15135720372200012, + -0.10321885347366333, + 0.4105985164642334, + 0.14404356479644775, + 1.4542734622955322, + 0.7610377073287964, + 0.12167501449584961, + 0.44386324286460876, + 0.3336743414402008, + 1.4940791130065918, + -0.2051582634449005, + 0.3130677044391632, + -0.8540957570075989, + -2.5529897212982178, + 0.653618574142456, + 0.8644362092018127, + -0.7421650290489197, + 2.269754648208618, + -1.4543657302856445, + 0.04575851559638977, + -0.18718385696411133, + 1.5327792167663574, + 1.4693588018417358, + 0.154947429895401, + 0.37816253304481506, + + -0.8877857327461243, + -1.980796456336975, + -0.34791216254234314, + 0.15634897351264954, + 1.2302906513214111, + 1.202379822731018, + -0.38732680678367615, + -0.302302747964859, + -1.0485529899597168, + -1.420017957687378, + -1.7062702178955078, + 1.950775384902954, + -0.5096521973609924, + -0.4380742907524109, + -1.2527953386306763, + 0.7774903774261475, + -1.6138978004455566, + -0.21274028718471527, + -0.8954665660858154, + 0.38690251111984253, + -0.5108051300048828, + -1.18063223361969, + -0.02818222902715206, + 0.4283318817615509, + 0.06651721894741058, + 0.30247190594673157, + -0.6343221068382263, + -0.3627411723136902, + -0.6724604368209839, + -0.35955315828323364, + -0.8131462931632996, + -1.7262825965881348, + + 0.17742614448070526, + -0.4017809331417084, + -1.630198359489441, + 0.46278226375579834, + -0.9072983860969543, + 0.05194539576768875, + 0.7290905714035034, + 0.12898291647434235, + 1.1394007205963135, + -1.234825849533081, + 0.4023416340351105, + -0.6848101019859314, + -0.8707971572875977, + -0.5788496732711792, + -0.3115525245666504, + 0.056165341287851334, + -1.1651498079299927, + 0.9008265137672424, + 0.4656624495983124, + -1.5362436771392822, + 1.4882521629333496, + 1.895889163017273, + 1.1787796020507812, + -0.1799248307943344, + -1.0707526206970215, + 1.0544517040252686, + -0.4031769335269928, + 1.222445011138916, + 0.2082749754190445, + 0.9766390323638916, + 0.3563663959503174, + 0.7065731883049011 + }; + std::vector golden = { + 1.764052391052246, + 0.978738009929657, + 2.2408931255340576, + 2.2408931255340576, + 1.8675580024719238, + 0.9500884413719177, + 0.9500884413719177, + -0.10321885347366333, + 0.4105985164642334, + 0.4105985164642334, + 1.4542734622955322, + 1.4542734622955322, + 0.7610377073287964, + 0.44386324286460876, + 0.44386324286460876, + 1.4940791130065918, + 1.4940791130065918, + 0.3130677044391632, + 0.3130677044391632, + -0.8540957570075989, + 0.653618574142456, + 0.8644362092018127, + 0.8644362092018127, + 2.269754648208618, + 2.269754648208618, + 0.04575851559638977, + 0.04575851559638977, + 1.5327792167663574, + 1.5327792167663574, + 1.4693588018417358, + 0.37816253304481506, + + -0.8877857327461243, + -0.34791216254234314, + 0.15634897351264954, + 1.2302906513214111, + 1.2302906513214111, + 1.202379822731018, + -0.302302747964859, + -0.302302747964859, + -1.0485529899597168, + -1.420017957687378, + 1.950775384902954, + 1.950775384902954, + -0.4380742907524109, + -0.4380742907524109, + 0.7774903774261475, + 0.7774903774261475, + -0.21274028718471527, + -0.21274028718471527, + 0.38690251111984253, + 0.38690251111984253, + -0.5108051300048828, + -0.02818222902715206, + 0.4283318817615509, + 0.4283318817615509, + 0.30247190594673157, + 0.30247190594673157, + -0.3627411723136902, + -0.3627411723136902, + -0.35955315828323364, + -0.35955315828323364, + -0.8131462931632996, + + 0.17742614448070526, + -0.4017809331417084, + 0.46278226375579834, + 0.46278226375579834, + 0.05194539576768875, + 0.7290905714035034, + 0.7290905714035034, + 1.1394007205963135, + 1.1394007205963135, + 0.4023416340351105, + 0.4023416340351105, + -0.6848101019859314, + -0.5788496732711792, + -0.3115525245666504, + 0.056165341287851334, + 0.056165341287851334, + 0.9008265137672424, + 0.9008265137672424, + 0.4656624495983124, + 1.4882521629333496, + 1.895889163017273, + 1.895889163017273, + 1.1787796020507812, + -0.1799248307943344, + 1.0544517040252686, + 1.0544517040252686, + 1.222445011138916, + 1.222445011138916, + 0.9766390323638916, + 0.9766390323638916, + 0.7065731883049011 + }; + + EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(), in_data.size()*4)); + uint32_t ksize = 2; + uint32_t stride = 1; + auto op = graph->CreateOperation(tim::vx::PoolType::MAX, + tim::vx::PadType::VALID, ksize, stride); + (*op).BindInputs({input_tensor}).BindOutputs({output_tensor}); + + EXPECT_TRUE(graph->Compile()); + EXPECT_TRUE(graph->Run()); + + std::vector output(golden.size()); + EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); + EXPECT_EQ(golden, output); +} + TEST(MAX, shape_6_6_1_1_fp32_kernel_3_stride_2) { auto ctx = tim::vx::Context::Create(); auto graph = ctx->CreateGraph(); diff --git a/src/tim/vx/ops/pool1d.cc b/src/tim/vx/ops/pool1d.cc new file mode 100644 index 0000000..d26aeef --- /dev/null +++ b/src/tim/vx/ops/pool1d.cc @@ -0,0 +1,110 @@ +/**************************************************************************** +* +* Copyright (c) 2022 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#include "tim/vx/ops/pool1d.h" + +#include "builtin_op_impl.h" +#include "type_utils.h" +#include "vsi_nn_pub.h" + +namespace tim { +namespace vx { +namespace ops { + +// for Classic Pool1d +Pool1d::Pool1d(Graph* graph, PoolType type, PadType padding, + uint32_t ksize, uint32_t stride, RoundType round_type, + DataLayout layout) + : BuiltinOp(graph, VSI_NN_OP_POOL, 1, 1, layout), + type_(type), + padding_(padding), + ksize_(ksize), + stride_(stride), + round_type_(round_type), + pad_({0,0}) { + this->impl()->node()->nn_param.pool.type = TranslatePoolType(type_); + this->impl()->node()->nn_param.pool.round_type = + TranslateRoundType(round_type_); + this->impl()->node()->nn_param.pool.ksize[0] = ksize_; + this->impl()->node()->nn_param.pool.stride[0] = stride_; + this->impl()->node()->nn_param.pool.pad_type = TranslatePadType(padding_); + this->SetRoundingPolicy(OverflowPolicy::SATURATE, RoundingPolicy::RTNE, round_type_); +} + +Pool1d::Pool1d(Graph* graph, PoolType type, + const std::array& pad, + uint32_t ksize, + uint32_t stride, + RoundType round_type, + DataLayout layout) + : BuiltinOp(graph, VSI_NN_OP_POOL, 1, 1, layout), + type_(type), padding_(PadType::AUTO), ksize_(ksize), stride_(stride), + round_type_(round_type), pad_(pad) { + Init(); +} + +// for Global Pool1d +Pool1d::Pool1d(Graph* graph, PoolType type, + uint32_t input_size, + RoundType round_type, + DataLayout layout) + : BuiltinOp(graph, VSI_NN_OP_POOL, 1, 1, layout), + type_(type), padding_(PadType::AUTO), ksize_(input_size), stride_(input_size), + round_type_(round_type), pad_({0, 0}) { + Init(); +} + +// for Adaptive Pool1d +Pool1d::Pool1d(Graph* graph, PoolType type, + uint32_t input_size, + uint32_t output_size, + RoundType round_type, + DataLayout layout) + : BuiltinOp(graph, VSI_NN_OP_POOL, 1, 1, layout), + type_(type), padding_(PadType::AUTO), + round_type_(round_type), pad_({0, 0}) { + stride_ = floor(input_size / (float)(output_size)); + ksize_ = input_size - (output_size - 1) * stride_; + Init(); +} + +void Pool1d::Init() { + this->impl()->node()->nn_param.pool.type = TranslatePoolType(type_); + this->impl()->node()->nn_param.pool.round_type = + TranslateRoundType(round_type_); + this->impl()->node()->nn_param.pool.ksize[0] = ksize_; + this->impl()->node()->nn_param.pool.stride[0] = stride_; + this->impl()->node()->nn_param.pool.pad[0] = pad_[0]; + this->impl()->node()->nn_param.pool.pad[1] = pad_[1]; + this->SetRoundingPolicy(OverflowPolicy::SATURATE, RoundingPolicy::RTNE, round_type_); +} + +std::shared_ptr Pool1d::Clone(std::shared_ptr& graph) const { + return graph->CreateOperation(this->type_, this->pad_, this->ksize_, + this->stride_, this->round_type_, + this->impl_->layout_); +} + +} // namespace ops +} // namespace vx +} // namespace tim \ No newline at end of file