Fixed IExecutable object not being bound to a DeviceID (#624)
If an Executable object is not bound to a concrete DeviceID, it runs on the first device by default. When multiple executables are run on multiple devices, this behavior is not what is expected. Fixed by attaching the device id to CompileOption. Signed-off-by: xiang.zhang <xiang.zhang@verisilicon.com>
This commit is contained in:
parent
680e8d59cb
commit
821864a582
|
|
@ -28,6 +28,10 @@ if(${TIM_VX_CODE_COVERAGE})
|
|||
set(CMAKE_C_FLAGS "-g -O0 --coverage -fprofile-arcs -ftest-coverage")
|
||||
endif()
|
||||
|
||||
# When TIM_VX_ENABLE_PLATFORM is set, append -DENABLE_PLATFORM to the C++ flags
# so that sources guarded by `#if defined(ENABLE_PLATFORM)` are compiled in.
if(${TIM_VX_ENABLE_PLATFORM})
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_PLATFORM")
|
||||
endif()
|
||||
|
||||
# When TIM_VX_ENABLE_PLATFORM_LITE is set, append -DENABLE_PLATFORM_LITE to the
# C++ flags (existing flags are preserved).
if(${TIM_VX_ENABLE_PLATFORM_LITE})
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_PLATFORM_LITE")
|
||||
endif()
|
||||
|
|
|
|||
|
|
@ -27,6 +27,10 @@
|
|||
#include <map>
|
||||
#include <memory>
|
||||
|
||||
#if defined(ENABLE_PLATFORM)
|
||||
#include "platform/platform.h"
|
||||
#endif
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
struct CompileOptionImpl;
|
||||
|
|
@ -38,6 +42,11 @@ class CompileOption {
|
|||
bool isRelaxMode() const;
|
||||
bool setRelaxMode(bool enable = false);
|
||||
|
||||
#if defined(ENABLE_PLATFORM)
|
||||
void setDeviceId(::tim::vx::platform::IDevice::device_id_t device);
|
||||
::tim::vx::platform::IDevice::device_id_t getDeviceId();
|
||||
#endif
|
||||
|
||||
static CompileOption DefaultOptions;
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -42,11 +42,15 @@ class Tensor;
|
|||
struct TensorSpec;
|
||||
struct DmaBufferDesc;
|
||||
class Operation;
|
||||
class CompileOption;
|
||||
|
||||
class Graph {
|
||||
public:
|
||||
virtual ~Graph() {}
|
||||
|
||||
/// Attach CompileOption
|
||||
virtual void SetCompileOption(const CompileOption&) = 0;
|
||||
|
||||
/// Create a tensor with given `TensorSpec`
|
||||
virtual std::shared_ptr<Tensor> CreateTensor(const TensorSpec& spec,
|
||||
const void* data = nullptr) = 0;
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@
|
|||
#include "tim/vx/graph.h"
|
||||
#include "tim/vx/tensor.h"
|
||||
#include "tim/vx/context.h"
|
||||
#include "tim/vx/ops/nbg.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
|
|||
|
|
@ -34,6 +34,9 @@ struct CompileOptionImpl {
|
|||
using RelaxModeType = std::tuple<std::string, bool, bool, bool>;
|
||||
// Default constructor: initializes relax_mode_ to the tuple
// ("RelaxMode", false, false, false) and, only in ENABLE_PLATFORM builds,
// sets device_id_ to 0.
// NOTE(review): id 0 is presumably the first device -- confirm against the
// platform layer.
CompileOptionImpl() {
|
||||
relax_mode_ = RelaxModeType(std::string("RelaxMode"), false, false, false);
|
||||
#if defined(ENABLE_PLATFORM)
|
||||
device_id_ = 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
bool RelaxMode() const {
|
||||
|
|
@ -46,6 +49,17 @@ struct CompileOptionImpl {
|
|||
: std::get<3>(relax_mode_);
|
||||
}
|
||||
|
||||
#if defined(ENABLE_PLATFORM)
|
||||
// Stores `device` in device_id_, overwriting the previous value.
// Only available when ENABLE_PLATFORM is defined.
void setDeviceId(::tim::vx::platform::IDevice::device_id_t device) {
|
||||
device_id_ = device;
|
||||
}
|
||||
// Returns the current value of device_id_ (set to 0 by the constructor
// until setDeviceId() is called).
// NOTE(review): this accessor does not modify state but is not declared
// const; it therefore cannot be called on a const object.
::tim::vx::platform::IDevice::device_id_t getDeviceId() {
|
||||
return device_id_;
|
||||
}
|
||||
|
||||
::tim::vx::platform::IDevice::device_id_t device_id_;
|
||||
#endif
|
||||
|
||||
RelaxModeType relax_mode_;
|
||||
};
|
||||
|
||||
|
|
@ -56,5 +70,17 @@ bool CompileOption::isRelaxMode() const { return this->impl_->RelaxMode(); }
|
|||
// Assigns `enable` to the result of impl_->RelaxMode() and returns the value
// of that assignment expression.
// NOTE(review): this relies on a RelaxMode() overload that yields an
// assignable reference; that overload is not visible here -- confirm.
bool CompileOption::setRelaxMode(bool enable) {
|
||||
return this->impl_->RelaxMode() = enable;
|
||||
}
|
||||
|
||||
#if defined(ENABLE_PLATFORM)
|
||||
// Forwards `device` to the implementation object's setDeviceId().
// Compiled only when ENABLE_PLATFORM is defined.
void CompileOption::setDeviceId(::tim::vx::platform::IDevice::device_id_t device) {
|
||||
this->impl_->setDeviceId(device);
|
||||
}
|
||||
|
||||
// Returns the device id held by the implementation object, by delegating to
// impl_->getDeviceId(). Compiled only when ENABLE_PLATFORM is defined.
::tim::vx::platform::IDevice::device_id_t CompileOption::getDeviceId() {
|
||||
return this->impl_->getDeviceId();
|
||||
}
|
||||
|
||||
|
||||
#endif
|
||||
} // namespace vx
|
||||
} // namespace tim
|
||||
|
|
|
|||
|
|
@ -150,6 +150,10 @@ std::shared_ptr<Tensor> GraphImpl::GetTensorFromCache(const TensorSpec& spec, co
|
|||
}
|
||||
#endif
|
||||
|
||||
// Replaces the graph's stored options_ by assigning `new_options` to it.
// In ENABLE_PLATFORM builds, Setup() later reads the device id from options_
// and applies it to the underlying graph via vxSetGraphAttribute.
void GraphImpl::SetCompileOption(const CompileOption& new_options) {
|
||||
options_ = new_options;
|
||||
}
|
||||
|
||||
// Returns the graph_ member (the low-level vsi_nn_graph_t pointer).
vsi_nn_graph_t* GraphImpl::graph() { return graph_; }
|
||||
|
||||
void GraphImpl::AddInput(vsi_nn_tensor_id_t id) {
|
||||
|
|
@ -321,6 +325,12 @@ bool GraphImpl::Setup() {
|
|||
}
|
||||
vsi_nn_SetGraphFastMode(graph_, is_fast_mode);
|
||||
|
||||
#if defined(ENABLE_PLATFORM)
|
||||
auto id = options_.getDeviceId();
|
||||
vxSetGraphAttribute(graph_->g, VX_GRAPH_DEVICE_INDEX_VIV,
|
||||
(void*)(&id), sizeof(id));
|
||||
#endif
|
||||
|
||||
std::call_once(setio_once_, [&status, this]() {
|
||||
status = (vsi_nn_SetGraphInputs(this->graph_, this->inputs_.data(),
|
||||
this->inputs_.size()) &&
|
||||
|
|
|
|||
|
|
@ -49,6 +49,9 @@ class GraphImpl : public Graph {
|
|||
const std::string CalculateCacheKey(const TensorSpec& spec, const void* data);
|
||||
std::map<std::string, std::shared_ptr<tim::vx::Tensor>>& GetTensorCacheMap();
|
||||
#endif
|
||||
|
||||
void SetCompileOption(const CompileOption& new_option) override;
|
||||
|
||||
/// Return the low-level graph object
|
||||
vsi_nn_graph_t* graph();
|
||||
void AddInput(vsi_nn_tensor_id_t id);
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#include "tim/vx/platform/native.h"
|
||||
#include "native_device_private.h"
|
||||
#include "tim/vx/ops/nbg.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -103,9 +104,13 @@ std::shared_ptr<IExecutor> IExecutable::Executor() const {
|
|||
NativeExecutable::NativeExecutable(const std::shared_ptr<IExecutor>& executor,
|
||||
const std::vector<char>& nb_buf,
|
||||
size_t inputs, size_t outputs) {
|
||||
CompileOption opt;
|
||||
opt.setDeviceId(executor->Device()->Id());
|
||||
|
||||
executor_ = executor;
|
||||
context_ = executor->Contex();
|
||||
nb_graph_ = context_->CreateGraph();
|
||||
nb_graph_ = context_->CreateGraph(opt);
|
||||
|
||||
nb_buf_ = nb_buf;
|
||||
nb_node_ = nb_graph_->CreateOperation<tim::vx::ops::NBG>(nb_buf_.data(),
|
||||
inputs, outputs);
|
||||
|
|
@ -269,11 +274,11 @@ bool NativeExecutor::Trigger(bool async) {
|
|||
|
||||
std::shared_ptr<IExecutable> NativeExecutor::Compile(
|
||||
const std::shared_ptr<Graph>& graph) {
|
||||
GraphImpl* graphimp =
|
||||
dynamic_cast<GraphImpl*>(graph.get()); // hack to downcast
|
||||
IDevice::device_id_t id = device_->Id();
|
||||
vxSetGraphAttribute(graphimp->graph()->g, VX_GRAPH_DEVICE_INDEX_VIV,
|
||||
(void*)(&id), sizeof(id));
|
||||
|
||||
CompileOption option;
|
||||
option.setDeviceId(device_->Id());
|
||||
graph->SetCompileOption(option);
|
||||
|
||||
size_t bin_size = -1;
|
||||
graph->CompileToBinary(nullptr, &bin_size);
|
||||
std::vector<char> nb_buf;
|
||||
|
|
|
|||
Loading…
Reference in New Issue