From 37f686c34d90865530c6e92d2173970d96c622ce Mon Sep 17 00:00:00 2001 From: "zhao.xia" Date: Tue, 25 May 2021 16:03:05 +0800 Subject: [PATCH] Remove DownScaleSizeRounding type Use RoundType instead of DownScaleSizeRounding. Signed-off-by: zhao.xia --- include/tim/vx/operation.h | 5 ++--- include/tim/vx/types.h | 2 -- src/tim/vx/operation.cc | 4 ++-- src/tim/vx/operation_private.h | 3 +-- src/tim/vx/ops/pool2d.cc | 1 + src/tim/vx/type_utils.cc | 6 +++--- src/tim/vx/type_utils.h | 2 +- 7 files changed, 10 insertions(+), 13 deletions(-) diff --git a/include/tim/vx/operation.h b/include/tim/vx/operation.h index 5736585..2367879 100644 --- a/include/tim/vx/operation.h +++ b/include/tim/vx/operation.h @@ -43,9 +43,8 @@ class Operation { Operation& BindOutputs(const std::vector>& tensors); Operation& SetRoundingPolicy( OverflowPolicy overflow_policy = OverflowPolicy::SATURATE, - RoundingPolicy rounding_policy = RoundingPolicy::TO_ZERO, - DownScaleSizeRounding down_scale_size_rounding = - DownScaleSizeRounding::FLOOR, + RoundingPolicy rounding_policy = RoundingPolicy::RTNE, + RoundType down_scale_size_rounding = RoundType::FLOOR, uint32_t accumulator_bits = 0); std::unique_ptr& impl(); diff --git a/include/tim/vx/types.h b/include/tim/vx/types.h index 2f9554c..9ea9727 100644 --- a/include/tim/vx/types.h +++ b/include/tim/vx/types.h @@ -60,8 +60,6 @@ enum class OverflowPolicy { WRAP, SATURATE }; enum class RoundingPolicy { TO_ZERO, RTNE }; -enum class DownScaleSizeRounding { FLOOR, CEILING }; - enum class ResizeType { NEAREST_NEIGHBOR, BILINEAR, AREA }; enum class DataLayout { diff --git a/src/tim/vx/operation.cc b/src/tim/vx/operation.cc index a757968..d180db4 100644 --- a/src/tim/vx/operation.cc +++ b/src/tim/vx/operation.cc @@ -70,7 +70,7 @@ OperationImpl& OperationImpl::BindOutput( OperationImpl& OperationImpl::SetRoundingPolicy( OverflowPolicy overflow_policy, RoundingPolicy rounding_policy, - DownScaleSizeRounding down_scale_size_rounding, uint32_t accumulator_bits) { + RoundType down_scale_size_rounding, uint32_t accumulator_bits) { node_->vx_param.overflow_policy = TranslateOverflowPolicy(overflow_policy); node_->vx_param.rounding_policy = TranslateRoundingPolicy(rounding_policy); node_->vx_param.down_scale_size_rounding = @@ -104,7 +104,7 @@ Operation& Operation::BindOutput(const std::shared_ptr& tensor) { Operation& Operation::SetRoundingPolicy( OverflowPolicy overflow_policy, RoundingPolicy rounding_policy, - DownScaleSizeRounding down_scale_size_rounding, uint32_t accumulator_bits) { + RoundType down_scale_size_rounding, uint32_t accumulator_bits) { impl_->SetRoundingPolicy(overflow_policy, rounding_policy, down_scale_size_rounding, accumulator_bits); return *this; diff --git a/src/tim/vx/operation_private.h b/src/tim/vx/operation_private.h index 48c6612..176b8a1 100644 --- a/src/tim/vx/operation_private.h +++ b/src/tim/vx/operation_private.h @@ -41,8 +41,7 @@ class OperationImpl { OperationImpl& SetRoundingPolicy( OverflowPolicy overflow_policy = OverflowPolicy::SATURATE, RoundingPolicy rounding_policy = RoundingPolicy::RTNE, - DownScaleSizeRounding down_scale_size_rounding = - DownScaleSizeRounding::FLOOR, + RoundType down_scale_size_rounding = RoundType::FLOOR, uint32_t accumulator_bits = 0); vsi_nn_node_t* node() { return this->node_; } diff --git a/src/tim/vx/ops/pool2d.cc b/src/tim/vx/ops/pool2d.cc index 82bd0cd..c3771dc 100644 --- a/src/tim/vx/ops/pool2d.cc +++ b/src/tim/vx/ops/pool2d.cc @@ -49,6 +49,7 @@ Pool2d::Pool2d(Graph* graph, PoolType type, PadType padding, this->impl()->node()->nn_param.pool.stride[0] = stride_[0]; this->impl()->node()->nn_param.pool.stride[1] = stride_[1]; this->impl()->node()->nn_param.pool.pad_type = TranslatePadType(padding_); + this->SetRoundingPolicy(OverflowPolicy::SATURATE, RoundingPolicy::RTNE, round_type_); } } // namespace ops diff --git a/src/tim/vx/type_utils.cc b/src/tim/vx/type_utils.cc index 12e4991..b84e1ed 100644 --- a/src/tim/vx/type_utils.cc +++ b/src/tim/vx/type_utils.cc @@ -134,11 +134,11 @@ vsi_enum TranslateRoundingPolicy(RoundingPolicy type) { return VX_ROUND_POLICY_TO_NEAREST_EVEN; } -vsi_enum TranslateDownScaleSizeRounding(DownScaleSizeRounding type) { +vsi_enum TranslateDownScaleSizeRounding(RoundType type) { switch (type) { - case DownScaleSizeRounding::FLOOR: + case RoundType::FLOOR: return VX_CONVOLUTIONAL_NETWORK_DS_SIZE_ROUNDING_FLOOR; - case DownScaleSizeRounding::CEILING: + case RoundType::CEILING: return VX_CONVOLUTIONAL_NETWORK_DS_SIZE_ROUNDING_CEILING; default: break; diff --git a/src/tim/vx/type_utils.h b/src/tim/vx/type_utils.h index 140e1e2..5a4c9d9 100644 --- a/src/tim/vx/type_utils.h +++ b/src/tim/vx/type_utils.h @@ -36,7 +36,7 @@ vsi_enum TranslatePoolType(PoolType type); vsi_nn_round_type_e TranslateRoundType(RoundType type); vsi_enum TranslateOverflowPolicy(OverflowPolicy type); vsi_enum TranslateRoundingPolicy(RoundingPolicy type); -vsi_enum TranslateDownScaleSizeRounding(DownScaleSizeRounding type); +vsi_enum TranslateDownScaleSizeRounding(RoundType type); vsi_enum TranslateResizeType(ResizeType type); vx_bool_e ToVxBool(bool val); } // namespace vx