fixed tensor cache mismatch issue (#644)
Type: Bug fix Signed-off-by: Chen <jack.chen@verisilicon.com> Co-authored-by: Chen <jack.chen@verisilicon.com>
This commit is contained in:
parent
5668856fc9
commit
01235266c5
|
|
@ -102,6 +102,8 @@ struct TensorSpec {
|
|||
|
||||
TensorSpec(const TensorSpec& other);
|
||||
|
||||
bool operator==(const TensorSpec& other_spec) const;
|
||||
|
||||
TensorSpec& operator=(const TensorSpec& other);
|
||||
|
||||
TensorSpec& SetDataType(DataType datatype);
|
||||
|
|
|
|||
|
|
@ -142,10 +142,7 @@ std::shared_ptr<Tensor> GraphImpl::GetTensorFromCache(const TensorSpec& spec,
|
|||
std::shared_ptr<tim::vx::Tensor> tensor;
|
||||
std::string md5_key = CalculateCacheKey(spec, data);
|
||||
if (GetTensorCacheMap().find(md5_key) != GetTensorCacheMap().end() &&
|
||||
GetTensorCacheMap()[md5_key]->GetQuantization().Scales() ==
|
||||
spec.quantization_.Scales() &&
|
||||
GetTensorCacheMap()[md5_key]->GetQuantization().ZeroPoints() ==
|
||||
spec.quantization_.ZeroPoints()) {
|
||||
GetTensorCacheMap()[md5_key]->GetSpec() == spec) {
|
||||
tensor = GetTensorCacheMap()[md5_key];
|
||||
} else {
|
||||
tensor = std::make_shared<TensorImpl>(this, spec, data);
|
||||
|
|
|
|||
|
|
@ -508,6 +508,15 @@ int64_t TensorSpec::GetByteSize() const {
|
|||
return GetElementNum() * GetElementByteSize();
|
||||
}
|
||||
|
||||
bool TensorSpec::operator==(const TensorSpec& other_spec) const {
|
||||
if (datatype_ == other_spec.datatype_ && shape_ == other_spec.shape_ &&
|
||||
attr_ == other_spec.attr_ && quantization_ == other_spec.quantization_) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool Quantization::operator==(const Quantization& other_quant) const {
|
||||
if (type_ != tim::vx::QuantType::DYNAMIC_FIXED_POINT) {
|
||||
if (type_ == other_quant.type_ && scales_ == other_quant.scales_ &&
|
||||
|
|
|
|||
Loading…
Reference in New Issue