Refine Lite API (#221)

Signed-off-by: Zongwu Yang <zongwu.yang@verisilicon.com>
This commit is contained in:
Zongwu.Yang 2021-11-19 20:30:26 +08:00 committed by GitHub
parent 0ca4970d72
commit c90efe70c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 247 additions and 244 deletions

View File

@ -34,14 +34,24 @@ namespace tim {
namespace lite { namespace lite {
class Execution { class Execution {
public: public:
static std::shared_ptr<Execution> Create( static std::shared_ptr<Execution> Create(const void* executable,
const void* executable, size_t executable_size); size_t executable_size);
virtual Execution& BindInputs(const std::vector<std::shared_ptr<Handle>>& handles) = 0; virtual std::shared_ptr<Handle> CreateInputHandle(uint32_t in_idx,
virtual Execution& BindOutputs(const std::vector<std::shared_ptr<Handle>>& handles) = 0; uint8_t* buffer,
virtual bool Trigger() = 0; size_t size) = 0;
virtual std::shared_ptr<Handle> CreateOutputHandle(uint32_t out_idx,
uint8_t* buffer,
size_t size) = 0;
virtual Execution& BindInputs(
const std::vector<std::shared_ptr<Handle>>& handles) = 0;
virtual Execution& BindOutputs(
const std::vector<std::shared_ptr<Handle>>& handles) = 0;
virtual Execution& UnBindInput(const std::shared_ptr<Handle>& Handle) = 0;
virtual Execution& UnBindOutput(const std::shared_ptr<Handle>& handle) = 0;
virtual bool Trigger() = 0;
}; };
} } // namespace lite
} } // namespace tim
#endif #endif

View File

@ -30,21 +30,10 @@
namespace tim { namespace tim {
namespace lite { namespace lite {
class HandleImpl;
class Handle { class Handle {
public: public:
std::unique_ptr<HandleImpl>& impl() { return impl_; } virtual bool Flush() = 0;
bool Flush(); virtual bool Invalidate() = 0;
bool Invalidate();
protected:
std::unique_ptr<HandleImpl> impl_;
};
class UserHandle : public Handle {
public:
UserHandle(void* buffer, size_t size);
~UserHandle();
}; };
} }

View File

@ -121,16 +121,28 @@ int main() {
assert(input); assert(input);
assert(output); assert(output);
memset(output, 0, output_sz); memset(output, 0, output_sz);
memcpy(input, input_data.data(), input_data.size());
auto input_handle = std::make_shared<tim::lite::UserHandle>( auto input_handle = exec->CreateInputHandle(0, input, input_sz);
input, input_data.size()); auto output_handle = exec->CreateOutputHandle(0, (uint8_t*)output, output_sz);
auto output_handle = std::make_shared<tim::lite::UserHandle>(
output, lenet_output_size * sizeof(float));
exec->BindInputs({input_handle}); exec->BindInputs({input_handle});
exec->BindOutputs({output_handle}); exec->BindOutputs({output_handle});
memcpy(input, input_data.data(), input_data.size());
input_handle->Flush();
exec->Trigger(); exec->Trigger();
output_handle->Invalidate();
printTopN(output, lenet_output_size, 5); printTopN(output, lenet_output_size, 5);
// rebind input and output
exec->UnBindInput(input_handle);
exec->UnBindOutput(output_handle);
exec->BindInputs({input_handle});
exec->BindOutputs({output_handle});
input_handle->Flush();
exec->Trigger();
output_handle->Invalidate();
printTopN(output, lenet_output_size, 5);
free(output); free(output);
free(input); free(input);
} else { } else {

View File

@ -28,6 +28,9 @@
#include <cstring> #include <cstring>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <algorithm>
#include <iostream>
#include <cassert>
#include "handle_private.h" #include "handle_private.h"
#include "vip_lite.h" #include "vip_lite.h"
@ -35,67 +38,6 @@
namespace tim { namespace tim {
namespace lite { namespace lite {
namespace {
bool QueryInputBufferParameters(
vip_buffer_create_params_t& param, uint32_t index, vip_network network) {
uint32_t count = 0;
vip_query_network(network, VIP_NETWORK_PROP_INPUT_COUNT, &count);
if (index >= count) {
return false;
}
memset(&param, 0, sizeof(param));
param.memory_type = VIP_BUFFER_MEMORY_TYPE_DEFAULT;
vip_query_input(network, index, VIP_BUFFER_PROP_DATA_FORMAT, &param.data_format);
vip_query_input(network, index, VIP_BUFFER_PROP_NUM_OF_DIMENSION, &param.num_of_dims);
vip_query_input(network, index, VIP_BUFFER_PROP_SIZES_OF_DIMENSION, param.sizes);
vip_query_input(network, index, VIP_BUFFER_PROP_QUANT_FORMAT, &param.quant_format);
switch(param.quant_format) {
case VIP_BUFFER_QUANTIZE_DYNAMIC_FIXED_POINT:
vip_query_input(network, index, VIP_BUFFER_PROP_FIXED_POINT_POS,
&param.quant_data.dfp.fixed_point_pos);
break;
case VIP_BUFFER_QUANTIZE_TF_ASYMM:
vip_query_input(network, index, VIP_BUFFER_PROP_TF_SCALE,
&param.quant_data.affine.scale);
vip_query_input(network, index, VIP_BUFFER_PROP_TF_ZERO_POINT,
&param.quant_data.affine.zeroPoint);
default:
break;
}
return true;
}
bool QueryOutputBufferParameters(
vip_buffer_create_params_t& param, uint32_t index, vip_network network) {
uint32_t count = 0;
vip_query_network(network, VIP_NETWORK_PROP_OUTPUT_COUNT, &count);
if (index >= count) {
return false;
}
memset(&param, 0, sizeof(param));
param.memory_type = VIP_BUFFER_MEMORY_TYPE_DEFAULT;
vip_query_output(network, index, VIP_BUFFER_PROP_DATA_FORMAT, &param.data_format);
vip_query_output(network, index, VIP_BUFFER_PROP_NUM_OF_DIMENSION, &param.num_of_dims);
vip_query_output(network, index, VIP_BUFFER_PROP_SIZES_OF_DIMENSION, param.sizes);
vip_query_output(network, index, VIP_BUFFER_PROP_QUANT_FORMAT, &param.quant_format);
switch(param.quant_format) {
case VIP_BUFFER_QUANTIZE_DYNAMIC_FIXED_POINT:
vip_query_output(network, index, VIP_BUFFER_PROP_FIXED_POINT_POS,
&param.quant_data.dfp.fixed_point_pos);
break;
case VIP_BUFFER_QUANTIZE_TF_ASYMM:
vip_query_output(network, index, VIP_BUFFER_PROP_TF_SCALE,
&param.quant_data.affine.scale);
vip_query_output(network, index, VIP_BUFFER_PROP_TF_ZERO_POINT,
&param.quant_data.affine.zeroPoint);
break;
default:
break;
}
return true;
}
}
ExecutionImpl::ExecutionImpl(const void* executable, size_t executable_size) { ExecutionImpl::ExecutionImpl(const void* executable, size_t executable_size) {
vip_status_e status = VIP_SUCCESS; vip_status_e status = VIP_SUCCESS;
vip_network network = nullptr; vip_network network = nullptr;
@ -130,41 +72,44 @@ ExecutionImpl::~ExecutionImpl() {
vip_finish_network(network_); vip_finish_network(network_);
vip_destroy_network(network_); vip_destroy_network(network_);
} }
input_maps_.clear(); input_handles_.clear();
output_maps_.clear(); output_handles_.clear();
vip_destroy(); vip_destroy();
} }
std::shared_ptr<Handle> ExecutionImpl::CreateInputHandle(uint32_t in_idx, uint8_t* buffer, size_t size) {
auto handle = std::make_shared<HandleImpl>(buffer, size);
if (handle->CreateVipInputBuffer(network_, in_idx)) {
return handle;
} else {
return nullptr;
}
}
std::shared_ptr<Handle> ExecutionImpl::CreateOutputHandle(uint32_t out_idx, uint8_t* buffer, size_t size) {
auto handle = std::make_shared<HandleImpl>(buffer, size);
if (handle->CreateVipPOutputBuffer(network_, out_idx)) {
return handle;
} else {
return nullptr;
}
}
Execution& ExecutionImpl::BindInputs(const std::vector<std::shared_ptr<Handle>>& handles) { Execution& ExecutionImpl::BindInputs(const std::vector<std::shared_ptr<Handle>>& handles) {
if (!IsValid()) { if (!IsValid()) {
return *this; return *this;
} }
vip_status_e status = VIP_SUCCESS; for (auto handle : handles) {
vip_buffer_create_params_t param; if (input_handles_.end() == std::find(input_handles_.begin(), input_handles_.end(), handle)) {
for (uint32_t i = 0; i < handles.size(); i ++) { input_handles_.push_back(handle);
auto handle = handles[i]; auto handle_impl = std::dynamic_pointer_cast<HandleImpl>(handle);
if (!handle) { vip_status_e status = vip_set_input(network_, handle_impl->Index(), handle_impl->VipHandle());
status = VIP_ERROR_FAILURE; if (status != VIP_SUCCESS) {
break; std::cout << "Set input for network failed." << std::endl;
} assert(false);
std::shared_ptr<InternalHandle> internal_handle = nullptr;
if (input_maps_.find(handle) == input_maps_.end()) {
if (!QueryInputBufferParameters(param, i, network_)) {
status = VIP_ERROR_FAILURE;
break;
} }
internal_handle = handle->impl()->Register(param);
if (!internal_handle) {
status = VIP_ERROR_FAILURE;
break;
}
input_maps_[handle] = internal_handle;
} else { } else {
internal_handle = input_maps_.at(handle); std::cout << "The input handle has been binded, need not bind it again." << std::endl;
}
status = vip_set_input(network_, i, internal_handle->handle());
if (status != VIP_SUCCESS) {
break;
} }
} }
return *this; return *this;
@ -174,37 +119,38 @@ Execution& ExecutionImpl::BindOutputs(const std::vector<std::shared_ptr<Handle>>
if (!IsValid()) { if (!IsValid()) {
return *this; return *this;
} }
vip_status_e status = VIP_SUCCESS; for (auto handle : handles) {
vip_buffer_create_params_t param; if (output_handles_.end() == std::find(output_handles_.begin(), output_handles_.end(), handle)) {
for (uint32_t i = 0; i < handles.size(); i ++) { output_handles_.push_back(handle);
auto handle = handles[i]; auto handle_impl = std::dynamic_pointer_cast<HandleImpl>(handle);
if (!handle) { vip_status_e status = vip_set_output(network_, handle_impl->Index(), handle_impl->VipHandle());
status = VIP_ERROR_FAILURE; if (status != VIP_SUCCESS) {
break; std::cout << "Set output for network failed." << std::endl;
} assert(false);
std::shared_ptr<InternalHandle> internal_handle = nullptr;
if (output_maps_.find(handle) == output_maps_.end()) {
if (!QueryOutputBufferParameters(param, i, network_)) {
status = VIP_ERROR_FAILURE;
break;
} }
internal_handle = handle->impl()->Register(param);
if (!internal_handle) {
status = VIP_ERROR_FAILURE;
break;
}
output_maps_[handle] = internal_handle;
} else { } else {
internal_handle = output_maps_.at(handle); std::cout << "The output handle has been binded, need not bind it again." << std::endl;
}
status = vip_set_output(network_, i, internal_handle->handle());
if (status != VIP_SUCCESS) {
break;
} }
} }
return *this; return *this;
}; };
Execution& ExecutionImpl::UnBindInput(const std::shared_ptr<Handle>& handle) {
auto it = std::find(input_handles_.begin(), input_handles_.end(), handle);
if (input_handles_.end() != it) {
input_handles_.erase(it);
}
return *this;
}
Execution& ExecutionImpl::UnBindOutput(const std::shared_ptr<Handle>& handle) {
auto it = std::find(output_handles_.begin(), output_handles_.end(), handle);
if (output_handles_.end() != it) {
output_handles_.erase(it);
}
return *this;
}
bool ExecutionImpl::Trigger() { bool ExecutionImpl::Trigger() {
if (!IsValid()) { if (!IsValid()) {
return false; return false;

View File

@ -36,21 +36,30 @@ namespace tim {
namespace lite { namespace lite {
class ExecutionImpl : public Execution { class ExecutionImpl : public Execution {
public : public:
ExecutionImpl(const void* executable, size_t executable_size); ExecutionImpl(const void* executable, size_t executable_size);
~ExecutionImpl(); ~ExecutionImpl();
Execution& BindInputs(const std::vector<std::shared_ptr<Handle>>& handles) override; std::shared_ptr<Handle> CreateInputHandle(uint32_t in_idx, uint8_t* buffer,
Execution& BindOutputs(const std::vector<std::shared_ptr<Handle>>& handles) override; size_t size) override;
bool Trigger() override; std::shared_ptr<Handle> CreateOutputHandle(uint32_t out_idx, uint8_t* buffer,
bool IsValid() const { return valid_; }; size_t size) override;
vip_network network() { return network_; }; Execution& BindInputs(
private: const std::vector<std::shared_ptr<Handle>>& handles) override;
std::map<std::shared_ptr<Handle>, std::shared_ptr<InternalHandle>> input_maps_; Execution& BindOutputs(
std::map<std::shared_ptr<Handle>, std::shared_ptr<InternalHandle>> output_maps_; const std::vector<std::shared_ptr<Handle>>& handles) override;
bool valid_; Execution& UnBindInput(const std::shared_ptr<Handle>& Handle) override;
vip_network network_; Execution& UnBindOutput(const std::shared_ptr<Handle>& handle) override;
bool Trigger() override;
bool IsValid() const { return valid_; };
vip_network network() { return network_; };
private:
std::vector<std::shared_ptr<Handle>> input_handles_;
std::vector<std::shared_ptr<Handle>> output_handles_;
bool valid_;
vip_network network_;
}; };
} } // namespace lite
} } // namespace tim
#endif #endif

View File

@ -28,6 +28,7 @@
#include <cstdint> #include <cstdint>
#include <memory> #include <memory>
#include <iostream> #include <iostream>
#include <string.h>
#include "execution_private.h" #include "execution_private.h"
#include "vip_lite.h" #include "vip_lite.h"
@ -36,70 +37,126 @@
namespace tim { namespace tim {
namespace lite { namespace lite {
bool Handle::Flush() { namespace {
auto internal_handle = impl_->internal_handle(); bool QueryInputBufferParameters(
return internal_handle->Flush(HandleFlushType::HandleFlush); vip_buffer_create_params_t& param, uint32_t index, vip_network network) {
} uint32_t count = 0;
vip_query_network(network, VIP_NETWORK_PROP_INPUT_COUNT, &count);
bool Handle::Invalidate() { if (index >= count) {
auto internal_handle = impl_->internal_handle(); return false;
return internal_handle->Flush(HandleFlushType::HandleInvalidate);
}
UserHandle::UserHandle(void* buffer, size_t size) {
assert((reinterpret_cast<uintptr_t>(buffer) % _64_BYTES_ALIGN) == 0);
impl_ = std::make_unique<UserHandleImpl>(buffer, size);
}
UserHandle::~UserHandle() {}
std::shared_ptr<InternalHandle> UserHandleImpl::Register(
vip_buffer_create_params_t& params) {
internal_handle_ = std::make_shared<InternalUserHandle>(
user_buffer_, user_buffer_size_, params);
if (!internal_handle_->handle()) {
internal_handle_.reset();
} }
return internal_handle_; memset(&param, 0, sizeof(param));
} param.memory_type = VIP_BUFFER_MEMORY_TYPE_DEFAULT;
vip_query_input(network, index, VIP_BUFFER_PROP_DATA_FORMAT, &param.data_format);
InternalHandle::~InternalHandle() { vip_query_input(network, index, VIP_BUFFER_PROP_NUM_OF_DIMENSION, &param.num_of_dims);
if (handle_) { vip_query_input(network, index, VIP_BUFFER_PROP_SIZES_OF_DIMENSION, param.sizes);
vip_destroy_buffer(handle_); vip_query_input(network, index, VIP_BUFFER_PROP_QUANT_FORMAT, &param.quant_format);
handle_ = nullptr; switch(param.quant_format) {
case VIP_BUFFER_QUANTIZE_DYNAMIC_FIXED_POINT:
vip_query_input(network, index, VIP_BUFFER_PROP_FIXED_POINT_POS,
&param.quant_data.dfp.fixed_point_pos);
break;
case VIP_BUFFER_QUANTIZE_TF_ASYMM:
vip_query_input(network, index, VIP_BUFFER_PROP_TF_SCALE,
&param.quant_data.affine.scale);
vip_query_input(network, index, VIP_BUFFER_PROP_TF_ZERO_POINT,
&param.quant_data.affine.zeroPoint);
default:
break;
} }
return true;
} }
InternalUserHandle::InternalUserHandle(void* user_buffer, size_t user_buffer_size, bool QueryOutputBufferParameters(
vip_buffer_create_params_t& params) { vip_buffer_create_params_t& param, uint32_t index, vip_network network) {
vip_status_e status = VIP_SUCCESS; uint32_t count = 0;
vip_buffer internal_buffer = nullptr; vip_query_network(network, VIP_NETWORK_PROP_OUTPUT_COUNT, &count);
status = vip_create_buffer_from_handle(&params, if (index >= count) {
user_buffer, user_buffer_size, &internal_buffer); return false;
if (status == VIP_SUCCESS) {
handle_ = internal_buffer;
} else {
handle_ = nullptr;
} }
memset(&param, 0, sizeof(param));
param.memory_type = VIP_BUFFER_MEMORY_TYPE_DEFAULT;
vip_query_output(network, index, VIP_BUFFER_PROP_DATA_FORMAT, &param.data_format);
vip_query_output(network, index, VIP_BUFFER_PROP_NUM_OF_DIMENSION, &param.num_of_dims);
vip_query_output(network, index, VIP_BUFFER_PROP_SIZES_OF_DIMENSION, param.sizes);
vip_query_output(network, index, VIP_BUFFER_PROP_QUANT_FORMAT, &param.quant_format);
switch(param.quant_format) {
case VIP_BUFFER_QUANTIZE_DYNAMIC_FIXED_POINT:
vip_query_output(network, index, VIP_BUFFER_PROP_FIXED_POINT_POS,
&param.quant_data.dfp.fixed_point_pos);
break;
case VIP_BUFFER_QUANTIZE_TF_ASYMM:
vip_query_output(network, index, VIP_BUFFER_PROP_TF_SCALE,
&param.quant_data.affine.scale);
vip_query_output(network, index, VIP_BUFFER_PROP_TF_ZERO_POINT,
&param.quant_data.affine.zeroPoint);
break;
default:
break;
}
return true;
}
} }
bool InternalUserHandle::Flush(HandleFlushType type) { bool HandleImpl::CreateVipInputBuffer(vip_network network,
uint32_t in_idx) {
vip_status_e status = VIP_SUCCESS; vip_status_e status = VIP_SUCCESS;
switch (type) { vip_buffer_create_params_t param;
case HandleFlushType::HandleFlush: { vip_buffer internal_buffer;
status = vip_flush_buffer(handle_, VIP_BUFFER_OPER_TYPE_FLUSH); assert((reinterpret_cast<uintptr_t>(buffer_) % _64_BYTES_ALIGN) == 0);
break; if (!QueryInputBufferParameters(param, in_idx, network)) {
} status = VIP_ERROR_FAILURE;
case HandleFlushType::HandleInvalidate: { return false;
status = vip_flush_buffer(handle_, VIP_BUFFER_OPER_TYPE_INVALIDATE); }
break; status = vip_create_buffer_from_handle(&param, buffer_, buffer_size_,
} &internal_buffer);
default: if (status == VIP_SUCCESS) {
std::cout << __FUNCTION__ << ":" << __LINE__ << " Unkown HandleFlushType." handle_ = internal_buffer;
<< std::endl; SetIndex(in_idx);
assert(false); return true;
} else {
handle_ = nullptr;
return false;
}
}
bool HandleImpl::CreateVipPOutputBuffer(vip_network network,
uint32_t out_idx) {
vip_status_e status = VIP_SUCCESS;
vip_buffer_create_params_t param;
vip_buffer internal_buffer;
assert((reinterpret_cast<uintptr_t>(buffer_) % _64_BYTES_ALIGN) == 0);
if (!QueryOutputBufferParameters(param, out_idx, network)) {
status = VIP_ERROR_FAILURE;
return false;
}
status = vip_create_buffer_from_handle(&param, buffer_, buffer_size_,
&internal_buffer);
if (status == VIP_SUCCESS) {
handle_ = internal_buffer;
SetIndex(out_idx);
return true;
} else {
handle_ = nullptr;
return false;
}
}
bool HandleImpl::Flush() {
vip_status_e status = vip_flush_buffer(handle_, VIP_BUFFER_OPER_TYPE_FLUSH);
return status == VIP_SUCCESS ? true : false;
}
bool HandleImpl::Invalidate() {
vip_status_e status = vip_flush_buffer(handle_, VIP_BUFFER_OPER_TYPE_INVALIDATE);
return status == VIP_SUCCESS ? true : false;
}
HandleImpl::~HandleImpl() {
if (handle_) {
vip_destroy_buffer(handle_);
handle_ = nullptr;
} }
return status == VIP_SUCCESS ? true : false;
} }
} }

View File

@ -37,46 +37,26 @@ enum class HandleFlushType {
HandleInvalidate = 1 HandleInvalidate = 1
}; };
class InternalHandle; class HandleImpl : public Handle {
public:
HandleImpl(uint8_t* buffer, size_t size)
: buffer_(buffer), buffer_size_(size) {}
class HandleImpl { bool CreateVipInputBuffer(vip_network network, uint32_t in_idx);
public: bool CreateVipPOutputBuffer(vip_network network, uint32_t out_idx);
virtual std::shared_ptr<InternalHandle> Register( uint32_t Index() { return index_; }
vip_buffer_create_params_t& params) = 0; vip_buffer VipHandle() { return handle_; }
virtual std::shared_ptr<InternalHandle>& internal_handle() = 0; bool Flush() override;
}; bool Invalidate() override;
class UserHandleImpl : public HandleImpl { ~HandleImpl();
public:
UserHandleImpl(void* buffer, size_t size)
: user_buffer_(buffer), user_buffer_size_(size) {}
std::shared_ptr<InternalHandle> Register(
vip_buffer_create_params_t& params) override;
std::shared_ptr<InternalHandle>& internal_handle() override {
return internal_handle_;
}
size_t user_buffer_size() const { return user_buffer_size_; }
void* user_buffer() { return user_buffer_; }
private:
void* user_buffer_;
size_t user_buffer_size_;
std::shared_ptr<InternalHandle> internal_handle_ = nullptr;
};
class InternalHandle { private:
public: void SetIndex(uint32_t idx) { index_ = idx; }
~InternalHandle(); uint8_t* buffer_ = nullptr;
virtual bool Flush(HandleFlushType type) = 0; size_t buffer_size_ = 0;
vip_buffer handle() { return handle_; }; vip_buffer handle_;
protected: uint32_t index_;
vip_buffer handle_;
};
class InternalUserHandle : public InternalHandle {
public:
InternalUserHandle(void* user_buffer, size_t user_buffer_size,
vip_buffer_create_params_t& params);
bool Flush(HandleFlushType type) override;
}; };
} }