From b81f7979fade7504ea0feabeacdf56a90f797fd5 Mon Sep 17 00:00:00 2001 From: Chen Feiyue <69809761+chenfeiyue-cfy@users.noreply.github.com> Date: Wed, 10 May 2023 16:58:30 +0800 Subject: [PATCH] Reload "==" operator for quantizations of two tensor (#583) Reload operator "==" to check two quantization same or not Type: New Feature Signed-off-by: Feiyue Chen --- include/tim/vx/tensor.h | 2 ++ src/tim/vx/tensor.cc | 15 ++++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/include/tim/vx/tensor.h b/include/tim/vx/tensor.h index 20934c9..ac50e69 100644 --- a/include/tim/vx/tensor.h +++ b/include/tim/vx/tensor.h @@ -81,6 +81,8 @@ class Quantization { const std::int8_t& Fl() const{ return this->fl_; } + bool operator == (const Quantization& other_quant) const; + protected: QuantType type_{QuantType::NONE}; int32_t channel_dim_{-1}; diff --git a/src/tim/vx/tensor.cc b/src/tim/vx/tensor.cc index 581f8c2..a2b58f8 100644 --- a/src/tim/vx/tensor.cc +++ b/src/tim/vx/tensor.cc @@ -444,6 +444,19 @@ int64_t TensorSpec::GetElementByteSize() const { int64_t TensorSpec::GetByteSize() const { return GetElementNum() * GetElementByteSize(); } + +bool Quantization::operator == (const Quantization& other_quant) const { + if (type_ != tim::vx::QuantType::DYNAMIC_FIXED_POINT){ + if(type_ == other_quant.type_ && + scales_ == other_quant.scales_ && + zero_points_ == other_quant.zero_points_ && + channel_dim_ == other_quant.channel_dim_) + return true; + } + else if(fl_ == other_quant.fl_) return true; + return false; +} + namespace utils{ bool Float32ToDtype(std::shared_ptr tensor, std::vector fval, uint8_t* tensorData){ bool retn = true; @@ -462,7 +475,7 @@ return retn; } bool DtypeToFloat32(std::shared_ptr tensor, uint8_t* tensorData, float* data){ - bool retn = true; + bool retn = true; vsi_nn_tensor_attr_t attr; PackTensorDtype(tensor->GetSpec(), &attr.dtype);