diff --git a/include/tim/vx/tensor.h b/include/tim/vx/tensor.h index ecf1fa8..20934c9 100644 --- a/include/tim/vx/tensor.h +++ b/include/tim/vx/tensor.h @@ -28,6 +28,7 @@ #include #include #include +#include #include "tim/vx/types.h" @@ -145,8 +146,13 @@ class Tensor { virtual void unmap() = 0; virtual bool IsPlaceHolder() = 0; virtual bool IsConstTensor() = 0; + virtual bool SaveTensorToTextByFp32(std::string filename) = 0; + virtual void* ConvertTensorToData(uint8_t* tensorData) = 0; }; - +namespace utils{ + bool Float32ToDtype(std::shared_ptr tensor, std::vector fval, uint8_t* tensorData); + bool DtypeToFloat32(std::shared_ptr tensor, uint8_t* tensorData, float* data); +} //namespace utils } // namespace vx } // namespace tim diff --git a/src/tim/vx/tensor.cc b/src/tim/vx/tensor.cc index ae5914d..581f8c2 100644 --- a/src/tim/vx/tensor.cc +++ b/src/tim/vx/tensor.cc @@ -118,6 +118,18 @@ TensorImpl::TensorImpl(Graph* graph, const TensorSpec& spec, void* data) TensorImpl::~TensorImpl() {} +bool TensorImpl::SaveTensorToTextByFp32(std::string filename){ + vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_); + vsi_nn_SaveTensorToTextByFp32(graph_->graph(), tensor, filename.c_str(), NULL); + return true; +} + +void* TensorImpl::ConvertTensorToData(uint8_t* tensorData){ + vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_); + tensorData = vsi_nn_ConvertTensorToData(graph_->graph(), tensor); + return tensorData; +} + bool TensorImpl::CopyDataToTensor(const void* data, uint32_t size_in_bytes) { (void)size_in_bytes; if (!IsWriteable()) { @@ -432,6 +444,31 @@ int64_t TensorSpec::GetElementByteSize() const { int64_t TensorSpec::GetByteSize() const { return GetElementNum() * GetElementByteSize(); } +namespace utils{ +bool Float32ToDtype(std::shared_ptr tensor, std::vector fval, uint8_t* tensorData){ +bool retn = true; +vsi_nn_tensor_attr_t attr; +uint32_t sz = tensor->GetSpec().GetElementNum(); +uint32_t stride = tensor->GetSpec().GetElementByteSize(); +PackTensorDtype(tensor->GetSpec(), &attr.dtype); +for (uint32_t i = 0; i < sz; i++){ + retn = (VSI_SUCCESS == vsi_nn_Float32ToDtype(fval[i], &tensorData[i * stride], &attr.dtype)); + if (!retn) { + VSILOGE("Convert data fail"); + return retn; + } +} +return retn; +} +bool DtypeToFloat32(std::shared_ptr tensor, uint8_t* tensorData, float* data){ + bool retn = true; + vsi_nn_tensor_attr_t attr; + + PackTensorDtype(tensor->GetSpec(), &attr.dtype); + retn = (VSI_SUCCESS == vsi_nn_DtypeToFloat32(tensorData, data, &attr.dtype)); + return retn; +} +} //namespace utils } // namespace vx } // namespace tim diff --git a/src/tim/vx/tensor_private.h b/src/tim/vx/tensor_private.h index e22eb9a..904634c 100644 --- a/src/tim/vx/tensor_private.h +++ b/src/tim/vx/tensor_private.h @@ -41,21 +41,23 @@ class TensorImpl : public Tensor { bool IsWriteable(); bool IsReadable(); - const ShapeType& GetShape() { return spec_.shape_; } - DataType GetDataType() { return spec_.datatype_; } - const Quantization& GetQuantization() { return spec_.quantization_; } - TensorSpec& GetSpec() { return spec_; } - uint32_t GetId(); - bool CopyDataToTensor(const void* data, uint32_t size = 0); - bool CopyDataFromTensor(void* data); - bool FlushCacheForHandle(); - bool InvalidateCacheForHandle(); - void* map(bool invalidate_cpu_cache = false); - void unmap(); - bool IsPlaceHolder() { return false; } - bool IsConstTensor() { + const ShapeType& GetShape() override { return spec_.shape_; } + DataType GetDataType() override { return spec_.datatype_; } + const Quantization& GetQuantization() override { return spec_.quantization_; } + TensorSpec& GetSpec() override { return spec_; } + uint32_t GetId() override; + bool CopyDataToTensor(const void* data, uint32_t size = 0) override; + bool CopyDataFromTensor(void* data) override; + bool FlushCacheForHandle() override; + bool InvalidateCacheForHandle() override; + void* map(bool invalidate_cpu_cache = false) override; + void unmap() override; + bool IsPlaceHolder() override { return false; } + bool IsConstTensor() override { return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT; } + bool SaveTensorToTextByFp32(std::string filename) override; + void* ConvertTensorToData(uint8_t* tensorData) override; GraphImpl* graph_; vsi_nn_tensor_id_t id_; @@ -69,30 +71,38 @@ class TensorPlaceholder : public Tensor { TensorPlaceholder(Graph* graph) : id_(VSI_NN_TENSOR_ID_NA) {(void)(graph);} ~TensorPlaceholder(){}; - const ShapeType& GetShape() { return spec_.shape_; } - DataType GetDataType() { return spec_.datatype_; } - const Quantization& GetQuantization() { return spec_.quantization_; } - TensorSpec& GetSpec() { return spec_; } - uint32_t GetId() { return id_; }; - bool CopyDataToTensor(const void* data, uint32_t size = 0) { + const ShapeType& GetShape() override { return spec_.shape_; } + DataType GetDataType() override { return spec_.datatype_; } + const Quantization& GetQuantization() override { return spec_.quantization_; } + TensorSpec& GetSpec() override { return spec_; } + uint32_t GetId() override { return id_; }; + bool CopyDataToTensor(const void* data, uint32_t size = 0) override { (void)data, void(size); return false; } - bool CopyDataFromTensor(void* data) { + bool CopyDataFromTensor(void* data) override { (void)data; return false; } - bool InvalidateCacheForHandle() { return false; } - bool FlushCacheForHandle() { return false; } - void* map(bool invalidate_cpu_cache = false) { + bool InvalidateCacheForHandle() override { return false; } + bool FlushCacheForHandle() override { return false; } + void* map(bool invalidate_cpu_cache = false) override { (void)invalidate_cpu_cache; return nullptr; } - void unmap() { return; } - bool IsPlaceHolder() { return true; } - bool IsConstTensor() { + void unmap() override { return; } + bool IsPlaceHolder() override { return true; } + bool IsConstTensor() override { return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT; } + bool SaveTensorToTextByFp32(std::string filename) override { + (void)filename; + return false; + } + void* ConvertTensorToData(uint8_t* tensorData) override { + (void)tensorData; + return nullptr; + } vsi_nn_tensor_id_t id_; TensorSpec spec_;