Refine Lite API (#221)

Signed-off-by: Zongwu Yang <zongwu.yang@verisilicon.com>
This commit is contained in:
Zongwu.Yang 2021-11-19 20:30:26 +08:00 committed by GitHub
parent 0ca4970d72
commit c90efe70c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 247 additions and 244 deletions

View File

@ -34,14 +34,24 @@ namespace tim {
namespace lite {
class Execution {
public:
static std::shared_ptr<Execution> Create(
const void* executable, size_t executable_size);
virtual Execution& BindInputs(const std::vector<std::shared_ptr<Handle>>& handles) = 0;
virtual Execution& BindOutputs(const std::vector<std::shared_ptr<Handle>>& handles) = 0;
virtual bool Trigger() = 0;
public:
static std::shared_ptr<Execution> Create(const void* executable,
size_t executable_size);
virtual std::shared_ptr<Handle> CreateInputHandle(uint32_t in_idx,
uint8_t* buffer,
size_t size) = 0;
virtual std::shared_ptr<Handle> CreateOutputHandle(uint32_t out_idx,
uint8_t* buffer,
size_t size) = 0;
virtual Execution& BindInputs(
const std::vector<std::shared_ptr<Handle>>& handles) = 0;
virtual Execution& BindOutputs(
const std::vector<std::shared_ptr<Handle>>& handles) = 0;
virtual Execution& UnBindInput(const std::shared_ptr<Handle>& Handle) = 0;
virtual Execution& UnBindOutput(const std::shared_ptr<Handle>& handle) = 0;
virtual bool Trigger() = 0;
};
}
}
} // namespace lite
} // namespace tim
#endif

View File

@ -30,21 +30,10 @@
namespace tim {
namespace lite {
class HandleImpl;
class Handle {
public:
std::unique_ptr<HandleImpl>& impl() { return impl_; }
bool Flush();
bool Invalidate();
protected:
std::unique_ptr<HandleImpl> impl_;
};
class UserHandle : public Handle {
public:
UserHandle(void* buffer, size_t size);
~UserHandle();
virtual bool Flush() = 0;
virtual bool Invalidate() = 0;
};
}

View File

@ -121,16 +121,28 @@ int main() {
assert(input);
assert(output);
memset(output, 0, output_sz);
memcpy(input, input_data.data(), input_data.size());
auto input_handle = std::make_shared<tim::lite::UserHandle>(
input, input_data.size());
auto output_handle = std::make_shared<tim::lite::UserHandle>(
output, lenet_output_size * sizeof(float));
auto input_handle = exec->CreateInputHandle(0, input, input_sz);
auto output_handle = exec->CreateOutputHandle(0, (uint8_t*)output, output_sz);
exec->BindInputs({input_handle});
exec->BindOutputs({output_handle});
memcpy(input, input_data.data(), input_data.size());
input_handle->Flush();
exec->Trigger();
output_handle->Invalidate();
printTopN(output, lenet_output_size, 5);
// rebind input and output
exec->UnBindInput(input_handle);
exec->UnBindOutput(output_handle);
exec->BindInputs({input_handle});
exec->BindOutputs({output_handle});
input_handle->Flush();
exec->Trigger();
output_handle->Invalidate();
printTopN(output, lenet_output_size, 5);
free(output);
free(input);
} else {

View File

@ -28,6 +28,9 @@
#include <cstring>
#include <vector>
#include <memory>
#include <algorithm>
#include <iostream>
#include <cassert>
#include "handle_private.h"
#include "vip_lite.h"
@ -35,67 +38,6 @@
namespace tim {
namespace lite {
namespace {
bool QueryInputBufferParameters(
vip_buffer_create_params_t& param, uint32_t index, vip_network network) {
uint32_t count = 0;
vip_query_network(network, VIP_NETWORK_PROP_INPUT_COUNT, &count);
if (index >= count) {
return false;
}
memset(&param, 0, sizeof(param));
param.memory_type = VIP_BUFFER_MEMORY_TYPE_DEFAULT;
vip_query_input(network, index, VIP_BUFFER_PROP_DATA_FORMAT, &param.data_format);
vip_query_input(network, index, VIP_BUFFER_PROP_NUM_OF_DIMENSION, &param.num_of_dims);
vip_query_input(network, index, VIP_BUFFER_PROP_SIZES_OF_DIMENSION, param.sizes);
vip_query_input(network, index, VIP_BUFFER_PROP_QUANT_FORMAT, &param.quant_format);
switch(param.quant_format) {
case VIP_BUFFER_QUANTIZE_DYNAMIC_FIXED_POINT:
vip_query_input(network, index, VIP_BUFFER_PROP_FIXED_POINT_POS,
&param.quant_data.dfp.fixed_point_pos);
break;
case VIP_BUFFER_QUANTIZE_TF_ASYMM:
vip_query_input(network, index, VIP_BUFFER_PROP_TF_SCALE,
&param.quant_data.affine.scale);
vip_query_input(network, index, VIP_BUFFER_PROP_TF_ZERO_POINT,
&param.quant_data.affine.zeroPoint);
default:
break;
}
return true;
}
bool QueryOutputBufferParameters(
vip_buffer_create_params_t& param, uint32_t index, vip_network network) {
uint32_t count = 0;
vip_query_network(network, VIP_NETWORK_PROP_OUTPUT_COUNT, &count);
if (index >= count) {
return false;
}
memset(&param, 0, sizeof(param));
param.memory_type = VIP_BUFFER_MEMORY_TYPE_DEFAULT;
vip_query_output(network, index, VIP_BUFFER_PROP_DATA_FORMAT, &param.data_format);
vip_query_output(network, index, VIP_BUFFER_PROP_NUM_OF_DIMENSION, &param.num_of_dims);
vip_query_output(network, index, VIP_BUFFER_PROP_SIZES_OF_DIMENSION, param.sizes);
vip_query_output(network, index, VIP_BUFFER_PROP_QUANT_FORMAT, &param.quant_format);
switch(param.quant_format) {
case VIP_BUFFER_QUANTIZE_DYNAMIC_FIXED_POINT:
vip_query_output(network, index, VIP_BUFFER_PROP_FIXED_POINT_POS,
&param.quant_data.dfp.fixed_point_pos);
break;
case VIP_BUFFER_QUANTIZE_TF_ASYMM:
vip_query_output(network, index, VIP_BUFFER_PROP_TF_SCALE,
&param.quant_data.affine.scale);
vip_query_output(network, index, VIP_BUFFER_PROP_TF_ZERO_POINT,
&param.quant_data.affine.zeroPoint);
break;
default:
break;
}
return true;
}
}
ExecutionImpl::ExecutionImpl(const void* executable, size_t executable_size) {
vip_status_e status = VIP_SUCCESS;
vip_network network = nullptr;
@ -130,41 +72,44 @@ ExecutionImpl::~ExecutionImpl() {
vip_finish_network(network_);
vip_destroy_network(network_);
}
input_maps_.clear();
output_maps_.clear();
input_handles_.clear();
output_handles_.clear();
vip_destroy();
}
std::shared_ptr<Handle> ExecutionImpl::CreateInputHandle(uint32_t in_idx, uint8_t* buffer, size_t size) {
auto handle = std::make_shared<HandleImpl>(buffer, size);
if (handle->CreateVipInputBuffer(network_, in_idx)) {
return handle;
} else {
return nullptr;
}
}
std::shared_ptr<Handle> ExecutionImpl::CreateOutputHandle(uint32_t out_idx, uint8_t* buffer, size_t size) {
auto handle = std::make_shared<HandleImpl>(buffer, size);
if (handle->CreateVipPOutputBuffer(network_, out_idx)) {
return handle;
} else {
return nullptr;
}
}
Execution& ExecutionImpl::BindInputs(const std::vector<std::shared_ptr<Handle>>& handles) {
if (!IsValid()) {
return *this;
}
vip_status_e status = VIP_SUCCESS;
vip_buffer_create_params_t param;
for (uint32_t i = 0; i < handles.size(); i ++) {
auto handle = handles[i];
if (!handle) {
status = VIP_ERROR_FAILURE;
break;
}
std::shared_ptr<InternalHandle> internal_handle = nullptr;
if (input_maps_.find(handle) == input_maps_.end()) {
if (!QueryInputBufferParameters(param, i, network_)) {
status = VIP_ERROR_FAILURE;
break;
for (auto handle : handles) {
if (input_handles_.end() == std::find(input_handles_.begin(), input_handles_.end(), handle)) {
input_handles_.push_back(handle);
auto handle_impl = std::dynamic_pointer_cast<HandleImpl>(handle);
vip_status_e status = vip_set_input(network_, handle_impl->Index(), handle_impl->VipHandle());
if (status != VIP_SUCCESS) {
std::cout << "Set input for network failed." << std::endl;
assert(false);
}
internal_handle = handle->impl()->Register(param);
if (!internal_handle) {
status = VIP_ERROR_FAILURE;
break;
}
input_maps_[handle] = internal_handle;
} else {
internal_handle = input_maps_.at(handle);
}
status = vip_set_input(network_, i, internal_handle->handle());
if (status != VIP_SUCCESS) {
break;
std::cout << "The input handle has been binded, need not bind it again." << std::endl;
}
}
return *this;
@ -174,37 +119,38 @@ Execution& ExecutionImpl::BindOutputs(const std::vector<std::shared_ptr<Handle>>
if (!IsValid()) {
return *this;
}
vip_status_e status = VIP_SUCCESS;
vip_buffer_create_params_t param;
for (uint32_t i = 0; i < handles.size(); i ++) {
auto handle = handles[i];
if (!handle) {
status = VIP_ERROR_FAILURE;
break;
}
std::shared_ptr<InternalHandle> internal_handle = nullptr;
if (output_maps_.find(handle) == output_maps_.end()) {
if (!QueryOutputBufferParameters(param, i, network_)) {
status = VIP_ERROR_FAILURE;
break;
for (auto handle : handles) {
if (output_handles_.end() == std::find(output_handles_.begin(), output_handles_.end(), handle)) {
output_handles_.push_back(handle);
auto handle_impl = std::dynamic_pointer_cast<HandleImpl>(handle);
vip_status_e status = vip_set_output(network_, handle_impl->Index(), handle_impl->VipHandle());
if (status != VIP_SUCCESS) {
std::cout << "Set output for network failed." << std::endl;
assert(false);
}
internal_handle = handle->impl()->Register(param);
if (!internal_handle) {
status = VIP_ERROR_FAILURE;
break;
}
output_maps_[handle] = internal_handle;
} else {
internal_handle = output_maps_.at(handle);
}
status = vip_set_output(network_, i, internal_handle->handle());
if (status != VIP_SUCCESS) {
break;
std::cout << "The output handle has been binded, need not bind it again." << std::endl;
}
}
return *this;
};
Execution& ExecutionImpl::UnBindInput(const std::shared_ptr<Handle>& handle) {
auto it = std::find(input_handles_.begin(), input_handles_.end(), handle);
if (input_handles_.end() != it) {
input_handles_.erase(it);
}
return *this;
}
Execution& ExecutionImpl::UnBindOutput(const std::shared_ptr<Handle>& handle) {
auto it = std::find(output_handles_.begin(), output_handles_.end(), handle);
if (output_handles_.end() != it) {
output_handles_.erase(it);
}
return *this;
}
bool ExecutionImpl::Trigger() {
if (!IsValid()) {
return false;

View File

@ -36,21 +36,30 @@ namespace tim {
namespace lite {
class ExecutionImpl : public Execution {
public :
ExecutionImpl(const void* executable, size_t executable_size);
~ExecutionImpl();
Execution& BindInputs(const std::vector<std::shared_ptr<Handle>>& handles) override;
Execution& BindOutputs(const std::vector<std::shared_ptr<Handle>>& handles) override;
bool Trigger() override;
bool IsValid() const { return valid_; };
vip_network network() { return network_; };
private:
std::map<std::shared_ptr<Handle>, std::shared_ptr<InternalHandle>> input_maps_;
std::map<std::shared_ptr<Handle>, std::shared_ptr<InternalHandle>> output_maps_;
bool valid_;
vip_network network_;
public:
ExecutionImpl(const void* executable, size_t executable_size);
~ExecutionImpl();
std::shared_ptr<Handle> CreateInputHandle(uint32_t in_idx, uint8_t* buffer,
size_t size) override;
std::shared_ptr<Handle> CreateOutputHandle(uint32_t out_idx, uint8_t* buffer,
size_t size) override;
Execution& BindInputs(
const std::vector<std::shared_ptr<Handle>>& handles) override;
Execution& BindOutputs(
const std::vector<std::shared_ptr<Handle>>& handles) override;
Execution& UnBindInput(const std::shared_ptr<Handle>& Handle) override;
Execution& UnBindOutput(const std::shared_ptr<Handle>& handle) override;
bool Trigger() override;
bool IsValid() const { return valid_; };
vip_network network() { return network_; };
private:
std::vector<std::shared_ptr<Handle>> input_handles_;
std::vector<std::shared_ptr<Handle>> output_handles_;
bool valid_;
vip_network network_;
};
}
}
} // namespace lite
} // namespace tim
#endif

View File

@ -28,6 +28,7 @@
#include <cstdint>
#include <memory>
#include <iostream>
#include <string.h>
#include "execution_private.h"
#include "vip_lite.h"
@ -36,70 +37,126 @@
namespace tim {
namespace lite {
bool Handle::Flush() {
auto internal_handle = impl_->internal_handle();
return internal_handle->Flush(HandleFlushType::HandleFlush);
}
bool Handle::Invalidate() {
auto internal_handle = impl_->internal_handle();
return internal_handle->Flush(HandleFlushType::HandleInvalidate);
}
UserHandle::UserHandle(void* buffer, size_t size) {
assert((reinterpret_cast<uintptr_t>(buffer) % _64_BYTES_ALIGN) == 0);
impl_ = std::make_unique<UserHandleImpl>(buffer, size);
}
UserHandle::~UserHandle() {}
std::shared_ptr<InternalHandle> UserHandleImpl::Register(
vip_buffer_create_params_t& params) {
internal_handle_ = std::make_shared<InternalUserHandle>(
user_buffer_, user_buffer_size_, params);
if (!internal_handle_->handle()) {
internal_handle_.reset();
namespace {
bool QueryInputBufferParameters(
vip_buffer_create_params_t& param, uint32_t index, vip_network network) {
uint32_t count = 0;
vip_query_network(network, VIP_NETWORK_PROP_INPUT_COUNT, &count);
if (index >= count) {
return false;
}
return internal_handle_;
}
InternalHandle::~InternalHandle() {
if (handle_) {
vip_destroy_buffer(handle_);
handle_ = nullptr;
memset(&param, 0, sizeof(param));
param.memory_type = VIP_BUFFER_MEMORY_TYPE_DEFAULT;
vip_query_input(network, index, VIP_BUFFER_PROP_DATA_FORMAT, &param.data_format);
vip_query_input(network, index, VIP_BUFFER_PROP_NUM_OF_DIMENSION, &param.num_of_dims);
vip_query_input(network, index, VIP_BUFFER_PROP_SIZES_OF_DIMENSION, param.sizes);
vip_query_input(network, index, VIP_BUFFER_PROP_QUANT_FORMAT, &param.quant_format);
switch(param.quant_format) {
case VIP_BUFFER_QUANTIZE_DYNAMIC_FIXED_POINT:
vip_query_input(network, index, VIP_BUFFER_PROP_FIXED_POINT_POS,
&param.quant_data.dfp.fixed_point_pos);
break;
case VIP_BUFFER_QUANTIZE_TF_ASYMM:
vip_query_input(network, index, VIP_BUFFER_PROP_TF_SCALE,
&param.quant_data.affine.scale);
vip_query_input(network, index, VIP_BUFFER_PROP_TF_ZERO_POINT,
&param.quant_data.affine.zeroPoint);
default:
break;
}
return true;
}
InternalUserHandle::InternalUserHandle(void* user_buffer, size_t user_buffer_size,
vip_buffer_create_params_t& params) {
vip_status_e status = VIP_SUCCESS;
vip_buffer internal_buffer = nullptr;
status = vip_create_buffer_from_handle(&params,
user_buffer, user_buffer_size, &internal_buffer);
if (status == VIP_SUCCESS) {
handle_ = internal_buffer;
} else {
handle_ = nullptr;
bool QueryOutputBufferParameters(
vip_buffer_create_params_t& param, uint32_t index, vip_network network) {
uint32_t count = 0;
vip_query_network(network, VIP_NETWORK_PROP_OUTPUT_COUNT, &count);
if (index >= count) {
return false;
}
memset(&param, 0, sizeof(param));
param.memory_type = VIP_BUFFER_MEMORY_TYPE_DEFAULT;
vip_query_output(network, index, VIP_BUFFER_PROP_DATA_FORMAT, &param.data_format);
vip_query_output(network, index, VIP_BUFFER_PROP_NUM_OF_DIMENSION, &param.num_of_dims);
vip_query_output(network, index, VIP_BUFFER_PROP_SIZES_OF_DIMENSION, param.sizes);
vip_query_output(network, index, VIP_BUFFER_PROP_QUANT_FORMAT, &param.quant_format);
switch(param.quant_format) {
case VIP_BUFFER_QUANTIZE_DYNAMIC_FIXED_POINT:
vip_query_output(network, index, VIP_BUFFER_PROP_FIXED_POINT_POS,
&param.quant_data.dfp.fixed_point_pos);
break;
case VIP_BUFFER_QUANTIZE_TF_ASYMM:
vip_query_output(network, index, VIP_BUFFER_PROP_TF_SCALE,
&param.quant_data.affine.scale);
vip_query_output(network, index, VIP_BUFFER_PROP_TF_ZERO_POINT,
&param.quant_data.affine.zeroPoint);
break;
default:
break;
}
return true;
}
}
bool InternalUserHandle::Flush(HandleFlushType type) {
bool HandleImpl::CreateVipInputBuffer(vip_network network,
uint32_t in_idx) {
vip_status_e status = VIP_SUCCESS;
switch (type) {
case HandleFlushType::HandleFlush: {
status = vip_flush_buffer(handle_, VIP_BUFFER_OPER_TYPE_FLUSH);
break;
}
case HandleFlushType::HandleInvalidate: {
status = vip_flush_buffer(handle_, VIP_BUFFER_OPER_TYPE_INVALIDATE);
break;
}
default:
std::cout << __FUNCTION__ << ":" << __LINE__ << " Unkown HandleFlushType."
<< std::endl;
assert(false);
vip_buffer_create_params_t param;
vip_buffer internal_buffer;
assert((reinterpret_cast<uintptr_t>(buffer_) % _64_BYTES_ALIGN) == 0);
if (!QueryInputBufferParameters(param, in_idx, network)) {
status = VIP_ERROR_FAILURE;
return false;
}
status = vip_create_buffer_from_handle(&param, buffer_, buffer_size_,
&internal_buffer);
if (status == VIP_SUCCESS) {
handle_ = internal_buffer;
SetIndex(in_idx);
return true;
} else {
handle_ = nullptr;
return false;
}
}
bool HandleImpl::CreateVipPOutputBuffer(vip_network network,
uint32_t out_idx) {
vip_status_e status = VIP_SUCCESS;
vip_buffer_create_params_t param;
vip_buffer internal_buffer;
assert((reinterpret_cast<uintptr_t>(buffer_) % _64_BYTES_ALIGN) == 0);
if (!QueryOutputBufferParameters(param, out_idx, network)) {
status = VIP_ERROR_FAILURE;
return false;
}
status = vip_create_buffer_from_handle(&param, buffer_, buffer_size_,
&internal_buffer);
if (status == VIP_SUCCESS) {
handle_ = internal_buffer;
SetIndex(out_idx);
return true;
} else {
handle_ = nullptr;
return false;
}
}
bool HandleImpl::Flush() {
vip_status_e status = vip_flush_buffer(handle_, VIP_BUFFER_OPER_TYPE_FLUSH);
return status == VIP_SUCCESS ? true : false;
}
bool HandleImpl::Invalidate() {
vip_status_e status = vip_flush_buffer(handle_, VIP_BUFFER_OPER_TYPE_INVALIDATE);
return status == VIP_SUCCESS ? true : false;
}
HandleImpl::~HandleImpl() {
if (handle_) {
vip_destroy_buffer(handle_);
handle_ = nullptr;
}
return status == VIP_SUCCESS ? true : false;
}
}

View File

@ -37,46 +37,26 @@ enum class HandleFlushType {
HandleInvalidate = 1
};
class InternalHandle;
class HandleImpl : public Handle {
public:
HandleImpl(uint8_t* buffer, size_t size)
: buffer_(buffer), buffer_size_(size) {}
class HandleImpl {
public:
virtual std::shared_ptr<InternalHandle> Register(
vip_buffer_create_params_t& params) = 0;
virtual std::shared_ptr<InternalHandle>& internal_handle() = 0;
};
bool CreateVipInputBuffer(vip_network network, uint32_t in_idx);
bool CreateVipPOutputBuffer(vip_network network, uint32_t out_idx);
uint32_t Index() { return index_; }
vip_buffer VipHandle() { return handle_; }
bool Flush() override;
bool Invalidate() override;
class UserHandleImpl : public HandleImpl {
public:
UserHandleImpl(void* buffer, size_t size)
: user_buffer_(buffer), user_buffer_size_(size) {}
std::shared_ptr<InternalHandle> Register(
vip_buffer_create_params_t& params) override;
std::shared_ptr<InternalHandle>& internal_handle() override {
return internal_handle_;
}
size_t user_buffer_size() const { return user_buffer_size_; }
void* user_buffer() { return user_buffer_; }
private:
void* user_buffer_;
size_t user_buffer_size_;
std::shared_ptr<InternalHandle> internal_handle_ = nullptr;
};
~HandleImpl();
class InternalHandle {
public:
~InternalHandle();
virtual bool Flush(HandleFlushType type) = 0;
vip_buffer handle() { return handle_; };
protected:
vip_buffer handle_;
};
class InternalUserHandle : public InternalHandle {
public:
InternalUserHandle(void* user_buffer, size_t user_buffer_size,
vip_buffer_create_params_t& params);
bool Flush(HandleFlushType type) override;
private:
void SetIndex(uint32_t idx) { index_ = idx; }
uint8_t* buffer_ = nullptr;
size_t buffer_size_ = 0;
vip_buffer handle_;
uint32_t index_;
};
}