Reload "==" operator for quantizations of two tensor (#583)

Reload operator "==" to check two quantization same or not

Type: New Feature
Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>
This commit is contained in:
Chen Feiyue 2023-05-10 16:58:30 +08:00 committed by GitHub
parent 308a967bcf
commit b81f7979fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 1 deletions

View File

@ -81,6 +81,8 @@ class Quantization {
const std::int8_t& Fl() const{ return this->fl_; } const std::int8_t& Fl() const{ return this->fl_; }
bool operator == (const Quantization& other_quant) const;
protected: protected:
QuantType type_{QuantType::NONE}; QuantType type_{QuantType::NONE};
int32_t channel_dim_{-1}; int32_t channel_dim_{-1};

View File

@ -444,6 +444,19 @@ int64_t TensorSpec::GetElementByteSize() const {
int64_t TensorSpec::GetByteSize() const { int64_t TensorSpec::GetByteSize() const {
return GetElementNum() * GetElementByteSize(); return GetElementNum() * GetElementByteSize();
} }
bool Quantization::operator == (const Quantization& other_quant) const {
if (type_ != tim::vx::QuantType::DYNAMIC_FIXED_POINT){
if(type_ == other_quant.type_ &&
scales_ == other_quant.scales_ &&
zero_points_ == other_quant.zero_points_ &&
channel_dim_ == other_quant.channel_dim_)
return true;
}
else if(fl_ == other_quant.fl_) return true;
return false;
}
namespace utils{ namespace utils{
bool Float32ToDtype(std::shared_ptr<tim::vx::Tensor> tensor, std::vector<float> fval, uint8_t* tensorData){ bool Float32ToDtype(std::shared_ptr<tim::vx::Tensor> tensor, std::vector<float> fval, uint8_t* tensorData){
bool retn = true; bool retn = true;
@ -462,7 +475,7 @@ return retn;
} }
bool DtypeToFloat32(std::shared_ptr<tim::vx::Tensor> tensor, uint8_t* tensorData, float* data){ bool DtypeToFloat32(std::shared_ptr<tim::vx::Tensor> tensor, uint8_t* tensorData, float* data){
bool retn = true; bool retn = true;
vsi_nn_tensor_attr_t attr; vsi_nn_tensor_attr_t attr;
PackTensorDtype(tensor->GetSpec(), &attr.dtype); PackTensorDtype(tensor->GetSpec(), &attr.dtype);