From b3677305c4a39183aa718c766c2fcfbfe80068d5 Mon Sep 17 00:00:00 2001 From: Antkillerfarm Date: Fri, 13 May 2022 14:29:25 +0800 Subject: [PATCH] add GetElementNum/GetElementByteSize/GetByteSize for TensorSpec (#393) --- include/tim/vx/tensor.h | 46 ++++++++------------------ src/tim/vx/tensor.cc | 71 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 33 deletions(-) diff --git a/include/tim/vx/tensor.h b/include/tim/vx/tensor.h index 7297fcc..910a701 100644 --- a/include/tim/vx/tensor.h +++ b/include/tim/vx/tensor.h @@ -94,45 +94,25 @@ struct TensorSpec { this->quantization_ = quantization; } - TensorSpec(const TensorSpec& other) { - this->datatype_ = other.datatype_; - this->shape_ = other.shape_; - this->attr_ = other.attr_; - this->quantization_ = other.quantization_; - } + TensorSpec(const TensorSpec& other); - TensorSpec& operator =(const TensorSpec& other) { - this->datatype_ = other.datatype_; - this->shape_ = other.shape_; - this->attr_ = other.attr_; - this->quantization_ = other.quantization_; - return *this; - } + TensorSpec& operator=(const TensorSpec& other); - TensorSpec& SetDataType(DataType datatype) { - this->datatype_ = datatype; - return *this; - } + TensorSpec& SetDataType(DataType datatype); - TensorSpec& SetShape(ShapeType& shape) { - this->shape_ = shape; - return *this; - } + TensorSpec& SetShape(ShapeType& shape); - TensorSpec& SetAttribute(TensorAttribute attr) { - this->attr_ = attr; - return *this; - } + TensorSpec& SetAttribute(TensorAttribute attr); - TensorSpec& SetQuantization(Quantization& quantization) { - this->quantization_ = quantization; - return *this; - } + TensorSpec& SetQuantization(Quantization& quantization); - TensorSpec AsTransientSpec() const { - return TensorSpec(this->datatype_, ShapeType({}), - TensorAttribute::TRANSIENT, this->quantization_); - } + TensorSpec AsTransientSpec() const; + + int64_t GetElementNum() const; + + int64_t GetElementByteSize() const; + + int64_t GetByteSize() const; DataType datatype_; ShapeType shape_; diff --git a/src/tim/vx/tensor.cc b/src/tim/vx/tensor.cc index 884cbbb..2ee281a 100644 --- a/src/tim/vx/tensor.cc +++ b/src/tim/vx/tensor.cc @@ -240,5 +240,76 @@ bool TensorImpl::IsReadable() { return spec_.attr_ != TensorAttribute::TRANSIENT; } +TensorSpec::TensorSpec(const TensorSpec& other) { + this->datatype_ = other.datatype_; + this->shape_ = other.shape_; + this->attr_ = other.attr_; + this->quantization_ = other.quantization_; +} + +TensorSpec& TensorSpec::operator=(const TensorSpec& other) { + this->datatype_ = other.datatype_; + this->shape_ = other.shape_; + this->attr_ = other.attr_; + this->quantization_ = other.quantization_; + return *this; +} + +TensorSpec& TensorSpec::SetDataType(DataType datatype) { + this->datatype_ = datatype; + return *this; +} + +TensorSpec& TensorSpec::SetShape(ShapeType& shape) { + this->shape_ = shape; + return *this; +} + +TensorSpec& TensorSpec::SetAttribute(TensorAttribute attr) { + this->attr_ = attr; + return *this; +} + +TensorSpec& TensorSpec::SetQuantization(Quantization& quantization) { + this->quantization_ = quantization; + return *this; +} + +TensorSpec TensorSpec::AsTransientSpec() const { + return TensorSpec(this->datatype_, ShapeType({}), TensorAttribute::TRANSIENT, + this->quantization_); +} + +int64_t TensorSpec::GetElementNum() const { + int64_t count = 1; + for (auto dim : shape_) { + count *= dim; + } + return count; +} + +int64_t TensorSpec::GetElementByteSize() const { + switch (datatype_) { + case DataType::INT8: + case DataType::UINT8: + case DataType::BOOL8: + return 1; + case DataType::INT16: + case DataType::UINT16: + case DataType::FLOAT16: + return 2; + case DataType::INT32: + case DataType::UINT32: + case DataType::FLOAT32: + return 4; + default: + return 1; + } +} + +int64_t TensorSpec::GetByteSize() const { + return GetElementNum() * GetElementByteSize(); +} + } // namespace vx } // namespace tim