export Swap Handle API (#635)

export vsi_nn_SwapHandle & vsi_nn_SwapTensorHandle &
vsi_nn_SwapTensorHandleWithCache for TIM-VX usage.

Type: New Feature

Signed-off-by: Tang Jing <jing.tang@verisilicon.com>
This commit is contained in:
Antkillerfarm 2023-08-28 09:15:43 +08:00 committed by GitHub
parent 7fc264a9e6
commit 3bbe2ef9ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 85 additions and 0 deletions

View File

@ -149,6 +149,12 @@ class Tensor {
virtual bool CopyDataToTensor(const void* data,
uint32_t size_in_bytes = 0) = 0;
virtual bool CopyDataFromTensor(void* data) = 0;
virtual bool SwapHandle(void* new_ptr, bool is_new_ptr_malloc_by_ovxlib,
void** old_ptr) = 0;
virtual bool SwapHandle(std::shared_ptr<tim::vx::Tensor> tensor) = 0;
#ifdef VSI_SWAP_HANDLE_CACHE_SUPPORT
virtual bool SwapHandleWithCache(std::shared_ptr<tim::vx::Tensor> tensor) = 0;
#endif
virtual bool FlushCacheForHandle() = 0;
virtual bool InvalidateCacheForHandle() = 0;
virtual void* map(bool invalidate_cpu_cache = false) = 0;

View File

@ -213,6 +213,62 @@ float* TensorImpl::ConvertTensorToFloat32Data() {
graph_->graph(), vsi_nn_GetTensor(graph_->graph(), id_));
}
bool TensorImpl::SwapHandle(void* new_ptr, bool is_new_ptr_malloc_by_ovxlib,
void** old_ptr) {
bool retn = true;
if (VSI_NN_TENSOR_ID_NA != id_) {
retn = false;
vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_);
if (tensor && tensor->attr.is_created_from_handle) {
retn = (VSI_SUCCESS == vsi_nn_SwapHandle(tensor, new_ptr,
is_new_ptr_malloc_by_ovxlib,
old_ptr));
if (!retn) {
VSILOGE("SwapHandle fail");
}
}
}
return retn;
}
bool TensorImpl::SwapHandle(std::shared_ptr<tim::vx::Tensor> tensor) {
bool retn = true;
if (VSI_NN_TENSOR_ID_NA != id_) {
retn = false;
vsi_nn_tensor_t* tensor0 = vsi_nn_GetTensor(graph_->graph(), id_);
vsi_nn_tensor_t* tensor1 =
vsi_nn_GetTensor(graph_->graph(), tensor->GetId());
if (tensor0 && tensor0->attr.is_created_from_handle && tensor1 &&
tensor1->attr.is_created_from_handle) {
retn = (VSI_SUCCESS == vsi_nn_SwapTensorHandle(tensor0, tensor1));
if (!retn) {
VSILOGE("SwapHandle fail");
}
}
}
return retn;
}
#ifdef VSI_SWAP_HANDLE_CACHE_SUPPORT
bool TensorImpl::SwapHandleWithCache(std::shared_ptr<tim::vx::Tensor> tensor) {
bool retn = true;
if (VSI_NN_TENSOR_ID_NA != id_) {
retn = false;
vsi_nn_tensor_t* tensor0 = vsi_nn_GetTensor(graph_->graph(), id_);
vsi_nn_tensor_t* tensor1 =
vsi_nn_GetTensor(graph_->graph(), tensor->GetId());
if (tensor0 && tensor0->attr.is_created_from_handle && tensor1 &&
tensor1->attr.is_created_from_handle) {
retn = (VSI_SUCCESS == vsi_nn_SwapTensorHandleWithCache(graph_->graph(), tensor0, tensor1));
if (!retn) {
VSILOGE("SwapHandle fail");
}
}
}
return retn;
}
#endif
bool TensorImpl::FlushCacheForHandle() {
if (!(spec_.attr_ & TensorAttribute::INPUT)) {
return false;

View File

@ -48,6 +48,12 @@ class TensorImpl : public Tensor {
uint32_t GetId() override;
bool CopyDataToTensor(const void* data, uint32_t size = 0) override;
bool CopyDataFromTensor(void* data) override;
bool SwapHandle(void* new_ptr, bool is_new_ptr_malloc_by_ovxlib,
void** old_ptr) override;
bool SwapHandle(std::shared_ptr<tim::vx::Tensor> tensor) override;
#ifdef VSI_SWAP_HANDLE_CACHE_SUPPORT
bool SwapHandleWithCache(std::shared_ptr<tim::vx::Tensor> tensor) override;
#endif
bool FlushCacheForHandle() override;
bool InvalidateCacheForHandle() override;
void* map(bool invalidate_cpu_cache = false) override;
@ -84,6 +90,23 @@ class TensorPlaceholder : public Tensor {
(void)data;
return false;
}
bool SwapHandle(void* new_ptr, bool is_new_ptr_malloc_by_ovxlib,
void** old_ptr) override {
(void)new_ptr;
(void)old_ptr;
(void)is_new_ptr_malloc_by_ovxlib;
return false;
}
bool SwapHandle(std::shared_ptr<tim::vx::Tensor> tensor) override {
(void)tensor;
return false;
}
#ifdef VSI_SWAP_HANDLE_CACHE_SUPPORT
bool SwapHandleWithCache(std::shared_ptr<tim::vx::Tensor> tensor) override {
(void)tensor;
return false;
}
#endif
bool InvalidateCacheForHandle() override { return false; }
bool FlushCacheForHandle() override { return false; }
void* map(bool invalidate_cpu_cache = false) override {