Minor refinement: use tensor pointer after check

Signed-off-by: xiang.zhang <xiang.zhang@verisilicon.com>
This commit is contained in:
xiang.zhang 2021-08-03 21:54:06 +08:00 committed by Kainan Cha
parent f0d4118f87
commit d4a13e18a9
1 changed files with 9 additions and 6 deletions

View File

@ -97,9 +97,10 @@ bool TensorImpl::CopyDataToTensor(const void* data, uint32_t size_in_bytes) {
if (data && VSI_NN_TENSOR_ID_NA != id_) { if (data && VSI_NN_TENSOR_ID_NA != id_) {
retn = false; retn = false;
vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_); vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_);
uint32_t tensor_bytes = vsi_nn_GetTensorSize(
tensor->attr.size, tensor->attr.dim_num, tensor->attr.dtype.vx_type);
if (tensor) { if (tensor) {
uint32_t tensor_bytes = vsi_nn_GetTensorSize(
tensor->attr.size, tensor->attr.dim_num, tensor->attr.dtype.vx_type);
if (tensor->attr.is_created_from_handle) { if (tensor->attr.is_created_from_handle) {
void *ptr = NULL; void *ptr = NULL;
vsi_nn_GetTensorHandle(tensor, &ptr); vsi_nn_GetTensorHandle(tensor, &ptr);
@ -119,8 +120,8 @@ bool TensorImpl::CopyDataToTensor(const void* data, uint32_t size_in_bytes) {
const uint8_t* end = static_cast<const uint8_t*>(data) + tensor_bytes; const uint8_t* end = static_cast<const uint8_t*>(data) + tensor_bytes;
std::vector<uint8_t> data_copy(static_cast<const uint8_t*>(data), end); std::vector<uint8_t> data_copy(static_cast<const uint8_t*>(data), end);
retn = VSI_SUCCESS == retn = (VSI_SUCCESS ==
vsi_nn_CopyDataToTensor(graph_->graph(), tensor, data_copy.data()); vsi_nn_CopyDataToTensor(graph_->graph(), tensor, data_copy.data()));
} }
} }
} }
@ -136,9 +137,11 @@ bool TensorImpl::CopyDataFromTensor(void* data) {
if (data && VSI_NN_TENSOR_ID_NA != id_) { if (data && VSI_NN_TENSOR_ID_NA != id_) {
retn = false; retn = false;
vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_); vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_);
uint32_t tensor_bytes = vsi_nn_GetTensorSize(
tensor->attr.size, tensor->attr.dim_num, tensor->attr.dtype.vx_type);
if (tensor) { if (tensor) {
uint32_t tensor_bytes = vsi_nn_GetTensorSize(
tensor->attr.size, tensor->attr.dim_num, tensor->attr.dtype.vx_type);
if (tensor->attr.is_created_from_handle) { if (tensor->attr.is_created_from_handle) {
void* ptr = NULL; void* ptr = NULL;
vsi_nn_GetTensorHandle(tensor, &ptr); vsi_nn_GetTensorHandle(tensor, &ptr);