mapped pool1d
This commit is contained in:
parent
aa0b474c19
commit
cc34b5f0ea
|
|
@ -63,6 +63,7 @@
|
||||||
#include "tim/vx/ops/onehot.h"
|
#include "tim/vx/ops/onehot.h"
|
||||||
#include "tim/vx/ops/pad.h"
|
#include "tim/vx/ops/pad.h"
|
||||||
#include "tim/vx/ops/pad_v2.h"
|
#include "tim/vx/ops/pad_v2.h"
|
||||||
|
#include "tim/vx/ops/pool1d.h"
|
||||||
#include "tim/vx/ops/pool2d.h"
|
#include "tim/vx/ops/pool2d.h"
|
||||||
#include "tim/vx/ops/reduce.h"
|
#include "tim/vx/ops/reduce.h"
|
||||||
#include "tim/vx/ops/relational_operations.h"
|
#include "tim/vx/ops/relational_operations.h"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,110 @@
|
||||||
|
/****************************************************************************
|
||||||
|
*
|
||||||
|
* Copyright (c) 2022 Vivante Corporation
|
||||||
|
*
|
||||||
|
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
* copy of this software and associated documentation files (the "Software"),
|
||||||
|
* to deal in the Software without restriction, including without limitation
|
||||||
|
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
* and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
* Software is furnished to do so, subject to the following conditions:
|
||||||
|
*
|
||||||
|
* The above copyright notice and this permission notice shall be included in
|
||||||
|
* all copies or substantial portions of the Software.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
* DEALINGS IN THE SOFTWARE.
|
||||||
|
*
|
||||||
|
*****************************************************************************/
|
||||||
|
#ifndef TIM_VX_OPS_POOL1D_H_
|
||||||
|
#define TIM_VX_OPS_POOL1D_H_
|
||||||
|
|
||||||
|
#include <array>
|
||||||
|
|
||||||
|
#include "tim/vx/builtin_op.h"
|
||||||
|
#include "tim/vx/types.h"
|
||||||
|
|
||||||
|
namespace tim {
|
||||||
|
namespace vx {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ## Pool1d
|
||||||
|
*
|
||||||
|
* ### Classic Pool1d
|
||||||
|
*
|
||||||
|
* Performs an 1-D pooling operation.
|
||||||
|
*
|
||||||
|
* - type : MAX, AVG, L2 or AVG_ANDROID.
|
||||||
|
* - padding : AUTO, VALID or SAME.
|
||||||
|
* - pad : Specify the number of pad values for left, right.
|
||||||
|
* - ksize : filter size.
|
||||||
|
* - stride : stride along each spatial axis.
|
||||||
|
* - round_type : CEILING or FLOOR.
|
||||||
|
*
|
||||||
|
* ### Global Pool1d
|
||||||
|
*
|
||||||
|
* - type : MAX, AVG, L2 or AVG_ANDROID.
|
||||||
|
* - input_size : input size(only [W])
|
||||||
|
* - round_type : CEILING or FLOOR.
|
||||||
|
*
|
||||||
|
* ### Adaptive Pool1d
|
||||||
|
*
|
||||||
|
* Same as torch.nn.AdaptiveXXXPool1d.
|
||||||
|
*
|
||||||
|
* - type : MAX, AVG, L2 or AVG_ANDROID.
|
||||||
|
* - input_size : input size(only [W])
|
||||||
|
* - output_size : output size(only [W])
|
||||||
|
* - round_type : CEILING or FLOOR.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
class Pool1d : public BuiltinOp {
|
||||||
|
public:
|
||||||
|
/* for Classic Pool1d, pool does not support auto-completion of pad value,
|
||||||
|
you need to specify pad size explicitly, it is recommended to use the second api.*/
|
||||||
|
Pool1d(Graph* graph, PoolType type, PadType padding,
|
||||||
|
uint32_t ksize,
|
||||||
|
uint32_t stride,
|
||||||
|
RoundType round_type = RoundType::FLOOR,
|
||||||
|
DataLayout layout = DataLayout::WCN);
|
||||||
|
Pool1d(Graph* graph, PoolType type, const std::array<uint32_t, 2>& pad,
|
||||||
|
uint32_t ksize,
|
||||||
|
uint32_t stride,
|
||||||
|
RoundType round_type = RoundType::FLOOR,
|
||||||
|
DataLayout layout = DataLayout::WCN);
|
||||||
|
|
||||||
|
// for Global Pool1d
|
||||||
|
Pool1d(Graph* graph, PoolType type, uint32_t input_size,
|
||||||
|
RoundType round_type = RoundType::FLOOR,
|
||||||
|
DataLayout layout = DataLayout::WCN);
|
||||||
|
|
||||||
|
// for Adaptive Pool1d
|
||||||
|
Pool1d(Graph* graph, PoolType type, uint32_t input_size,
|
||||||
|
uint32_t output_size,
|
||||||
|
RoundType round_type = RoundType::FLOOR,
|
||||||
|
DataLayout layout = DataLayout::WCN);
|
||||||
|
|
||||||
|
std::shared_ptr<Operation> Clone(
|
||||||
|
std::shared_ptr<Graph>& graph) const override;
|
||||||
|
void Init();
|
||||||
|
|
||||||
|
protected:
|
||||||
|
const PoolType type_;
|
||||||
|
const PadType padding_;
|
||||||
|
uint32_t ksize_;
|
||||||
|
uint32_t stride_;
|
||||||
|
const RoundType round_type_;
|
||||||
|
const std::array<uint32_t, 2> pad_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace vx
|
||||||
|
} // namespace tim
|
||||||
|
|
||||||
|
#endif /* TIM_VX_OPS_POOL1D_H_ */
|
||||||
|
|
@ -24,10 +24,238 @@
|
||||||
#include "tim/vx/context.h"
|
#include "tim/vx/context.h"
|
||||||
#include "tim/vx/graph.h"
|
#include "tim/vx/graph.h"
|
||||||
#include "tim/vx/ops/pool2d.h"
|
#include "tim/vx/ops/pool2d.h"
|
||||||
|
#include "tim/vx/ops/pool1d.h"
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
#include "test_utils.h"
|
#include "test_utils.h"
|
||||||
|
|
||||||
|
TEST(AVG, shape_32_3_1_fp32_kernel_2_stride_1) {
|
||||||
|
auto ctx = tim::vx::Context::Create();
|
||||||
|
auto graph = ctx->CreateGraph();
|
||||||
|
|
||||||
|
tim::vx::ShapeType in_shape({32, 3, 1});
|
||||||
|
tim::vx::ShapeType out_shape({31, 3, 1});
|
||||||
|
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32,
|
||||||
|
in_shape, tim::vx::TensorAttribute::INPUT);
|
||||||
|
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32,
|
||||||
|
out_shape, tim::vx::TensorAttribute::OUTPUT);
|
||||||
|
|
||||||
|
auto input_tensor = graph->CreateTensor(input_spec);
|
||||||
|
auto output_tensor = graph->CreateTensor(output_spec);
|
||||||
|
|
||||||
|
std::vector<float> in_data = {
|
||||||
|
1.764052391052246,
|
||||||
|
0.40015721321105957,
|
||||||
|
0.978738009929657,
|
||||||
|
2.2408931255340576,
|
||||||
|
1.8675580024719238,
|
||||||
|
-0.9772778749465942,
|
||||||
|
0.9500884413719177,
|
||||||
|
-0.15135720372200012,
|
||||||
|
-0.10321885347366333,
|
||||||
|
0.4105985164642334,
|
||||||
|
0.14404356479644775,
|
||||||
|
1.4542734622955322,
|
||||||
|
0.7610377073287964,
|
||||||
|
0.12167501449584961,
|
||||||
|
0.44386324286460876,
|
||||||
|
0.3336743414402008,
|
||||||
|
1.4940791130065918,
|
||||||
|
-0.2051582634449005,
|
||||||
|
0.3130677044391632,
|
||||||
|
-0.8540957570075989,
|
||||||
|
-2.5529897212982178,
|
||||||
|
0.653618574142456,
|
||||||
|
0.8644362092018127,
|
||||||
|
-0.7421650290489197,
|
||||||
|
2.269754648208618,
|
||||||
|
-1.4543657302856445,
|
||||||
|
0.04575851559638977,
|
||||||
|
-0.18718385696411133,
|
||||||
|
1.5327792167663574,
|
||||||
|
1.4693588018417358,
|
||||||
|
0.154947429895401,
|
||||||
|
0.37816253304481506,
|
||||||
|
|
||||||
|
-0.8877857327461243,
|
||||||
|
-1.980796456336975,
|
||||||
|
-0.34791216254234314,
|
||||||
|
0.15634897351264954,
|
||||||
|
1.2302906513214111,
|
||||||
|
1.202379822731018,
|
||||||
|
-0.38732680678367615,
|
||||||
|
-0.302302747964859,
|
||||||
|
-1.0485529899597168,
|
||||||
|
-1.420017957687378,
|
||||||
|
-1.7062702178955078,
|
||||||
|
1.950775384902954,
|
||||||
|
-0.5096521973609924,
|
||||||
|
-0.4380742907524109,
|
||||||
|
-1.2527953386306763,
|
||||||
|
0.7774903774261475,
|
||||||
|
-1.6138978004455566,
|
||||||
|
-0.21274028718471527,
|
||||||
|
-0.8954665660858154,
|
||||||
|
0.38690251111984253,
|
||||||
|
-0.5108051300048828,
|
||||||
|
-1.18063223361969,
|
||||||
|
-0.02818222902715206,
|
||||||
|
0.4283318817615509,
|
||||||
|
0.06651721894741058,
|
||||||
|
0.30247190594673157,
|
||||||
|
-0.6343221068382263,
|
||||||
|
-0.3627411723136902,
|
||||||
|
-0.6724604368209839,
|
||||||
|
-0.35955315828323364,
|
||||||
|
-0.8131462931632996,
|
||||||
|
-1.7262825965881348,
|
||||||
|
|
||||||
|
0.17742614448070526,
|
||||||
|
-0.4017809331417084,
|
||||||
|
-1.630198359489441,
|
||||||
|
0.46278226375579834,
|
||||||
|
-0.9072983860969543,
|
||||||
|
0.05194539576768875,
|
||||||
|
0.7290905714035034,
|
||||||
|
0.12898291647434235,
|
||||||
|
1.1394007205963135,
|
||||||
|
-1.234825849533081,
|
||||||
|
0.4023416340351105,
|
||||||
|
-0.6848101019859314,
|
||||||
|
-0.8707971572875977,
|
||||||
|
-0.5788496732711792,
|
||||||
|
-0.3115525245666504,
|
||||||
|
0.056165341287851334,
|
||||||
|
-1.1651498079299927,
|
||||||
|
0.9008265137672424,
|
||||||
|
0.4656624495983124,
|
||||||
|
-1.5362436771392822,
|
||||||
|
1.4882521629333496,
|
||||||
|
1.895889163017273,
|
||||||
|
1.1787796020507812,
|
||||||
|
-0.1799248307943344,
|
||||||
|
-1.0707526206970215,
|
||||||
|
1.0544517040252686,
|
||||||
|
-0.4031769335269928,
|
||||||
|
1.222445011138916,
|
||||||
|
0.2082749754190445,
|
||||||
|
0.9766390323638916,
|
||||||
|
0.3563663959503174,
|
||||||
|
0.7065731883049011
|
||||||
|
};
|
||||||
|
std::vector<float> golden = {
|
||||||
|
1.0821048021316528,
|
||||||
|
0.6894476413726807,
|
||||||
|
1.6098155975341797,
|
||||||
|
2.054225444793701,
|
||||||
|
0.4451400637626648,
|
||||||
|
-0.013594716787338257,
|
||||||
|
0.3993656039237976,
|
||||||
|
-0.12728802859783173,
|
||||||
|
0.15368983149528503,
|
||||||
|
0.2773210406303406,
|
||||||
|
0.79915851354599,
|
||||||
|
1.1076555252075195,
|
||||||
|
0.441356360912323,
|
||||||
|
0.2827691435813904,
|
||||||
|
0.3887687921524048,
|
||||||
|
0.9138767123222351,
|
||||||
|
0.6444604396820068,
|
||||||
|
0.05395472049713135,
|
||||||
|
-0.27051401138305664,
|
||||||
|
-1.703542709350586,
|
||||||
|
-0.9496855735778809,
|
||||||
|
0.759027361869812,
|
||||||
|
0.06113559007644653,
|
||||||
|
0.7637947797775269,
|
||||||
|
0.4076944589614868,
|
||||||
|
-0.7043036222457886,
|
||||||
|
-0.07071267068386078,
|
||||||
|
0.672797679901123,
|
||||||
|
1.5010690689086914,
|
||||||
|
0.8121531009674072,
|
||||||
|
0.26655498147010803,
|
||||||
|
|
||||||
|
-1.434291124343872,
|
||||||
|
-1.1643543243408203,
|
||||||
|
-0.0957815945148468,
|
||||||
|
0.6933197975158691,
|
||||||
|
1.2163352966308594,
|
||||||
|
0.40752649307250977,
|
||||||
|
-0.3448147773742676,
|
||||||
|
-0.6754278540611267,
|
||||||
|
-1.2342854738235474,
|
||||||
|
-1.5631440877914429,
|
||||||
|
0.12225258350372314,
|
||||||
|
0.7205616235733032,
|
||||||
|
-0.47386324405670166,
|
||||||
|
-0.8454347848892212,
|
||||||
|
-0.2376524806022644,
|
||||||
|
-0.4182037115097046,
|
||||||
|
-0.9133190512657166,
|
||||||
|
-0.554103434085846,
|
||||||
|
-0.25428202748298645,
|
||||||
|
-0.06195130944252014,
|
||||||
|
-0.8457186818122864,
|
||||||
|
-0.6044072508811951,
|
||||||
|
0.20007482171058655,
|
||||||
|
0.24742454290390015,
|
||||||
|
0.18449455499649048,
|
||||||
|
-0.16592510044574738,
|
||||||
|
-0.49853163957595825,
|
||||||
|
-0.5176007747650146,
|
||||||
|
-0.5160068273544312,
|
||||||
|
-0.5863497257232666,
|
||||||
|
-1.2697144746780396,
|
||||||
|
|
||||||
|
-0.11217739433050156,
|
||||||
|
-1.0159896612167358,
|
||||||
|
-0.5837080478668213,
|
||||||
|
-0.222258061170578,
|
||||||
|
-0.4276764988899231,
|
||||||
|
0.3905179798603058,
|
||||||
|
0.4290367364883423,
|
||||||
|
0.6341918110847473,
|
||||||
|
-0.04771256446838379,
|
||||||
|
-0.4162421226501465,
|
||||||
|
-0.14123423397541046,
|
||||||
|
-0.7778036594390869,
|
||||||
|
-0.7248234152793884,
|
||||||
|
-0.4452010989189148,
|
||||||
|
-0.12769359350204468,
|
||||||
|
-0.5544922351837158,
|
||||||
|
-0.13216164708137512,
|
||||||
|
0.6832444667816162,
|
||||||
|
-0.5352905988693237,
|
||||||
|
-0.02399575710296631,
|
||||||
|
1.692070722579956,
|
||||||
|
1.5373344421386719,
|
||||||
|
0.4994273781776428,
|
||||||
|
-0.6253387331962585,
|
||||||
|
-0.008150458335876465,
|
||||||
|
0.3256374001502991,
|
||||||
|
0.4096340537071228,
|
||||||
|
0.7153599858283997,
|
||||||
|
0.5924569964408875,
|
||||||
|
0.6665027141571045,
|
||||||
|
0.5314698219299316
|
||||||
|
};
|
||||||
|
|
||||||
|
EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(), in_data.size()*4));
|
||||||
|
uint32_t ksize = 2;
|
||||||
|
uint32_t stride = 1;
|
||||||
|
auto op = graph->CreateOperation<tim::vx::ops::Pool1d>(tim::vx::PoolType::AVG,
|
||||||
|
tim::vx::PadType::VALID, ksize, stride);
|
||||||
|
(*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
|
||||||
|
|
||||||
|
EXPECT_TRUE(graph->Compile());
|
||||||
|
EXPECT_TRUE(graph->Run());
|
||||||
|
|
||||||
|
std::vector<float> output(golden.size());
|
||||||
|
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||||
|
EXPECT_EQ(golden, output);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(AVG, shape_3_3_1_2_fp32_kernel_2_stride_1) {
|
TEST(AVG, shape_3_3_1_2_fp32_kernel_2_stride_1) {
|
||||||
auto ctx = tim::vx::Context::Create();
|
auto ctx = tim::vx::Context::Create();
|
||||||
auto graph = ctx->CreateGraph();
|
auto graph = ctx->CreateGraph();
|
||||||
|
|
|
||||||
|
|
@ -24,10 +24,238 @@
|
||||||
#include "tim/vx/context.h"
|
#include "tim/vx/context.h"
|
||||||
#include "tim/vx/graph.h"
|
#include "tim/vx/graph.h"
|
||||||
#include "tim/vx/ops/pool2d.h"
|
#include "tim/vx/ops/pool2d.h"
|
||||||
|
#include "tim/vx/ops/pool1d.h"
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
#include "test_utils.h"
|
#include "test_utils.h"
|
||||||
|
|
||||||
|
TEST(MAX, shape_32_3_1_fp32_kernel_2_stride_1) {
|
||||||
|
auto ctx = tim::vx::Context::Create();
|
||||||
|
auto graph = ctx->CreateGraph();
|
||||||
|
|
||||||
|
tim::vx::ShapeType in_shape({32, 3, 1});
|
||||||
|
tim::vx::ShapeType out_shape({31, 3, 1});
|
||||||
|
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32,
|
||||||
|
in_shape, tim::vx::TensorAttribute::INPUT);
|
||||||
|
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32,
|
||||||
|
out_shape, tim::vx::TensorAttribute::OUTPUT);
|
||||||
|
|
||||||
|
auto input_tensor = graph->CreateTensor(input_spec);
|
||||||
|
auto output_tensor = graph->CreateTensor(output_spec);
|
||||||
|
|
||||||
|
std::vector<float> in_data = {
|
||||||
|
1.764052391052246,
|
||||||
|
0.40015721321105957,
|
||||||
|
0.978738009929657,
|
||||||
|
2.2408931255340576,
|
||||||
|
1.8675580024719238,
|
||||||
|
-0.9772778749465942,
|
||||||
|
0.9500884413719177,
|
||||||
|
-0.15135720372200012,
|
||||||
|
-0.10321885347366333,
|
||||||
|
0.4105985164642334,
|
||||||
|
0.14404356479644775,
|
||||||
|
1.4542734622955322,
|
||||||
|
0.7610377073287964,
|
||||||
|
0.12167501449584961,
|
||||||
|
0.44386324286460876,
|
||||||
|
0.3336743414402008,
|
||||||
|
1.4940791130065918,
|
||||||
|
-0.2051582634449005,
|
||||||
|
0.3130677044391632,
|
||||||
|
-0.8540957570075989,
|
||||||
|
-2.5529897212982178,
|
||||||
|
0.653618574142456,
|
||||||
|
0.8644362092018127,
|
||||||
|
-0.7421650290489197,
|
||||||
|
2.269754648208618,
|
||||||
|
-1.4543657302856445,
|
||||||
|
0.04575851559638977,
|
||||||
|
-0.18718385696411133,
|
||||||
|
1.5327792167663574,
|
||||||
|
1.4693588018417358,
|
||||||
|
0.154947429895401,
|
||||||
|
0.37816253304481506,
|
||||||
|
|
||||||
|
-0.8877857327461243,
|
||||||
|
-1.980796456336975,
|
||||||
|
-0.34791216254234314,
|
||||||
|
0.15634897351264954,
|
||||||
|
1.2302906513214111,
|
||||||
|
1.202379822731018,
|
||||||
|
-0.38732680678367615,
|
||||||
|
-0.302302747964859,
|
||||||
|
-1.0485529899597168,
|
||||||
|
-1.420017957687378,
|
||||||
|
-1.7062702178955078,
|
||||||
|
1.950775384902954,
|
||||||
|
-0.5096521973609924,
|
||||||
|
-0.4380742907524109,
|
||||||
|
-1.2527953386306763,
|
||||||
|
0.7774903774261475,
|
||||||
|
-1.6138978004455566,
|
||||||
|
-0.21274028718471527,
|
||||||
|
-0.8954665660858154,
|
||||||
|
0.38690251111984253,
|
||||||
|
-0.5108051300048828,
|
||||||
|
-1.18063223361969,
|
||||||
|
-0.02818222902715206,
|
||||||
|
0.4283318817615509,
|
||||||
|
0.06651721894741058,
|
||||||
|
0.30247190594673157,
|
||||||
|
-0.6343221068382263,
|
||||||
|
-0.3627411723136902,
|
||||||
|
-0.6724604368209839,
|
||||||
|
-0.35955315828323364,
|
||||||
|
-0.8131462931632996,
|
||||||
|
-1.7262825965881348,
|
||||||
|
|
||||||
|
0.17742614448070526,
|
||||||
|
-0.4017809331417084,
|
||||||
|
-1.630198359489441,
|
||||||
|
0.46278226375579834,
|
||||||
|
-0.9072983860969543,
|
||||||
|
0.05194539576768875,
|
||||||
|
0.7290905714035034,
|
||||||
|
0.12898291647434235,
|
||||||
|
1.1394007205963135,
|
||||||
|
-1.234825849533081,
|
||||||
|
0.4023416340351105,
|
||||||
|
-0.6848101019859314,
|
||||||
|
-0.8707971572875977,
|
||||||
|
-0.5788496732711792,
|
||||||
|
-0.3115525245666504,
|
||||||
|
0.056165341287851334,
|
||||||
|
-1.1651498079299927,
|
||||||
|
0.9008265137672424,
|
||||||
|
0.4656624495983124,
|
||||||
|
-1.5362436771392822,
|
||||||
|
1.4882521629333496,
|
||||||
|
1.895889163017273,
|
||||||
|
1.1787796020507812,
|
||||||
|
-0.1799248307943344,
|
||||||
|
-1.0707526206970215,
|
||||||
|
1.0544517040252686,
|
||||||
|
-0.4031769335269928,
|
||||||
|
1.222445011138916,
|
||||||
|
0.2082749754190445,
|
||||||
|
0.9766390323638916,
|
||||||
|
0.3563663959503174,
|
||||||
|
0.7065731883049011
|
||||||
|
};
|
||||||
|
std::vector<float> golden = {
|
||||||
|
1.764052391052246,
|
||||||
|
0.978738009929657,
|
||||||
|
2.2408931255340576,
|
||||||
|
2.2408931255340576,
|
||||||
|
1.8675580024719238,
|
||||||
|
0.9500884413719177,
|
||||||
|
0.9500884413719177,
|
||||||
|
-0.10321885347366333,
|
||||||
|
0.4105985164642334,
|
||||||
|
0.4105985164642334,
|
||||||
|
1.4542734622955322,
|
||||||
|
1.4542734622955322,
|
||||||
|
0.7610377073287964,
|
||||||
|
0.44386324286460876,
|
||||||
|
0.44386324286460876,
|
||||||
|
1.4940791130065918,
|
||||||
|
1.4940791130065918,
|
||||||
|
0.3130677044391632,
|
||||||
|
0.3130677044391632,
|
||||||
|
-0.8540957570075989,
|
||||||
|
0.653618574142456,
|
||||||
|
0.8644362092018127,
|
||||||
|
0.8644362092018127,
|
||||||
|
2.269754648208618,
|
||||||
|
2.269754648208618,
|
||||||
|
0.04575851559638977,
|
||||||
|
0.04575851559638977,
|
||||||
|
1.5327792167663574,
|
||||||
|
1.5327792167663574,
|
||||||
|
1.4693588018417358,
|
||||||
|
0.37816253304481506,
|
||||||
|
|
||||||
|
-0.8877857327461243,
|
||||||
|
-0.34791216254234314,
|
||||||
|
0.15634897351264954,
|
||||||
|
1.2302906513214111,
|
||||||
|
1.2302906513214111,
|
||||||
|
1.202379822731018,
|
||||||
|
-0.302302747964859,
|
||||||
|
-0.302302747964859,
|
||||||
|
-1.0485529899597168,
|
||||||
|
-1.420017957687378,
|
||||||
|
1.950775384902954,
|
||||||
|
1.950775384902954,
|
||||||
|
-0.4380742907524109,
|
||||||
|
-0.4380742907524109,
|
||||||
|
0.7774903774261475,
|
||||||
|
0.7774903774261475,
|
||||||
|
-0.21274028718471527,
|
||||||
|
-0.21274028718471527,
|
||||||
|
0.38690251111984253,
|
||||||
|
0.38690251111984253,
|
||||||
|
-0.5108051300048828,
|
||||||
|
-0.02818222902715206,
|
||||||
|
0.4283318817615509,
|
||||||
|
0.4283318817615509,
|
||||||
|
0.30247190594673157,
|
||||||
|
0.30247190594673157,
|
||||||
|
-0.3627411723136902,
|
||||||
|
-0.3627411723136902,
|
||||||
|
-0.35955315828323364,
|
||||||
|
-0.35955315828323364,
|
||||||
|
-0.8131462931632996,
|
||||||
|
|
||||||
|
0.17742614448070526,
|
||||||
|
-0.4017809331417084,
|
||||||
|
0.46278226375579834,
|
||||||
|
0.46278226375579834,
|
||||||
|
0.05194539576768875,
|
||||||
|
0.7290905714035034,
|
||||||
|
0.7290905714035034,
|
||||||
|
1.1394007205963135,
|
||||||
|
1.1394007205963135,
|
||||||
|
0.4023416340351105,
|
||||||
|
0.4023416340351105,
|
||||||
|
-0.6848101019859314,
|
||||||
|
-0.5788496732711792,
|
||||||
|
-0.3115525245666504,
|
||||||
|
0.056165341287851334,
|
||||||
|
0.056165341287851334,
|
||||||
|
0.9008265137672424,
|
||||||
|
0.9008265137672424,
|
||||||
|
0.4656624495983124,
|
||||||
|
1.4882521629333496,
|
||||||
|
1.895889163017273,
|
||||||
|
1.895889163017273,
|
||||||
|
1.1787796020507812,
|
||||||
|
-0.1799248307943344,
|
||||||
|
1.0544517040252686,
|
||||||
|
1.0544517040252686,
|
||||||
|
1.222445011138916,
|
||||||
|
1.222445011138916,
|
||||||
|
0.9766390323638916,
|
||||||
|
0.9766390323638916,
|
||||||
|
0.7065731883049011
|
||||||
|
};
|
||||||
|
|
||||||
|
EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(), in_data.size()*4));
|
||||||
|
uint32_t ksize = 2;
|
||||||
|
uint32_t stride = 1;
|
||||||
|
auto op = graph->CreateOperation<tim::vx::ops::Pool1d>(tim::vx::PoolType::MAX,
|
||||||
|
tim::vx::PadType::VALID, ksize, stride);
|
||||||
|
(*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
|
||||||
|
|
||||||
|
EXPECT_TRUE(graph->Compile());
|
||||||
|
EXPECT_TRUE(graph->Run());
|
||||||
|
|
||||||
|
std::vector<float> output(golden.size());
|
||||||
|
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||||
|
EXPECT_EQ(golden, output);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(MAX, shape_6_6_1_1_fp32_kernel_3_stride_2) {
|
TEST(MAX, shape_6_6_1_1_fp32_kernel_3_stride_2) {
|
||||||
auto ctx = tim::vx::Context::Create();
|
auto ctx = tim::vx::Context::Create();
|
||||||
auto graph = ctx->CreateGraph();
|
auto graph = ctx->CreateGraph();
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,110 @@
|
||||||
|
/****************************************************************************
|
||||||
|
*
|
||||||
|
* Copyright (c) 2022 Vivante Corporation
|
||||||
|
*
|
||||||
|
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
* copy of this software and associated documentation files (the "Software"),
|
||||||
|
* to deal in the Software without restriction, including without limitation
|
||||||
|
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
* and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
* Software is furnished to do so, subject to the following conditions:
|
||||||
|
*
|
||||||
|
* The above copyright notice and this permission notice shall be included in
|
||||||
|
* all copies or substantial portions of the Software.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
* DEALINGS IN THE SOFTWARE.
|
||||||
|
*
|
||||||
|
*****************************************************************************/
|
||||||
|
#include "tim/vx/ops/pool1d.h"
|
||||||
|
|
||||||
|
#include "builtin_op_impl.h"
|
||||||
|
#include "type_utils.h"
|
||||||
|
#include "vsi_nn_pub.h"
|
||||||
|
|
||||||
|
namespace tim {
|
||||||
|
namespace vx {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
// for Classic Pool1d
|
||||||
|
Pool1d::Pool1d(Graph* graph, PoolType type, PadType padding,
|
||||||
|
uint32_t ksize, uint32_t stride, RoundType round_type,
|
||||||
|
DataLayout layout)
|
||||||
|
: BuiltinOp(graph, VSI_NN_OP_POOL, 1, 1, layout),
|
||||||
|
type_(type),
|
||||||
|
padding_(padding),
|
||||||
|
ksize_(ksize),
|
||||||
|
stride_(stride),
|
||||||
|
round_type_(round_type),
|
||||||
|
pad_({0,0}) {
|
||||||
|
this->impl()->node()->nn_param.pool.type = TranslatePoolType(type_);
|
||||||
|
this->impl()->node()->nn_param.pool.round_type =
|
||||||
|
TranslateRoundType(round_type_);
|
||||||
|
this->impl()->node()->nn_param.pool.ksize[0] = ksize_;
|
||||||
|
this->impl()->node()->nn_param.pool.stride[0] = stride_;
|
||||||
|
this->impl()->node()->nn_param.pool.pad_type = TranslatePadType(padding_);
|
||||||
|
this->SetRoundingPolicy(OverflowPolicy::SATURATE, RoundingPolicy::RTNE, round_type_);
|
||||||
|
}
|
||||||
|
|
||||||
|
Pool1d::Pool1d(Graph* graph, PoolType type,
|
||||||
|
const std::array<uint32_t, 2>& pad,
|
||||||
|
uint32_t ksize,
|
||||||
|
uint32_t stride,
|
||||||
|
RoundType round_type,
|
||||||
|
DataLayout layout)
|
||||||
|
: BuiltinOp(graph, VSI_NN_OP_POOL, 1, 1, layout),
|
||||||
|
type_(type), padding_(PadType::AUTO), ksize_(ksize), stride_(stride),
|
||||||
|
round_type_(round_type), pad_(pad) {
|
||||||
|
Init();
|
||||||
|
}
|
||||||
|
|
||||||
|
// for Global Pool1d
|
||||||
|
Pool1d::Pool1d(Graph* graph, PoolType type,
|
||||||
|
uint32_t input_size,
|
||||||
|
RoundType round_type,
|
||||||
|
DataLayout layout)
|
||||||
|
: BuiltinOp(graph, VSI_NN_OP_POOL, 1, 1, layout),
|
||||||
|
type_(type), padding_(PadType::AUTO), ksize_(input_size), stride_(input_size),
|
||||||
|
round_type_(round_type), pad_({0, 0}) {
|
||||||
|
Init();
|
||||||
|
}
|
||||||
|
|
||||||
|
// for Adaptive Pool1d
|
||||||
|
Pool1d::Pool1d(Graph* graph, PoolType type,
|
||||||
|
uint32_t input_size,
|
||||||
|
uint32_t output_size,
|
||||||
|
RoundType round_type,
|
||||||
|
DataLayout layout)
|
||||||
|
: BuiltinOp(graph, VSI_NN_OP_POOL, 1, 1, layout),
|
||||||
|
type_(type), padding_(PadType::AUTO),
|
||||||
|
round_type_(round_type), pad_({0, 0}) {
|
||||||
|
stride_ = floor(input_size / (float)(output_size));
|
||||||
|
ksize_ = input_size - (output_size - 1) * stride_;
|
||||||
|
Init();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Pool1d::Init() {
|
||||||
|
this->impl()->node()->nn_param.pool.type = TranslatePoolType(type_);
|
||||||
|
this->impl()->node()->nn_param.pool.round_type =
|
||||||
|
TranslateRoundType(round_type_);
|
||||||
|
this->impl()->node()->nn_param.pool.ksize[0] = ksize_;
|
||||||
|
this->impl()->node()->nn_param.pool.stride[0] = stride_;
|
||||||
|
this->impl()->node()->nn_param.pool.pad[0] = pad_[0];
|
||||||
|
this->impl()->node()->nn_param.pool.pad[1] = pad_[1];
|
||||||
|
this->SetRoundingPolicy(OverflowPolicy::SATURATE, RoundingPolicy::RTNE, round_type_);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<Operation> Pool1d::Clone(std::shared_ptr<Graph>& graph) const {
|
||||||
|
return graph->CreateOperation<Pool1d>(this->type_, this->pad_, this->ksize_,
|
||||||
|
this->stride_, this->round_type_,
|
||||||
|
this->impl_->layout_);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace vx
|
||||||
|
} // namespace tim
|
||||||
Loading…
Reference in New Issue