add GetElementNum/GetElementByteSize/GetByteSize for TensorSpec (#393)
This commit is contained in:
parent
0d8ac3dc2b
commit
b3677305c4
|
|
@ -94,45 +94,25 @@ struct TensorSpec {
|
|||
this->quantization_ = quantization;
|
||||
}
|
||||
|
||||
TensorSpec(const TensorSpec& other) {
|
||||
this->datatype_ = other.datatype_;
|
||||
this->shape_ = other.shape_;
|
||||
this->attr_ = other.attr_;
|
||||
this->quantization_ = other.quantization_;
|
||||
}
|
||||
TensorSpec(const TensorSpec& other);
|
||||
|
||||
TensorSpec& operator =(const TensorSpec& other) {
|
||||
this->datatype_ = other.datatype_;
|
||||
this->shape_ = other.shape_;
|
||||
this->attr_ = other.attr_;
|
||||
this->quantization_ = other.quantization_;
|
||||
return *this;
|
||||
}
|
||||
TensorSpec& operator=(const TensorSpec& other);
|
||||
|
||||
TensorSpec& SetDataType(DataType datatype) {
|
||||
this->datatype_ = datatype;
|
||||
return *this;
|
||||
}
|
||||
TensorSpec& SetDataType(DataType datatype);
|
||||
|
||||
TensorSpec& SetShape(ShapeType& shape) {
|
||||
this->shape_ = shape;
|
||||
return *this;
|
||||
}
|
||||
TensorSpec& SetShape(ShapeType& shape);
|
||||
|
||||
TensorSpec& SetAttribute(TensorAttribute attr) {
|
||||
this->attr_ = attr;
|
||||
return *this;
|
||||
}
|
||||
TensorSpec& SetAttribute(TensorAttribute attr);
|
||||
|
||||
TensorSpec& SetQuantization(Quantization& quantization) {
|
||||
this->quantization_ = quantization;
|
||||
return *this;
|
||||
}
|
||||
TensorSpec& SetQuantization(Quantization& quantization);
|
||||
|
||||
TensorSpec AsTransientSpec() const {
|
||||
return TensorSpec(this->datatype_, ShapeType({}),
|
||||
TensorAttribute::TRANSIENT, this->quantization_);
|
||||
}
|
||||
TensorSpec AsTransientSpec() const;
|
||||
|
||||
int64_t GetElementNum() const;
|
||||
|
||||
int64_t GetElementByteSize() const;
|
||||
|
||||
int64_t GetByteSize() const;
|
||||
|
||||
DataType datatype_;
|
||||
ShapeType shape_;
|
||||
|
|
|
|||
|
|
@ -240,5 +240,76 @@ bool TensorImpl::IsReadable() {
|
|||
return spec_.attr_ != TensorAttribute::TRANSIENT;
|
||||
}
|
||||
|
||||
TensorSpec::TensorSpec(const TensorSpec& other) {
|
||||
this->datatype_ = other.datatype_;
|
||||
this->shape_ = other.shape_;
|
||||
this->attr_ = other.attr_;
|
||||
this->quantization_ = other.quantization_;
|
||||
}
|
||||
|
||||
TensorSpec& TensorSpec::operator=(const TensorSpec& other) {
|
||||
this->datatype_ = other.datatype_;
|
||||
this->shape_ = other.shape_;
|
||||
this->attr_ = other.attr_;
|
||||
this->quantization_ = other.quantization_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
TensorSpec& TensorSpec::SetDataType(DataType datatype) {
|
||||
this->datatype_ = datatype;
|
||||
return *this;
|
||||
}
|
||||
|
||||
TensorSpec& TensorSpec::SetShape(ShapeType& shape) {
|
||||
this->shape_ = shape;
|
||||
return *this;
|
||||
}
|
||||
|
||||
TensorSpec& TensorSpec::SetAttribute(TensorAttribute attr) {
|
||||
this->attr_ = attr;
|
||||
return *this;
|
||||
}
|
||||
|
||||
TensorSpec& TensorSpec::SetQuantization(Quantization& quantization) {
|
||||
this->quantization_ = quantization;
|
||||
return *this;
|
||||
}
|
||||
|
||||
TensorSpec TensorSpec::AsTransientSpec() const {
|
||||
return TensorSpec(this->datatype_, ShapeType({}), TensorAttribute::TRANSIENT,
|
||||
this->quantization_);
|
||||
}
|
||||
|
||||
int64_t TensorSpec::GetElementNum() const {
|
||||
int64_t count = 1;
|
||||
for (auto dim : shape_) {
|
||||
count *= dim;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
int64_t TensorSpec::GetElementByteSize() const {
|
||||
switch (datatype_) {
|
||||
case DataType::INT8:
|
||||
case DataType::UINT8:
|
||||
case DataType::BOOL8:
|
||||
return 1;
|
||||
case DataType::INT16:
|
||||
case DataType::UINT16:
|
||||
case DataType::FLOAT16:
|
||||
return 2;
|
||||
case DataType::INT32:
|
||||
case DataType::UINT32:
|
||||
case DataType::FLOAT32:
|
||||
return 4;
|
||||
default:
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
int64_t TensorSpec::GetByteSize() const {
|
||||
return GetElementNum() * GetElementByteSize();
|
||||
}
|
||||
|
||||
} // namespace vx
|
||||
} // namespace tim
|
||||
|
|
|
|||
Loading…
Reference in New Issue