Add some tensor dtype convert APIs (#576)
For pre/post process. Type: Code Refine Co-authored-by: wangqian <wangqian@CNCDD9444.verisilicon.com>
This commit is contained in:
parent
c688ca6e81
commit
1543efe098
|
|
@ -28,6 +28,7 @@
|
|||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "tim/vx/types.h"
|
||||
|
||||
|
|
@ -145,8 +146,13 @@ class Tensor {
|
|||
virtual void unmap() = 0;
|
||||
virtual bool IsPlaceHolder() = 0;
|
||||
virtual bool IsConstTensor() = 0;
|
||||
virtual bool SaveTensorToTextByFp32(std::string filename) = 0;
|
||||
virtual void* ConvertTensorToData(uint8_t* tensorData) = 0;
|
||||
};
|
||||
|
||||
namespace utils{
|
||||
bool Float32ToDtype(std::shared_ptr<tim::vx::Tensor> tensor, std::vector<float> fval, uint8_t* tensorData);
|
||||
bool DtypeToFloat32(std::shared_ptr<tim::vx::Tensor> tensor, uint8_t* tensorData, float* data);
|
||||
} //namespace utils
|
||||
} // namespace vx
|
||||
} // namespace tim
|
||||
|
||||
|
|
|
|||
|
|
@ -118,6 +118,18 @@ TensorImpl::TensorImpl(Graph* graph, const TensorSpec& spec, void* data)
|
|||
|
||||
TensorImpl::~TensorImpl() {}
|
||||
|
||||
bool TensorImpl::SaveTensorToTextByFp32(std::string filename){
|
||||
vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_);
|
||||
vsi_nn_SaveTensorToTextByFp32(graph_->graph(), tensor, filename.c_str(), NULL);
|
||||
return true;
|
||||
}
|
||||
|
||||
void* TensorImpl::ConvertTensorToData(uint8_t* tensorData){
|
||||
vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_);
|
||||
tensorData = vsi_nn_ConvertTensorToData(graph_->graph(), tensor);
|
||||
return tensorData;
|
||||
}
|
||||
|
||||
bool TensorImpl::CopyDataToTensor(const void* data, uint32_t size_in_bytes) {
|
||||
(void)size_in_bytes;
|
||||
if (!IsWriteable()) {
|
||||
|
|
@ -432,6 +444,31 @@ int64_t TensorSpec::GetElementByteSize() const {
|
|||
int64_t TensorSpec::GetByteSize() const {
|
||||
return GetElementNum() * GetElementByteSize();
|
||||
}
|
||||
namespace utils{
|
||||
bool Float32ToDtype(std::shared_ptr<tim::vx::Tensor> tensor, std::vector<float> fval, uint8_t* tensorData){
|
||||
bool retn = true;
|
||||
vsi_nn_tensor_attr_t attr;
|
||||
uint32_t sz = tensor->GetSpec().GetElementNum();
|
||||
uint32_t stride = tensor->GetSpec().GetElementByteSize();
|
||||
PackTensorDtype(tensor->GetSpec(), &attr.dtype);
|
||||
for (uint32_t i = 0; i < sz; i++){
|
||||
retn = (VSI_SUCCESS == vsi_nn_Float32ToDtype(fval[i], &tensorData[i * stride], &attr.dtype));
|
||||
if (!retn) {
|
||||
VSILOGE("Convert data fail");
|
||||
return retn;
|
||||
}
|
||||
}
|
||||
return retn;
|
||||
}
|
||||
|
||||
bool DtypeToFloat32(std::shared_ptr<tim::vx::Tensor> tensor, uint8_t* tensorData, float* data){
|
||||
bool retn = true;
|
||||
vsi_nn_tensor_attr_t attr;
|
||||
|
||||
PackTensorDtype(tensor->GetSpec(), &attr.dtype);
|
||||
retn = (VSI_SUCCESS == vsi_nn_DtypeToFloat32(tensorData, data, &attr.dtype));
|
||||
return retn;
|
||||
}
|
||||
} //namespace utils
|
||||
} // namespace vx
|
||||
} // namespace tim
|
||||
|
|
|
|||
|
|
@ -41,21 +41,23 @@ class TensorImpl : public Tensor {
|
|||
bool IsWriteable();
|
||||
bool IsReadable();
|
||||
|
||||
const ShapeType& GetShape() { return spec_.shape_; }
|
||||
DataType GetDataType() { return spec_.datatype_; }
|
||||
const Quantization& GetQuantization() { return spec_.quantization_; }
|
||||
TensorSpec& GetSpec() { return spec_; }
|
||||
uint32_t GetId();
|
||||
bool CopyDataToTensor(const void* data, uint32_t size = 0);
|
||||
bool CopyDataFromTensor(void* data);
|
||||
bool FlushCacheForHandle();
|
||||
bool InvalidateCacheForHandle();
|
||||
void* map(bool invalidate_cpu_cache = false);
|
||||
void unmap();
|
||||
bool IsPlaceHolder() { return false; }
|
||||
bool IsConstTensor() {
|
||||
const ShapeType& GetShape() override { return spec_.shape_; }
|
||||
DataType GetDataType() override { return spec_.datatype_; }
|
||||
const Quantization& GetQuantization() override { return spec_.quantization_; }
|
||||
TensorSpec& GetSpec() override { return spec_; }
|
||||
uint32_t GetId() override;
|
||||
bool CopyDataToTensor(const void* data, uint32_t size = 0) override;
|
||||
bool CopyDataFromTensor(void* data) override;
|
||||
bool FlushCacheForHandle() override;
|
||||
bool InvalidateCacheForHandle() override;
|
||||
void* map(bool invalidate_cpu_cache = false) override;
|
||||
void unmap() override;
|
||||
bool IsPlaceHolder() override { return false; }
|
||||
bool IsConstTensor() override {
|
||||
return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT;
|
||||
}
|
||||
bool SaveTensorToTextByFp32(std::string filename) override;
|
||||
void* ConvertTensorToData(uint8_t* tensorData) override;
|
||||
|
||||
GraphImpl* graph_;
|
||||
vsi_nn_tensor_id_t id_;
|
||||
|
|
@ -69,30 +71,38 @@ class TensorPlaceholder : public Tensor {
|
|||
TensorPlaceholder(Graph* graph) : id_(VSI_NN_TENSOR_ID_NA) {(void)(graph);}
|
||||
~TensorPlaceholder(){};
|
||||
|
||||
const ShapeType& GetShape() { return spec_.shape_; }
|
||||
DataType GetDataType() { return spec_.datatype_; }
|
||||
const Quantization& GetQuantization() { return spec_.quantization_; }
|
||||
TensorSpec& GetSpec() { return spec_; }
|
||||
uint32_t GetId() { return id_; };
|
||||
bool CopyDataToTensor(const void* data, uint32_t size = 0) {
|
||||
const ShapeType& GetShape() override { return spec_.shape_; }
|
||||
DataType GetDataType() override { return spec_.datatype_; }
|
||||
const Quantization& GetQuantization() override { return spec_.quantization_; }
|
||||
TensorSpec& GetSpec() override { return spec_; }
|
||||
uint32_t GetId() override { return id_; };
|
||||
bool CopyDataToTensor(const void* data, uint32_t size = 0) override {
|
||||
(void)data, void(size);
|
||||
return false;
|
||||
}
|
||||
bool CopyDataFromTensor(void* data) {
|
||||
bool CopyDataFromTensor(void* data) override {
|
||||
(void)data;
|
||||
return false;
|
||||
}
|
||||
bool InvalidateCacheForHandle() { return false; }
|
||||
bool FlushCacheForHandle() { return false; }
|
||||
void* map(bool invalidate_cpu_cache = false) {
|
||||
bool InvalidateCacheForHandle() override { return false; }
|
||||
bool FlushCacheForHandle() override { return false; }
|
||||
void* map(bool invalidate_cpu_cache = false) override {
|
||||
(void)invalidate_cpu_cache;
|
||||
return nullptr;
|
||||
}
|
||||
void unmap() { return; }
|
||||
bool IsPlaceHolder() { return true; }
|
||||
bool IsConstTensor() {
|
||||
void unmap() override { return; }
|
||||
bool IsPlaceHolder() override { return true; }
|
||||
bool IsConstTensor() override {
|
||||
return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT;
|
||||
}
|
||||
bool SaveTensorToTextByFp32(std::string filename) override {
|
||||
(void)filename;
|
||||
return false;
|
||||
}
|
||||
void* ConvertTensorToData(uint8_t* tensorData) override {
|
||||
(void)tensorData;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
vsi_nn_tensor_id_t id_;
|
||||
TensorSpec spec_;
|
||||
|
|
|
|||
Loading…
Reference in New Issue