Fixed IExecutable object not bind with DeviceID (#624)

If Executable object doesn't bind with a concrete DeviceID,
it will go first device by default.

When run multi executable with multi device, the behavior is not
expected. Fixed by attach device id with CompileOption.

Signed-off-by: xiang.zhang <xiang.zhang@verisilicon.com>
This commit is contained in:
Sven 2023-07-24 22:45:54 +08:00 committed by GitHub
parent 680e8d59cb
commit 821864a582
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 67 additions and 7 deletions

View File

@ -28,6 +28,10 @@ if(${TIM_VX_CODE_COVERAGE})
set(CMAKE_C_FLAGS "-g -O0 --coverage -fprofile-arcs -ftest-coverage") set(CMAKE_C_FLAGS "-g -O0 --coverage -fprofile-arcs -ftest-coverage")
endif() endif()
if(${TIM_VX_ENABLE_PLATFORM})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_PLATFORM")
endif()
if(${TIM_VX_ENABLE_PLATFORM_LITE}) if(${TIM_VX_ENABLE_PLATFORM_LITE})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_PLATFORM_LITE") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_PLATFORM_LITE")
endif() endif()

View File

@ -27,6 +27,10 @@
#include <map> #include <map>
#include <memory> #include <memory>
#if defined(ENABLE_PLATFORM)
#include "platform/platform.h"
#endif
namespace tim { namespace tim {
namespace vx { namespace vx {
struct CompileOptionImpl; struct CompileOptionImpl;
@ -38,6 +42,11 @@ class CompileOption {
bool isRelaxMode() const; bool isRelaxMode() const;
bool setRelaxMode(bool enable = false); bool setRelaxMode(bool enable = false);
#if defined(ENABLE_PLATFORM)
void setDeviceId(::tim::vx::platform::IDevice::device_id_t device);
::tim::vx::platform::IDevice::device_id_t getDeviceId();
#endif
static CompileOption DefaultOptions; static CompileOption DefaultOptions;
private: private:

View File

@ -42,11 +42,15 @@ class Tensor;
struct TensorSpec; struct TensorSpec;
struct DmaBufferDesc; struct DmaBufferDesc;
class Operation; class Operation;
class CompileOption;
class Graph { class Graph {
public: public:
virtual ~Graph() {} virtual ~Graph() {}
/// Attach CompileOption
virtual void SetCompileOption(const CompileOption&) = 0;
/// Create a tensor with given `TensorSpec` /// Create a tensor with given `TensorSpec`
virtual std::shared_ptr<Tensor> CreateTensor(const TensorSpec& spec, virtual std::shared_ptr<Tensor> CreateTensor(const TensorSpec& spec,
const void* data = nullptr) = 0; const void* data = nullptr) = 0;

View File

@ -31,7 +31,6 @@
#include "tim/vx/graph.h" #include "tim/vx/graph.h"
#include "tim/vx/tensor.h" #include "tim/vx/tensor.h"
#include "tim/vx/context.h" #include "tim/vx/context.h"
#include "tim/vx/ops/nbg.h"
namespace tim { namespace tim {
namespace vx { namespace vx {

View File

@ -34,6 +34,9 @@ struct CompileOptionImpl {
using RelaxModeType = std::tuple<std::string, bool, bool, bool>; using RelaxModeType = std::tuple<std::string, bool, bool, bool>;
CompileOptionImpl() { CompileOptionImpl() {
relax_mode_ = RelaxModeType(std::string("RelaxMode"), false, false, false); relax_mode_ = RelaxModeType(std::string("RelaxMode"), false, false, false);
#if defined(ENABLE_PLATFORM)
device_id_ = 0;
#endif
} }
bool RelaxMode() const { bool RelaxMode() const {
@ -46,6 +49,17 @@ struct CompileOptionImpl {
: std::get<3>(relax_mode_); : std::get<3>(relax_mode_);
} }
#if defined(ENABLE_PLATFORM)
void setDeviceId(::tim::vx::platform::IDevice::device_id_t device) {
device_id_ = device;
}
::tim::vx::platform::IDevice::device_id_t getDeviceId() {
return device_id_;
}
::tim::vx::platform::IDevice::device_id_t device_id_;
#endif
RelaxModeType relax_mode_; RelaxModeType relax_mode_;
}; };
@ -56,5 +70,17 @@ bool CompileOption::isRelaxMode() const { return this->impl_->RelaxMode(); }
bool CompileOption::setRelaxMode(bool enable) { bool CompileOption::setRelaxMode(bool enable) {
return this->impl_->RelaxMode() = enable; return this->impl_->RelaxMode() = enable;
} }
#if defined(ENABLE_PLATFORM)
void CompileOption::setDeviceId(::tim::vx::platform::IDevice::device_id_t device) {
this->impl_->setDeviceId(device);
}
::tim::vx::platform::IDevice::device_id_t CompileOption::getDeviceId() {
return this->impl_->getDeviceId();
}
#endif
} // namespace vx } // namespace vx
} // namespace tim } // namespace tim

View File

@ -150,6 +150,10 @@ std::shared_ptr<Tensor> GraphImpl::GetTensorFromCache(const TensorSpec& spec, co
} }
#endif #endif
void GraphImpl::SetCompileOption(const CompileOption& new_options) {
options_ = new_options;
}
vsi_nn_graph_t* GraphImpl::graph() { return graph_; } vsi_nn_graph_t* GraphImpl::graph() { return graph_; }
void GraphImpl::AddInput(vsi_nn_tensor_id_t id) { void GraphImpl::AddInput(vsi_nn_tensor_id_t id) {
@ -321,6 +325,12 @@ bool GraphImpl::Setup() {
} }
vsi_nn_SetGraphFastMode(graph_, is_fast_mode); vsi_nn_SetGraphFastMode(graph_, is_fast_mode);
#if defined(ENABLE_PLATFORM)
auto id = options_.getDeviceId();
vxSetGraphAttribute(graph_->g, VX_GRAPH_DEVICE_INDEX_VIV,
(void*)(&id), sizeof(id));
#endif
std::call_once(setio_once_, [&status, this]() { std::call_once(setio_once_, [&status, this]() {
status = (vsi_nn_SetGraphInputs(this->graph_, this->inputs_.data(), status = (vsi_nn_SetGraphInputs(this->graph_, this->inputs_.data(),
this->inputs_.size()) && this->inputs_.size()) &&

View File

@ -49,6 +49,9 @@ class GraphImpl : public Graph {
const std::string CalculateCacheKey(const TensorSpec& spec, const void* data); const std::string CalculateCacheKey(const TensorSpec& spec, const void* data);
std::map<std::string, std::shared_ptr<tim::vx::Tensor>>& GetTensorCacheMap(); std::map<std::string, std::shared_ptr<tim::vx::Tensor>>& GetTensorCacheMap();
#endif #endif
void SetCompileOption(const CompileOption& new_option) override;
/// Return the low-level graph object /// Return the low-level graph object
vsi_nn_graph_t* graph(); vsi_nn_graph_t* graph();
void AddInput(vsi_nn_tensor_id_t id); void AddInput(vsi_nn_tensor_id_t id);

View File

@ -23,6 +23,7 @@
*****************************************************************************/ *****************************************************************************/
#include "tim/vx/platform/native.h" #include "tim/vx/platform/native.h"
#include "native_device_private.h" #include "native_device_private.h"
#include "tim/vx/ops/nbg.h"
namespace tim { namespace tim {
namespace vx { namespace vx {
@ -103,9 +104,13 @@ std::shared_ptr<IExecutor> IExecutable::Executor() const {
NativeExecutable::NativeExecutable(const std::shared_ptr<IExecutor>& executor, NativeExecutable::NativeExecutable(const std::shared_ptr<IExecutor>& executor,
const std::vector<char>& nb_buf, const std::vector<char>& nb_buf,
size_t inputs, size_t outputs) { size_t inputs, size_t outputs) {
CompileOption opt;
opt.setDeviceId(executor->Device()->Id());
executor_ = executor; executor_ = executor;
context_ = executor->Contex(); context_ = executor->Contex();
nb_graph_ = context_->CreateGraph(); nb_graph_ = context_->CreateGraph(opt);
nb_buf_ = nb_buf; nb_buf_ = nb_buf;
nb_node_ = nb_graph_->CreateOperation<tim::vx::ops::NBG>(nb_buf_.data(), nb_node_ = nb_graph_->CreateOperation<tim::vx::ops::NBG>(nb_buf_.data(),
inputs, outputs); inputs, outputs);
@ -269,11 +274,11 @@ bool NativeExecutor::Trigger(bool async) {
std::shared_ptr<IExecutable> NativeExecutor::Compile( std::shared_ptr<IExecutable> NativeExecutor::Compile(
const std::shared_ptr<Graph>& graph) { const std::shared_ptr<Graph>& graph) {
GraphImpl* graphimp =
dynamic_cast<GraphImpl*>(graph.get()); // hack to downcast CompileOption option;
IDevice::device_id_t id = device_->Id(); option.setDeviceId(device_->Id());
vxSetGraphAttribute(graphimp->graph()->g, VX_GRAPH_DEVICE_INDEX_VIV, graph->SetCompileOption(option);
(void*)(&id), sizeof(id));
size_t bin_size = -1; size_t bin_size = -1;
graph->CompileToBinary(nullptr, &bin_size); graph->CompileToBinary(nullptr, &bin_size);
std::vector<char> nb_buf; std::vector<char> nb_buf;