Remove DownScaleSizeRounding type
Use RoundType instead of DownScaleSizeRounding. Signed-off-by: zhao.xia <zhao.xia@verisilicon.com>
This commit is contained in:
parent
eccc117ec5
commit
37f686c34d
|
|
@ -43,9 +43,8 @@ class Operation {
|
||||||
Operation& BindOutputs(const std::vector<std::shared_ptr<Tensor>>& tensors);
|
Operation& BindOutputs(const std::vector<std::shared_ptr<Tensor>>& tensors);
|
||||||
Operation& SetRoundingPolicy(
|
Operation& SetRoundingPolicy(
|
||||||
OverflowPolicy overflow_policy = OverflowPolicy::SATURATE,
|
OverflowPolicy overflow_policy = OverflowPolicy::SATURATE,
|
||||||
RoundingPolicy rounding_policy = RoundingPolicy::TO_ZERO,
|
RoundingPolicy rounding_policy = RoundingPolicy::RTNE,
|
||||||
DownScaleSizeRounding down_scale_size_rounding =
|
RoundType down_scale_size_rounding = RoundType::FLOOR,
|
||||||
DownScaleSizeRounding::FLOOR,
|
|
||||||
uint32_t accumulator_bits = 0);
|
uint32_t accumulator_bits = 0);
|
||||||
std::unique_ptr<OperationImpl>& impl();
|
std::unique_ptr<OperationImpl>& impl();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -60,8 +60,6 @@ enum class OverflowPolicy { WRAP, SATURATE };
|
||||||
|
|
||||||
enum class RoundingPolicy { TO_ZERO, RTNE };
|
enum class RoundingPolicy { TO_ZERO, RTNE };
|
||||||
|
|
||||||
enum class DownScaleSizeRounding { FLOOR, CEILING };
|
|
||||||
|
|
||||||
enum class ResizeType { NEAREST_NEIGHBOR, BILINEAR, AREA };
|
enum class ResizeType { NEAREST_NEIGHBOR, BILINEAR, AREA };
|
||||||
|
|
||||||
enum class DataLayout {
|
enum class DataLayout {
|
||||||
|
|
|
||||||
|
|
@ -70,7 +70,7 @@ OperationImpl& OperationImpl::BindOutput(
|
||||||
|
|
||||||
OperationImpl& OperationImpl::SetRoundingPolicy(
|
OperationImpl& OperationImpl::SetRoundingPolicy(
|
||||||
OverflowPolicy overflow_policy, RoundingPolicy rounding_policy,
|
OverflowPolicy overflow_policy, RoundingPolicy rounding_policy,
|
||||||
DownScaleSizeRounding down_scale_size_rounding, uint32_t accumulator_bits) {
|
RoundType down_scale_size_rounding, uint32_t accumulator_bits) {
|
||||||
node_->vx_param.overflow_policy = TranslateOverflowPolicy(overflow_policy);
|
node_->vx_param.overflow_policy = TranslateOverflowPolicy(overflow_policy);
|
||||||
node_->vx_param.rounding_policy = TranslateRoundingPolicy(rounding_policy);
|
node_->vx_param.rounding_policy = TranslateRoundingPolicy(rounding_policy);
|
||||||
node_->vx_param.down_scale_size_rounding =
|
node_->vx_param.down_scale_size_rounding =
|
||||||
|
|
@ -104,7 +104,7 @@ Operation& Operation::BindOutput(const std::shared_ptr<Tensor>& tensor) {
|
||||||
|
|
||||||
Operation& Operation::SetRoundingPolicy(
|
Operation& Operation::SetRoundingPolicy(
|
||||||
OverflowPolicy overflow_policy, RoundingPolicy rounding_policy,
|
OverflowPolicy overflow_policy, RoundingPolicy rounding_policy,
|
||||||
DownScaleSizeRounding down_scale_size_rounding, uint32_t accumulator_bits) {
|
RoundType down_scale_size_rounding, uint32_t accumulator_bits) {
|
||||||
impl_->SetRoundingPolicy(overflow_policy, rounding_policy,
|
impl_->SetRoundingPolicy(overflow_policy, rounding_policy,
|
||||||
down_scale_size_rounding, accumulator_bits);
|
down_scale_size_rounding, accumulator_bits);
|
||||||
return *this;
|
return *this;
|
||||||
|
|
|
||||||
|
|
@ -41,8 +41,7 @@ class OperationImpl {
|
||||||
OperationImpl& SetRoundingPolicy(
|
OperationImpl& SetRoundingPolicy(
|
||||||
OverflowPolicy overflow_policy = OverflowPolicy::SATURATE,
|
OverflowPolicy overflow_policy = OverflowPolicy::SATURATE,
|
||||||
RoundingPolicy rounding_policy = RoundingPolicy::RTNE,
|
RoundingPolicy rounding_policy = RoundingPolicy::RTNE,
|
||||||
DownScaleSizeRounding down_scale_size_rounding =
|
RoundType down_scale_size_rounding = RoundType::FLOOR,
|
||||||
DownScaleSizeRounding::FLOOR,
|
|
||||||
uint32_t accumulator_bits = 0);
|
uint32_t accumulator_bits = 0);
|
||||||
|
|
||||||
vsi_nn_node_t* node() { return this->node_; }
|
vsi_nn_node_t* node() { return this->node_; }
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,7 @@ Pool2d::Pool2d(Graph* graph, PoolType type, PadType padding,
|
||||||
this->impl()->node()->nn_param.pool.stride[0] = stride_[0];
|
this->impl()->node()->nn_param.pool.stride[0] = stride_[0];
|
||||||
this->impl()->node()->nn_param.pool.stride[1] = stride_[1];
|
this->impl()->node()->nn_param.pool.stride[1] = stride_[1];
|
||||||
this->impl()->node()->nn_param.pool.pad_type = TranslatePadType(padding_);
|
this->impl()->node()->nn_param.pool.pad_type = TranslatePadType(padding_);
|
||||||
|
this->SetRoundingPolicy(OverflowPolicy::SATURATE, RoundingPolicy::RTNE, round_type_);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
||||||
|
|
@ -134,11 +134,11 @@ vsi_enum TranslateRoundingPolicy(RoundingPolicy type) {
|
||||||
return VX_ROUND_POLICY_TO_NEAREST_EVEN;
|
return VX_ROUND_POLICY_TO_NEAREST_EVEN;
|
||||||
}
|
}
|
||||||
|
|
||||||
vsi_enum TranslateDownScaleSizeRounding(DownScaleSizeRounding type) {
|
vsi_enum TranslateDownScaleSizeRounding(RoundType type) {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case DownScaleSizeRounding::FLOOR:
|
case RoundType::FLOOR:
|
||||||
return VX_CONVOLUTIONAL_NETWORK_DS_SIZE_ROUNDING_FLOOR;
|
return VX_CONVOLUTIONAL_NETWORK_DS_SIZE_ROUNDING_FLOOR;
|
||||||
case DownScaleSizeRounding::CEILING:
|
case RoundType::CEILING:
|
||||||
return VX_CONVOLUTIONAL_NETWORK_DS_SIZE_ROUNDING_CEILING;
|
return VX_CONVOLUTIONAL_NETWORK_DS_SIZE_ROUNDING_CEILING;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ vsi_enum TranslatePoolType(PoolType type);
|
||||||
vsi_nn_round_type_e TranslateRoundType(RoundType type);
|
vsi_nn_round_type_e TranslateRoundType(RoundType type);
|
||||||
vsi_enum TranslateOverflowPolicy(OverflowPolicy type);
|
vsi_enum TranslateOverflowPolicy(OverflowPolicy type);
|
||||||
vsi_enum TranslateRoundingPolicy(RoundingPolicy type);
|
vsi_enum TranslateRoundingPolicy(RoundingPolicy type);
|
||||||
vsi_enum TranslateDownScaleSizeRounding(DownScaleSizeRounding type);
|
vsi_enum TranslateDownScaleSizeRounding(RoundType type);
|
||||||
vsi_enum TranslateResizeType(ResizeType type);
|
vsi_enum TranslateResizeType(ResizeType type);
|
||||||
vx_bool_e ToVxBool(bool val);
|
vx_bool_e ToVxBool(bool val);
|
||||||
} // namespace vx
|
} // namespace vx
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue