From 821864a582919313075f40f5c3f4ff64bf75e7c9 Mon Sep 17 00:00:00 2001 From: Sven Date: Mon, 24 Jul 2023 22:45:54 +0800 Subject: [PATCH] Fixed IExecutable object not bound to DeviceID (#624) If an Executable object is not bound to a concrete DeviceID, it runs on the first device by default. When running multiple executables on multiple devices, this behavior is not what is expected. Fixed by attaching the device id to CompileOption. Signed-off-by: xiang.zhang --- CMakeLists.txt | 4 ++++ include/tim/vx/compile_option.h | 9 +++++++++ include/tim/vx/graph.h | 4 ++++ include/tim/vx/platform/platform.h | 1 - src/tim/vx/compile_option.cc | 26 ++++++++++++++++++++++++++ src/tim/vx/graph.cc | 10 ++++++++++ src/tim/vx/graph_private.h | 3 +++ src/tim/vx/platform/native.cc | 17 +++++++++++------ 8 files changed, 67 insertions(+), 7 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7c12000..72691e6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,6 +28,10 @@ if(${TIM_VX_CODE_COVERAGE}) set(CMAKE_C_FLAGS "-g -O0 --coverage -fprofile-arcs -ftest-coverage") endif() +if(${TIM_VX_ENABLE_PLATFORM}) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_PLATFORM") +endif() + if(${TIM_VX_ENABLE_PLATFORM_LITE}) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_PLATFORM_LITE") endif() diff --git a/include/tim/vx/compile_option.h b/include/tim/vx/compile_option.h index 9658013..b4e1761 100644 --- a/include/tim/vx/compile_option.h +++ b/include/tim/vx/compile_option.h @@ -27,6 +27,10 @@ #include #include +#if defined(ENABLE_PLATFORM) +#include "platform/platform.h" +#endif + namespace tim { namespace vx { struct CompileOptionImpl; @@ -38,6 +42,11 @@ class CompileOption { bool isRelaxMode() const; bool setRelaxMode(bool enable = false); +#if defined(ENABLE_PLATFORM) + void setDeviceId(::tim::vx::platform::IDevice::device_id_t device); + ::tim::vx::platform::IDevice::device_id_t getDeviceId(); +#endif + static CompileOption DefaultOptions; private: diff --git a/include/tim/vx/graph.h 
b/include/tim/vx/graph.h index 14ee305..66d68ac 100644 --- a/include/tim/vx/graph.h +++ b/include/tim/vx/graph.h @@ -42,11 +42,15 @@ class Tensor; struct TensorSpec; struct DmaBufferDesc; class Operation; +class CompileOption; class Graph { public: virtual ~Graph() {} + /// Attach CompileOption + virtual void SetCompileOption(const CompileOption&) = 0; + /// Create a tensor with given `TensorSpec` virtual std::shared_ptr CreateTensor(const TensorSpec& spec, const void* data = nullptr) = 0; diff --git a/include/tim/vx/platform/platform.h b/include/tim/vx/platform/platform.h index e90bce7..cc2799d 100644 --- a/include/tim/vx/platform/platform.h +++ b/include/tim/vx/platform/platform.h @@ -31,7 +31,6 @@ #include "tim/vx/graph.h" #include "tim/vx/tensor.h" #include "tim/vx/context.h" -#include "tim/vx/ops/nbg.h" namespace tim { namespace vx { diff --git a/src/tim/vx/compile_option.cc b/src/tim/vx/compile_option.cc index c87efbf..6371b3e 100644 --- a/src/tim/vx/compile_option.cc +++ b/src/tim/vx/compile_option.cc @@ -34,6 +34,9 @@ struct CompileOptionImpl { using RelaxModeType = std::tuple; CompileOptionImpl() { relax_mode_ = RelaxModeType(std::string("RelaxMode"), false, false, false); + #if defined(ENABLE_PLATFORM) + device_id_ = 0; + #endif } bool RelaxMode() const { @@ -46,6 +49,17 @@ struct CompileOptionImpl { : std::get<3>(relax_mode_); } +#if defined(ENABLE_PLATFORM) + void setDeviceId(::tim::vx::platform::IDevice::device_id_t device) { + device_id_ = device; + } + ::tim::vx::platform::IDevice::device_id_t getDeviceId() { + return device_id_; + } + + ::tim::vx::platform::IDevice::device_id_t device_id_; +#endif + RelaxModeType relax_mode_; }; @@ -56,5 +70,17 @@ bool CompileOption::isRelaxMode() const { return this->impl_->RelaxMode(); } bool CompileOption::setRelaxMode(bool enable) { return this->impl_->RelaxMode() = enable; } + +#if defined(ENABLE_PLATFORM) + void CompileOption::setDeviceId(::tim::vx::platform::IDevice::device_id_t device) { + 
this->impl_->setDeviceId(device); + } + + ::tim::vx::platform::IDevice::device_id_t CompileOption::getDeviceId() { + return this->impl_->getDeviceId(); + } + + +#endif } // namespace vx } // namespace tim diff --git a/src/tim/vx/graph.cc b/src/tim/vx/graph.cc index dfafb59..814cb71 100644 --- a/src/tim/vx/graph.cc +++ b/src/tim/vx/graph.cc @@ -150,6 +150,10 @@ std::shared_ptr GraphImpl::GetTensorFromCache(const TensorSpec& spec, co } #endif +void GraphImpl::SetCompileOption(const CompileOption& new_options) { + options_ = new_options; +} + vsi_nn_graph_t* GraphImpl::graph() { return graph_; } void GraphImpl::AddInput(vsi_nn_tensor_id_t id) { @@ -321,6 +325,12 @@ bool GraphImpl::Setup() { } vsi_nn_SetGraphFastMode(graph_, is_fast_mode); + #if defined(ENABLE_PLATFORM) + auto id = options_.getDeviceId(); + vxSetGraphAttribute(graph_->g, VX_GRAPH_DEVICE_INDEX_VIV, + (void*)(&id), sizeof(id)); + #endif + std::call_once(setio_once_, [&status, this]() { status = (vsi_nn_SetGraphInputs(this->graph_, this->inputs_.data(), this->inputs_.size()) && diff --git a/src/tim/vx/graph_private.h b/src/tim/vx/graph_private.h index 230339b..c64a3e3 100644 --- a/src/tim/vx/graph_private.h +++ b/src/tim/vx/graph_private.h @@ -49,6 +49,9 @@ class GraphImpl : public Graph { const std::string CalculateCacheKey(const TensorSpec& spec, const void* data); std::map>& GetTensorCacheMap(); #endif + + void SetCompileOption(const CompileOption& new_option) override; + /// Return the low-level graph object vsi_nn_graph_t* graph(); void AddInput(vsi_nn_tensor_id_t id); diff --git a/src/tim/vx/platform/native.cc b/src/tim/vx/platform/native.cc index 86c2ab2..81c9c3a 100644 --- a/src/tim/vx/platform/native.cc +++ b/src/tim/vx/platform/native.cc @@ -23,6 +23,7 @@ *****************************************************************************/ #include "tim/vx/platform/native.h" #include "native_device_private.h" +#include "tim/vx/ops/nbg.h" namespace tim { namespace vx { @@ -103,9 +104,13 @@ 
std::shared_ptr IExecutable::Executor() const { NativeExecutable::NativeExecutable(const std::shared_ptr& executor, const std::vector& nb_buf, size_t inputs, size_t outputs) { + CompileOption opt; + opt.setDeviceId(executor->Device()->Id()); + executor_ = executor; context_ = executor->Contex(); - nb_graph_ = context_->CreateGraph(); + nb_graph_ = context_->CreateGraph(opt); + nb_buf_ = nb_buf; nb_node_ = nb_graph_->CreateOperation(nb_buf_.data(), inputs, outputs); @@ -269,11 +274,11 @@ bool NativeExecutor::Trigger(bool async) { std::shared_ptr NativeExecutor::Compile( const std::shared_ptr& graph) { - GraphImpl* graphimp = - dynamic_cast(graph.get()); // hack to downcast - IDevice::device_id_t id = device_->Id(); - vxSetGraphAttribute(graphimp->graph()->g, VX_GRAPH_DEVICE_INDEX_VIV, - (void*)(&id), sizeof(id)); + + CompileOption option; + option.setDeviceId(device_->Id()); + graph->SetCompileOption(option); + size_t bin_size = -1; graph->CompileToBinary(nullptr, &bin_size); std::vector nb_buf;