add GetElementNum/GetElementByteSize/GetByteSize for TensorSpec (#393)

This commit is contained in:
Antkillerfarm 2022-05-13 14:29:25 +08:00 committed by GitHub
parent 0d8ac3dc2b
commit b3677305c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 84 additions and 33 deletions

View File

@ -94,45 +94,25 @@ struct TensorSpec {
this->quantization_ = quantization;
}
TensorSpec(const TensorSpec& other) {
this->datatype_ = other.datatype_;
this->shape_ = other.shape_;
this->attr_ = other.attr_;
this->quantization_ = other.quantization_;
}
TensorSpec(const TensorSpec& other);
TensorSpec& operator =(const TensorSpec& other) {
this->datatype_ = other.datatype_;
this->shape_ = other.shape_;
this->attr_ = other.attr_;
this->quantization_ = other.quantization_;
return *this;
}
TensorSpec& operator=(const TensorSpec& other);
TensorSpec& SetDataType(DataType datatype) {
this->datatype_ = datatype;
return *this;
}
TensorSpec& SetDataType(DataType datatype);
TensorSpec& SetShape(ShapeType& shape) {
this->shape_ = shape;
return *this;
}
TensorSpec& SetShape(ShapeType& shape);
TensorSpec& SetAttribute(TensorAttribute attr) {
this->attr_ = attr;
return *this;
}
TensorSpec& SetAttribute(TensorAttribute attr);
TensorSpec& SetQuantization(Quantization& quantization) {
this->quantization_ = quantization;
return *this;
}
TensorSpec& SetQuantization(Quantization& quantization);
TensorSpec AsTransientSpec() const {
return TensorSpec(this->datatype_, ShapeType({}),
TensorAttribute::TRANSIENT, this->quantization_);
}
TensorSpec AsTransientSpec() const;
int64_t GetElementNum() const;
int64_t GetElementByteSize() const;
int64_t GetByteSize() const;
DataType datatype_;
ShapeType shape_;

View File

@ -240,5 +240,76 @@ bool TensorImpl::IsReadable() {
return spec_.attr_ != TensorAttribute::TRANSIENT;
}
TensorSpec::TensorSpec(const TensorSpec& other) {
this->datatype_ = other.datatype_;
this->shape_ = other.shape_;
this->attr_ = other.attr_;
this->quantization_ = other.quantization_;
}
TensorSpec& TensorSpec::operator=(const TensorSpec& other) {
this->datatype_ = other.datatype_;
this->shape_ = other.shape_;
this->attr_ = other.attr_;
this->quantization_ = other.quantization_;
return *this;
}
TensorSpec& TensorSpec::SetDataType(DataType datatype) {
this->datatype_ = datatype;
return *this;
}
TensorSpec& TensorSpec::SetShape(ShapeType& shape) {
this->shape_ = shape;
return *this;
}
TensorSpec& TensorSpec::SetAttribute(TensorAttribute attr) {
this->attr_ = attr;
return *this;
}
TensorSpec& TensorSpec::SetQuantization(Quantization& quantization) {
this->quantization_ = quantization;
return *this;
}
TensorSpec TensorSpec::AsTransientSpec() const {
return TensorSpec(this->datatype_, ShapeType({}), TensorAttribute::TRANSIENT,
this->quantization_);
}
int64_t TensorSpec::GetElementNum() const {
int64_t count = 1;
for (auto dim : shape_) {
count *= dim;
}
return count;
}
int64_t TensorSpec::GetElementByteSize() const {
switch (datatype_) {
case DataType::INT8:
case DataType::UINT8:
case DataType::BOOL8:
return 1;
case DataType::INT16:
case DataType::UINT16:
case DataType::FLOAT16:
return 2;
case DataType::INT32:
case DataType::UINT32:
case DataType::FLOAT32:
return 4;
default:
return 1;
}
}
int64_t TensorSpec::GetByteSize() const {
return GetElementNum() * GetElementByteSize();
}
} // namespace vx
} // namespace tim