Add tensor dtype conversion APIs (#576)

For pre/post processing.

Type: Code Refine

Co-authored-by: wangqian <wangqian@CNCDD9444.verisilicon.com>
Authored by SCUWQ on 2023-04-27 09:04:39 +08:00; committed by GitHub
parent c688ca6e81
commit 1543efe098
3 changed files with 80 additions and 27 deletions


@@ -28,6 +28,7 @@
 #include <iostream>
 #include <memory>
 #include <vector>
+#include <string>
 #include "tim/vx/types.h"
@@ -145,8 +146,13 @@ class Tensor {
   virtual void unmap() = 0;
   virtual bool IsPlaceHolder() = 0;
   virtual bool IsConstTensor() = 0;
+  virtual bool SaveTensorToTextByFp32(std::string filename) = 0;
+  virtual void* ConvertTensorToData(uint8_t* tensorData) = 0;
 };
+namespace utils{
+bool Float32ToDtype(std::shared_ptr<tim::vx::Tensor> tensor, std::vector<float> fval, uint8_t* tensorData);
+bool DtypeToFloat32(std::shared_ptr<tim::vx::Tensor> tensor, uint8_t* tensorData, float* data);
+} //namespace utils
 } // namespace vx
 } // namespace tim
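
The two free functions in `tim::vx::utils` are the pre/post-processing entry points. Below is a minimal pre-processing sketch, assuming an already-created quantized input tensor; `FeedFloatInput` and `input_tensor` are hypothetical names, not part of this change.

```cpp
#include <cstdint>
#include <memory>
#include <vector>

#include "tim/vx/tensor.h"

// Quantize float input values into the tensor's native dtype, then upload
// them. Returns false if the dtype conversion or the copy fails.
bool FeedFloatInput(std::shared_ptr<tim::vx::Tensor> input_tensor,
                    const std::vector<float>& input_values) {
  // Scratch buffer sized to the tensor's raw byte size.
  std::vector<uint8_t> raw(input_tensor->GetSpec().GetByteSize());
  // float32 -> tensor dtype (e.g. uint8/int16), element by element.
  if (!tim::vx::utils::Float32ToDtype(input_tensor, input_values, raw.data())) {
    return false;
  }
  // Push the converted bytes into the device tensor.
  return input_tensor->CopyDataToTensor(raw.data(),
                                        static_cast<uint32_t>(raw.size()));
}
```

Sizing the scratch buffer with `GetByteSize()` gives `Float32ToDtype` room to write every element at its dtype stride before the whole buffer is copied to the tensor.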


@@ -118,6 +118,18 @@ TensorImpl::TensorImpl(Graph* graph, const TensorSpec& spec, void* data)
 TensorImpl::~TensorImpl() {}
+
+bool TensorImpl::SaveTensorToTextByFp32(std::string filename){
+  vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_);
+  vsi_nn_SaveTensorToTextByFp32(graph_->graph(), tensor, filename.c_str(), NULL);
+  return true;
+}
+
+void* TensorImpl::ConvertTensorToData(uint8_t* tensorData){
+  vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_);
+  tensorData = vsi_nn_ConvertTensorToData(graph_->graph(), tensor);
+  return tensorData;
+}
 bool TensorImpl::CopyDataToTensor(const void* data, uint32_t size_in_bytes) {
   (void)size_in_bytes;
   if (!IsWriteable()) {
@@ -432,6 +444,31 @@ int64_t TensorSpec::GetElementByteSize() const {
 int64_t TensorSpec::GetByteSize() const {
   return GetElementNum() * GetElementByteSize();
 }
+
+namespace utils{
+bool Float32ToDtype(std::shared_ptr<tim::vx::Tensor> tensor, std::vector<float> fval, uint8_t* tensorData){
+  bool retn = true;
+  vsi_nn_tensor_attr_t attr;
+  uint32_t sz = tensor->GetSpec().GetElementNum();
+  uint32_t stride = tensor->GetSpec().GetElementByteSize();
+  PackTensorDtype(tensor->GetSpec(), &attr.dtype);
+  for (uint32_t i = 0; i < sz; i++){
+    retn = (VSI_SUCCESS == vsi_nn_Float32ToDtype(fval[i], &tensorData[i * stride], &attr.dtype));
+    if (!retn) {
+      VSILOGE("Convert data fail");
+      return retn;
+    }
+  }
+  return retn;
+}
+bool DtypeToFloat32(std::shared_ptr<tim::vx::Tensor> tensor, uint8_t* tensorData, float* data){
+  bool retn = true;
+  vsi_nn_tensor_attr_t attr;
+  PackTensorDtype(tensor->GetSpec(), &attr.dtype);
+  retn = (VSI_SUCCESS == vsi_nn_DtypeToFloat32(tensorData, data, &attr.dtype));
+  return retn;
+}
+} //namespace utils
 } // namespace vx
 } // namespace tim
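
For the post-processing direction, here is a sketch of reading a graph output back as float32. `ReadFloatOutput`, `output_tensor`, and the per-element loop are assumptions: the committed `DtypeToFloat32` wraps a single `vsi_nn_DtypeToFloat32` call, so the sketch applies it once per element, mirroring the stride handling in `Float32ToDtype`.

```cpp
#include <cstdint>
#include <memory>
#include <vector>

#include "tim/vx/tensor.h"

// Copy a (possibly quantized) output tensor back to the host and dequantize
// it to float32, one element at a time.
bool ReadFloatOutput(std::shared_ptr<tim::vx::Tensor> output_tensor,
                     std::vector<float>* out_values) {
  auto& spec = output_tensor->GetSpec();
  const auto count = static_cast<uint32_t>(spec.GetElementNum());
  const auto stride = static_cast<uint32_t>(spec.GetElementByteSize());

  // Pull the raw bytes off the device first.
  std::vector<uint8_t> raw(spec.GetByteSize());
  if (!output_tensor->CopyDataFromTensor(raw.data())) {
    return false;
  }

  // Dequantize element by element at the dtype stride.
  out_values->resize(count);
  for (uint32_t i = 0; i < count; ++i) {
    if (!tim::vx::utils::DtypeToFloat32(output_tensor, &raw[i * stride],
                                        &(*out_values)[i])) {
      return false;
    }
  }
  return true;
}
```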


@@ -41,21 +41,23 @@ class TensorImpl : public Tensor {
   bool IsWriteable();
   bool IsReadable();
-  const ShapeType& GetShape() { return spec_.shape_; }
-  DataType GetDataType() { return spec_.datatype_; }
-  const Quantization& GetQuantization() { return spec_.quantization_; }
-  TensorSpec& GetSpec() { return spec_; }
-  uint32_t GetId();
-  bool CopyDataToTensor(const void* data, uint32_t size = 0);
-  bool CopyDataFromTensor(void* data);
-  bool FlushCacheForHandle();
-  bool InvalidateCacheForHandle();
-  void* map(bool invalidate_cpu_cache = false);
-  void unmap();
-  bool IsPlaceHolder() { return false; }
-  bool IsConstTensor() {
+  const ShapeType& GetShape() override { return spec_.shape_; }
+  DataType GetDataType() override { return spec_.datatype_; }
+  const Quantization& GetQuantization() override { return spec_.quantization_; }
+  TensorSpec& GetSpec() override { return spec_; }
+  uint32_t GetId() override;
+  bool CopyDataToTensor(const void* data, uint32_t size = 0) override;
+  bool CopyDataFromTensor(void* data) override;
+  bool FlushCacheForHandle() override;
+  bool InvalidateCacheForHandle() override;
+  void* map(bool invalidate_cpu_cache = false) override;
+  void unmap() override;
+  bool IsPlaceHolder() override { return false; }
+  bool IsConstTensor() override {
     return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT;
   }
+  bool SaveTensorToTextByFp32(std::string filename) override;
+  void* ConvertTensorToData(uint8_t* tensorData) override;
   GraphImpl* graph_;
   vsi_nn_tensor_id_t id_;
@@ -69,30 +71,38 @@ class TensorPlaceholder : public Tensor {
   TensorPlaceholder(Graph* graph) : id_(VSI_NN_TENSOR_ID_NA) {(void)(graph);}
   ~TensorPlaceholder(){};
-  const ShapeType& GetShape() { return spec_.shape_; }
-  DataType GetDataType() { return spec_.datatype_; }
-  const Quantization& GetQuantization() { return spec_.quantization_; }
-  TensorSpec& GetSpec() { return spec_; }
-  uint32_t GetId() { return id_; };
-  bool CopyDataToTensor(const void* data, uint32_t size = 0) {
+  const ShapeType& GetShape() override { return spec_.shape_; }
+  DataType GetDataType() override { return spec_.datatype_; }
+  const Quantization& GetQuantization() override { return spec_.quantization_; }
+  TensorSpec& GetSpec() override { return spec_; }
+  uint32_t GetId() override { return id_; };
+  bool CopyDataToTensor(const void* data, uint32_t size = 0) override {
     (void)data, void(size);
     return false;
   }
-  bool CopyDataFromTensor(void* data) {
+  bool CopyDataFromTensor(void* data) override {
    (void)data;
    return false;
  }
-  bool InvalidateCacheForHandle() { return false; }
-  bool FlushCacheForHandle() { return false; }
-  void* map(bool invalidate_cpu_cache = false) {
+  bool InvalidateCacheForHandle() override { return false; }
+  bool FlushCacheForHandle() override { return false; }
+  void* map(bool invalidate_cpu_cache = false) override {
     (void)invalidate_cpu_cache;
     return nullptr;
   }
-  void unmap() { return; }
-  bool IsPlaceHolder() { return true; }
-  bool IsConstTensor() {
+  void unmap() override { return; }
+  bool IsPlaceHolder() override { return true; }
+  bool IsConstTensor() override {
     return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT;
   }
+  bool SaveTensorToTextByFp32(std::string filename) override {
+    (void)filename;
+    return false;
+  }
+  void* ConvertTensorToData(uint8_t* tensorData) override {
+    (void)tensorData;
+    return nullptr;
+  }
   vsi_nn_tensor_id_t id_;
   TensorSpec spec_;
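
The `TensorImpl` overrides are mainly debugging aids; on a `TensorPlaceholder` they simply return `false`/`nullptr`. Below is a sketch of how they might be used. `DumpTensorForDebug` is a hypothetical helper, and releasing the buffer from `ConvertTensorToData` with `std::free` is an assumption about how ovxlib allocates it.

```cpp
#include <cstdint>
#include <cstdlib>
#include <memory>

#include "tim/vx/tensor.h"

// Dump a tensor for debugging. On a real TensorImpl this writes a float32
// text file and returns the raw tensor bytes; on a TensorPlaceholder both
// calls are no-ops (false / nullptr).
void DumpTensorForDebug(const std::shared_ptr<tim::vx::Tensor>& tensor) {
  // Write the tensor contents as float32 text via ovxlib.
  tensor->SaveTensorToTextByFp32("debug_tensor.txt");

  // The returned buffer comes from vsi_nn_ConvertTensorToData inside ovxlib;
  // freeing it here is an assumption about that allocation.
  uint8_t* raw = static_cast<uint8_t*>(tensor->ConvertTensorToData(nullptr));
  if (raw != nullptr) {
    // ... inspect raw[0 .. GetSpec().GetByteSize() - 1] here ...
    std::free(raw);
  }
}
```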