fixed tensor cache mismatch issue (#644)

Type: Bug fix

Signed-off-by: Chen <jack.chen@verisilicon.com>
Co-authored-by: Chen <jack.chen@verisilicon.com>
This commit is contained in:
chxin66 2023-08-30 14:23:20 +08:00 committed by GitHub
parent 5668856fc9
commit 01235266c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 12 additions and 4 deletions

View File

@ -102,6 +102,8 @@ struct TensorSpec {
TensorSpec(const TensorSpec& other);
bool operator==(const TensorSpec& other_spec) const;
TensorSpec& operator=(const TensorSpec& other);
TensorSpec& SetDataType(DataType datatype);

View File

@ -142,10 +142,7 @@ std::shared_ptr<Tensor> GraphImpl::GetTensorFromCache(const TensorSpec& spec,
std::shared_ptr<tim::vx::Tensor> tensor;
std::string md5_key = CalculateCacheKey(spec, data);
if (GetTensorCacheMap().find(md5_key) != GetTensorCacheMap().end() &&
GetTensorCacheMap()[md5_key]->GetQuantization().Scales() ==
spec.quantization_.Scales() &&
GetTensorCacheMap()[md5_key]->GetQuantization().ZeroPoints() ==
spec.quantization_.ZeroPoints()) {
GetTensorCacheMap()[md5_key]->GetSpec() == spec) {
tensor = GetTensorCacheMap()[md5_key];
} else {
tensor = std::make_shared<TensorImpl>(this, spec, data);

View File

@ -508,6 +508,15 @@ int64_t TensorSpec::GetByteSize() const {
return GetElementNum() * GetElementByteSize();
}
bool TensorSpec::operator==(const TensorSpec& other_spec) const {
if (datatype_ == other_spec.datatype_ && shape_ == other_spec.shape_ &&
attr_ == other_spec.attr_ && quantization_ == other_spec.quantization_) {
return true;
} else {
return false;
}
}
bool Quantization::operator==(const Quantization& other_quant) const {
if (type_ != tim::vx::QuantType::DYNAMIC_FIXED_POINT) {
if (type_ == other_quant.type_ && scales_ == other_quant.scales_ &&