[New API] Add compile_option support - relax_mode (#285)
Added new API for tim::vx::Context::CreateGraph with a CompileOption
Only one option added in CompileOption:
relax_mode : Run float32 mode with bfloat16
Signed-off-by: xiang.zhang <xiang.zhang@verisilicon.com>
This commit is contained in:
parent
1c5302ba51
commit
7c1a00213b
2
BUILD
2
BUILD
|
|
@ -34,6 +34,7 @@ cc_library(
|
||||||
"include/tim/vx/ops.h",
|
"include/tim/vx/ops.h",
|
||||||
"include/tim/vx/tensor.h",
|
"include/tim/vx/tensor.h",
|
||||||
"include/tim/vx/types.h",
|
"include/tim/vx/types.h",
|
||||||
|
"include/tim/vx/compile_option.h",
|
||||||
"include/tim/transform/layout_inference.h",
|
"include/tim/transform/layout_inference.h",
|
||||||
] + glob([
|
] + glob([
|
||||||
"include/tim/vx/ops/*.h"
|
"include/tim/vx/ops/*.h"
|
||||||
|
|
@ -41,6 +42,7 @@ cc_library(
|
||||||
srcs = [
|
srcs = [
|
||||||
"src/tim/vx/context_private.h",
|
"src/tim/vx/context_private.h",
|
||||||
"src/tim/vx/context.cc",
|
"src/tim/vx/context.cc",
|
||||||
|
"src/tim/vx/compile_option.cc",
|
||||||
"src/tim/vx/graph_private.h",
|
"src/tim/vx/graph_private.h",
|
||||||
"src/tim/vx/graph.cc",
|
"src/tim/vx/graph.cc",
|
||||||
"src/tim/vx/direct_map_op_impl.cc",
|
"src/tim/vx/direct_map_op_impl.cc",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,50 @@
|
||||||
|
/****************************************************************************
|
||||||
|
*
|
||||||
|
* Copyright (c) 2022 Vivante Corporation
|
||||||
|
*
|
||||||
|
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
* copy of this software and associated documentation files (the "Software"),
|
||||||
|
* to deal in the Software without restriction, including without limitation
|
||||||
|
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
* and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
* Software is furnished to do so, subject to the following conditions:
|
||||||
|
*
|
||||||
|
* The above copyright notice and this permission notice shall be included in
|
||||||
|
* all copies or substantial portions of the Software.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
* DEALINGS IN THE SOFTWARE.
|
||||||
|
*
|
||||||
|
*****************************************************************************/
|
||||||
|
#ifndef TIM_VX_COMPILE_OPTION_H_
|
||||||
|
#define TIM_VX_COMPILE_OPTION_H_
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace tim {
|
||||||
|
namespace vx {
|
||||||
|
struct CompileOptionImpl;
|
||||||
|
class CompileOption {
|
||||||
|
public:
|
||||||
|
CompileOption();
|
||||||
|
~CompileOption(){};
|
||||||
|
|
||||||
|
bool isRelaxMode() const;
|
||||||
|
bool setRelaxMode(bool enable = false);
|
||||||
|
|
||||||
|
static CompileOption DefaultOptions;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// option can have dafult values
|
||||||
|
std::shared_ptr<CompileOptionImpl> impl_;
|
||||||
|
};
|
||||||
|
} // namespace vx
|
||||||
|
} // namespace tim
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
@ -29,13 +29,14 @@
|
||||||
namespace tim {
|
namespace tim {
|
||||||
namespace vx {
|
namespace vx {
|
||||||
|
|
||||||
|
struct CompileOption;
|
||||||
class Graph;
|
class Graph;
|
||||||
|
|
||||||
class Context {
|
class Context {
|
||||||
public:
|
public:
|
||||||
virtual ~Context() {}
|
virtual ~Context() {}
|
||||||
static std::shared_ptr<Context> Create();
|
static std::shared_ptr<Context> Create();
|
||||||
virtual std::shared_ptr<Graph> CreateGraph() = 0;
|
virtual std::shared_ptr<Graph> CreateGraph() = 0;
|
||||||
|
virtual std::shared_ptr<Graph> CreateGraph(const CompileOption& options) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace vx
|
} // namespace vx
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace tim {
|
namespace tim {
|
||||||
namespace vx {
|
namespace vx {
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,6 @@ if (NOT CMAKE_INSTALL_LIBDIR)
|
||||||
set(CMAKE_INSTALL_LIBDIR "lib")
|
set(CMAKE_INSTALL_LIBDIR "lib")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
install(TARGETS ${TARGET_NAME} ${TARGET_NAME}
|
install(TARGETS ${TARGET_NAME} ${TARGET_NAME}
|
||||||
DESTINATION ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR})
|
DESTINATION ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR})
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,60 @@
|
||||||
|
/****************************************************************************
|
||||||
|
*
|
||||||
|
* Copyright (c) 2022 Vivante Corporation
|
||||||
|
*
|
||||||
|
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
* copy of this software and associated documentation files (the "Software"),
|
||||||
|
* to deal in the Software without restriction, including without limitation
|
||||||
|
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
* and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
* Software is furnished to do so, subject to the following conditions:
|
||||||
|
*
|
||||||
|
* The above copyright notice and this permission notice shall be included in
|
||||||
|
* all copies or substantial portions of the Software.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
* DEALINGS IN THE SOFTWARE.
|
||||||
|
*
|
||||||
|
*****************************************************************************/
|
||||||
|
#include "tim/vx/compile_option.h"
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace tim {
|
||||||
|
namespace vx {
|
||||||
|
|
||||||
|
CompileOption CompileOption::DefaultOptions;
|
||||||
|
|
||||||
|
struct CompileOptionImpl {
|
||||||
|
// string: readable name; bool: setup or not; bool: value if setup; bool: default value if not setup;
|
||||||
|
using RelaxModeType = std::tuple<std::string, bool, bool, bool>;
|
||||||
|
CompileOptionImpl() {
|
||||||
|
relax_mode_ = {std::string("RelaxMode"), false, false, false};
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RelaxMode() const {
|
||||||
|
return std::get<1>(relax_mode_) ? std::get<2>(relax_mode_)
|
||||||
|
: std::get<3>(relax_mode_);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool& RelaxMode() {
|
||||||
|
return std::get<1>(relax_mode_) ? std::get<2>(relax_mode_)
|
||||||
|
: std::get<3>(relax_mode_);
|
||||||
|
}
|
||||||
|
|
||||||
|
RelaxModeType relax_mode_;
|
||||||
|
};
|
||||||
|
|
||||||
|
CompileOption::CompileOption() : impl_(new CompileOptionImpl()) {}
|
||||||
|
|
||||||
|
bool CompileOption::isRelaxMode() const { return this->impl_->RelaxMode(); }
|
||||||
|
|
||||||
|
bool CompileOption::setRelaxMode(bool enable) {
|
||||||
|
return this->impl_->RelaxMode() = enable;
|
||||||
|
}
|
||||||
|
} // namespace vx
|
||||||
|
} // namespace tim
|
||||||
|
|
@ -0,0 +1,36 @@
|
||||||
|
/****************************************************************************
|
||||||
|
*
|
||||||
|
* Copyright (c) 2022 Vivante Corporation
|
||||||
|
*
|
||||||
|
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
* copy of this software and associated documentation files (the "Software"),
|
||||||
|
* to deal in the Software without restriction, including without limitation
|
||||||
|
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
* and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
* Software is furnished to do so, subject to the following conditions:
|
||||||
|
*
|
||||||
|
* The above copyright notice and this permission notice shall be included in
|
||||||
|
* all copies or substantial portions of the Software.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
* DEALINGS IN THE SOFTWARE.
|
||||||
|
*
|
||||||
|
*****************************************************************************/
|
||||||
|
#include "tim/vx/compile_option.h"
|
||||||
|
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
|
TEST(compile_option, relax_mode) {
|
||||||
|
tim::vx::CompileOption opt;
|
||||||
|
|
||||||
|
EXPECT_TRUE(opt.isRelaxMode() == false);
|
||||||
|
opt.setRelaxMode(true);
|
||||||
|
EXPECT_TRUE(opt.isRelaxMode() == true);
|
||||||
|
|
||||||
|
EXPECT_TRUE(tim::vx::CompileOption::DefaultOptions.isRelaxMode() == false);
|
||||||
|
}
|
||||||
|
|
@ -26,6 +26,7 @@
|
||||||
#include "context_private.h"
|
#include "context_private.h"
|
||||||
#include "graph_private.h"
|
#include "graph_private.h"
|
||||||
#include "tim/vx/graph.h"
|
#include "tim/vx/graph.h"
|
||||||
|
#include "tim/vx/compile_option.h"
|
||||||
#include "vsi_nn_pub.h"
|
#include "vsi_nn_pub.h"
|
||||||
|
|
||||||
namespace tim {
|
namespace tim {
|
||||||
|
|
@ -48,5 +49,10 @@ std::shared_ptr<Context> Context::Create() {
|
||||||
std::shared_ptr<Graph> ContextImpl::CreateGraph() {
|
std::shared_ptr<Graph> ContextImpl::CreateGraph() {
|
||||||
return std::make_shared<GraphImpl>(this);
|
return std::make_shared<GraphImpl>(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<Graph> ContextImpl::CreateGraph(const CompileOption& options) {
|
||||||
|
return std::make_shared<GraphImpl>(this, options);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace vx
|
} // namespace vx
|
||||||
} // namespace tim
|
} // namespace tim
|
||||||
|
|
@ -35,6 +35,7 @@ class ContextImpl : public Context {
|
||||||
~ContextImpl();
|
~ContextImpl();
|
||||||
vsi_nn_context_t context();
|
vsi_nn_context_t context();
|
||||||
std::shared_ptr<Graph> CreateGraph() override;
|
std::shared_ptr<Graph> CreateGraph() override;
|
||||||
|
std::shared_ptr<Graph> CreateGraph(const CompileOption&) override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
vsi_nn_context_t context_;
|
vsi_nn_context_t context_;
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@
|
||||||
#include "tensor_private.h"
|
#include "tensor_private.h"
|
||||||
#include "tim/vx/context.h"
|
#include "tim/vx/context.h"
|
||||||
#include "tim/vx/ops/nbg.h"
|
#include "tim/vx/ops/nbg.h"
|
||||||
|
#include "tim/vx/compile_option.h"
|
||||||
#include "vsi_nn_pub.h"
|
#include "vsi_nn_pub.h"
|
||||||
|
|
||||||
namespace tim {
|
namespace tim {
|
||||||
|
|
@ -44,10 +45,11 @@ const std::vector<std::shared_ptr<Tensor>> Graph::GetConstantInputs() const {
|
||||||
return const_inputs;
|
return const_inputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
GraphImpl::GraphImpl(ContextImpl* context)
|
GraphImpl::GraphImpl(ContextImpl* context, const CompileOption& options)
|
||||||
: context_(context),
|
: context_(context),
|
||||||
graph_(vsi_nn_CreateGraph(context_->context(), 0, 0)),
|
graph_(vsi_nn_CreateGraph(context_->context(), 0, 0)),
|
||||||
tensor_placeholder_(nullptr) {}
|
tensor_placeholder_(nullptr),
|
||||||
|
options_(options){}
|
||||||
|
|
||||||
GraphImpl::~GraphImpl() { vsi_nn_ReleaseGraph(&graph_); }
|
GraphImpl::~GraphImpl() { vsi_nn_ReleaseGraph(&graph_); }
|
||||||
|
|
||||||
|
|
@ -156,7 +158,12 @@ bool GraphImpl::Compile() {
|
||||||
|
|
||||||
vsi_nn_SetGraphVersion(graph_, major, minor, patch);
|
vsi_nn_SetGraphVersion(graph_, major, minor, patch);
|
||||||
|
|
||||||
vsi_nn_SetGraphFastMode(graph_, false);
|
bool is_fast_mode = options_.isRelaxMode();
|
||||||
|
if (is_fast_mode) {
|
||||||
|
VSILOGW("Important notice: float model executed in bfloat16 "
|
||||||
|
"mode which will have better performance but lower precesion");
|
||||||
|
}
|
||||||
|
vsi_nn_SetGraphFastMode(graph_, is_fast_mode);
|
||||||
|
|
||||||
std::call_once(setio_once_, [&status, this]() {
|
std::call_once(setio_once_, [&status, this]() {
|
||||||
status = (vsi_nn_SetGraphInputs(this->graph_, this->inputs_.data(),
|
status = (vsi_nn_SetGraphInputs(this->graph_, this->inputs_.data(),
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@
|
||||||
#include <map>
|
#include <map>
|
||||||
|
|
||||||
#include "tim/vx/tensor.h"
|
#include "tim/vx/tensor.h"
|
||||||
|
#include "tim/vx/compile_option.h"
|
||||||
#include "context_private.h"
|
#include "context_private.h"
|
||||||
|
|
||||||
#include "vsi_nn_pub.h"
|
#include "vsi_nn_pub.h"
|
||||||
|
|
@ -40,7 +41,7 @@ namespace vx {
|
||||||
|
|
||||||
class GraphImpl : public Graph {
|
class GraphImpl : public Graph {
|
||||||
public:
|
public:
|
||||||
GraphImpl(ContextImpl* context);
|
GraphImpl(ContextImpl* context, const CompileOption& options = CompileOption::DefaultOptions);
|
||||||
~GraphImpl();
|
~GraphImpl();
|
||||||
|
|
||||||
/// Return the low-level graph object
|
/// Return the low-level graph object
|
||||||
|
|
@ -88,6 +89,8 @@ class GraphImpl : public Graph {
|
||||||
std::vector<std::shared_ptr<Tensor>> outputs_tensor_;
|
std::vector<std::shared_ptr<Tensor>> outputs_tensor_;
|
||||||
std::map<std::shared_ptr<Tensor>, std::vector<std::shared_ptr<Operation>>> tensor_consumers_;
|
std::map<std::shared_ptr<Tensor>, std::vector<std::shared_ptr<Operation>>> tensor_consumers_;
|
||||||
std::map<std::shared_ptr<Tensor>, std::shared_ptr<Operation>> tensor_producer_;
|
std::map<std::shared_ptr<Tensor>, std::shared_ptr<Operation>> tensor_producer_;
|
||||||
|
|
||||||
|
CompileOption options_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace vx
|
} // namespace vx
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue