diff --git a/src/tim/vx/tensor.cc b/src/tim/vx/tensor.cc index 4fe6063..d12ed79 100644 --- a/src/tim/vx/tensor.cc +++ b/src/tim/vx/tensor.cc @@ -97,9 +97,10 @@ bool TensorImpl::CopyDataToTensor(const void* data, uint32_t size_in_bytes) { if (data && VSI_NN_TENSOR_ID_NA != id_) { retn = false; vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_); - uint32_t tensor_bytes = vsi_nn_GetTensorSize( - tensor->attr.size, tensor->attr.dim_num, tensor->attr.dtype.vx_type); if (tensor) { + uint32_t tensor_bytes = vsi_nn_GetTensorSize( + tensor->attr.size, tensor->attr.dim_num, tensor->attr.dtype.vx_type); + if (tensor->attr.is_created_from_handle) { void *ptr = NULL; vsi_nn_GetTensorHandle(tensor, &ptr); @@ -119,8 +120,8 @@ bool TensorImpl::CopyDataToTensor(const void* data, uint32_t size_in_bytes) { const uint8_t* end = static_cast(data) + tensor_bytes; std::vector data_copy(static_cast(data), end); - retn = VSI_SUCCESS == - vsi_nn_CopyDataToTensor(graph_->graph(), tensor, data_copy.data()); + retn = (VSI_SUCCESS == + vsi_nn_CopyDataToTensor(graph_->graph(), tensor, data_copy.data())); } } } @@ -136,9 +137,11 @@ bool TensorImpl::CopyDataFromTensor(void* data) { if (data && VSI_NN_TENSOR_ID_NA != id_) { retn = false; vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_); - uint32_t tensor_bytes = vsi_nn_GetTensorSize( - tensor->attr.size, tensor->attr.dim_num, tensor->attr.dtype.vx_type); + if (tensor) { + uint32_t tensor_bytes = vsi_nn_GetTensorSize( + tensor->attr.size, tensor->attr.dim_num, tensor->attr.dtype.vx_type); + if (tensor->attr.is_created_from_handle) { void* ptr = NULL; vsi_nn_GetTensorHandle(tensor, &ptr);