Add some tensor dtype convert APIs (#576)
For pre/post process. Type: Code Refine Co-authored-by: wangqian <wangqian@CNCDD9444.verisilicon.com>
This commit is contained in:
parent
c688ca6e81
commit
1543efe098
|
|
@ -28,6 +28,7 @@
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
#include "tim/vx/types.h"
|
#include "tim/vx/types.h"
|
||||||
|
|
||||||
|
|
@ -145,8 +146,13 @@ class Tensor {
|
||||||
virtual void unmap() = 0;
|
virtual void unmap() = 0;
|
||||||
virtual bool IsPlaceHolder() = 0;
|
virtual bool IsPlaceHolder() = 0;
|
||||||
virtual bool IsConstTensor() = 0;
|
virtual bool IsConstTensor() = 0;
|
||||||
|
virtual bool SaveTensorToTextByFp32(std::string filename) = 0;
|
||||||
|
virtual void* ConvertTensorToData(uint8_t* tensorData) = 0;
|
||||||
};
|
};
|
||||||
|
namespace utils{
|
||||||
|
bool Float32ToDtype(std::shared_ptr<tim::vx::Tensor> tensor, std::vector<float> fval, uint8_t* tensorData);
|
||||||
|
bool DtypeToFloat32(std::shared_ptr<tim::vx::Tensor> tensor, uint8_t* tensorData, float* data);
|
||||||
|
} //namespace utils
|
||||||
} // namespace vx
|
} // namespace vx
|
||||||
} // namespace tim
|
} // namespace tim
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -118,6 +118,18 @@ TensorImpl::TensorImpl(Graph* graph, const TensorSpec& spec, void* data)
|
||||||
|
|
||||||
TensorImpl::~TensorImpl() {}
|
TensorImpl::~TensorImpl() {}
|
||||||
|
|
||||||
|
bool TensorImpl::SaveTensorToTextByFp32(std::string filename){
|
||||||
|
vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_);
|
||||||
|
vsi_nn_SaveTensorToTextByFp32(graph_->graph(), tensor, filename.c_str(), NULL);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void* TensorImpl::ConvertTensorToData(uint8_t* tensorData){
|
||||||
|
vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_);
|
||||||
|
tensorData = vsi_nn_ConvertTensorToData(graph_->graph(), tensor);
|
||||||
|
return tensorData;
|
||||||
|
}
|
||||||
|
|
||||||
bool TensorImpl::CopyDataToTensor(const void* data, uint32_t size_in_bytes) {
|
bool TensorImpl::CopyDataToTensor(const void* data, uint32_t size_in_bytes) {
|
||||||
(void)size_in_bytes;
|
(void)size_in_bytes;
|
||||||
if (!IsWriteable()) {
|
if (!IsWriteable()) {
|
||||||
|
|
@ -432,6 +444,31 @@ int64_t TensorSpec::GetElementByteSize() const {
|
||||||
int64_t TensorSpec::GetByteSize() const {
|
int64_t TensorSpec::GetByteSize() const {
|
||||||
return GetElementNum() * GetElementByteSize();
|
return GetElementNum() * GetElementByteSize();
|
||||||
}
|
}
|
||||||
|
namespace utils{
|
||||||
|
bool Float32ToDtype(std::shared_ptr<tim::vx::Tensor> tensor, std::vector<float> fval, uint8_t* tensorData){
|
||||||
|
bool retn = true;
|
||||||
|
vsi_nn_tensor_attr_t attr;
|
||||||
|
uint32_t sz = tensor->GetSpec().GetElementNum();
|
||||||
|
uint32_t stride = tensor->GetSpec().GetElementByteSize();
|
||||||
|
PackTensorDtype(tensor->GetSpec(), &attr.dtype);
|
||||||
|
for (uint32_t i = 0; i < sz; i++){
|
||||||
|
retn = (VSI_SUCCESS == vsi_nn_Float32ToDtype(fval[i], &tensorData[i * stride], &attr.dtype));
|
||||||
|
if (!retn) {
|
||||||
|
VSILOGE("Convert data fail");
|
||||||
|
return retn;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return retn;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DtypeToFloat32(std::shared_ptr<tim::vx::Tensor> tensor, uint8_t* tensorData, float* data){
|
||||||
|
bool retn = true;
|
||||||
|
vsi_nn_tensor_attr_t attr;
|
||||||
|
|
||||||
|
PackTensorDtype(tensor->GetSpec(), &attr.dtype);
|
||||||
|
retn = (VSI_SUCCESS == vsi_nn_DtypeToFloat32(tensorData, data, &attr.dtype));
|
||||||
|
return retn;
|
||||||
|
}
|
||||||
|
} //namespace utils
|
||||||
} // namespace vx
|
} // namespace vx
|
||||||
} // namespace tim
|
} // namespace tim
|
||||||
|
|
|
||||||
|
|
@ -41,21 +41,23 @@ class TensorImpl : public Tensor {
|
||||||
bool IsWriteable();
|
bool IsWriteable();
|
||||||
bool IsReadable();
|
bool IsReadable();
|
||||||
|
|
||||||
const ShapeType& GetShape() { return spec_.shape_; }
|
const ShapeType& GetShape() override { return spec_.shape_; }
|
||||||
DataType GetDataType() { return spec_.datatype_; }
|
DataType GetDataType() override { return spec_.datatype_; }
|
||||||
const Quantization& GetQuantization() { return spec_.quantization_; }
|
const Quantization& GetQuantization() override { return spec_.quantization_; }
|
||||||
TensorSpec& GetSpec() { return spec_; }
|
TensorSpec& GetSpec() override { return spec_; }
|
||||||
uint32_t GetId();
|
uint32_t GetId() override;
|
||||||
bool CopyDataToTensor(const void* data, uint32_t size = 0);
|
bool CopyDataToTensor(const void* data, uint32_t size = 0) override;
|
||||||
bool CopyDataFromTensor(void* data);
|
bool CopyDataFromTensor(void* data) override;
|
||||||
bool FlushCacheForHandle();
|
bool FlushCacheForHandle() override;
|
||||||
bool InvalidateCacheForHandle();
|
bool InvalidateCacheForHandle() override;
|
||||||
void* map(bool invalidate_cpu_cache = false);
|
void* map(bool invalidate_cpu_cache = false) override;
|
||||||
void unmap();
|
void unmap() override;
|
||||||
bool IsPlaceHolder() { return false; }
|
bool IsPlaceHolder() override { return false; }
|
||||||
bool IsConstTensor() {
|
bool IsConstTensor() override {
|
||||||
return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT;
|
return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT;
|
||||||
}
|
}
|
||||||
|
bool SaveTensorToTextByFp32(std::string filename) override;
|
||||||
|
void* ConvertTensorToData(uint8_t* tensorData) override;
|
||||||
|
|
||||||
GraphImpl* graph_;
|
GraphImpl* graph_;
|
||||||
vsi_nn_tensor_id_t id_;
|
vsi_nn_tensor_id_t id_;
|
||||||
|
|
@ -69,30 +71,38 @@ class TensorPlaceholder : public Tensor {
|
||||||
TensorPlaceholder(Graph* graph) : id_(VSI_NN_TENSOR_ID_NA) {(void)(graph);}
|
TensorPlaceholder(Graph* graph) : id_(VSI_NN_TENSOR_ID_NA) {(void)(graph);}
|
||||||
~TensorPlaceholder(){};
|
~TensorPlaceholder(){};
|
||||||
|
|
||||||
const ShapeType& GetShape() { return spec_.shape_; }
|
const ShapeType& GetShape() override { return spec_.shape_; }
|
||||||
DataType GetDataType() { return spec_.datatype_; }
|
DataType GetDataType() override { return spec_.datatype_; }
|
||||||
const Quantization& GetQuantization() { return spec_.quantization_; }
|
const Quantization& GetQuantization() override { return spec_.quantization_; }
|
||||||
TensorSpec& GetSpec() { return spec_; }
|
TensorSpec& GetSpec() override { return spec_; }
|
||||||
uint32_t GetId() { return id_; };
|
uint32_t GetId() override { return id_; };
|
||||||
bool CopyDataToTensor(const void* data, uint32_t size = 0) {
|
bool CopyDataToTensor(const void* data, uint32_t size = 0) override {
|
||||||
(void)data, void(size);
|
(void)data, void(size);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
bool CopyDataFromTensor(void* data) {
|
bool CopyDataFromTensor(void* data) override {
|
||||||
(void)data;
|
(void)data;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
bool InvalidateCacheForHandle() { return false; }
|
bool InvalidateCacheForHandle() override { return false; }
|
||||||
bool FlushCacheForHandle() { return false; }
|
bool FlushCacheForHandle() override { return false; }
|
||||||
void* map(bool invalidate_cpu_cache = false) {
|
void* map(bool invalidate_cpu_cache = false) override {
|
||||||
(void)invalidate_cpu_cache;
|
(void)invalidate_cpu_cache;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
void unmap() { return; }
|
void unmap() override { return; }
|
||||||
bool IsPlaceHolder() { return true; }
|
bool IsPlaceHolder() override { return true; }
|
||||||
bool IsConstTensor() {
|
bool IsConstTensor() override {
|
||||||
return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT;
|
return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT;
|
||||||
}
|
}
|
||||||
|
bool SaveTensorToTextByFp32(std::string filename) override {
|
||||||
|
(void)filename;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
void* ConvertTensorToData(uint8_t* tensorData) override {
|
||||||
|
(void)tensorData;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
vsi_nn_tensor_id_t id_;
|
vsi_nn_tensor_id_t id_;
|
||||||
TensorSpec spec_;
|
TensorSpec spec_;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue