Remove DownScaleSizeRounding type

Use RoundType instead of DownScaleSizeRounding.

Signed-off-by: zhao.xia <zhao.xia@verisilicon.com>
This commit is contained in:
zhao.xia 2021-05-25 16:03:05 +08:00 committed by Kainan Cha
parent eccc117ec5
commit 37f686c34d
7 changed files with 10 additions and 13 deletions

View File

@ -43,9 +43,8 @@ class Operation {
Operation& BindOutputs(const std::vector<std::shared_ptr<Tensor>>& tensors);
Operation& SetRoundingPolicy(
OverflowPolicy overflow_policy = OverflowPolicy::SATURATE,
RoundingPolicy rounding_policy = RoundingPolicy::TO_ZERO,
DownScaleSizeRounding down_scale_size_rounding =
DownScaleSizeRounding::FLOOR,
RoundingPolicy rounding_policy = RoundingPolicy::RTNE,
RoundType down_scale_size_rounding = RoundType::FLOOR,
uint32_t accumulator_bits = 0);
std::unique_ptr<OperationImpl>& impl();

View File

@ -60,8 +60,6 @@ enum class OverflowPolicy { WRAP, SATURATE };
enum class RoundingPolicy { TO_ZERO, RTNE };
enum class DownScaleSizeRounding { FLOOR, CEILING };
enum class ResizeType { NEAREST_NEIGHBOR, BILINEAR, AREA };
enum class DataLayout {

View File

@ -70,7 +70,7 @@ OperationImpl& OperationImpl::BindOutput(
OperationImpl& OperationImpl::SetRoundingPolicy(
OverflowPolicy overflow_policy, RoundingPolicy rounding_policy,
DownScaleSizeRounding down_scale_size_rounding, uint32_t accumulator_bits) {
RoundType down_scale_size_rounding, uint32_t accumulator_bits) {
node_->vx_param.overflow_policy = TranslateOverflowPolicy(overflow_policy);
node_->vx_param.rounding_policy = TranslateRoundingPolicy(rounding_policy);
node_->vx_param.down_scale_size_rounding =
@ -104,7 +104,7 @@ Operation& Operation::BindOutput(const std::shared_ptr<Tensor>& tensor) {
Operation& Operation::SetRoundingPolicy(
OverflowPolicy overflow_policy, RoundingPolicy rounding_policy,
DownScaleSizeRounding down_scale_size_rounding, uint32_t accumulator_bits) {
RoundType down_scale_size_rounding, uint32_t accumulator_bits) {
impl_->SetRoundingPolicy(overflow_policy, rounding_policy,
down_scale_size_rounding, accumulator_bits);
return *this;

View File

@ -41,8 +41,7 @@ class OperationImpl {
OperationImpl& SetRoundingPolicy(
OverflowPolicy overflow_policy = OverflowPolicy::SATURATE,
RoundingPolicy rounding_policy = RoundingPolicy::RTNE,
DownScaleSizeRounding down_scale_size_rounding =
DownScaleSizeRounding::FLOOR,
RoundType down_scale_size_rounding = RoundType::FLOOR,
uint32_t accumulator_bits = 0);
vsi_nn_node_t* node() { return this->node_; }

View File

@ -49,6 +49,7 @@ Pool2d::Pool2d(Graph* graph, PoolType type, PadType padding,
this->impl()->node()->nn_param.pool.stride[0] = stride_[0];
this->impl()->node()->nn_param.pool.stride[1] = stride_[1];
this->impl()->node()->nn_param.pool.pad_type = TranslatePadType(padding_);
this->SetRoundingPolicy(OverflowPolicy::SATURATE, RoundingPolicy::RTNE, round_type_);
}
} // namespace ops

View File

@ -134,11 +134,11 @@ vsi_enum TranslateRoundingPolicy(RoundingPolicy type) {
return VX_ROUND_POLICY_TO_NEAREST_EVEN;
}
vsi_enum TranslateDownScaleSizeRounding(DownScaleSizeRounding type) {
vsi_enum TranslateDownScaleSizeRounding(RoundType type) {
switch (type) {
case DownScaleSizeRounding::FLOOR:
case RoundType::FLOOR:
return VX_CONVOLUTIONAL_NETWORK_DS_SIZE_ROUNDING_FLOOR;
case DownScaleSizeRounding::CEILING:
case RoundType::CEILING:
return VX_CONVOLUTIONAL_NETWORK_DS_SIZE_ROUNDING_CEILING;
default:
break;

View File

@ -36,7 +36,7 @@ vsi_enum TranslatePoolType(PoolType type);
vsi_nn_round_type_e TranslateRoundType(RoundType type);
vsi_enum TranslateOverflowPolicy(OverflowPolicy type);
vsi_enum TranslateRoundingPolicy(RoundingPolicy type);
vsi_enum TranslateDownScaleSizeRounding(DownScaleSizeRounding type);
vsi_enum TranslateDownScaleSizeRounding(RoundType type);
vsi_enum TranslateResizeType(ResizeType type);
vx_bool_e ToVxBool(bool val);
} // namespace vx