diff --git a/include/tim/lite/execution.h b/include/tim/lite/execution.h index 8ceaf5c..6d73a5b 100644 --- a/include/tim/lite/execution.h +++ b/include/tim/lite/execution.h @@ -34,14 +34,24 @@ namespace tim { namespace lite { class Execution { - public: - static std::shared_ptr Create( - const void* executable, size_t executable_size); - virtual Execution& BindInputs(const std::vector>& handles) = 0; - virtual Execution& BindOutputs(const std::vector>& handles) = 0; - virtual bool Trigger() = 0; + public: + static std::shared_ptr Create(const void* executable, + size_t executable_size); + virtual std::shared_ptr CreateInputHandle(uint32_t in_idx, + uint8_t* buffer, + size_t size) = 0; + virtual std::shared_ptr CreateOutputHandle(uint32_t out_idx, + uint8_t* buffer, + size_t size) = 0; + virtual Execution& BindInputs( + const std::vector>& handles) = 0; + virtual Execution& BindOutputs( + const std::vector>& handles) = 0; + virtual Execution& UnBindInput(const std::shared_ptr& Handle) = 0; + virtual Execution& UnBindOutput(const std::shared_ptr& handle) = 0; + virtual bool Trigger() = 0; }; -} -} +} // namespace lite +} // namespace tim #endif \ No newline at end of file diff --git a/include/tim/lite/handle.h b/include/tim/lite/handle.h index 3dcc136..7708894 100644 --- a/include/tim/lite/handle.h +++ b/include/tim/lite/handle.h @@ -30,21 +30,10 @@ namespace tim { namespace lite { -class HandleImpl; - class Handle { public: - std::unique_ptr& impl() { return impl_; } - bool Flush(); - bool Invalidate(); - protected: - std::unique_ptr impl_; -}; - -class UserHandle : public Handle { - public: - UserHandle(void* buffer, size_t size); - ~UserHandle(); + virtual bool Flush() = 0; + virtual bool Invalidate() = 0; }; } diff --git a/samples/lenet_lite/lenet_lite_asymu8.cc b/samples/lenet_lite/lenet_lite_asymu8.cc index e19db77..17ad36e 100644 --- a/samples/lenet_lite/lenet_lite_asymu8.cc +++ b/samples/lenet_lite/lenet_lite_asymu8.cc @@ -121,16 +121,28 @@ int main() { assert(input); assert(output); memset(output, 0, output_sz); - memcpy(input, input_data.data(), input_data.size()); - auto input_handle = std::make_shared( - input, input_data.size()); - auto output_handle = std::make_shared( - output, lenet_output_size * sizeof(float)); + auto input_handle = exec->CreateInputHandle(0, input, input_sz); + auto output_handle = exec->CreateOutputHandle(0, (uint8_t*)output, output_sz); + exec->BindInputs({input_handle}); exec->BindOutputs({output_handle}); + memcpy(input, input_data.data(), input_data.size()); + input_handle->Flush(); exec->Trigger(); + output_handle->Invalidate(); printTopN(output, lenet_output_size, 5); + + // rebind input and output + exec->UnBindInput(input_handle); + exec->UnBindOutput(output_handle); + exec->BindInputs({input_handle}); + exec->BindOutputs({output_handle}); + input_handle->Flush(); + exec->Trigger(); + output_handle->Invalidate(); + printTopN(output, lenet_output_size, 5); + free(output); free(input); } else { diff --git a/src/tim/lite/execution.cc b/src/tim/lite/execution.cc index a436140..c0c80c8 100644 --- a/src/tim/lite/execution.cc +++ b/src/tim/lite/execution.cc @@ -28,6 +28,9 @@ #include #include #include +#include +#include +#include #include "handle_private.h" #include "vip_lite.h" @@ -35,67 +38,6 @@ namespace tim { namespace lite { -namespace { -bool QueryInputBufferParameters( - vip_buffer_create_params_t& param, uint32_t index, vip_network network) { - uint32_t count = 0; - vip_query_network(network, VIP_NETWORK_PROP_INPUT_COUNT, &count); - if (index >= count) { - return false; - } - memset(¶m, 0, sizeof(param)); - param.memory_type = VIP_BUFFER_MEMORY_TYPE_DEFAULT; - vip_query_input(network, index, VIP_BUFFER_PROP_DATA_FORMAT, ¶m.data_format); - vip_query_input(network, index, VIP_BUFFER_PROP_NUM_OF_DIMENSION, ¶m.num_of_dims); - vip_query_input(network, index, VIP_BUFFER_PROP_SIZES_OF_DIMENSION, param.sizes); - vip_query_input(network, index, VIP_BUFFER_PROP_QUANT_FORMAT, ¶m.quant_format); - switch(param.quant_format) { - case VIP_BUFFER_QUANTIZE_DYNAMIC_FIXED_POINT: - vip_query_input(network, index, VIP_BUFFER_PROP_FIXED_POINT_POS, - ¶m.quant_data.dfp.fixed_point_pos); - break; - case VIP_BUFFER_QUANTIZE_TF_ASYMM: - vip_query_input(network, index, VIP_BUFFER_PROP_TF_SCALE, - ¶m.quant_data.affine.scale); - vip_query_input(network, index, VIP_BUFFER_PROP_TF_ZERO_POINT, - ¶m.quant_data.affine.zeroPoint); - default: - break; - } - return true; -} - -bool QueryOutputBufferParameters( - vip_buffer_create_params_t& param, uint32_t index, vip_network network) { - uint32_t count = 0; - vip_query_network(network, VIP_NETWORK_PROP_OUTPUT_COUNT, &count); - if (index >= count) { - return false; - } - memset(¶m, 0, sizeof(param)); - param.memory_type = VIP_BUFFER_MEMORY_TYPE_DEFAULT; - vip_query_output(network, index, VIP_BUFFER_PROP_DATA_FORMAT, ¶m.data_format); - vip_query_output(network, index, VIP_BUFFER_PROP_NUM_OF_DIMENSION, ¶m.num_of_dims); - vip_query_output(network, index, VIP_BUFFER_PROP_SIZES_OF_DIMENSION, param.sizes); - vip_query_output(network, index, VIP_BUFFER_PROP_QUANT_FORMAT, ¶m.quant_format); - switch(param.quant_format) { - case VIP_BUFFER_QUANTIZE_DYNAMIC_FIXED_POINT: - vip_query_output(network, index, VIP_BUFFER_PROP_FIXED_POINT_POS, - ¶m.quant_data.dfp.fixed_point_pos); - break; - case VIP_BUFFER_QUANTIZE_TF_ASYMM: - vip_query_output(network, index, VIP_BUFFER_PROP_TF_SCALE, - ¶m.quant_data.affine.scale); - vip_query_output(network, index, VIP_BUFFER_PROP_TF_ZERO_POINT, - ¶m.quant_data.affine.zeroPoint); - break; - default: - break; - } - return true; -} -} - ExecutionImpl::ExecutionImpl(const void* executable, size_t executable_size) { vip_status_e status = VIP_SUCCESS; vip_network network = nullptr; @@ -130,41 +72,44 @@ ExecutionImpl::~ExecutionImpl() { vip_finish_network(network_); vip_destroy_network(network_); } - input_maps_.clear(); - output_maps_.clear(); + input_handles_.clear(); + output_handles_.clear(); vip_destroy(); } +std::shared_ptr ExecutionImpl::CreateInputHandle(uint32_t in_idx, uint8_t* buffer, size_t size) { + auto handle = std::make_shared(buffer, size); + if (handle->CreateVipInputBuffer(network_, in_idx)) { + return handle; + } else { + return nullptr; + } +} + +std::shared_ptr ExecutionImpl::CreateOutputHandle(uint32_t out_idx, uint8_t* buffer, size_t size) { + auto handle = std::make_shared(buffer, size); + if (handle->CreateVipPOutputBuffer(network_, out_idx)) { + return handle; + } else { + return nullptr; + } +} + Execution& ExecutionImpl::BindInputs(const std::vector>& handles) { if (!IsValid()) { return *this; } - vip_status_e status = VIP_SUCCESS; - vip_buffer_create_params_t param; - for (uint32_t i = 0; i < handles.size(); i ++) { - auto handle = handles[i]; - if (!handle) { - status = VIP_ERROR_FAILURE; - break; - } - std::shared_ptr internal_handle = nullptr; - if (input_maps_.find(handle) == input_maps_.end()) { - if (!QueryInputBufferParameters(param, i, network_)) { - status = VIP_ERROR_FAILURE; - break; + for (auto handle : handles) { + if (input_handles_.end() == std::find(input_handles_.begin(), input_handles_.end(), handle)) { + input_handles_.push_back(handle); + auto handle_impl = std::dynamic_pointer_cast(handle); + vip_status_e status = vip_set_input(network_, handle_impl->Index(), handle_impl->VipHandle()); + if (status != VIP_SUCCESS) { + std::cout << "Set input for network failed." << std::endl; + assert(false); } - internal_handle = handle->impl()->Register(param); - if (!internal_handle) { - status = VIP_ERROR_FAILURE; - break; - } - input_maps_[handle] = internal_handle; } else { - internal_handle = input_maps_.at(handle); - } - status = vip_set_input(network_, i, internal_handle->handle()); - if (status != VIP_SUCCESS) { - break; + std::cout << "The input handle has been binded, need not bind it again." << std::endl; } } return *this; @@ -174,37 +119,38 @@ Execution& ExecutionImpl::BindOutputs(const std::vector> if (!IsValid()) { return *this; } - vip_status_e status = VIP_SUCCESS; - vip_buffer_create_params_t param; - for (uint32_t i = 0; i < handles.size(); i ++) { - auto handle = handles[i]; - if (!handle) { - status = VIP_ERROR_FAILURE; - break; - } - std::shared_ptr internal_handle = nullptr; - if (output_maps_.find(handle) == output_maps_.end()) { - if (!QueryOutputBufferParameters(param, i, network_)) { - status = VIP_ERROR_FAILURE; - break; + for (auto handle : handles) { + if (output_handles_.end() == std::find(output_handles_.begin(), output_handles_.end(), handle)) { + output_handles_.push_back(handle); + auto handle_impl = std::dynamic_pointer_cast(handle); + vip_status_e status = vip_set_output(network_, handle_impl->Index(), handle_impl->VipHandle()); + if (status != VIP_SUCCESS) { + std::cout << "Set output for network failed." << std::endl; + assert(false); } - internal_handle = handle->impl()->Register(param); - if (!internal_handle) { - status = VIP_ERROR_FAILURE; - break; - } - output_maps_[handle] = internal_handle; } else { - internal_handle = output_maps_.at(handle); - } - status = vip_set_output(network_, i, internal_handle->handle()); - if (status != VIP_SUCCESS) { - break; + std::cout << "The output handle has been binded, need not bind it again." << std::endl; } } return *this; }; +Execution& ExecutionImpl::UnBindInput(const std::shared_ptr& handle) { + auto it = std::find(input_handles_.begin(), input_handles_.end(), handle); + if (input_handles_.end() != it) { + input_handles_.erase(it); + } + return *this; +} + +Execution& ExecutionImpl::UnBindOutput(const std::shared_ptr& handle) { + auto it = std::find(output_handles_.begin(), output_handles_.end(), handle); + if (output_handles_.end() != it) { + output_handles_.erase(it); + } + return *this; +} + bool ExecutionImpl::Trigger() { if (!IsValid()) { return false; diff --git a/src/tim/lite/execution_private.h b/src/tim/lite/execution_private.h index 22e0bd3..01020c4 100644 --- a/src/tim/lite/execution_private.h +++ b/src/tim/lite/execution_private.h @@ -36,21 +36,30 @@ namespace tim { namespace lite { class ExecutionImpl : public Execution { - public : - ExecutionImpl(const void* executable, size_t executable_size); - ~ExecutionImpl(); - Execution& BindInputs(const std::vector>& handles) override; - Execution& BindOutputs(const std::vector>& handles) override; - bool Trigger() override; - bool IsValid() const { return valid_; }; - vip_network network() { return network_; }; - private: - std::map, std::shared_ptr> input_maps_; - std::map, std::shared_ptr> output_maps_; - bool valid_; - vip_network network_; + public: + ExecutionImpl(const void* executable, size_t executable_size); + ~ExecutionImpl(); + std::shared_ptr CreateInputHandle(uint32_t in_idx, uint8_t* buffer, + size_t size) override; + std::shared_ptr CreateOutputHandle(uint32_t out_idx, uint8_t* buffer, + size_t size) override; + Execution& BindInputs( + const std::vector>& handles) override; + Execution& BindOutputs( + const std::vector>& handles) override; + Execution& UnBindInput(const std::shared_ptr& Handle) override; + Execution& UnBindOutput(const std::shared_ptr& handle) override; + bool Trigger() override; + bool IsValid() const { return valid_; }; + vip_network network() { return network_; }; + + private: + std::vector> input_handles_; + std::vector> output_handles_; + bool valid_; + vip_network network_; }; -} -} +} // namespace lite +} // namespace tim #endif \ No newline at end of file diff --git a/src/tim/lite/handle.cc b/src/tim/lite/handle.cc index 31c7ede..2297a86 100644 --- a/src/tim/lite/handle.cc +++ b/src/tim/lite/handle.cc @@ -28,6 +28,7 @@ #include #include #include +#include #include "execution_private.h" #include "vip_lite.h" @@ -36,70 +37,126 @@ namespace tim { namespace lite { -bool Handle::Flush() { - auto internal_handle = impl_->internal_handle(); - return internal_handle->Flush(HandleFlushType::HandleFlush); -} - -bool Handle::Invalidate() { - auto internal_handle = impl_->internal_handle(); - return internal_handle->Flush(HandleFlushType::HandleInvalidate); -} - -UserHandle::UserHandle(void* buffer, size_t size) { - assert((reinterpret_cast(buffer) % _64_BYTES_ALIGN) == 0); - impl_ = std::make_unique(buffer, size); -} - -UserHandle::~UserHandle() {} - -std::shared_ptr UserHandleImpl::Register( - vip_buffer_create_params_t& params) { - internal_handle_ = std::make_shared( - user_buffer_, user_buffer_size_, params); - if (!internal_handle_->handle()) { - internal_handle_.reset(); +namespace { +bool QueryInputBufferParameters( + vip_buffer_create_params_t& param, uint32_t index, vip_network network) { + uint32_t count = 0; + vip_query_network(network, VIP_NETWORK_PROP_INPUT_COUNT, &count); + if (index >= count) { + return false; } - return internal_handle_; -} - -InternalHandle::~InternalHandle() { - if (handle_) { - vip_destroy_buffer(handle_); - handle_ = nullptr; + memset(¶m, 0, sizeof(param)); + param.memory_type = VIP_BUFFER_MEMORY_TYPE_DEFAULT; + vip_query_input(network, index, VIP_BUFFER_PROP_DATA_FORMAT, ¶m.data_format); + vip_query_input(network, index, VIP_BUFFER_PROP_NUM_OF_DIMENSION, ¶m.num_of_dims); + vip_query_input(network, index, VIP_BUFFER_PROP_SIZES_OF_DIMENSION, param.sizes); + vip_query_input(network, index, VIP_BUFFER_PROP_QUANT_FORMAT, ¶m.quant_format); + switch(param.quant_format) { + case VIP_BUFFER_QUANTIZE_DYNAMIC_FIXED_POINT: + vip_query_input(network, index, VIP_BUFFER_PROP_FIXED_POINT_POS, + ¶m.quant_data.dfp.fixed_point_pos); + break; + case VIP_BUFFER_QUANTIZE_TF_ASYMM: + vip_query_input(network, index, VIP_BUFFER_PROP_TF_SCALE, + ¶m.quant_data.affine.scale); + vip_query_input(network, index, VIP_BUFFER_PROP_TF_ZERO_POINT, + ¶m.quant_data.affine.zeroPoint); + default: + break; } + return true; } -InternalUserHandle::InternalUserHandle(void* user_buffer, size_t user_buffer_size, - vip_buffer_create_params_t& params) { - vip_status_e status = VIP_SUCCESS; - vip_buffer internal_buffer = nullptr; - status = vip_create_buffer_from_handle(¶ms, - user_buffer, user_buffer_size, &internal_buffer); - if (status == VIP_SUCCESS) { - handle_ = internal_buffer; - } else { - handle_ = nullptr; +bool QueryOutputBufferParameters( + vip_buffer_create_params_t& param, uint32_t index, vip_network network) { + uint32_t count = 0; + vip_query_network(network, VIP_NETWORK_PROP_OUTPUT_COUNT, &count); + if (index >= count) { + return false; } + memset(¶m, 0, sizeof(param)); + param.memory_type = VIP_BUFFER_MEMORY_TYPE_DEFAULT; + vip_query_output(network, index, VIP_BUFFER_PROP_DATA_FORMAT, ¶m.data_format); + vip_query_output(network, index, VIP_BUFFER_PROP_NUM_OF_DIMENSION, ¶m.num_of_dims); + vip_query_output(network, index, VIP_BUFFER_PROP_SIZES_OF_DIMENSION, param.sizes); + vip_query_output(network, index, VIP_BUFFER_PROP_QUANT_FORMAT, ¶m.quant_format); + switch(param.quant_format) { + case VIP_BUFFER_QUANTIZE_DYNAMIC_FIXED_POINT: + vip_query_output(network, index, VIP_BUFFER_PROP_FIXED_POINT_POS, + ¶m.quant_data.dfp.fixed_point_pos); + break; + case VIP_BUFFER_QUANTIZE_TF_ASYMM: + vip_query_output(network, index, VIP_BUFFER_PROP_TF_SCALE, + ¶m.quant_data.affine.scale); + vip_query_output(network, index, VIP_BUFFER_PROP_TF_ZERO_POINT, + ¶m.quant_data.affine.zeroPoint); + break; + default: + break; + } + return true; +} } -bool InternalUserHandle::Flush(HandleFlushType type) { +bool HandleImpl::CreateVipInputBuffer(vip_network network, + uint32_t in_idx) { vip_status_e status = VIP_SUCCESS; - switch (type) { - case HandleFlushType::HandleFlush: { - status = vip_flush_buffer(handle_, VIP_BUFFER_OPER_TYPE_FLUSH); - break; - } - case HandleFlushType::HandleInvalidate: { - status = vip_flush_buffer(handle_, VIP_BUFFER_OPER_TYPE_INVALIDATE); - break; - } - default: - std::cout << __FUNCTION__ << ":" << __LINE__ << " Unkown HandleFlushType." - << std::endl; - assert(false); + vip_buffer_create_params_t param; + vip_buffer internal_buffer; + assert((reinterpret_cast(buffer_) % _64_BYTES_ALIGN) == 0); + if (!QueryInputBufferParameters(param, in_idx, network)) { + status = VIP_ERROR_FAILURE; + return false; + } + status = vip_create_buffer_from_handle(¶m, buffer_, buffer_size_, + &internal_buffer); + if (status == VIP_SUCCESS) { + handle_ = internal_buffer; + SetIndex(in_idx); + return true; + } else { + handle_ = nullptr; + return false; + } +} + +bool HandleImpl::CreateVipPOutputBuffer(vip_network network, + uint32_t out_idx) { + vip_status_e status = VIP_SUCCESS; + vip_buffer_create_params_t param; + vip_buffer internal_buffer; + assert((reinterpret_cast(buffer_) % _64_BYTES_ALIGN) == 0); + if (!QueryOutputBufferParameters(param, out_idx, network)) { + status = VIP_ERROR_FAILURE; + return false; + } + status = vip_create_buffer_from_handle(¶m, buffer_, buffer_size_, + &internal_buffer); + if (status == VIP_SUCCESS) { + handle_ = internal_buffer; + SetIndex(out_idx); + return true; + } else { + handle_ = nullptr; + return false; + } +} + +bool HandleImpl::Flush() { + vip_status_e status = vip_flush_buffer(handle_, VIP_BUFFER_OPER_TYPE_FLUSH); + return status == VIP_SUCCESS ? true : false; +} + +bool HandleImpl::Invalidate() { + vip_status_e status = vip_flush_buffer(handle_, VIP_BUFFER_OPER_TYPE_INVALIDATE); + return status == VIP_SUCCESS ? true : false; +} + +HandleImpl::~HandleImpl() { + if (handle_) { + vip_destroy_buffer(handle_); + handle_ = nullptr; } - return status == VIP_SUCCESS ? true : false; } } diff --git a/src/tim/lite/handle_private.h b/src/tim/lite/handle_private.h index a9e883a..66b3d32 100644 --- a/src/tim/lite/handle_private.h +++ b/src/tim/lite/handle_private.h @@ -37,46 +37,26 @@ enum class HandleFlushType { HandleInvalidate = 1 }; -class InternalHandle; +class HandleImpl : public Handle { + public: + HandleImpl(uint8_t* buffer, size_t size) + : buffer_(buffer), buffer_size_(size) {} -class HandleImpl { - public: - virtual std::shared_ptr Register( - vip_buffer_create_params_t& params) = 0; - virtual std::shared_ptr& internal_handle() = 0; -}; + bool CreateVipInputBuffer(vip_network network, uint32_t in_idx); + bool CreateVipPOutputBuffer(vip_network network, uint32_t out_idx); + uint32_t Index() { return index_; } + vip_buffer VipHandle() { return handle_; } + bool Flush() override; + bool Invalidate() override; -class UserHandleImpl : public HandleImpl { - public: - UserHandleImpl(void* buffer, size_t size) - : user_buffer_(buffer), user_buffer_size_(size) {} - std::shared_ptr Register( - vip_buffer_create_params_t& params) override; - std::shared_ptr& internal_handle() override { - return internal_handle_; - } - size_t user_buffer_size() const { return user_buffer_size_; } - void* user_buffer() { return user_buffer_; } - private: - void* user_buffer_; - size_t user_buffer_size_; - std::shared_ptr internal_handle_ = nullptr; -}; + ~HandleImpl(); -class InternalHandle { - public: - ~InternalHandle(); - virtual bool Flush(HandleFlushType type) = 0; - vip_buffer handle() { return handle_; }; - protected: - vip_buffer handle_; -}; - -class InternalUserHandle : public InternalHandle { - public: - InternalUserHandle(void* user_buffer, size_t user_buffer_size, - vip_buffer_create_params_t& params); - bool Flush(HandleFlushType type) override; + private: + void SetIndex(uint32_t idx) { index_ = idx; } + uint8_t* buffer_ = nullptr; + size_t buffer_size_ = 0; + vip_buffer handle_; + uint32_t index_; }; }