diff --git a/include/tim/vx/graph.h b/include/tim/vx/graph.h
index 33f4e45..08cef92 100644
--- a/include/tim/vx/graph.h
+++ b/include/tim/vx/graph.h
@@ -45,6 +45,12 @@ class Graph {
   virtual std::shared_ptr<Tensor> CreateTensor(const TensorSpec& spec,
                                                const DmaBufferDesc& dmafd) = 0;
+
+  /// Create a tensor with the given `TensorSpec`.
+  /// spec.attr_ must be TensorAttribute::INPUT or TensorAttribute::OUTPUT.
+  virtual std::shared_ptr<Tensor> CreateIOTensor(const TensorSpec& spec,
+                                                 void* data = nullptr) = 0;
+
   /// Create a placeholder tensor for optional inputs of operations
   virtual std::shared_ptr<Tensor> CreateTensorPlaceHolder() = 0;
diff --git a/include/tim/vx/tensor.h b/include/tim/vx/tensor.h
index 910a701..572ea0a 100644
--- a/include/tim/vx/tensor.h
+++ b/include/tim/vx/tensor.h
@@ -134,6 +134,10 @@ class Tensor {
   virtual uint32_t GetId() = 0;
   virtual bool CopyDataToTensor(const void* data, uint32_t size_in_bytes = 0) = 0;
   virtual bool CopyDataFromTensor(void* data) = 0;
+  virtual bool FlushCacheForHandle() = 0;
+  virtual bool InvalidateCacheForHandle() = 0;
+  virtual void* map(bool invalidate_cpu_cache = false) = 0;
+  virtual void unmap() = 0;
   virtual bool IsPlaceHolder() = 0;
   virtual bool IsConstTensor() = 0;
   virtual const void* GetDataRef() const = 0;
diff --git a/src/tim/vx/graph.cc b/src/tim/vx/graph.cc
index b7fcfba..afb8491 100644
--- a/src/tim/vx/graph.cc
+++ b/src/tim/vx/graph.cc
@@ -141,6 +141,11 @@ std::shared_ptr<Tensor> GraphImpl::CreateTensor(const TensorSpec& spec,
   return std::make_shared<TensorImpl>(this, spec, dmafd);
 }
 
+std::shared_ptr<Tensor> GraphImpl::CreateIOTensor(const TensorSpec& spec,
+                                                  void* data) {
+  return std::make_shared<TensorImpl>(this, spec, data);
+}
+
 std::shared_ptr<Tensor> GraphImpl::CreateTensorPlaceHolder() {
   if (!tensor_placeholder_) {
     tensor_placeholder_ = std::make_shared<TensorPlaceholder>(this);
diff --git a/src/tim/vx/graph_private.h b/src/tim/vx/graph_private.h
index 4bec654..9f80e69 100644
--- a/src/tim/vx/graph_private.h
+++ b/src/tim/vx/graph_private.h
@@ -70,6 +70,8 @@ class GraphImpl : public Graph {
                                        const void* data = nullptr) override;
   std::shared_ptr<Tensor> CreateTensor(const TensorSpec& spec,
                                        const DmaBufferDesc& dmafd) override;
+  std::shared_ptr<Tensor> CreateIOTensor(const TensorSpec& spec,
+                                         void* data = nullptr) override;
   std::shared_ptr<Tensor> CreateTensorPlaceHolder() override;
 
   bool Compile() override;
diff --git a/src/tim/vx/tensor.cc b/src/tim/vx/tensor.cc
index 2ee281a..0f357b6 100644
--- a/src/tim/vx/tensor.cc
+++ b/src/tim/vx/tensor.cc
@@ -79,18 +79,35 @@ TensorImpl::TensorImpl(Graph* graph, const TensorSpec& spec, const void* data)
     : graph_(reinterpret_cast<GraphImpl*>(graph)),
       id_(VSI_NN_TENSOR_ID_NA),
       spec_(spec),
-      data_(data) {
+      data_(const_cast<void*>(data)) {
   Init();
+  if (spec_.attr_ & (TensorAttribute::INPUT | TensorAttribute::OUTPUT)) {
+    data_ = nullptr;  // only a constant tensor needs to keep data_
+  }
 }
 
 TensorImpl::TensorImpl(Graph* graph, const TensorSpec& spec,
                        const DmaBufferDesc& dmafd)
     : graph_(reinterpret_cast<GraphImpl*>(graph)),
       id_(VSI_NN_TENSOR_ID_NA),
       spec_(spec),
+      data_(nullptr),
       fd_(dmafd.fd) {
   Init();
 }
 
+TensorImpl::TensorImpl(Graph* graph, const TensorSpec& spec, void* data)
+    : graph_(reinterpret_cast<GraphImpl*>(graph)),
+      id_(VSI_NN_TENSOR_ID_NA),
+      spec_(spec),
+      data_(nullptr) {
+  if (!(spec_.attr_ & (TensorAttribute::INPUT | TensorAttribute::OUTPUT))) {
+    VSILOGE("TensorImpl with external data got an unexpected attr");
+    return;
+  }
+  Init(data);
+  data_ = data;
+}
+
 TensorImpl::~TensorImpl() {}
 
 bool TensorImpl::CopyDataToTensor(const void* data, uint32_t size_in_bytes) {
@@ -167,7 +184,95 @@ bool TensorImpl::CopyDataFromTensor(void* data) {
   return retn;
 }
 
-bool TensorImpl::Init() {
+bool TensorImpl::FlushCacheForHandle() {
+  if (!(spec_.attr_ & TensorAttribute::INPUT)) {
+    return false;
+  }
+
+  bool retn = true;
+  if (VSI_NN_TENSOR_ID_NA != id_) {
+    retn = false;
+    vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_);
+    if (tensor && tensor->attr.is_created_from_handle) {
+      retn = (VSI_SUCCESS == vsi_nn_FlushHandle(tensor));
+      if (!retn) {
+        VSILOGE("FlushHandle fail");
+      }
+    }
+  }
+  return retn;
+}
+
+bool TensorImpl::InvalidateCacheForHandle() {
+  if (!(spec_.attr_ & TensorAttribute::OUTPUT)) {
+    return false;
+  }
+
+  bool retn = true;
+  if (VSI_NN_TENSOR_ID_NA != id_) {
+    retn = false;
+    vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_);
+    if (tensor && tensor->attr.is_created_from_handle) {
+      void* ptr = NULL;
+      retn = (VSI_SUCCESS == vsi_nn_GetTensorHandle(tensor, &ptr));
+      if (!retn) {
+        VSILOGE("GetTensorHandle fail");
+      }
+    }
+  }
+  return retn;
+}
+
+void* TensorImpl::map(bool invalidate_cpu_cache) {
+  if (!(spec_.attr_ & (TensorAttribute::INPUT | TensorAttribute::OUTPUT))) {
+    return nullptr;
+  }
+
+  void* cpu_ptr = nullptr;
+  if (VSI_NN_TENSOR_ID_NA != id_) {
+    vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_);
+    if (tensor && tensor->attr.is_created_from_handle) {
+      // Here `cpu_cache` refers to the L1/L2/... caches on the CPU.
+      // If data_ has been updated by another device such as the NPU,
+      // the CPU caches MUST be invalidated before reading.
+      if (data_ && !invalidate_cpu_cache) {
+        cpu_ptr = data_;
+      } else {
+        vsi_nn_GetTensorHandle(tensor, &cpu_ptr);
+        // TODO: what to do when fd_ != -1
+      }
+      if (!cpu_ptr) {
+        VSILOGE("GetTensorHandle fail");
+      }
+    }
+  }
+  return cpu_ptr;
+}
+
+void TensorImpl::unmap() {
+  if (!(spec_.attr_ & (TensorAttribute::INPUT | TensorAttribute::OUTPUT))) {
+    return;
+  }
+  if (VSI_NN_TENSOR_ID_NA == id_) {
+    return;
+  }
+  if (-1 == fd_) {
+    if (data_ && spec_.attr_ & TensorAttribute::INPUT) {
+      // data_ is an external buffer and may have been updated by the CPU
+      vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_);
+      if (tensor && tensor->attr.is_created_from_handle) {
+        bool retn = (VSI_SUCCESS == vsi_nn_FlushHandle(tensor));
+        if (!retn) {
+          VSILOGE("FlushHandle fail");
+        }
+      }
+    }
+    return;
+  }
+  // TODO: unmap fd_
+}
+
+bool TensorImpl::Init(void* external_cache) {
   vsi_nn_tensor_attr_t attr;
 
   memset(&attr, 0x00, sizeof(attr));
@@ -198,11 +303,11 @@ bool TensorImpl::Init() {
         graph_->graph(), VSI_NN_TENSOR_ID_AUTO,  // DMABUF's fd is created by TensorFromHandle as input or output,
         &attr,
-        fd_ != -1 ? (uint8_t*)fd_ : nullptr);  // and cannot be set to const
+        fd_ != -1 ? (uint8_t*)fd_ : (uint8_t*)external_cache);  // and cannot be set to const
 #else
   if (-1 == fd_) {
     id_ = vsi_nn_AddTensorFromHandle(graph_->graph(), VSI_NN_TENSOR_ID_AUTO,
-                                     &attr, nullptr);
+                                     &attr, (uint8_t*)external_cache);
   } else {
     id_ = 0xFFFFFFFF;
     VSILOGE("Create tensor fail: low-level driver doesn't support dmabuffer");
diff --git a/src/tim/vx/tensor_private.h b/src/tim/vx/tensor_private.h
index ce26d84..2c8e422 100644
--- a/src/tim/vx/tensor_private.h
+++ b/src/tim/vx/tensor_private.h
@@ -34,9 +34,10 @@ class TensorImpl : public Tensor {
  public:
   TensorImpl(Graph* graph, const TensorSpec& spec, const void* data = nullptr);
   TensorImpl(Graph* graph, const TensorSpec& spec, const DmaBufferDesc& dmafd);
+  TensorImpl(Graph* graph, const TensorSpec& spec, void* data = nullptr);
   ~TensorImpl();
 
-  bool Init();
+  bool Init(void* external_cache = nullptr);
   bool IsWriteable();
   bool IsReadable();
@@ -47,6 +48,10 @@ class TensorImpl : public Tensor {
   uint32_t GetId();
   bool CopyDataToTensor(const void* data, uint32_t size = 0);
   bool CopyDataFromTensor(void* data);
+  bool FlushCacheForHandle();
+  bool InvalidateCacheForHandle();
+  void* map(bool invalidate_cpu_cache = false);
+  void unmap();
   bool IsPlaceHolder() { return false; }
   bool IsConstTensor() {
     return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT;
@@ -56,7 +61,7 @@ class TensorImpl : public Tensor {
   GraphImpl* graph_;
   vsi_nn_tensor_id_t id_;
   TensorSpec spec_;
-  const void* data_;
+  void* data_;
   int64_t fd_{-1};
 };
 
@@ -78,6 +83,13 @@ class TensorPlaceholder : public Tensor {
     (void)data;
     return false;
   }
+  bool InvalidateCacheForHandle() { return false; }
+  bool FlushCacheForHandle() { return false; }
+  void* map(bool invalidate_cpu_cache = false) {
+    (void)invalidate_cpu_cache;
+    return nullptr;
+  }
+  void unmap() { return; }
   bool IsPlaceHolder() { return true; }
   bool IsConstTensor() {
     return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT;
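
Usage sketch (not part of the patch): the snippet below shows how a caller might wrap application-owned buffers with the new CreateIOTensor() and use map()/unmap() around a graph run. The Context/Graph/TensorSpec/operation calls are the existing TIM-VX API; the Relu graph, shapes, buffer names, and header choices are illustrative assumptions only, not something this patch prescribes.

// Sketch, assuming the usual TIM-VX setup; buffers and shapes are hypothetical.
#include <vector>

#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/ops/activations.h"

int main() {
  auto context = tim::vx::Context::Create();
  auto graph = context->CreateGraph();

  tim::vx::ShapeType shape({4, 1});
  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, shape,
                                 tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, shape,
                                  tim::vx::TensorAttribute::OUTPUT);

  // Application-owned buffers backing the I/O tensors (hypothetical).
  std::vector<float> input_data(4, 0.0f);
  std::vector<float> output_data(4, 0.0f);

  // New in this patch: tensors created from caller-provided memory.
  auto input = graph->CreateIOTensor(input_spec, input_data.data());
  auto output = graph->CreateIOTensor(output_spec, output_data.data());

  auto relu = graph->CreateOperation<tim::vx::ops::Relu>();
  relu->BindInput(input);
  relu->BindOutput(output);

  if (!graph->Compile()) return -1;

  // Write the input through map(); unmap() flushes the CPU cache for an
  // INPUT tensor so the NPU observes the new data.
  float* in = static_cast<float*>(input->map());
  for (int i = 0; i < 4; ++i) in[i] = static_cast<float>(i) - 2.0f;
  input->unmap();

  if (!graph->Run()) return -1;

  // Read the output with invalidate_cpu_cache = true, because the NPU has
  // just written the buffer behind the CPU's caches.
  float* out = static_cast<float*>(output->map(true));
  // ... consume out[0..3] ...
  output->unmap();
  return 0;
}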
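Alternatively, a caller that writes straight through the pointer it handed to CreateIOTensor() can use the explicit cache-maintenance hooks instead of map()/unmap(). A minimal sketch, reusing graph, input, output, input_data, and output_data from the example above; per the patch, FlushCacheForHandle() only acts on INPUT tensors and InvalidateCacheForHandle() only on OUTPUT tensors:

// Sketch only; identifiers come from the previous example.
input_data[0] = 1.0f;                // CPU writes the external input buffer
input->FlushCacheForHandle();        // flush CPU caches so the NPU reads fresh data

graph->Run();

output->InvalidateCacheForHandle();  // drop stale CPU cache lines for the output
float result = output_data[0];       // now safe to read what the NPU produced
(void)result;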