diff --git a/include/tim/vx/tensor.h b/include/tim/vx/tensor.h index 12845c5..8ba8861 100644 --- a/include/tim/vx/tensor.h +++ b/include/tim/vx/tensor.h @@ -102,6 +102,8 @@ struct TensorSpec { TensorSpec(const TensorSpec& other); + bool operator==(const TensorSpec& other_spec) const; + TensorSpec& operator=(const TensorSpec& other); TensorSpec& SetDataType(DataType datatype); diff --git a/src/tim/vx/graph.cc b/src/tim/vx/graph.cc index 998a373..06ab0db 100644 --- a/src/tim/vx/graph.cc +++ b/src/tim/vx/graph.cc @@ -142,10 +142,7 @@ std::shared_ptr GraphImpl::GetTensorFromCache(const TensorSpec& spec, std::shared_ptr tensor; std::string md5_key = CalculateCacheKey(spec, data); if (GetTensorCacheMap().find(md5_key) != GetTensorCacheMap().end() && - GetTensorCacheMap()[md5_key]->GetQuantization().Scales() == - spec.quantization_.Scales() && - GetTensorCacheMap()[md5_key]->GetQuantization().ZeroPoints() == - spec.quantization_.ZeroPoints()) { + GetTensorCacheMap()[md5_key]->GetSpec() == spec) { tensor = GetTensorCacheMap()[md5_key]; } else { tensor = std::make_shared(this, spec, data); diff --git a/src/tim/vx/tensor.cc b/src/tim/vx/tensor.cc index 973d078..4386b08 100644 --- a/src/tim/vx/tensor.cc +++ b/src/tim/vx/tensor.cc @@ -508,6 +508,15 @@ int64_t TensorSpec::GetByteSize() const { return GetElementNum() * GetElementByteSize(); } +bool TensorSpec::operator==(const TensorSpec& other_spec) const { + if (datatype_ == other_spec.datatype_ && shape_ == other_spec.shape_ && + attr_ == other_spec.attr_ && quantization_ == other_spec.quantization_) { + return true; + } else { + return false; + } +} + bool Quantization::operator==(const Quantization& other_quant) const { if (type_ != tim::vx::QuantType::DYNAMIC_FIXED_POINT) { if (type_ == other_quant.type_ && scales_ == other_quant.scales_ &&