Remove DownScaleSizeRounding type
Use RoundType instead of DownScaleSizeRounding. Signed-off-by: zhao.xia <zhao.xia@verisilicon.com>
This commit is contained in:
parent
eccc117ec5
commit
37f686c34d
|
|
@ -43,9 +43,8 @@ class Operation {
|
|||
Operation& BindOutputs(const std::vector<std::shared_ptr<Tensor>>& tensors);
|
||||
Operation& SetRoundingPolicy(
|
||||
OverflowPolicy overflow_policy = OverflowPolicy::SATURATE,
|
||||
RoundingPolicy rounding_policy = RoundingPolicy::TO_ZERO,
|
||||
DownScaleSizeRounding down_scale_size_rounding =
|
||||
DownScaleSizeRounding::FLOOR,
|
||||
RoundingPolicy rounding_policy = RoundingPolicy::RTNE,
|
||||
RoundType down_scale_size_rounding = RoundType::FLOOR,
|
||||
uint32_t accumulator_bits = 0);
|
||||
std::unique_ptr<OperationImpl>& impl();
|
||||
|
||||
|
|
|
|||
|
|
@ -60,8 +60,6 @@ enum class OverflowPolicy { WRAP, SATURATE };
|
|||
|
||||
enum class RoundingPolicy { TO_ZERO, RTNE };
|
||||
|
||||
enum class DownScaleSizeRounding { FLOOR, CEILING };
|
||||
|
||||
enum class ResizeType { NEAREST_NEIGHBOR, BILINEAR, AREA };
|
||||
|
||||
enum class DataLayout {
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ OperationImpl& OperationImpl::BindOutput(
|
|||
|
||||
OperationImpl& OperationImpl::SetRoundingPolicy(
|
||||
OverflowPolicy overflow_policy, RoundingPolicy rounding_policy,
|
||||
DownScaleSizeRounding down_scale_size_rounding, uint32_t accumulator_bits) {
|
||||
RoundType down_scale_size_rounding, uint32_t accumulator_bits) {
|
||||
node_->vx_param.overflow_policy = TranslateOverflowPolicy(overflow_policy);
|
||||
node_->vx_param.rounding_policy = TranslateRoundingPolicy(rounding_policy);
|
||||
node_->vx_param.down_scale_size_rounding =
|
||||
|
|
@ -104,7 +104,7 @@ Operation& Operation::BindOutput(const std::shared_ptr<Tensor>& tensor) {
|
|||
|
||||
Operation& Operation::SetRoundingPolicy(
|
||||
OverflowPolicy overflow_policy, RoundingPolicy rounding_policy,
|
||||
DownScaleSizeRounding down_scale_size_rounding, uint32_t accumulator_bits) {
|
||||
RoundType down_scale_size_rounding, uint32_t accumulator_bits) {
|
||||
impl_->SetRoundingPolicy(overflow_policy, rounding_policy,
|
||||
down_scale_size_rounding, accumulator_bits);
|
||||
return *this;
|
||||
|
|
|
|||
|
|
@ -41,8 +41,7 @@ class OperationImpl {
|
|||
OperationImpl& SetRoundingPolicy(
|
||||
OverflowPolicy overflow_policy = OverflowPolicy::SATURATE,
|
||||
RoundingPolicy rounding_policy = RoundingPolicy::RTNE,
|
||||
DownScaleSizeRounding down_scale_size_rounding =
|
||||
DownScaleSizeRounding::FLOOR,
|
||||
RoundType down_scale_size_rounding = RoundType::FLOOR,
|
||||
uint32_t accumulator_bits = 0);
|
||||
|
||||
vsi_nn_node_t* node() { return this->node_; }
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ Pool2d::Pool2d(Graph* graph, PoolType type, PadType padding,
|
|||
this->impl()->node()->nn_param.pool.stride[0] = stride_[0];
|
||||
this->impl()->node()->nn_param.pool.stride[1] = stride_[1];
|
||||
this->impl()->node()->nn_param.pool.pad_type = TranslatePadType(padding_);
|
||||
this->SetRoundingPolicy(OverflowPolicy::SATURATE, RoundingPolicy::RTNE, round_type_);
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
|
|
|
|||
|
|
@ -134,11 +134,11 @@ vsi_enum TranslateRoundingPolicy(RoundingPolicy type) {
|
|||
return VX_ROUND_POLICY_TO_NEAREST_EVEN;
|
||||
}
|
||||
|
||||
vsi_enum TranslateDownScaleSizeRounding(DownScaleSizeRounding type) {
|
||||
vsi_enum TranslateDownScaleSizeRounding(RoundType type) {
|
||||
switch (type) {
|
||||
case DownScaleSizeRounding::FLOOR:
|
||||
case RoundType::FLOOR:
|
||||
return VX_CONVOLUTIONAL_NETWORK_DS_SIZE_ROUNDING_FLOOR;
|
||||
case DownScaleSizeRounding::CEILING:
|
||||
case RoundType::CEILING:
|
||||
return VX_CONVOLUTIONAL_NETWORK_DS_SIZE_ROUNDING_CEILING;
|
||||
default:
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ vsi_enum TranslatePoolType(PoolType type);
|
|||
vsi_nn_round_type_e TranslateRoundType(RoundType type);
|
||||
vsi_enum TranslateOverflowPolicy(OverflowPolicy type);
|
||||
vsi_enum TranslateRoundingPolicy(RoundingPolicy type);
|
||||
vsi_enum TranslateDownScaleSizeRounding(DownScaleSizeRounding type);
|
||||
vsi_enum TranslateDownScaleSizeRounding(RoundType type);
|
||||
vsi_enum TranslateResizeType(ResizeType type);
|
||||
vx_bool_e ToVxBool(bool val);
|
||||
} // namespace vx
|
||||
|
|
|
|||
Loading…
Reference in New Issue