From 3bbe2ef9ec4184c62cc493507206eb590a3c20d5 Mon Sep 17 00:00:00 2001 From: Antkillerfarm Date: Mon, 28 Aug 2023 09:15:43 +0800 Subject: [PATCH] export Swap Handle API (#635) export vsi_nn_SwapHandle & vsi_nn_SwapTensorHandle & vsi_nn_SwapTensorHandleWithCache for TIM-VX usage. Type: New Feature Signed-off-by: Tang Jing --- include/tim/vx/tensor.h | 6 ++++ src/tim/vx/tensor.cc | 56 +++++++++++++++++++++++++++++++++++++ src/tim/vx/tensor_private.h | 23 +++++++++++++++ 3 files changed, 85 insertions(+) diff --git a/include/tim/vx/tensor.h b/include/tim/vx/tensor.h index 0d26d7a..12845c5 100644 --- a/include/tim/vx/tensor.h +++ b/include/tim/vx/tensor.h @@ -149,6 +149,12 @@ class Tensor { virtual bool CopyDataToTensor(const void* data, uint32_t size_in_bytes = 0) = 0; virtual bool CopyDataFromTensor(void* data) = 0; + virtual bool SwapHandle(void* new_ptr, bool is_new_ptr_malloc_by_ovxlib, + void** old_ptr) = 0; + virtual bool SwapHandle(std::shared_ptr tensor) = 0; +#ifdef VSI_SWAP_HANDLE_CACHE_SUPPORT + virtual bool SwapHandleWithCache(std::shared_ptr tensor) = 0; +#endif virtual bool FlushCacheForHandle() = 0; virtual bool InvalidateCacheForHandle() = 0; virtual void* map(bool invalidate_cpu_cache = false) = 0; diff --git a/src/tim/vx/tensor.cc b/src/tim/vx/tensor.cc index e424843..973d078 100644 --- a/src/tim/vx/tensor.cc +++ b/src/tim/vx/tensor.cc @@ -213,6 +213,62 @@ float* TensorImpl::ConvertTensorToFloat32Data() { graph_->graph(), vsi_nn_GetTensor(graph_->graph(), id_)); } +bool TensorImpl::SwapHandle(void* new_ptr, bool is_new_ptr_malloc_by_ovxlib, + void** old_ptr) { + bool retn = true; + if (VSI_NN_TENSOR_ID_NA != id_) { + retn = false; + vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_); + if (tensor && tensor->attr.is_created_from_handle) { + retn = (VSI_SUCCESS == vsi_nn_SwapHandle(tensor, new_ptr, + is_new_ptr_malloc_by_ovxlib, + old_ptr)); + if (!retn) { + VSILOGE("SwapHandle fail"); + } + } + } + return retn; +} + +bool TensorImpl::SwapHandle(std::shared_ptr tensor) { + bool retn = true; + if (VSI_NN_TENSOR_ID_NA != id_) { + retn = false; + vsi_nn_tensor_t* tensor0 = vsi_nn_GetTensor(graph_->graph(), id_); + vsi_nn_tensor_t* tensor1 = + vsi_nn_GetTensor(graph_->graph(), tensor->GetId()); + if (tensor0 && tensor0->attr.is_created_from_handle && tensor1 && + tensor1->attr.is_created_from_handle) { + retn = (VSI_SUCCESS == vsi_nn_SwapTensorHandle(tensor0, tensor1)); + if (!retn) { + VSILOGE("SwapHandle fail"); + } + } + } + return retn; +} + +#ifdef VSI_SWAP_HANDLE_CACHE_SUPPORT +bool TensorImpl::SwapHandleWithCache(std::shared_ptr tensor) { + bool retn = true; + if (VSI_NN_TENSOR_ID_NA != id_) { + retn = false; + vsi_nn_tensor_t* tensor0 = vsi_nn_GetTensor(graph_->graph(), id_); + vsi_nn_tensor_t* tensor1 = + vsi_nn_GetTensor(graph_->graph(), tensor->GetId()); + if (tensor0 && tensor0->attr.is_created_from_handle && tensor1 && + tensor1->attr.is_created_from_handle) { + retn = (VSI_SUCCESS == vsi_nn_SwapTensorHandleWithCache(graph_->graph(), tensor0, tensor1)); + if (!retn) { + VSILOGE("SwapHandle fail"); + } + } + } + return retn; +} +#endif + bool TensorImpl::FlushCacheForHandle() { if (!(spec_.attr_ & TensorAttribute::INPUT)) { return false; diff --git a/src/tim/vx/tensor_private.h b/src/tim/vx/tensor_private.h index 2470e8d..5e14d75 100644 --- a/src/tim/vx/tensor_private.h +++ b/src/tim/vx/tensor_private.h @@ -48,6 +48,12 @@ class TensorImpl : public Tensor { uint32_t GetId() override; bool CopyDataToTensor(const void* data, uint32_t size = 0) override; bool CopyDataFromTensor(void* data) override; + bool SwapHandle(void* new_ptr, bool is_new_ptr_malloc_by_ovxlib, + void** old_ptr) override; + bool SwapHandle(std::shared_ptr tensor) override; +#ifdef VSI_SWAP_HANDLE_CACHE_SUPPORT + bool SwapHandleWithCache(std::shared_ptr tensor) override; +#endif bool FlushCacheForHandle() override; bool InvalidateCacheForHandle() override; void* map(bool invalidate_cpu_cache = false) override; @@ -84,6 +90,23 @@ class TensorPlaceholder : public Tensor { (void)data; return false; } + bool SwapHandle(void* new_ptr, bool is_new_ptr_malloc_by_ovxlib, + void** old_ptr) override { + (void)new_ptr; + (void)old_ptr; + (void)is_new_ptr_malloc_by_ovxlib; + return false; + } + bool SwapHandle(std::shared_ptr tensor) override { + (void)tensor; + return false; + } +#ifdef VSI_SWAP_HANDLE_CACHE_SUPPORT + bool SwapHandleWithCache(std::shared_ptr tensor) override { + (void)tensor; + return false; + } +#endif bool InvalidateCacheForHandle() override { return false; } bool FlushCacheForHandle() override { return false; } void* map(bool invalidate_cpu_cache = false) override {