Add pad parameter to pool2d
Signed-off-by: zhao.xia <zhao.xia@verisilicon.com>
This commit is contained in:
parent
748658e47d
commit
bd9c5df70a
|
|
@ -52,6 +52,12 @@ class Pool2d : public Operation {
|
|||
const std::array<uint32_t, 2>& stride,
|
||||
RoundType round_type = RoundType::FLOOR,
|
||||
DataLayout layout = DataLayout::WHCN);
|
||||
Pool2d(Graph* graph, PoolType type,
|
||||
const std::array<uint32_t, 4>& pad,
|
||||
const std::array<uint32_t, 2>& ksize,
|
||||
const std::array<uint32_t, 2>& stride,
|
||||
RoundType round_type = RoundType::FLOOR,
|
||||
DataLayout layout = DataLayout::WHCN);
|
||||
|
||||
protected:
|
||||
const PoolType type_;
|
||||
|
|
@ -59,6 +65,7 @@ class Pool2d : public Operation {
|
|||
const std::array<uint32_t, 2> ksize_;
|
||||
const std::array<uint32_t, 2> stride_;
|
||||
const RoundType round_type_;
|
||||
const std::array<uint32_t, 4> pad_;
|
||||
};
|
||||
|
||||
} // namespace ops
|
||||
|
|
|
|||
|
|
@ -40,7 +40,8 @@ Pool2d::Pool2d(Graph* graph, PoolType type, PadType padding,
|
|||
padding_(padding),
|
||||
ksize_(ksize),
|
||||
stride_(stride),
|
||||
round_type_(round_type) {
|
||||
round_type_(round_type),
|
||||
pad_({0,0,0,0}) {
|
||||
this->impl()->node()->nn_param.pool.type = TranslatePoolType(type_);
|
||||
this->impl()->node()->nn_param.pool.round_type =
|
||||
TranslateRoundType(round_type_);
|
||||
|
|
@ -52,6 +53,28 @@ Pool2d::Pool2d(Graph* graph, PoolType type, PadType padding,
|
|||
this->SetRoundingPolicy(OverflowPolicy::SATURATE, RoundingPolicy::RTNE, round_type_);
|
||||
}
|
||||
|
||||
Pool2d::Pool2d(Graph* graph, PoolType type,
|
||||
const std::array<uint32_t, 4>& pad,
|
||||
const std::array<uint32_t, 2>& ksize,
|
||||
const std::array<uint32_t, 2>& stride, RoundType round_type,
|
||||
DataLayout layout)
|
||||
: Operation(graph, VSI_NN_OP_POOL, 1, 1, layout),
|
||||
type_(type), padding_(PadType::AUTO), ksize_(ksize), stride_(stride),
|
||||
round_type_(round_type), pad_(pad) {
|
||||
this->impl()->node()->nn_param.pool.type = TranslatePoolType(type_);
|
||||
this->impl()->node()->nn_param.pool.round_type =
|
||||
TranslateRoundType(round_type_);
|
||||
this->impl()->node()->nn_param.pool.ksize[0] = ksize_[0];
|
||||
this->impl()->node()->nn_param.pool.ksize[1] = ksize_[1];
|
||||
this->impl()->node()->nn_param.pool.stride[0] = stride_[0];
|
||||
this->impl()->node()->nn_param.pool.stride[1] = stride_[1];
|
||||
this->impl()->node()->nn_param.pool.pad[0] = pad_[0];
|
||||
this->impl()->node()->nn_param.pool.pad[1] = pad_[1];
|
||||
this->impl()->node()->nn_param.pool.pad[2] = pad_[2];
|
||||
this->impl()->node()->nn_param.pool.pad[3] = pad_[3];
|
||||
this->SetRoundingPolicy(OverflowPolicy::SATURATE, RoundingPolicy::RTNE, round_type_);
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
} // namespace vx
|
||||
} // namespace tim
|
||||
Loading…
Reference in New Issue