Add mapped Pool1d operator (1-D pooling mapped onto the VSI pool op)

This commit is contained in:
meseraph 2022-12-18 17:56:21 +08:00 committed by Sven
parent aa0b474c19
commit cc34b5f0ea
5 changed files with 677 additions and 0 deletions

View File

@ -63,6 +63,7 @@
#include "tim/vx/ops/onehot.h"
#include "tim/vx/ops/pad.h"
#include "tim/vx/ops/pad_v2.h"
#include "tim/vx/ops/pool1d.h"
#include "tim/vx/ops/pool2d.h"
#include "tim/vx/ops/reduce.h"
#include "tim/vx/ops/relational_operations.h"

110
include/tim/vx/ops/pool1d.h Normal file
View File

@ -0,0 +1,110 @@
/****************************************************************************
*
* Copyright (c) 2022 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_VX_OPS_POOL1D_H_
#define TIM_VX_OPS_POOL1D_H_
#include <array>
#include "tim/vx/builtin_op.h"
#include "tim/vx/types.h"
namespace tim {
namespace vx {
namespace ops {
/**
 * ## Pool1d
 *
 * ### Classic Pool1d
 *
 * Performs a 1-D pooling operation.
 *
 * - type : MAX, AVG, L2 or AVG_ANDROID.
 * - padding : AUTO, VALID or SAME.
 * - pad : Specify the number of pad values for left, right.
 * - ksize : filter size.
 * - stride : stride along each spatial axis.
 * - round_type : CEILING or FLOOR.
 *
 * ### Global Pool1d
 *
 * - type : MAX, AVG, L2 or AVG_ANDROID.
 * - input_size : input size (only [W]).
 * - round_type : CEILING or FLOOR.
 *
 * ### Adaptive Pool1d
 *
 * Same as torch.nn.AdaptiveXXXPool1d: kernel size and stride are derived
 * from the input and requested output sizes.
 *
 * - type : MAX, AVG, L2 or AVG_ANDROID.
 * - input_size : input size (only [W]).
 * - output_size : output size (only [W]).
 * - round_type : CEILING or FLOOR.
 *
 */
class Pool1d : public BuiltinOp {
 public:
  /* Classic Pool1d with a pad type. Pool does not support auto-completion
     of pad values; when you need to specify pad sizes explicitly, the
     second (pad-array) constructor is recommended. */
  Pool1d(Graph* graph, PoolType type, PadType padding,
         uint32_t ksize,
         uint32_t stride,
         RoundType round_type = RoundType::FLOOR,
         DataLayout layout = DataLayout::WCN);
  // Classic Pool1d with explicit {left, right} pad values.
  Pool1d(Graph* graph, PoolType type, const std::array<uint32_t, 2>& pad,
         uint32_t ksize,
         uint32_t stride,
         RoundType round_type = RoundType::FLOOR,
         DataLayout layout = DataLayout::WCN);
  // for Global Pool1d
  Pool1d(Graph* graph, PoolType type, uint32_t input_size,
         RoundType round_type = RoundType::FLOOR,
         DataLayout layout = DataLayout::WCN);
  // for Adaptive Pool1d
  Pool1d(Graph* graph, PoolType type, uint32_t input_size,
         uint32_t output_size,
         RoundType round_type = RoundType::FLOOR,
         DataLayout layout = DataLayout::WCN);
  // Re-creates this op, with the same configuration, inside another graph.
  std::shared_ptr<Operation> Clone(
      std::shared_ptr<Graph>& graph) const override;
  // Writes the configured pooling parameters into the underlying node.
  void Init();

 protected:
  const PoolType type_;         // pooling kind: MAX/AVG/L2/AVG_ANDROID
  const PadType padding_;       // pad type (classic pad-type constructor)
  uint32_t ksize_;              // kernel size along W
  uint32_t stride_;             // stride along W
  const RoundType round_type_;  // output-size rounding: CEILING or FLOOR
  const std::array<uint32_t, 2> pad_;  // explicit {left, right} pad values
};
} // namespace ops
} // namespace vx
} // namespace tim
#endif /* TIM_VX_OPS_POOL1D_H_ */

View File

@ -24,10 +24,238 @@
#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/ops/pool2d.h"
#include "tim/vx/ops/pool1d.h"
#include <iostream>
#include "gtest/gtest.h"
#include "test_utils.h"
// AVG Pool1d, classic API: kernel 2, stride 1, VALID padding.
// Layout is WCN: W=32, C=3, N=1, so output W = 32 - 2 + 1 = 31.
TEST(AVG, shape_32_3_1_fp32_kernel_2_stride_1) {
  auto ctx = tim::vx::Context::Create();
  auto graph = ctx->CreateGraph();

  tim::vx::ShapeType in_shape({32, 3, 1});
  tim::vx::ShapeType out_shape({31, 3, 1});
  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32,
                                 in_shape, tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32,
                                  out_shape, tim::vx::TensorAttribute::OUTPUT);
  auto input_tensor = graph->CreateTensor(input_spec);
  auto output_tensor = graph->CreateTensor(output_spec);

  std::vector<float> in_data = {
      1.764052391052246,
      0.40015721321105957,
      0.978738009929657,
      2.2408931255340576,
      1.8675580024719238,
      -0.9772778749465942,
      0.9500884413719177,
      -0.15135720372200012,
      -0.10321885347366333,
      0.4105985164642334,
      0.14404356479644775,
      1.4542734622955322,
      0.7610377073287964,
      0.12167501449584961,
      0.44386324286460876,
      0.3336743414402008,
      1.4940791130065918,
      -0.2051582634449005,
      0.3130677044391632,
      -0.8540957570075989,
      -2.5529897212982178,
      0.653618574142456,
      0.8644362092018127,
      -0.7421650290489197,
      2.269754648208618,
      -1.4543657302856445,
      0.04575851559638977,
      -0.18718385696411133,
      1.5327792167663574,
      1.4693588018417358,
      0.154947429895401,
      0.37816253304481506,
      -0.8877857327461243,
      -1.980796456336975,
      -0.34791216254234314,
      0.15634897351264954,
      1.2302906513214111,
      1.202379822731018,
      -0.38732680678367615,
      -0.302302747964859,
      -1.0485529899597168,
      -1.420017957687378,
      -1.7062702178955078,
      1.950775384902954,
      -0.5096521973609924,
      -0.4380742907524109,
      -1.2527953386306763,
      0.7774903774261475,
      -1.6138978004455566,
      -0.21274028718471527,
      -0.8954665660858154,
      0.38690251111984253,
      -0.5108051300048828,
      -1.18063223361969,
      -0.02818222902715206,
      0.4283318817615509,
      0.06651721894741058,
      0.30247190594673157,
      -0.6343221068382263,
      -0.3627411723136902,
      -0.6724604368209839,
      -0.35955315828323364,
      -0.8131462931632996,
      -1.7262825965881348,
      0.17742614448070526,
      -0.4017809331417084,
      -1.630198359489441,
      0.46278226375579834,
      -0.9072983860969543,
      0.05194539576768875,
      0.7290905714035034,
      0.12898291647434235,
      1.1394007205963135,
      -1.234825849533081,
      0.4023416340351105,
      -0.6848101019859314,
      -0.8707971572875977,
      -0.5788496732711792,
      -0.3115525245666504,
      0.056165341287851334,
      -1.1651498079299927,
      0.9008265137672424,
      0.4656624495983124,
      -1.5362436771392822,
      1.4882521629333496,
      1.895889163017273,
      1.1787796020507812,
      -0.1799248307943344,
      -1.0707526206970215,
      1.0544517040252686,
      -0.4031769335269928,
      1.222445011138916,
      0.2082749754190445,
      0.9766390323638916,
      0.3563663959503174,
      0.7065731883049011
  };
  // Expected values: sliding-window mean of every adjacent pair along W.
  std::vector<float> golden = {
      1.0821048021316528,
      0.6894476413726807,
      1.6098155975341797,
      2.054225444793701,
      0.4451400637626648,
      -0.013594716787338257,
      0.3993656039237976,
      -0.12728802859783173,
      0.15368983149528503,
      0.2773210406303406,
      0.79915851354599,
      1.1076555252075195,
      0.441356360912323,
      0.2827691435813904,
      0.3887687921524048,
      0.9138767123222351,
      0.6444604396820068,
      0.05395472049713135,
      -0.27051401138305664,
      -1.703542709350586,
      -0.9496855735778809,
      0.759027361869812,
      0.06113559007644653,
      0.7637947797775269,
      0.4076944589614868,
      -0.7043036222457886,
      -0.07071267068386078,
      0.672797679901123,
      1.5010690689086914,
      0.8121531009674072,
      0.26655498147010803,
      -1.434291124343872,
      -1.1643543243408203,
      -0.0957815945148468,
      0.6933197975158691,
      1.2163352966308594,
      0.40752649307250977,
      -0.3448147773742676,
      -0.6754278540611267,
      -1.2342854738235474,
      -1.5631440877914429,
      0.12225258350372314,
      0.7205616235733032,
      -0.47386324405670166,
      -0.8454347848892212,
      -0.2376524806022644,
      -0.4182037115097046,
      -0.9133190512657166,
      -0.554103434085846,
      -0.25428202748298645,
      -0.06195130944252014,
      -0.8457186818122864,
      -0.6044072508811951,
      0.20007482171058655,
      0.24742454290390015,
      0.18449455499649048,
      -0.16592510044574738,
      -0.49853163957595825,
      -0.5176007747650146,
      -0.5160068273544312,
      -0.5863497257232666,
      -1.2697144746780396,
      -0.11217739433050156,
      -1.0159896612167358,
      -0.5837080478668213,
      -0.222258061170578,
      -0.4276764988899231,
      0.3905179798603058,
      0.4290367364883423,
      0.6341918110847473,
      -0.04771256446838379,
      -0.4162421226501465,
      -0.14123423397541046,
      -0.7778036594390869,
      -0.7248234152793884,
      -0.4452010989189148,
      -0.12769359350204468,
      -0.5544922351837158,
      -0.13216164708137512,
      0.6832444667816162,
      -0.5352905988693237,
      -0.02399575710296631,
      1.692070722579956,
      1.5373344421386719,
      0.4994273781776428,
      -0.6253387331962585,
      -0.008150458335876465,
      0.3256374001502991,
      0.4096340537071228,
      0.7153599858283997,
      0.5924569964408875,
      0.6665027141571045,
      0.5314698219299316
  };
  // Use sizeof(float) instead of a hard-coded 4 so the byte count stays
  // correct if the element type ever changes.
  EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
                                             in_data.size() * sizeof(float)));
  uint32_t ksize = 2;
  uint32_t stride = 1;
  auto op = graph->CreateOperation<tim::vx::ops::Pool1d>(
      tim::vx::PoolType::AVG, tim::vx::PadType::VALID, ksize, stride);
  (*op).BindInputs({input_tensor}).BindOutputs({output_tensor});

  EXPECT_TRUE(graph->Compile());
  EXPECT_TRUE(graph->Run());

  std::vector<float> output(golden.size());
  EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
  // NOTE(review): exact float comparison — assumes bit-exact AVG on the
  // target; consider a tolerance-based check if this flakes across devices.
  EXPECT_EQ(golden, output);
}
TEST(AVG, shape_3_3_1_2_fp32_kernel_2_stride_1) {
auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph();

View File

@ -24,10 +24,238 @@
#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/ops/pool2d.h"
#include "tim/vx/ops/pool1d.h"
#include <iostream>
#include "gtest/gtest.h"
#include "test_utils.h"
// MAX Pool1d, classic API: kernel 2, stride 1, VALID padding.
// Layout is WCN: W=32, C=3, N=1, so output W = 32 - 2 + 1 = 31.
TEST(MAX, shape_32_3_1_fp32_kernel_2_stride_1) {
  auto ctx = tim::vx::Context::Create();
  auto graph = ctx->CreateGraph();

  tim::vx::ShapeType in_shape({32, 3, 1});
  tim::vx::ShapeType out_shape({31, 3, 1});
  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32,
                                 in_shape, tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32,
                                  out_shape, tim::vx::TensorAttribute::OUTPUT);
  auto input_tensor = graph->CreateTensor(input_spec);
  auto output_tensor = graph->CreateTensor(output_spec);

  std::vector<float> in_data = {
      1.764052391052246,
      0.40015721321105957,
      0.978738009929657,
      2.2408931255340576,
      1.8675580024719238,
      -0.9772778749465942,
      0.9500884413719177,
      -0.15135720372200012,
      -0.10321885347366333,
      0.4105985164642334,
      0.14404356479644775,
      1.4542734622955322,
      0.7610377073287964,
      0.12167501449584961,
      0.44386324286460876,
      0.3336743414402008,
      1.4940791130065918,
      -0.2051582634449005,
      0.3130677044391632,
      -0.8540957570075989,
      -2.5529897212982178,
      0.653618574142456,
      0.8644362092018127,
      -0.7421650290489197,
      2.269754648208618,
      -1.4543657302856445,
      0.04575851559638977,
      -0.18718385696411133,
      1.5327792167663574,
      1.4693588018417358,
      0.154947429895401,
      0.37816253304481506,
      -0.8877857327461243,
      -1.980796456336975,
      -0.34791216254234314,
      0.15634897351264954,
      1.2302906513214111,
      1.202379822731018,
      -0.38732680678367615,
      -0.302302747964859,
      -1.0485529899597168,
      -1.420017957687378,
      -1.7062702178955078,
      1.950775384902954,
      -0.5096521973609924,
      -0.4380742907524109,
      -1.2527953386306763,
      0.7774903774261475,
      -1.6138978004455566,
      -0.21274028718471527,
      -0.8954665660858154,
      0.38690251111984253,
      -0.5108051300048828,
      -1.18063223361969,
      -0.02818222902715206,
      0.4283318817615509,
      0.06651721894741058,
      0.30247190594673157,
      -0.6343221068382263,
      -0.3627411723136902,
      -0.6724604368209839,
      -0.35955315828323364,
      -0.8131462931632996,
      -1.7262825965881348,
      0.17742614448070526,
      -0.4017809331417084,
      -1.630198359489441,
      0.46278226375579834,
      -0.9072983860969543,
      0.05194539576768875,
      0.7290905714035034,
      0.12898291647434235,
      1.1394007205963135,
      -1.234825849533081,
      0.4023416340351105,
      -0.6848101019859314,
      -0.8707971572875977,
      -0.5788496732711792,
      -0.3115525245666504,
      0.056165341287851334,
      -1.1651498079299927,
      0.9008265137672424,
      0.4656624495983124,
      -1.5362436771392822,
      1.4882521629333496,
      1.895889163017273,
      1.1787796020507812,
      -0.1799248307943344,
      -1.0707526206970215,
      1.0544517040252686,
      -0.4031769335269928,
      1.222445011138916,
      0.2082749754190445,
      0.9766390323638916,
      0.3563663959503174,
      0.7065731883049011
  };
  // Expected values: sliding-window max of every adjacent pair along W.
  std::vector<float> golden = {
      1.764052391052246,
      0.978738009929657,
      2.2408931255340576,
      2.2408931255340576,
      1.8675580024719238,
      0.9500884413719177,
      0.9500884413719177,
      -0.10321885347366333,
      0.4105985164642334,
      0.4105985164642334,
      1.4542734622955322,
      1.4542734622955322,
      0.7610377073287964,
      0.44386324286460876,
      0.44386324286460876,
      1.4940791130065918,
      1.4940791130065918,
      0.3130677044391632,
      0.3130677044391632,
      -0.8540957570075989,
      0.653618574142456,
      0.8644362092018127,
      0.8644362092018127,
      2.269754648208618,
      2.269754648208618,
      0.04575851559638977,
      0.04575851559638977,
      1.5327792167663574,
      1.5327792167663574,
      1.4693588018417358,
      0.37816253304481506,
      -0.8877857327461243,
      -0.34791216254234314,
      0.15634897351264954,
      1.2302906513214111,
      1.2302906513214111,
      1.202379822731018,
      -0.302302747964859,
      -0.302302747964859,
      -1.0485529899597168,
      -1.420017957687378,
      1.950775384902954,
      1.950775384902954,
      -0.4380742907524109,
      -0.4380742907524109,
      0.7774903774261475,
      0.7774903774261475,
      -0.21274028718471527,
      -0.21274028718471527,
      0.38690251111984253,
      0.38690251111984253,
      -0.5108051300048828,
      -0.02818222902715206,
      0.4283318817615509,
      0.4283318817615509,
      0.30247190594673157,
      0.30247190594673157,
      -0.3627411723136902,
      -0.3627411723136902,
      -0.35955315828323364,
      -0.35955315828323364,
      -0.8131462931632996,
      0.17742614448070526,
      -0.4017809331417084,
      0.46278226375579834,
      0.46278226375579834,
      0.05194539576768875,
      0.7290905714035034,
      0.7290905714035034,
      1.1394007205963135,
      1.1394007205963135,
      0.4023416340351105,
      0.4023416340351105,
      -0.6848101019859314,
      -0.5788496732711792,
      -0.3115525245666504,
      0.056165341287851334,
      0.056165341287851334,
      0.9008265137672424,
      0.9008265137672424,
      0.4656624495983124,
      1.4882521629333496,
      1.895889163017273,
      1.895889163017273,
      1.1787796020507812,
      -0.1799248307943344,
      1.0544517040252686,
      1.0544517040252686,
      1.222445011138916,
      1.222445011138916,
      0.9766390323638916,
      0.9766390323638916,
      0.7065731883049011
  };
  // Use sizeof(float) instead of a hard-coded 4 so the byte count stays
  // correct if the element type ever changes.
  EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
                                             in_data.size() * sizeof(float)));
  uint32_t ksize = 2;
  uint32_t stride = 1;
  auto op = graph->CreateOperation<tim::vx::ops::Pool1d>(
      tim::vx::PoolType::MAX, tim::vx::PadType::VALID, ksize, stride);
  (*op).BindInputs({input_tensor}).BindOutputs({output_tensor});

  EXPECT_TRUE(graph->Compile());
  EXPECT_TRUE(graph->Run());

  std::vector<float> output(golden.size());
  EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
  // MAX selects existing inputs, so exact float equality is safe here.
  EXPECT_EQ(golden, output);
}
TEST(MAX, shape_6_6_1_1_fp32_kernel_3_stride_2) {
auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph();

110
src/tim/vx/ops/pool1d.cc Normal file
View File

@ -0,0 +1,110 @@
/****************************************************************************
*
* Copyright (c) 2022 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#include "tim/vx/ops/pool1d.h"
#include "builtin_op_impl.h"
#include "type_utils.h"
#include "vsi_nn_pub.h"
namespace tim {
namespace vx {
namespace ops {
// for Classic Pool1d
// Classic Pool1d: the concrete pad amounts are derived by the driver from
// the pad *type*; no explicit pad values are written to the node.
Pool1d::Pool1d(Graph* graph, PoolType type, PadType padding,
               uint32_t ksize, uint32_t stride, RoundType round_type,
               DataLayout layout)
    : BuiltinOp(graph, VSI_NN_OP_POOL, 1, 1, layout),
      type_(type),
      padding_(padding),
      ksize_(ksize),
      stride_(stride),
      round_type_(round_type),
      pad_({0,0}) {
  auto& pool_param = this->impl()->node()->nn_param.pool;
  pool_param.type = TranslatePoolType(type_);
  pool_param.round_type = TranslateRoundType(round_type_);
  pool_param.ksize[0] = ksize_;
  pool_param.stride[0] = stride_;
  pool_param.pad_type = TranslatePadType(padding_);
  this->SetRoundingPolicy(OverflowPolicy::SATURATE, RoundingPolicy::RTNE,
                          round_type_);
}
// Classic Pool1d with explicit {left, right} pad values. The pad type is
// left as AUTO; the concrete pad amounts reach the node via Init().
Pool1d::Pool1d(Graph* graph, PoolType type,
               const std::array<uint32_t, 2>& pad, uint32_t ksize,
               uint32_t stride, RoundType round_type, DataLayout layout)
    : BuiltinOp(graph, VSI_NN_OP_POOL, 1, 1, layout),
      type_(type),
      padding_(PadType::AUTO),
      ksize_(ksize),
      stride_(stride),
      round_type_(round_type),
      pad_(pad) {
  Init();
}
// for Global Pool1d
// Global Pool1d: one window spanning the whole spatial extent, i.e.
// kernel size == stride == input_size, with no padding.
Pool1d::Pool1d(Graph* graph, PoolType type, uint32_t input_size,
               RoundType round_type, DataLayout layout)
    : BuiltinOp(graph, VSI_NN_OP_POOL, 1, 1, layout),
      type_(type),
      padding_(PadType::AUTO),
      ksize_(input_size),
      stride_(input_size),
      round_type_(round_type),
      pad_({0, 0}) {
  Init();
}
// for Adaptive Pool1d
// Adaptive Pool1d: derive kernel size and stride from the input and
// requested output sizes, matching torch.nn.AdaptiveXXXPool1d:
//   stride = floor(input_size / output_size)
//   ksize  = input_size - (output_size - 1) * stride
// Assumes 0 < output_size <= input_size — TODO confirm callers guarantee it.
Pool1d::Pool1d(Graph* graph, PoolType type,
               uint32_t input_size,
               uint32_t output_size,
               RoundType round_type,
               DataLayout layout)
    : BuiltinOp(graph, VSI_NN_OP_POOL, 1, 1, layout),
      type_(type), padding_(PadType::AUTO),
      round_type_(round_type), pad_({0, 0}) {
  // Plain integer division is exact, unlike the former floor(float) round
  // trip, which can misround once input_size exceeds float's 2^24 precision.
  stride_ = input_size / output_size;
  ksize_ = input_size - (output_size - 1) * stride_;
  Init();
}
// Pushes the configured pooling parameters (type, rounding, kernel, stride,
// explicit pad values) into the underlying VSI node.
void Pool1d::Init() {
  auto& pool_param = this->impl()->node()->nn_param.pool;
  pool_param.type = TranslatePoolType(type_);
  pool_param.round_type = TranslateRoundType(round_type_);
  pool_param.ksize[0] = ksize_;
  pool_param.stride[0] = stride_;
  pool_param.pad[0] = pad_[0];
  pool_param.pad[1] = pad_[1];
  this->SetRoundingPolicy(OverflowPolicy::SATURATE, RoundingPolicy::RTNE,
                          round_type_);
}
std::shared_ptr<Operation> Pool1d::Clone(std::shared_ptr<Graph>& graph) const {
return graph->CreateOperation<Pool1d>(this->type_, this->pad_, this->ksize_,
this->stride_, this->round_type_,
this->impl_->layout_);
}
} // namespace ops
} // namespace vx
} // namespace tim