feat(tensor): support external buffer when creating input/output tensors (#389)
* support external buffer when creating input/output tensors * feat(tensor): add new map/unmap APIs
This commit is contained in:
parent
a9764291b0
commit
f8741b4704
|
|
@@ -45,6 +45,12 @@ class Graph {
|
|||
|
||||
virtual std::shared_ptr<Tensor> CreateTensor(const TensorSpec& spec,
|
||||
const DmaBufferDesc& dmafd) = 0;
|
||||
|
||||
/// Create a tensor with given `TensorSpec`.
|
||||
/// spec.attr_ must be TensorAttribute::Input or Output
|
||||
virtual std::shared_ptr<Tensor> CreateIOTensor(const TensorSpec& spec,
|
||||
void* data = nullptr) = 0;
|
||||
|
||||
/// Create a placeholder tensor for optional inputs of operations
|
||||
virtual std::shared_ptr<Tensor> CreateTensorPlaceHolder() = 0;
|
||||
|
||||
|
|
|
|||
|
|
@@ -134,6 +134,10 @@ class Tensor {
|
|||
virtual uint32_t GetId() = 0;
|
||||
virtual bool CopyDataToTensor(const void* data, uint32_t size_in_bytes = 0) = 0;
|
||||
virtual bool CopyDataFromTensor(void* data) = 0;
|
||||
virtual bool FlushCacheForHandle() = 0;
|
||||
virtual bool InvalidateCacheForHandle() = 0;
|
||||
virtual void* map(bool invalidate_cpu_cache = false) = 0;
|
||||
virtual void unmap() = 0;
|
||||
virtual bool IsPlaceHolder() = 0;
|
||||
virtual bool IsConstTensor() = 0;
|
||||
virtual const void* GetDataRef() const = 0;
|
||||
|
|
|
|||
|
|
@@ -141,6 +141,11 @@ std::shared_ptr<Tensor> GraphImpl::CreateTensor(const TensorSpec& spec,
|
|||
return std::make_shared<TensorImpl>(this, spec, dmafd);
|
||||
}
|
||||
|
||||
std::shared_ptr<Tensor> GraphImpl::CreateIOTensor(const TensorSpec& spec,
|
||||
void* data) {
|
||||
return std::make_shared<TensorImpl>(this, spec, data);
|
||||
}
|
||||
|
||||
std::shared_ptr<Tensor> GraphImpl::CreateTensorPlaceHolder() {
|
||||
if (!tensor_placeholder_) {
|
||||
tensor_placeholder_ = std::make_shared<TensorPlaceholder>(this);
|
||||
|
|
|
|||
|
|
@@ -70,6 +70,8 @@ class GraphImpl : public Graph {
|
|||
const void* data = nullptr) override;
|
||||
std::shared_ptr<Tensor> CreateTensor(const TensorSpec& spec,
|
||||
const DmaBufferDesc& dmafd) override;
|
||||
std::shared_ptr<Tensor> CreateIOTensor(const TensorSpec& spec,
|
||||
void* data = nullptr) override;
|
||||
std::shared_ptr<Tensor> CreateTensorPlaceHolder() override;
|
||||
|
||||
bool Compile() override;
|
||||
|
|
|
|||
|
|
@@ -79,18 +79,35 @@ TensorImpl::TensorImpl(Graph* graph, const TensorSpec& spec, const void* data)
|
|||
: graph_(reinterpret_cast<GraphImpl*>(graph)),
|
||||
id_(VSI_NN_TENSOR_ID_NA),
|
||||
spec_(spec),
|
||||
data_(data) {
|
||||
data_(const_cast<void *>(data)) {
|
||||
Init();
|
||||
if (spec_.attr_ & (TensorAttribute::INPUT | TensorAttribute::OUTPUT)) {
|
||||
data_ = nullptr; // it's not needed to reset it in a constant tensor
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a tensor backed by an external DMA buffer.
/// Only the file descriptor is kept; no CPU-side pointer exists yet
/// (data_ stays null until map() is called).
TensorImpl::TensorImpl(Graph* graph, const TensorSpec& spec,
                       const DmaBufferDesc& dmafd)
    : graph_(reinterpret_cast<GraphImpl*>(graph)),
      id_(VSI_NN_TENSOR_ID_NA),
      spec_(spec),
      data_(nullptr),
      fd_(dmafd.fd) {
  Init();
}
|
||||
|
||||
/// Construct an input/output tensor wrapping the caller-owned buffer `data`.
/// The buffer is handed to the driver via Init(data) so the tensor is
/// created from a handle; data_ records the external pointer afterwards.
TensorImpl::TensorImpl(Graph* graph, const TensorSpec& spec, void* data)
    : graph_(reinterpret_cast<GraphImpl*>(graph)),
      id_(VSI_NN_TENSOR_ID_NA),
      spec_(spec),
      data_(nullptr) {
  // External buffers are only meaningful for graph inputs/outputs.
  if (!(spec_.attr_ & (TensorAttribute::INPUT | TensorAttribute::OUTPUT))) {
    VSILOGE("TensorImpl with an external data got unexpected attr");
    return;
  }
  Init(data);
  // Remember the external cache so map() can hand it back without a
  // driver round-trip.
  data_ = data;
}
|
||||
|
||||
// Trivial destructor: the underlying vsi_nn tensor is owned by the graph,
// and data_/fd_ reference caller-owned resources this class does not free.
TensorImpl::~TensorImpl() {}
|
||||
|
||||
bool TensorImpl::CopyDataToTensor(const void* data, uint32_t size_in_bytes) {
|
||||
|
|
@@ -167,7 +184,95 @@ bool TensorImpl::CopyDataFromTensor(void* data) {
|
|||
return retn;
|
||||
}
|
||||
|
||||
bool TensorImpl::Init() {
|
||||
bool TensorImpl::FlushCacheForHandle() {
|
||||
if (!(spec_.attr_ & TensorAttribute::INPUT)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool retn = true;
|
||||
if (VSI_NN_TENSOR_ID_NA != id_) {
|
||||
retn = false;
|
||||
vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_);
|
||||
if (tensor && tensor->attr.is_created_from_handle) {
|
||||
retn = (VSI_SUCCESS == vsi_nn_FlushHandle(tensor));
|
||||
if (!retn) {
|
||||
VSILOGE("FlushHandle fail");
|
||||
}
|
||||
}
|
||||
}
|
||||
return retn;
|
||||
}
|
||||
|
||||
bool TensorImpl::InvalidateCacheForHandle() {
|
||||
if (!(spec_.attr_ & TensorAttribute::OUTPUT)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool retn = true;
|
||||
if (VSI_NN_TENSOR_ID_NA != id_) {
|
||||
retn = false;
|
||||
vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_);
|
||||
if (tensor && tensor->attr.is_created_from_handle) {
|
||||
void* ptr = NULL;
|
||||
retn = (VSI_SUCCESS == vsi_nn_GetTensorHandle(tensor, &ptr));
|
||||
if (!retn) {
|
||||
VSILOGE("GetTensorHandle fail");
|
||||
}
|
||||
}
|
||||
}
|
||||
return retn;
|
||||
}
|
||||
|
||||
void* TensorImpl::map(bool invalidate_cpu_cache) {
|
||||
if (!(spec_.attr_ & (TensorAttribute::INPUT | TensorAttribute::OUTPUT))) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void* cpu_ptr = nullptr;
|
||||
if (VSI_NN_TENSOR_ID_NA != id_) {
|
||||
vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_);
|
||||
if (tensor && tensor->attr.is_created_from_handle) {
|
||||
// Here `cpu_cache` means L1/L2/... cache on a CPU chip.
|
||||
// If data_ has been updated by other devices like NPU,
|
||||
// then caches on CPU MUST be invalidated before reading.
|
||||
if (data_ && !invalidate_cpu_cache) {
|
||||
cpu_ptr = data_;
|
||||
} else {
|
||||
vsi_nn_GetTensorHandle(tensor, &cpu_ptr);
|
||||
// TODO: what to do when fd_ != -1
|
||||
}
|
||||
if (!cpu_ptr) {
|
||||
VSILOGE("GetTensorHandle fail");
|
||||
}
|
||||
}
|
||||
}
|
||||
return cpu_ptr;
|
||||
}
|
||||
|
||||
void TensorImpl::unmap() {
|
||||
if (!(spec_.attr_ & (TensorAttribute::INPUT | TensorAttribute::OUTPUT))) {
|
||||
return;
|
||||
}
|
||||
if (VSI_NN_TENSOR_ID_NA == id_) {
|
||||
return;
|
||||
}
|
||||
if (-1 == fd_) {
|
||||
if (data_ && spec_.attr_ & TensorAttribute::INPUT) {
|
||||
// Here data_ is an external buffer and may have been updated
|
||||
vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_);
|
||||
if ( tensor && tensor->attr.is_created_from_handle) {
|
||||
bool retn = (VSI_SUCCESS == vsi_nn_FlushHandle(tensor));
|
||||
if (!retn) {
|
||||
VSILOGE("FlushHandle fail");
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
// TODO: unmap fd_
|
||||
}
|
||||
|
||||
bool TensorImpl::Init(void *external_cache) {
|
||||
vsi_nn_tensor_attr_t attr;
|
||||
|
||||
memset(&attr, 0x00, sizeof(attr));
|
||||
|
|
@@ -198,11 +303,11 @@ bool TensorImpl::Init() {
|
|||
graph_->graph(),
|
||||
VSI_NN_TENSOR_ID_AUTO, // DMABUF's fd is created by TensorFromHandle as input or output,
|
||||
&attr,
|
||||
fd_ != -1 ? (uint8_t*)fd_ : nullptr); // and cannot be set to const
|
||||
fd_ != -1 ? (uint8_t*)fd_ : (uint8_t*)external_cache); // and cannot be set to const
|
||||
#else
|
||||
if (-1 == fd_) {
|
||||
id_ = vsi_nn_AddTensorFromHandle(graph_->graph(), VSI_NN_TENSOR_ID_AUTO,
|
||||
&attr, nullptr);
|
||||
&attr, (uint8_t*)external_cache);
|
||||
} else {
|
||||
id_ = 0xFFFFFFFF;
|
||||
VSILOGE("Create tensor fail: low-level driver doesn't support dmabuffer");
|
||||
|
|
|
|||
|
|
@@ -34,9 +34,10 @@ class TensorImpl : public Tensor {
|
|||
public:
|
||||
TensorImpl(Graph* graph, const TensorSpec& spec, const void* data = nullptr);
|
||||
TensorImpl(Graph* graph, const TensorSpec& spec, const DmaBufferDesc& dmafd);
|
||||
TensorImpl(Graph* graph, const TensorSpec& spec, void* data = nullptr);
|
||||
~TensorImpl();
|
||||
|
||||
bool Init();
|
||||
bool Init(void *external_cache = nullptr);
|
||||
bool IsWriteable();
|
||||
bool IsReadable();
|
||||
|
||||
|
|
@@ -47,6 +48,10 @@ class TensorImpl : public Tensor {
|
|||
uint32_t GetId();
|
||||
bool CopyDataToTensor(const void* data, uint32_t size = 0);
|
||||
bool CopyDataFromTensor(void* data);
|
||||
bool FlushCacheForHandle();
|
||||
bool InvalidateCacheForHandle();
|
||||
void* map(bool invalidate_cpu_cache = false);
|
||||
void unmap();
|
||||
bool IsPlaceHolder() { return false; }
|
||||
bool IsConstTensor() {
|
||||
return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT;
|
||||
|
|
@@ -56,7 +61,7 @@ class TensorImpl : public Tensor {
|
|||
GraphImpl* graph_;
|
||||
vsi_nn_tensor_id_t id_;
|
||||
TensorSpec spec_;
|
||||
const void* data_;
|
||||
void* data_;
|
||||
int64_t fd_{-1};
|
||||
};
|
||||
|
||||
|
|
@@ -78,6 +83,13 @@ class TensorPlaceholder : public Tensor {
|
|||
(void)data;
|
||||
return false;
|
||||
}
|
||||
bool InvalidateCacheForHandle() { return false; }
|
||||
bool FlushCacheForHandle() { return false; }
|
||||
void* map(bool invalidate_cpu_cache = false) {
|
||||
(void)invalidate_cpu_cache;
|
||||
return nullptr;
|
||||
}
|
||||
void unmap() { return; }
|
||||
bool IsPlaceHolder() { return true; }
|
||||
bool IsConstTensor() {
|
||||
return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT;
|
||||
|
|
|
|||
Loading…
Reference in New Issue