Add pad parameter to pool2d
Signed-off-by: zhao.xia <zhao.xia@verisilicon.com>
This commit is contained in:
parent
748658e47d
commit
bd9c5df70a
|
|
@ -52,6 +52,12 @@ class Pool2d : public Operation {
|
||||||
const std::array<uint32_t, 2>& stride,
|
const std::array<uint32_t, 2>& stride,
|
||||||
RoundType round_type = RoundType::FLOOR,
|
RoundType round_type = RoundType::FLOOR,
|
||||||
DataLayout layout = DataLayout::WHCN);
|
DataLayout layout = DataLayout::WHCN);
|
||||||
|
Pool2d(Graph* graph, PoolType type,
|
||||||
|
const std::array<uint32_t, 4>& pad,
|
||||||
|
const std::array<uint32_t, 2>& ksize,
|
||||||
|
const std::array<uint32_t, 2>& stride,
|
||||||
|
RoundType round_type = RoundType::FLOOR,
|
||||||
|
DataLayout layout = DataLayout::WHCN);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
const PoolType type_;
|
const PoolType type_;
|
||||||
|
|
@ -59,6 +65,7 @@ class Pool2d : public Operation {
|
||||||
const std::array<uint32_t, 2> ksize_;
|
const std::array<uint32_t, 2> ksize_;
|
||||||
const std::array<uint32_t, 2> stride_;
|
const std::array<uint32_t, 2> stride_;
|
||||||
const RoundType round_type_;
|
const RoundType round_type_;
|
||||||
|
const std::array<uint32_t, 4> pad_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,8 @@ Pool2d::Pool2d(Graph* graph, PoolType type, PadType padding,
|
||||||
padding_(padding),
|
padding_(padding),
|
||||||
ksize_(ksize),
|
ksize_(ksize),
|
||||||
stride_(stride),
|
stride_(stride),
|
||||||
round_type_(round_type) {
|
round_type_(round_type),
|
||||||
|
pad_({0,0,0,0}) {
|
||||||
this->impl()->node()->nn_param.pool.type = TranslatePoolType(type_);
|
this->impl()->node()->nn_param.pool.type = TranslatePoolType(type_);
|
||||||
this->impl()->node()->nn_param.pool.round_type =
|
this->impl()->node()->nn_param.pool.round_type =
|
||||||
TranslateRoundType(round_type_);
|
TranslateRoundType(round_type_);
|
||||||
|
|
@ -52,6 +53,28 @@ Pool2d::Pool2d(Graph* graph, PoolType type, PadType padding,
|
||||||
this->SetRoundingPolicy(OverflowPolicy::SATURATE, RoundingPolicy::RTNE, round_type_);
|
this->SetRoundingPolicy(OverflowPolicy::SATURATE, RoundingPolicy::RTNE, round_type_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Pool2d::Pool2d(Graph* graph, PoolType type,
|
||||||
|
const std::array<uint32_t, 4>& pad,
|
||||||
|
const std::array<uint32_t, 2>& ksize,
|
||||||
|
const std::array<uint32_t, 2>& stride, RoundType round_type,
|
||||||
|
DataLayout layout)
|
||||||
|
: Operation(graph, VSI_NN_OP_POOL, 1, 1, layout),
|
||||||
|
type_(type), padding_(PadType::AUTO), ksize_(ksize), stride_(stride),
|
||||||
|
round_type_(round_type), pad_(pad) {
|
||||||
|
this->impl()->node()->nn_param.pool.type = TranslatePoolType(type_);
|
||||||
|
this->impl()->node()->nn_param.pool.round_type =
|
||||||
|
TranslateRoundType(round_type_);
|
||||||
|
this->impl()->node()->nn_param.pool.ksize[0] = ksize_[0];
|
||||||
|
this->impl()->node()->nn_param.pool.ksize[1] = ksize_[1];
|
||||||
|
this->impl()->node()->nn_param.pool.stride[0] = stride_[0];
|
||||||
|
this->impl()->node()->nn_param.pool.stride[1] = stride_[1];
|
||||||
|
this->impl()->node()->nn_param.pool.pad[0] = pad_[0];
|
||||||
|
this->impl()->node()->nn_param.pool.pad[1] = pad_[1];
|
||||||
|
this->impl()->node()->nn_param.pool.pad[2] = pad_[2];
|
||||||
|
this->impl()->node()->nn_param.pool.pad[3] = pad_[3];
|
||||||
|
this->SetRoundingPolicy(OverflowPolicy::SATURATE, RoundingPolicy::RTNE, round_type_);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace vx
|
} // namespace vx
|
||||||
} // namespace tim
|
} // namespace tim
|
||||||
Loading…
Reference in New Issue