diff --git a/include/tim/vx/types.h b/include/tim/vx/types.h index 7418fd4..3c8c7bb 100644 --- a/include/tim/vx/types.h +++ b/include/tim/vx/types.h @@ -43,7 +43,13 @@ enum class DataType { UINT4 }; -enum class QuantType { NONE, ASYMMETRIC, SYMMETRIC_PER_CHANNEL, DYNAMIC_FIXED_POINT }; +enum class QuantType { + NONE, + ASYMMETRIC, + SYMMETRIC_PER_CHANNEL, + ASYMMETRIC_PER_CHANNEL, + DYNAMIC_FIXED_POINT +}; enum TensorAttribute { CONSTANT = 1 << 0, diff --git a/src/tim/vx/tensor.cc b/src/tim/vx/tensor.cc index b6dd9ac..0450cbf 100644 --- a/src/tim/vx/tensor.cc +++ b/src/tim/vx/tensor.cc @@ -53,7 +53,26 @@ void PackTensorDtype(tim::vx::TensorSpec& spec, vsi_nn_dtype_t* dtype) { break; case tim::vx::QuantType::SYMMETRIC_PER_CHANNEL: { dtype->scales = spec.quantization_.Scales().data(); - dtype->scale_dim = spec.quantization_.ZeroPoints().size(); + dtype->scale_dim = spec.quantization_.Scales().size(); +#if (VSI_NN_VERSION_MAJOR == 1 && VSI_NN_VERSION_MINOR == 1 && \ + VSI_NN_VERSION_PATCH <= 18) + { + std::vector zps(spec.quantization_.ZeroPoints().size()); + std::transform(spec.quantization_.ZeroPoints().begin(), + spec.quantization_.ZeroPoints().end(), zps.begin(), + [](const int& it) { return static_cast(it); }); + dtype->zero_points = zps.data(); + } +#else + dtype->zero_points = spec.quantization_.ZeroPoints().data(); +#endif + dtype->zero_points_dim = spec.quantization_.ZeroPoints().size(); + dtype->channel_dim = spec.quantization_.ChannelDim(); + break; + } + case tim::vx::QuantType::ASYMMETRIC_PER_CHANNEL: { + dtype->scales = spec.quantization_.Scales().data(); + dtype->scale_dim = spec.quantization_.Scales().size(); #if (VSI_NN_VERSION_MAJOR == 1 && VSI_NN_VERSION_MINOR == 1 && \ VSI_NN_VERSION_PATCH <= 18) { diff --git a/src/tim/vx/type_utils.cc b/src/tim/vx/type_utils.cc index a701452..f1e4169 100644 --- a/src/tim/vx/type_utils.cc +++ b/src/tim/vx/type_utils.cc @@ -65,6 +65,8 @@ vsi_nn_qnt_type_e TranslateQuantType(QuantType qtype) { return VSI_NN_QNT_TYPE_AFFINE_ASYMMETRIC; case QuantType::SYMMETRIC_PER_CHANNEL: return VSI_NN_QNT_TYPE_AFFINE_PERCHANNEL_SYMMETRIC; + case QuantType::ASYMMETRIC_PER_CHANNEL: + return VSI_NN_QNT_TYPE_AFFINE_PERCHANNEL_ASYMMETRIC; case QuantType::DYNAMIC_FIXED_POINT: return VSI_NN_QNT_TYPE_DFP; default: