From 315adcf07676f23567f2693f31758d5cc8b9ee1f Mon Sep 17 00:00:00 2001
From: MercuryChen <52600596+MercuryChen@users.noreply.github.com>
Date: Wed, 19 Jul 2023 18:40:48 +0800
Subject: [PATCH] Integrate api trace into tim-vx source as an experimental feature. (#623)

* Add support for different input dtype of MaxPoolGrad.

Type: Code improvement

* Integrate api trace into tim-vx source code as an experimental feature.

Type: New Feature
---
 .gitignore                                    |    6 +
 CMakeLists.txt                                |    1 +
 include/tim/experimental/trace/README.md      |   82 ++
 include/tim/experimental/trace/trace_tvx.h    |   31 +
 include/tim/experimental/trace/tracer.h       | 1027 ++++++++++++++++++
 include/tim/experimental/trace/tvx/context.h  |   58 +
 include/tim/experimental/trace/tvx/graph.h    |  360 ++++++
 include/tim/experimental/trace/tvx/ops.h      |  323 ++++++
 include/tim/experimental/trace/tvx/tensor.h   |  246 +++++
 include/tim/experimental/trace/tvx/types.h    |   56 +
 include/tim/vx/tensor.h                       |    8 +
 src/tim/CMakeLists.txt                        |    8 +-
 src/tim/vx/graph_test.cc                      |   62 ++
 src/tim/vx/ops/maxpoolgrad.cc                 |   66 +-
 14 files changed, 2303 insertions(+), 31 deletions(-)
 create mode 100755 include/tim/experimental/trace/README.md
 create mode 100755 include/tim/experimental/trace/trace_tvx.h
 create mode 100755 include/tim/experimental/trace/tracer.h
 create mode 100755 include/tim/experimental/trace/tvx/context.h
 create mode 100755 include/tim/experimental/trace/tvx/graph.h
 create mode 100755 include/tim/experimental/trace/tvx/ops.h
 create mode 100755 include/tim/experimental/trace/tvx/tensor.h
 create mode 100755 include/tim/experimental/trace/tvx/types.h

diff --git a/.gitignore b/.gitignore
index d14405d..932e76e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -338,3 +338,9 @@ ASALocalRun/
 build/
 *_build/
 .DS_Store
+
+# experimental api trace
+trace_log.cc
+trace_bin.bin
+trace_log.rpl.cc
+trace_bin.rpl.bin
diff --git a/CMakeLists.txt b/CMakeLists.txt
index c9681f7..7c12000 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -16,6 +16,7 @@ option(TIM_VX_ENABLE_PLATFORM_LITE "Enable lite multi-device support"
 option(TIM_VX_ENABLE_GRPC "Enable gPRC support" OFF)
 option(TIM_VX_DBG_ENABLE_TENSOR_HNDL "Enable built-in tensor from handle: use malloced memory instead of VideoMemory by kernel driver" ON)
 option(TIM_VX_ENABLE_TENSOR_CACHE "Enable tensor cache for const tensor" OFF)
+option(TIM_VX_ENABLE_API_TRACE "Enable trace and replay of graph apis" OFF)
 
 set(CMAKE_CXX_STANDARD 14)
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)
diff --git a/include/tim/experimental/trace/README.md b/include/tim/experimental/trace/README.md
new file mode 100755
index 0000000..58d4b04
--- /dev/null
+++ b/include/tim/experimental/trace/README.md
@@ -0,0 +1,82 @@
+# ApiTracer - Header-only C++ trace and replay tool for object-oriented programs
+
+ApiTracer is a header-only library that provides macros and template functions for tracing C++ object-oriented programs. It records not only the call stacks, but also the runtime parameters of the traced C++ APIs. With the trace log and binary dump it produces, the traced execution can be replayed, which helps developers reproduce bugs and debug.
+
+## Coding work
+ApiTracer is implemented by wrapping the original APIs in traced APIs that keep the same class names and function names but live in a different namespace, so you first need to implement the traced APIs for the program you want to trace.
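Because the traced wrappers keep the original class and function names, application code does not need to change its call sites: it only needs to include the traced headers and refer to the `trace` namespace instead of `tim::vx`. The sketch below only illustrates that idea; it is not part of the patch, and whether the `TIM_VX_ENABLE_API_TRACE` CMake option is also exposed as a compile definition is an assumption here, so adapt the guard to your build setup.

``` C++
// Sketch only: pick the traced or the plain API surface at build time.
// The TIM_VX_ENABLE_API_TRACE compile definition is assumed, not something
// this patch is shown to define.
#ifdef TIM_VX_ENABLE_API_TRACE
  #include "tim/experimental/trace/trace_tvx.h"
  namespace api = trace;        // traced wrappers, same class/function names
#else
  #include "tim/vx/context.h"
  #include "tim/vx/graph.h"
  namespace api = tim::vx;      // plain TIM-VX
#endif

int main() {
  auto context = api::Context::Create();  // identical call sites either way
  auto graph = context->CreateGraph();
  return graph ? 0 : 1;
}
```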
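When a traced program runs, the tracer emits a replayable C++ source file (`trace_log.cc`) together with a binary dump of buffer and vector arguments (`trace_bin.bin`); the `TRACE_DUMP_PREFIX` environment variable defined in `tracer.h` controls where these files are written. The snippet below only sketches what such a log tends to look like, pieced together from the logging formats in `tracer.h`; the object names, enum values, and data offsets are invented for illustration, not captured output.

``` C++
// Hypothetical excerpt of a generated trace_log.cc (illustration only).
// Each traced call is replayed against the real tim::vx API, and bulk data
// is read back from trace_bin.bin through trace::Replayer.
auto ctx_0 = tim::vx::Context::Create();
auto graph_0 = ctx_0->CreateGraph();
auto spec_0 = tim::vx::TensorSpec((tim::vx::DataType)1,
    trace::Replayer::get_vector<unsigned int>(0, 4),  // shape dumped at offset 0
    (tim::vx::TensorAttribute)2);
auto tensor_0 = graph_0->CreateTensor(spec_0, nullptr);
graph_0->Compile();
graph_0->Run();
```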
### Example of implementing a traced class
Original TensorSpec:
``` C++
namespace tvx {
struct TensorSpec {
  TensorSpec() {}
  TensorSpec(DataType datatype, const ShapeType& shape, TensorAttribute attr)
      : datatype_(datatype), shape_(shape), attr_(attr) {}

  TensorSpec(DataType datatype, const ShapeType& shape, TensorAttribute attr,
             const Quantization& quantization)
      : TensorSpec(datatype, shape, attr) {
    this->quantization_ = quantization;
  }

  TensorSpec(const TensorSpec& other);

  TensorSpec& operator=(const TensorSpec& other);

  TensorSpec& SetDataType(DataType datatype);

  TensorSpec& SetShape(ShapeType& shape);

  TensorSpec& SetAttribute(TensorAttribute attr);
  ...
};

} /* namespace tvx */
```
Traced TensorSpec:
``` C++
namespace trace {
// All traced classes should derive from TraceClassBase.
struct TensorSpec : public TraceClassBase {
  DEF_CONSTRUCTOR(TensorSpec)

  DEF_CONSTRUCTOR(TensorSpec, ((DataType))
                              ((const ShapeType&))
                              ((TensorAttribute))
                              )

  DEF_CONSTRUCTOR(TensorSpec, ((DataType))
                              ((const ShapeType&))
                              ((TensorAttribute))
                              ((const Quantization&))
                              )

  DEF_CONSTRUCTOR(TensorSpec, ((const TensorSpec&))
                              )

  DEF_TRACED_API(TensorSpec&, operator=)

  DEF_TRACED_API(TensorSpec&, SetDataType)

  DEF_TRACED_API(TensorSpec&, SetShape)

  DEF_TRACED_API(TensorSpec&, SetAttribute)
  ...
};
}
```

### Macro definitions
- DEF_CONSTRUCTOR(class_name, optional_arguments_description, optional_macro_to_handle_pointer):
the `arguments_description` must follow the format `((arg0))((arg1)(arg1_default_value))`
- DEF_TRACED_API(return_type, member_function_name, optional_lambda_to_handle_pointer):
this macro can define most member functions, except:
  - functions with default parameters
  - static functions
  - read-only functions
- DEF_MEMFN_SP(return_type_removed_shared_pointer, member_function_name, optional_arguments_description, optional_macro_to_handle_pointer):
this macro defines member functions that return a `shared_ptr`; pass the return type with the `shared_ptr` removed
- DEF_MEMFN(return_type, member_function_name, optional_arguments_description, optional_macro_to_handle_pointer)
- DEF_INPLACE_MEMFN(member_function_name, optional_arguments_description, optional_macro_to_handle_pointer)
diff --git a/include/tim/experimental/trace/trace_tvx.h b/include/tim/experimental/trace/trace_tvx.h
new file mode 100755
index 0000000..a9473f0
--- /dev/null
+++ b/include/tim/experimental/trace/trace_tvx.h
@@ -0,0 +1,31 @@
+/****************************************************************************
+*
+* Copyright (c) 2020-2023 Vivante Corporation
+*
+* Permission is hereby granted, free of charge, to any person obtaining a
+* copy of this software and associated documentation files (the "Software"),
+* to deal in the Software without restriction, including without limitation
+* the rights to use, copy, modify, merge, publish, distribute, sublicense,
+* and/or sell copies of the Software, and to permit persons to whom the
+* Software is furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in
+* all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#ifndef TIM_EXPERIMENTAL_TRACE_TRACE_TVX_H_ +#define TIM_EXPERIMENTAL_TRACE_TRACE_TVX_H_ +#include "tim/experimental/trace/tvx/context.h" +#include "tim/experimental/trace/tvx/graph.h" +#include "tim/experimental/trace/tvx/ops.h" +#include "tim/experimental/trace/tvx/tensor.h" +#include "tim/experimental/trace/tvx/types.h" +#endif // TIM_EXPERIMENTAL_TRACE_TRACE_TVX_H_ diff --git a/include/tim/experimental/trace/tracer.h b/include/tim/experimental/trace/tracer.h new file mode 100755 index 0000000..7fbe5b6 --- /dev/null +++ b/include/tim/experimental/trace/tracer.h @@ -0,0 +1,1027 @@ +/**************************************************************************** +* +* Copyright (c) 2020-2023 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#ifndef TIM_EXPERIMENTAL_TRACE_TRACER_H_ +#define TIM_EXPERIMENTAL_TRACE_TRACER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#define TRACE_LOG_NAME_ "trace_log.cc" +#define TRACE_BIN_FILE_ "trace_bin.bin" +#define TRACE_LOG_FILE_REPLAY_ "trace_log.rpl.cc" +#define TRACE_BIN_FILE_REPLAY_ "trace_bin.rpl.bin" +#define TRACE_PREFIX_ENV_VAR_ "TRACE_DUMP_PREFIX" + +#define TCLOGE(fmt, ...) 
do { \ + printf("[ERROR] [%s:%s:%d]" fmt, __FILE__, __FUNCTION__, __LINE__, \ + ##__VA_ARGS__); \ + fflush(stdout); \ + } while (0) + + +/********************** definitions of extra type traits **********************/ +namespace trace { + +template +using void_t = void; + +template +struct is_fundamental_vector : std::false_type {}; + +template +struct is_fundamental_vector> { + static constexpr bool value = std::is_fundamental::value; +}; + +template +struct is_fundamental_array : std::false_type {}; + +template +struct is_fundamental_array> { + static constexpr bool value = std::is_fundamental::value; +}; + +template +struct is_fundamental_pointer : std::integral_constant::value && + std::is_fundamental>::value> {}; + +template +struct is_traced_obj : std::false_type {}; + +template +struct is_traced_obj().TraceGetObjName())>> + : std::true_type {}; + +template +struct is_traced_obj_ptr : std::false_type {}; + +template +struct is_traced_obj_ptr()->TraceGetObjName())>> + : std::true_type {}; + +template +struct is_traced_obj_ptr_vector : std::false_type {}; + +template +struct is_traced_obj_ptr_vector> { + static constexpr bool value = is_traced_obj_ptr::value; +}; + +template +struct is_others_type : std::integral_constant>::value && + !is_fundamental_array>::value && + !std::is_enum>::value && + !std::is_fundamental>::value && + !is_traced_obj>::value && + !is_traced_obj_ptr>::value && + !is_traced_obj_ptr_vector>::value> {}; + +template +struct is_not_traced_obj_like : std::integral_constant>::value && + !is_traced_obj_ptr>::value && + !is_traced_obj_ptr_vector>::value> {}; + +} /* namespace trace */ + +/*************************** definition of replayer ***************************/ +namespace trace { + +class Replayer { + static FILE* file_trace_bin_; + static FILE* open_file(const char* file_name); + public: + template + static std::vector get_vector(uint32_t offset, size_t vec_size); + + template + static std::vector get_vector_from( + const char* file_path, uint32_t offset, size_t vec_size); + + template + static std::array get_array(uint32_t offset, size_t vec_size); +}; + +// #define API_REPLAYER_IMPLEMENTATION +#ifdef API_REPLAYER_IMPLEMENTATION +FILE* Replayer::file_trace_bin_ = Replayer::open_file(TRACE_BIN_FILE_REPLAY_); +/* static */ FILE* Replayer::open_file(const char* file_name) { + char* prefix = getenv(TRACE_PREFIX_ENV_VAR_); + FILE* fp; + char path[1024] = {0}; + if (prefix != NULL) { + strcpy(path, prefix); + strcat(path, file_name); + } else { + strcpy(path, file_name); + } + fp = fopen(path, "r"); + if (!fp) { + TCLOGE("Can not open file at: %s\n", path); + } + return fp; +} + +template +/* static */ std::vector Replayer::get_vector( + uint32_t offset, size_t vec_size) { + std::vector ret_vec; + if (!file_trace_bin_) { + TCLOGE("FILE pointer is NULL!\n"); + } else { + T* buffer = new T[vec_size]; + fseek(file_trace_bin_, offset, SEEK_SET); + if (fread(buffer, sizeof(T), vec_size, file_trace_bin_) == vec_size) { + ret_vec.assign(buffer, buffer + vec_size); + } else { + TCLOGE("Read binary data failed!\n"); + } + delete[] buffer; + } + return ret_vec; +} + +template +/* static */ std::array Replayer::get_array( + uint32_t offset, size_t vec_size) { + std::vector ret_vec = get_vector(offset, vec_size); + std::array ret_arr; + std::copy_n(ret_vec.begin(), ret_vec.size(), ret_arr.begin()); + return ret_arr; +} + +template +/* static */ std::vector Replayer::get_vector_from( + const char* file_path,uint32_t offset, size_t vec_size) { + FILE* external_file = 
fopen(file_path, "r"); + if (!external_file) { + TCLOGE("Can not open file at: %s\n", file_path); + } + std::vector ret_vec; + T* buffer = new T[vec_size]; + fseek(external_file, offset, SEEK_SET); + if (fread(buffer, sizeof(T), vec_size, external_file) == vec_size) { + ret_vec.assign(buffer, buffer + vec_size); + } else { + TCLOGE("Read binary data failed!\n"); + } + delete[] buffer; + return ret_vec; +} + +#endif /* #ifdef API_TRACER_IMPLEMENTATION */ + +} /* namespace trace */ + +/**************************** definition of tracer ****************************/ +namespace trace { + +class Tracer { + static std::unordered_map obj_names_; + static std::unordered_map objs_prefix_; + static std::vector params_log_cache_; + static std::list msg_cache_; + static std::unordered_map target2trace_map_; + static FILE* file_trace_log_; + static FILE* file_trace_bin_; + static std::mutex log_mtx_; + static std::mutex bin_mtx_; + + static FILE* open_file(const char* file_name); + + public: + static void logging_msg(const char* format, ...); + + static uint32_t dump_data(const void* data, size_t byte_size, size_t count); + + static std::string allocate_obj_name(const std::string& prefix = "obj_"); + + static inline void insert_obj_name( + const void* obj_ptr, const std::string& obj_name) { + obj_names_[obj_ptr] = obj_name; + } + + static inline void insert_traced_obj(const void* p_target, void* p_traced) { + target2trace_map_.insert({p_target, p_traced}); + } + + static inline void* get_traced_obj(const void* p_target) { + return target2trace_map_[p_target]; + } + + static inline std::string& get_obj_name(const void* obj) { + return obj_names_[obj]; + } + + static inline std::string& get_obj_prefix(const std::string cls_name) { + return objs_prefix_[cls_name]; + } + + static inline void push_back_msg_cache(const std::string& msg) { + msg_cache_.push_back(msg); + } + + static inline void amend_last_msg_cache(const std::string& msg) { + if (msg_cache_.empty()) { + TCLOGE("Can't amend sub_msg, because msg cache is empty!\n"); + } + msg_cache_.back() += msg; + } + + static inline void insert_before_last_msg_cache(const std::string& msg) { + msg_cache_.insert(--msg_cache_.end(), msg); + } + + static inline void msg_cache_sync_to_file() { + while (!msg_cache_.empty()) { + logging_msg(msg_cache_.front().c_str()); + msg_cache_.pop_front(); + } + } + + static inline void clear_params_log_cache() { + params_log_cache_.clear(); + } + + static inline void init_params_log_cache(uint32_t params_size) { + params_log_cache_.clear(); + params_log_cache_.resize(params_size); + } + + static inline void append_params_log_cache(std::string param_log) { + params_log_cache_.push_back(param_log); + } + + static inline void insert_params_log_cache( + std::string param_log, uint32_t idx) { + if (idx != static_cast(-1)) { + params_log_cache_[idx] = param_log; + } else { + params_log_cache_.push_back(param_log); + } + } + + // pop the log of params into msg cache + static inline void pop_params_log_cache() { + if (params_log_cache_.size() == 0) return; + for (uint32_t i = 0; i < params_log_cache_.size() - 1; i++) { + amend_last_msg_cache(params_log_cache_[i] + ", "); + } + amend_last_msg_cache(params_log_cache_.back()); + } + + // directly dump the log of params to file + static inline void dump_params_log_cache() { + if (params_log_cache_.size() == 0) return; + for (uint32_t i = 0; i < params_log_cache_.size() - 1; i++) { + logging_msg("%s, ", params_log_cache_[i].c_str()); + } + logging_msg(params_log_cache_.back().c_str()); + 
} + + /* + * template functions for logging parameters to log file + */ + // default substitution + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wunused-parameter" + template ::value, int> = 0> + static inline void logging_param(const T& t, uint32_t idx) { + insert_params_log_cache(std::string(), idx); + } + #pragma GCC diagnostic pop + + // enable if T is fundamental std::vector + template >::value, int> = 0> + static inline void logging_param(const T& t, uint32_t idx) { + uint32_t offset = dump_data(t.data(), sizeof(t[0]), t.size()); + std::string element_type = + boost::typeindex::type_id().pretty_name(); + char log_msg[1024] = {0}; + snprintf(log_msg, 1024, "trace::Replayer::get_vector<%s>(%u, %u)", + element_type.c_str(), offset, (uint32_t)t.size()); + insert_params_log_cache(std::string(log_msg), idx); + } + + // enable if T is fundamental std::array + template >::value, int> = 0> + static inline void logging_param(const T& t, uint32_t idx) { + uint32_t offset = dump_data(t.data(), sizeof(t[0]), t.size()); + std::string element_type = + boost::typeindex::type_id().pretty_name(); + char log_msg[1024] = {0}; + snprintf(log_msg, 1024, "trace::Replayer::get_array<%s, %d>(%u, %u)", + element_type.c_str(), (uint32_t)t.size(), offset, + (uint32_t)t.size()); + insert_params_log_cache(std::string(log_msg), idx); + } + + // enable if T is enum + template >::value, int> = 0> + static inline void logging_param(const T& t, uint32_t idx) { + std::string enum_type = + boost::typeindex::type_id().pretty_name(); + char log_msg[1024] = {0}; + snprintf(log_msg, 1024, "(%s)%d", enum_type.c_str(), (int)t); + insert_params_log_cache(std::string(log_msg), idx); + } + + // enable if T is fundamental + template >::value, int> = 0> + static inline void logging_param(const T& t, uint32_t idx) { + insert_params_log_cache(std::to_string(t), idx); + } + + // enable if T is derive from TraceClassBase + template >::value, int> = 0> + static inline void logging_param(const T& t, uint32_t idx) { + insert_params_log_cache(t.TraceGetObjName(), idx); + } + + // enable if T is pointer to object which + // derive from TraceClassBase + template >::value, int> = 0> + static inline void logging_param(const T& t, uint32_t idx) { + insert_params_log_cache(t->TraceGetObjName(), idx); + } + + // enable if T is vector of pointer to object which + // derive from TraceClassBase + template >::value, int> = 0> + static inline void logging_param(const T& t, uint32_t idx) { + std::stringstream ss; + ss << "{"; + for (uint32_t i = 0; i < t.size() - 1; i++) { + ss << t[i]->TraceGetObjName() << ", "; + } + ss << t.back()->TraceGetObjName() << "}"; + insert_params_log_cache(ss.str(), idx); + } + + /* + * template functions for pass correct parameters to api implementation + */ + // default substitution + template ::value, int> = 0> + static inline T&& proc_param(T&& t) { + return std::forward(t); + } + + // enable if T is derive from TraceClassBase + template >::value, int> = 0> + static inline decltype(std::declval().TraceGetImpl())&& proc_param( + T&& t) { + return std::forward(t).TraceGetImpl(); + } + + // enable if T is pointer to object which + // derive from TraceClassBase + template >::value, int> = 0> + static inline decltype(std::declval()->TraceGetImplSp()) proc_param( + T&& t) { + return std::forward(t)->TraceGetImplSp(); + } + + // enable if T is std::vector which element is pointer of traced object + template >::value, int> = 0> + static inline std::vector()[0]->TraceGetImplSp())> + proc_param(T&& t) { + 
std::vectorTraceGetImplSp())> impl_vec; + for (auto& x : std::forward(t)) { + impl_vec.emplace_back(x->TraceGetImplSp()); + } + return impl_vec; + } +}; + +#ifdef API_TRACER_IMPLEMENTATION +std::unordered_map Tracer::obj_names_; +std::vector Tracer::params_log_cache_; +std::list Tracer::msg_cache_; +std::unordered_map Tracer::target2trace_map_; +FILE* Tracer::file_trace_log_ = Tracer::open_file(TRACE_LOG_NAME_); +FILE* Tracer::file_trace_bin_ = Tracer::open_file(TRACE_BIN_FILE_); +std::mutex Tracer::log_mtx_; +std::mutex Tracer::bin_mtx_; +std::unordered_map Tracer::objs_prefix_ = { + {"Quantization", "quant_" }, + {"TensorSpec", "spec_" }, + {"Tensor", "tensor_" }, + {"Graph", "graph_" } +}; +/* static */ FILE* Tracer::open_file(const char* file_name) { + char* prefix = getenv(TRACE_PREFIX_ENV_VAR_); + FILE* fp; + char path[1024] = {0}; + if (prefix != NULL) { + strcpy(path, prefix); + strcat(path, file_name); + } else { + strcpy(path, file_name); + } + fp = fopen(path, "w"); + if (!fp) { + TCLOGE("Can not open file at: %s\n", path); + } + return fp; +} + +/* static */ std::string Tracer::allocate_obj_name( + const std::string& prefix) { + static std::unordered_map objects_counter; + if (objects_counter.find(prefix) == objects_counter.end()) { + objects_counter[prefix] = 0; + } + return prefix + std::to_string(objects_counter[prefix]++); +} + +/* static */ void Tracer::logging_msg(const char* format, ...) { + char arg_buffer[1024] = {0}; + va_list args; + va_start(args, format); + vsnprintf(arg_buffer, 1024, format, args); + va_end(args); + // printf("%s", arg_buffer); + if (!file_trace_log_) { + TCLOGE("Log file do not exist\n"); + } + std::lock_guard lock(log_mtx_); + fprintf(file_trace_log_, "%s", arg_buffer); + fflush(file_trace_log_); +} + +/* static */ uint32_t Tracer::dump_data( + const void* data, size_t byte_size, size_t count) { + std::lock_guard lock(bin_mtx_); + if (fwrite(data, byte_size, count, file_trace_bin_) != count) { + TCLOGE("Write trace binary data failed!\n"); + } + fflush(file_trace_bin_); + static uint32_t offset = 0; + uint32_t temp = offset; + offset += byte_size * count; + return temp; +} +#endif /* #ifdef API_TRACER_IMPLEMENTATION */ + +} /* namespace trace */ + +/************************ definition of TraceClassBase ************************/ +namespace trace { + +template +struct TraceClassBase { + TargetClass& TraceGetImpl() const { return *impl_; } + + // temporary return rvalue to prevent compile error + std::shared_ptr TraceGetImplSp() { return impl_; } + + std::string& TraceGetObjName() const { + return Tracer::get_obj_name(static_cast(impl_.get())); + } + + std::shared_ptr impl_; + // static const char* target_namespace_name_; +}; + +// #ifdef TARGET_NAMESPACE_NAME +// template +// const char* TraceClassBase::target_namespace_name_ = +// TARGET_NAMESPACE_NAME; +// #endif // #ifdef TARGET_NAMESPACE_NAME + +} /* namespace trace */ + +#ifdef TARGET_NAMESPACE_NAME +static const char* target_namespace_name_ = TARGET_NAMESPACE_NAME; +#endif // #ifdef TARGET_NAMESPACE_NAME + +/**************************** definitions of macros ***************************/ +#define LOG_PARAM_IMPL_(r, _, i, param) \ + Tracer::logging_param(param, i); + +#define LOG_PARAMS(params) \ + Tracer::init_params_log_cache(BOOST_PP_SEQ_SIZE(params)); \ + BOOST_PP_SEQ_FOR_EACH_I(LOG_PARAM_IMPL_, _, params) + +#define PROC_PARAM_IMPL_COMMA_(r, _, param) \ + Tracer::proc_param(param), + +#define PROC_PARAM_IMPL_NO_COMMA_(param) \ + Tracer::proc_param(param) + +#define 
PROC_SINGLE_PARAM_(params) \ + PROC_PARAM_IMPL_NO_COMMA_(BOOST_PP_SEQ_ELEM(0, params)) + +#define PROC_MULTI_PARAMS_(params) \ + BOOST_PP_SEQ_FOR_EACH( \ + PROC_PARAM_IMPL_COMMA_, _, \ + BOOST_PP_SEQ_SUBSEQ(params, 0, BOOST_PP_DEC(BOOST_PP_SEQ_SIZE(params)))) \ + PROC_PARAM_IMPL_NO_COMMA_( \ + BOOST_PP_SEQ_ELEM(BOOST_PP_DEC(BOOST_PP_SEQ_SIZE(params)), params)) + +#define PROC_PARAMS(params) \ + BOOST_PP_IF(BOOST_PP_EQUAL(BOOST_PP_SEQ_SIZE(params), 1), \ + PROC_SINGLE_PARAM_, PROC_MULTI_PARAMS_)(params) + + +#define NAME_A_PARAM_(r, data, i, elem) (param_##i) + +#define ARGS_DESC_TO_PARAMS(descs) \ + BOOST_PP_SEQ_FOR_EACH_I(NAME_A_PARAM_, _, descs) + +#define IS_WITH_DEFAULT_VAL_(desc) \ + BOOST_PP_EQUAL(BOOST_PP_SEQ_SIZE(desc), 2) + +#define SET_DEFAULT_VAL_(val) = BOOST_PP_SEQ_HEAD(val) + +#define DO_NOTHING(x) + +#define DECLARE_AN_ARG_COMMA_(r, names, i, desc) \ + BOOST_PP_SEQ_HEAD(desc) BOOST_PP_SEQ_ELEM(i, names) \ + BOOST_PP_IF(IS_WITH_DEFAULT_VAL_(desc), SET_DEFAULT_VAL_, DO_NOTHING) \ + (BOOST_PP_SEQ_TAIL(desc)), + +#define DECLARE_AN_ARG_NO_COMMA_(name, desc) \ + BOOST_PP_SEQ_HEAD(desc) name \ + BOOST_PP_IF(IS_WITH_DEFAULT_VAL_(desc), SET_DEFAULT_VAL_, DO_NOTHING) \ + (BOOST_PP_SEQ_TAIL(desc)) + +#define SINGLE_ARG_DESC_TO_DECLARATION_(desc) \ + DECLARE_AN_ARG_NO_COMMA_( \ + BOOST_PP_SEQ_ELEM(0, ARGS_DESC_TO_PARAMS(desc)), \ + BOOST_PP_SEQ_ELEM(0, desc)) + +#define MULTI_ARGS_DESC_TO_DECLARATION_(descs) \ + BOOST_PP_SEQ_FOR_EACH_I(DECLARE_AN_ARG_COMMA_, \ + BOOST_PP_SEQ_SUBSEQ(ARGS_DESC_TO_PARAMS(descs), \ + 0, BOOST_PP_DEC(BOOST_PP_SEQ_SIZE(descs))), \ + BOOST_PP_SEQ_SUBSEQ(descs, 0, \ + BOOST_PP_DEC(BOOST_PP_SEQ_SIZE(descs)))) \ + DECLARE_AN_ARG_NO_COMMA_( \ + BOOST_PP_SEQ_ELEM(BOOST_PP_DEC(BOOST_PP_SEQ_SIZE(descs)), \ + ARGS_DESC_TO_PARAMS(descs)), \ + BOOST_PP_SEQ_ELEM(BOOST_PP_DEC(BOOST_PP_SEQ_SIZE(descs)), descs)) + +#define ARGS_DESC_TO_DECLARATION(descs) \ + BOOST_PP_IF(BOOST_PP_EQUAL(BOOST_PP_SEQ_SIZE(descs), 1), \ + SINGLE_ARG_DESC_TO_DECLARATION_, \ + MULTI_ARGS_DESC_TO_DECLARATION_)(descs) + +#define TO_VARIADIC_IMPL_COMMA_(r, _, elem) elem, +#define TO_VARIADIC_IMPL_NO_COMMA_(elem) elem + +#define SEQ_TO_VARIADICS(seqs) \ + BOOST_PP_SEQ_FOR_EACH(TO_VARIADIC_IMPL_COMMA_, _, \ + BOOST_PP_SEQ_SUBSEQ(seqs, 0, BOOST_PP_DEC(BOOST_PP_SEQ_SIZE(seqs)))) \ + TO_VARIADIC_IMPL_NO_COMMA_( \ + BOOST_PP_SEQ_ELEM(BOOST_PP_DEC(BOOST_PP_SEQ_SIZE(seqs)), seqs)) + +#define DEF_MEMFN_SP_2_(ret_class, api_name) \ + std::shared_ptr api_name() { \ + std::string this_obj_name = TraceGetObjName(); \ + std::string obj_name = \ + Tracer::allocate_obj_name(Tracer::get_obj_prefix(#ret_class)); \ + Tracer::logging_msg("auto %s = %s->%s();\n", obj_name.c_str(), \ + this_obj_name.c_str(), __FUNCTION__); \ + auto obj_impl_sp = impl_->api_name(); \ + auto obj_sp = std::make_shared(obj_impl_sp); \ + Tracer::insert_obj_name(static_cast(obj_impl_sp.get()), obj_name); \ + return obj_sp; \ + } + +#define DEF_MEMFN_SP_3_(ret_class, api_name, args_desc) \ + std::shared_ptr api_name(ARGS_DESC_TO_DECLARATION(args_desc)) { \ + std::string this_obj_name = TraceGetObjName(); \ + std::string obj_name = \ + Tracer::allocate_obj_name(Tracer::get_obj_prefix(#ret_class)); \ + Tracer::logging_msg("auto %s = %s->%s(", obj_name.c_str(), \ + this_obj_name.c_str(), __FUNCTION__); \ + LOG_PARAMS(ARGS_DESC_TO_PARAMS(args_desc)) \ + Tracer::dump_params_log_cache(); \ + Tracer::logging_msg(");\n"); \ + auto obj_impl_sp = impl_->api_name( \ + PROC_PARAMS(ARGS_DESC_TO_PARAMS(args_desc))); \ + auto obj_sp = 
std::make_shared(obj_impl_sp); \ + Tracer::insert_obj_name(static_cast(obj_impl_sp.get()), obj_name); \ + return obj_sp; \ + } + +#define DEF_MEMFN_SP_4_(ret_class, api_name, args_desc, SPECIAL_MACRO_) \ + std::shared_ptr api_name(ARGS_DESC_TO_DECLARATION(args_desc)) { \ + std::string this_obj_name = TraceGetObjName(); \ + std::string obj_name = \ + Tracer::allocate_obj_name(Tracer::get_obj_prefix(#ret_class)); \ + Tracer::push_back_msg_cache("auto " + obj_name + " = " + this_obj_name \ + + "->" + __FUNCTION__ + "("); \ + LOG_PARAMS(ARGS_DESC_TO_PARAMS(args_desc)) \ + SPECIAL_MACRO_(ARGS_DESC_TO_PARAMS(args_desc)) \ + Tracer::pop_params_log_cache(); \ + Tracer::amend_last_msg_cache(");\n"); \ + Tracer::msg_cache_sync_to_file(); \ + auto obj_impl_sp = impl_->api_name( \ + PROC_PARAMS(ARGS_DESC_TO_PARAMS(args_desc))); \ + auto obj_sp = std::make_shared(obj_impl_sp); \ + Tracer::insert_obj_name(static_cast(obj_impl_sp.get()), obj_name); \ + return obj_sp; \ + } + +#define DEF_SIMPLE_UNTRACED_API(retval, api_name) \ + retval api_name() { \ + return impl_->api_name(); \ + } + +#define DEF_MEMFN_2_(retval, api_name) \ + retval api_name() { \ + std::string this_obj_name = TraceGetObjName(); \ + Tracer::logging_msg("%s->%s();\n", \ + this_obj_name.c_str(), __FUNCTION__); \ + return impl_->api_name(); \ + } + +#define DEF_MEMFN_3_(retval, api_name, args_desc) \ + retval api_name(ARGS_DESC_TO_DECLARATION(args_desc)) { \ + std::string this_obj_name = TraceGetObjName(); \ + Tracer::logging_msg("%s->%s(", \ + this_obj_name.c_str(), __FUNCTION__); \ + LOG_PARAMS(ARGS_DESC_TO_PARAMS(args_desc)) \ + Tracer::dump_params_log_cache(); \ + Tracer::logging_msg(");\n"); \ + return impl_->api_name( \ + PROC_PARAMS(ARGS_DESC_TO_PARAMS(args_desc))); \ + } + +#define DEF_MEMFN_4_(retval, api_name, args_desc, SPECIAL_MACRO_) \ + retval api_name(ARGS_DESC_TO_DECLARATION(args_desc)) { \ + std::string this_obj_name = TraceGetObjName(); \ + Tracer::push_back_msg_cache( \ + this_obj_name + "->" + __FUNCTION__ + "("); \ + LOG_PARAMS(ARGS_DESC_TO_PARAMS(args_desc)) \ + SPECIAL_MACRO_(ARGS_DESC_TO_PARAMS(args_desc)) \ + Tracer::pop_params_log_cache(); \ + Tracer::amend_last_msg_cache(");\n"); \ + Tracer::msg_cache_sync_to_file(); \ + return impl_->api_name( \ + PROC_PARAMS(ARGS_DESC_TO_PARAMS(args_desc))); \ + } + +#define DEF_INPLACE_MEMFN_2_(retval, api_name) \ + retval api_name() { \ + std::string this_obj_name = TraceGetObjName(); \ + Tracer::logging_msg("%s->%s();\n", \ + this_obj_name.c_str(), __FUNCTION__); \ + impl_->api_name(); \ + return *this; \ + } + +#define DEF_INPLACE_MEMFN_3_(retval, api_name, args_desc) \ + retval api_name(ARGS_DESC_TO_DECLARATION(args_desc)) { \ + std::string this_obj_name = TraceGetObjName(); \ + Tracer::logging_msg("%s->%s(", \ + this_obj_name.c_str(), __FUNCTION__); \ + LOG_PARAMS(ARGS_DESC_TO_PARAMS(args_desc)) \ + Tracer::dump_params_log_cache(); \ + Tracer::logging_msg(");\n"); \ + impl_->api_name(PROC_PARAMS(ARGS_DESC_TO_PARAMS(args_desc))); \ + return *this; \ + } + +#define DEF_INPLACE_MEMFN_4_(retval, api_name, args_desc, SPECIAL_MACRO_) \ + retval api_name(ARGS_DESC_TO_DECLARATION(args_desc)) { \ + std::string this_obj_name = TraceGetObjName(); \ + Tracer::push_back_msg_cache( \ + this_obj_name + "->" + __FUNCTION__ + "("); \ + LOG_PARAMS(ARGS_DESC_TO_PARAMS(args_desc)) \ + SPECIAL_MACRO_(ARGS_DESC_TO_PARAMS(args_desc)) \ + Tracer::pop_params_log_cache(); \ + Tracer::amend_last_msg_cache(");\n"); \ + Tracer::msg_cache_sync_to_file(); \ + 
impl_->api_name(PROC_PARAMS(ARGS_DESC_TO_PARAMS(args_desc))); \ + return *this; \ + } + +#define DEF_CONSTRUCTOR_1_(class_name) \ + class_name() { \ + std::string obj_name = \ + Tracer::allocate_obj_name(Tracer::get_obj_prefix(#class_name)); \ + Tracer::logging_msg("auto %s = %s::%s();\n", obj_name.c_str(), \ + target_namespace_name_, __FUNCTION__); \ + impl_ = std::make_shared(); \ + Tracer::insert_traced_obj( \ + static_cast(impl_.get()), static_cast(this)); \ + Tracer::insert_obj_name(static_cast(impl_.get()), obj_name); \ + } + +#define DEF_CONSTRUCTOR_2_(class_name, args_desc) \ + class_name(ARGS_DESC_TO_DECLARATION(args_desc)) { \ + std::string obj_name = \ + Tracer::allocate_obj_name(Tracer::get_obj_prefix(#class_name)); \ + Tracer::logging_msg("auto %s = %s::%s(", obj_name.c_str(), \ + target_namespace_name_, __FUNCTION__); \ + LOG_PARAMS(ARGS_DESC_TO_PARAMS(args_desc)) \ + Tracer::dump_params_log_cache(); \ + Tracer::logging_msg(");\n"); \ + impl_ = std::make_shared( \ + PROC_PARAMS(ARGS_DESC_TO_PARAMS(args_desc))); \ + Tracer::insert_traced_obj( \ + static_cast(impl_.get()), static_cast(this)); \ + Tracer::insert_obj_name(static_cast(impl_.get()), obj_name); \ + } + +#define DEF_CONSTRUCTOR_3_(class_name, args_desc, SPECIAL_MACRO_) \ + class_name(ARGS_DESC_TO_DECLARATION(args_desc)) { \ + std::string obj_name = \ + Tracer::allocate_obj_name(Tracer::get_obj_prefix(#class_name)); \ + Tracer::push_back_msg_cache( \ + "auto " + obj_name + " = " + target_namespace_name_ + "::" + \ + __FUNCTION__ + "("); \ + LOG_PARAMS(ARGS_DESC_TO_PARAMS(args_desc)) \ + SPECIAL_MACRO_(ARGS_DESC_TO_PARAMS(args_desc)) \ + Tracer::pop_params_log_cache(); \ + Tracer::amend_last_msg_cache(");\n"); \ + Tracer::msg_cache_sync_to_file(); \ + impl_ = std::make_shared( \ + PROC_PARAMS(ARGS_DESC_TO_PARAMS(args_desc))); \ + Tracer::insert_traced_obj( \ + static_cast(impl_.get()), static_cast(this)); \ + Tracer::insert_obj_name(static_cast(impl_.get()), obj_name); \ + } + + +/* + * Follows code to detect empty macro variadic is from: + * https://gustedt.wordpress.com/2010/06/08/detect-empty-macro-arguments/ + */ +#define __ARG16(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13,_14,\ + _15, ...) _15 +#define __HAS_COMMA(...) __ARG16(__VA_ARGS__, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, \ + 1, 1, 1, 0) +#define __TRIGGER_PARENTHESIS_(...) , +#define __PASTE5(_0, _1, _2, _3, _4) _0 ## _1 ## _2 ## _3 ## _4 +#define __IS_EMPTY_CASE_0001 , +#define __IS_EMPTY(_0, _1, _2, _3) __HAS_COMMA( \ + __PASTE5(__IS_EMPTY_CASE_, _0, _1, _2, _3)) + +#define TUPLE_IS_EMPTY(...) \ + __IS_EMPTY( \ + /* test if there is just one argument, eventually an empty one */ \ + __HAS_COMMA(__VA_ARGS__), \ + /* test if _TRIGGER_PARENTHESIS_ together with the argument adds a comma */\ + __HAS_COMMA(__TRIGGER_PARENTHESIS_ __VA_ARGS__), \ + /* test if the argument together with a parenthesis adds a comma */ \ + __HAS_COMMA(__VA_ARGS__ (/*empty*/)), \ + /* test if placing it between _TRIGGER_PARENTHESIS_ */ \ + /* and the parenthesis adds a comma */ \ + __HAS_COMMA(__TRIGGER_PARENTHESIS_ __VA_ARGS__ (/*empty*/)) \ + ) + +#define EMPTY_LAMBDA_ placeholder +// BOOST_PP_VARIADIC_SIZE() variadic is empty, but it's will expand to 1. +// And there are unknow issue of tensorflow build, ##__VA_ARGS__ can't correctly +// remove the comma before. +#define VARIADIC_SIZE_PLUS_ONE(...) \ + BOOST_PP_VARIADIC_SIZE(EMPTY_LAMBDA_, ##__VA_ARGS__) +#define INVOKE_LAMBDA_(r, _, lambda) lambda(); +#define INVOKE_LAMBDAS_(...) 
\ + BOOST_PP_IF(BOOST_PP_EQUAL(TUPLE_IS_EMPTY(__VA_ARGS__), 1), \ + BOOST_PP_EMPTY(), BOOST_PP_SEQ_FOR_EACH(INVOKE_LAMBDA_, _, \ + BOOST_PP_VARIADIC_TO_SEQ(__VA_ARGS__))) + +/* + * In substitution for traced obj, api_impl got pointer retval because if + * directly return reference will get build error of using pure virtual class + * constructor. + */ +#define DEF_TRACED_API(Ret, api_name, ...) \ + template \ + typename std::enable_if_t::value, R> api_name( \ + Args... params) { \ + std::string this_obj_name = TraceGetObjName(); \ + Tracer::push_back_msg_cache( \ + this_obj_name + "->" + __FUNCTION__ + "("); \ + Tracer::clear_params_log_cache(); \ + boost::hana::tuple params_tuple = {params...}; \ + boost::hana::for_each(params_tuple, [&] (auto x) { \ + Tracer::logging_param(x, -1); \ + }); \ + INVOKE_LAMBDAS_(__VA_ARGS__) \ + Tracer::pop_params_log_cache(); \ + Tracer::amend_last_msg_cache(");\n"); \ + Tracer::msg_cache_sync_to_file(); \ + auto api_impl = [&] (auto&&... ts) { \ + return &(impl_->api_name(ts...)); \ + }; \ + auto proc_param_lambda = [] (auto&& t) { \ + return Tracer::proc_param(t); \ + }; \ + auto params_impl = boost::hana::transform(params_tuple, proc_param_lambda);\ + auto ret_impl_p = boost::hana::unpack(params_impl, api_impl); \ + return *static_cast*>(Tracer::get_traced_obj(ret_impl_p)); \ + } \ + \ + template \ + typename std::enable_if_t::value, R> api_name( \ + Args... params) { \ + std::string this_obj_name = TraceGetObjName(); \ + std::string obj_name = \ + Tracer::allocate_obj_name(Tracer::get_obj_prefix( \ + boost::typeindex::type_id() \ + .pretty_name())); \ + Tracer::push_back_msg_cache( \ + this_obj_name + "->" + __FUNCTION__ + "("); \ + Tracer::clear_params_log_cache(); \ + boost::hana::tuple params_tuple = {params...}; \ + boost::hana::for_each(params_tuple, [&] (auto x) { \ + Tracer::logging_param(x, -1); \ + }); \ + INVOKE_LAMBDAS_(__VA_ARGS__) \ + Tracer::pop_params_log_cache(); \ + Tracer::amend_last_msg_cache(");\n"); \ + Tracer::msg_cache_sync_to_file(); \ + auto api_impl = [&] (auto&&... ts) { \ + return impl_->api_name(ts...); \ + }; \ + auto proc_param_lambda = [] (auto&& t) { \ + return Tracer::proc_param(t); \ + }; \ + auto params_impl = boost::hana::transform(params_tuple, proc_param_lambda);\ + auto obj_impl_sp = boost::hana::unpack(params_impl, api_impl); \ + R obj_sp(obj_impl_sp); \ + Tracer::insert_obj_name(static_cast(obj_impl_sp.get()), obj_name); \ + return obj_sp; \ + } \ + \ + template \ + typename std::enable_if_t::value, R> api_name( \ + Args... params) { \ + std::string this_obj_name = TraceGetObjName(); \ + Tracer::push_back_msg_cache( \ + this_obj_name + "->" + __FUNCTION__ + "("); \ + Tracer::clear_params_log_cache(); \ + boost::hana::tuple params_tuple = {params...}; \ + boost::hana::for_each(params_tuple, [&] (auto x) { \ + Tracer::logging_param(x, -1); \ + }); \ + INVOKE_LAMBDAS_(__VA_ARGS__) \ + Tracer::pop_params_log_cache(); \ + Tracer::amend_last_msg_cache(");\n"); \ + Tracer::msg_cache_sync_to_file(); \ + auto api_impl = [&] (auto&&... 
ts) { \ + return impl_->api_name(ts...); \ + }; \ + auto proc_param_lambda = [] (auto&& t) { \ + return Tracer::proc_param(t); \ + }; \ + auto params_impl = boost::hana::transform(params_tuple, proc_param_lambda);\ + auto ret_impl = boost::hana::unpack(params_impl, api_impl); \ + R ret; \ + for(auto& x : ret_impl) { \ + ret.push_back(std::make_shared( \ + *static_cast( \ + Tracer::get_traced_obj(x.get())))); \ + } \ + return ret; \ + } \ + \ + template \ + typename std::enable_if_t::value, R> api_name( \ + Args... params) { \ + std::string this_obj_name = TraceGetObjName(); \ + Tracer::push_back_msg_cache( \ + this_obj_name + "->" + __FUNCTION__ + "("); \ + Tracer::clear_params_log_cache(); \ + boost::hana::tuple params_tuple = {params...}; \ + boost::hana::for_each(params_tuple, [&] (auto x) { \ + Tracer::logging_param(x, -1); \ + }); \ + INVOKE_LAMBDAS_(__VA_ARGS__) \ + Tracer::pop_params_log_cache(); \ + Tracer::amend_last_msg_cache(");\n"); \ + Tracer::msg_cache_sync_to_file(); \ + auto api_impl = [&] (auto&&... ts) { \ + return impl_->api_name(ts...); \ + }; \ + auto proc_param_lambda = [] (auto&& t) { \ + return Tracer::proc_param(t); \ + }; \ + auto params_impl = boost::hana::transform(params_tuple, proc_param_lambda);\ + return boost::hana::unpack(params_impl, api_impl); \ + } + +#define GET_MACRO_OVERLOAD_4_(_1, _2, _3, _4, MACRO, ...) MACRO +#define GET_MACRO_OVERLOAD_3_(_1, _2, _3, MACRO, ...) MACRO + +#define DEF_MEMFN_SP(...) \ + GET_MACRO_OVERLOAD_4_(__VA_ARGS__, \ + DEF_MEMFN_SP_4_, \ + DEF_MEMFN_SP_3_, \ + DEF_MEMFN_SP_2_)(__VA_ARGS__) + +#define DEF_MEMFN(...) \ + GET_MACRO_OVERLOAD_4_(__VA_ARGS__, \ + DEF_MEMFN_4_, \ + DEF_MEMFN_3_, \ + DEF_MEMFN_2_)(__VA_ARGS__) + +#define DEF_INPLACE_MEMFN(...) \ + GET_MACRO_OVERLOAD_4_(__VA_ARGS__, \ + DEF_INPLACE_MEMFN_4_, \ + DEF_INPLACE_MEMFN_3_, \ + DEF_INPLACE_MEMFN_2_)(__VA_ARGS__) + +#define DEF_CONSTRUCTOR(...) \ + GET_MACRO_OVERLOAD_3_(__VA_ARGS__, \ + DEF_CONSTRUCTOR_3_, \ + DEF_CONSTRUCTOR_2_, \ + DEF_CONSTRUCTOR_1_)(__VA_ARGS__) + +#define LOGGING_POINTER_MSG(offset, length, idx) \ + char log_msg[1024] = {0}; \ + snprintf(log_msg, 1024, \ + "trace::Replayer::get_vector(%u, %u).data()", \ + offset, length); \ + Tracer::insert_params_log_cache(std::string(log_msg), idx); + +#define DEF_INTERFACE_CONSTRUCTOR(interface) \ + interface(const std::shared_ptr& impl) { \ + Tracer::insert_traced_obj(static_cast(impl.get()), \ + static_cast(this)); \ + impl_ = impl; \ + } + +#endif // TIM_EXPERIMENTAL_TRACE_TRACER_H_ diff --git a/include/tim/experimental/trace/tvx/context.h b/include/tim/experimental/trace/tvx/context.h new file mode 100755 index 0000000..3e9a42e --- /dev/null +++ b/include/tim/experimental/trace/tvx/context.h @@ -0,0 +1,58 @@ +/**************************************************************************** +* +* Copyright (c) 2020-2023 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. 
+* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#ifndef TIM_EXPERIMENTAL_TRACE_TVX_CONTEXT_H_ +#define TIM_EXPERIMENTAL_TRACE_TVX_CONTEXT_H_ +#include "tim/vx/context.h" +#include "tim/experimental/trace/tvx/graph.h" +#include "tim/experimental/trace/tracer.h" + +namespace trace { + +namespace target = ::tim::vx; + +struct Context : public TraceClassBase { + DEF_INTERFACE_CONSTRUCTOR(Context) + + DEF_MEMFN_SP(Graph, CreateGraph) + + DEF_TRACED_API(bool, isClOnly) + + static inline std::shared_ptr Create() { + std::string obj_name = Tracer::allocate_obj_name("ctx_"); + std::string pf(__PRETTY_FUNCTION__); + pf.replace(pf.rfind("trace"), 5, target_namespace_name_); + char log_msg[1024] = {0}; + snprintf(log_msg, 1024, "auto %s =%s;\n", obj_name.c_str(), + pf.substr(pf.rfind(" "), pf.size()).c_str()); + Tracer::logging_msg(log_msg); + auto obj = std::make_shared(target::Context::Create()); + Tracer::insert_obj_name(static_cast( + obj->TraceGetImplSp().get()), obj_name); + return obj; + } +}; + +} /* namespace trace */ + +#endif // TIM_EXPERIMENTAL_TRACE_TVX_CONTEXT_H_ diff --git a/include/tim/experimental/trace/tvx/graph.h b/include/tim/experimental/trace/tvx/graph.h new file mode 100755 index 0000000..85796d5 --- /dev/null +++ b/include/tim/experimental/trace/tvx/graph.h @@ -0,0 +1,360 @@ +/**************************************************************************** +* +* Copyright (c) 2020-2023 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. 
+* +*****************************************************************************/ +#ifndef TIM_EXPERIMENTAL_TRACE_TVX_GRAPH_H_ +#define TIM_EXPERIMENTAL_TRACE_TVX_GRAPH_H_ +#include "tim/vx/graph.h" +#include "tim/experimental/trace/tvx/ops.h" +#include "tim/experimental/trace/tvx/tensor.h" +#include "tim/experimental/trace/tracer.h" + +namespace trace { + +namespace target = ::tim::vx; + +#define DECL_CREATE_OP_IMPL_(r, _, opname) \ + template \ + std::shared_ptr CreateOperationImpl( \ + BOOST_PP_CAT(ops::_VSI_Tag_of_, opname), Params... params); + +#define DECL_CREATE_OPS(ops_seq) \ + BOOST_PP_SEQ_FOR_EACH(DECL_CREATE_OP_IMPL_, _, ops_seq) + +#define TO_STRING_(expr) #expr + +#define DEF_CREATE_OP_IMPL_(r, _, opname) \ +template \ +std::shared_ptr Graph::CreateOperationImpl( \ + BOOST_PP_CAT(ops::_VSI_Tag_of_, opname), Params... params) { \ + std::string this_obj_name = TraceGetObjName(); \ + std::string obj_name = \ + Tracer::allocate_obj_name(std::string(TO_STRING_(opname)) + "_"); \ + Tracer::logging_msg( \ + "auto %s = %s->CreateOperation<%s::ops::%s>(", obj_name.c_str(), \ + this_obj_name.c_str(), target_namespace_name_, TO_STRING_(opname)); \ + Tracer::clear_params_log_cache(); \ + boost::hana::tuple params_tuple = {params...}; \ + boost::hana::for_each(params_tuple, [&] (auto x) { \ + Tracer::logging_param(x, -1); \ + }); \ + Tracer::dump_params_log_cache(); \ + Tracer::logging_msg(");\n"); \ + auto CreateOpImpl = [&] (auto&&... ts) { \ + return impl_->CreateOperation(ts...); \ + }; \ + auto proc_param_lambda = [] (auto&& t) { \ + return Tracer::proc_param(t); \ + }; \ + auto params_impl = boost::hana::transform(params_tuple, proc_param_lambda); \ + auto op_impl_sp = boost::hana::unpack(params_impl, CreateOpImpl); \ + auto op_sp = std::make_shared(op_impl_sp); \ + Tracer::insert_obj_name(static_cast(op_impl_sp.get()), obj_name); \ + return op_sp; \ +} + +#define DEF_CREATE_OP_IMPLS(ops_seq) \ + BOOST_PP_SEQ_FOR_EACH(DEF_CREATE_OP_IMPL_, _, ops_seq) + +#define SPECIALIZATION_CREATE_OP(opname, args_desc, SPECIAL_MACRO_) \ +template <> \ +inline std::shared_ptr Graph::CreateOperation( \ + ARGS_DESC_TO_DECLARATION(args_desc)) { \ + std::string this_obj_name = TraceGetObjName(); \ + std::string obj_name = Tracer::allocate_obj_name(std::string(#opname) + "_");\ + Tracer::push_back_msg_cache( \ + "auto " + obj_name + " = " + this_obj_name + \ + "->CreateOperation("); \ + LOG_PARAMS(ARGS_DESC_TO_PARAMS(args_desc)) \ + SPECIAL_MACRO_(ARGS_DESC_TO_PARAMS(args_desc)) \ + Tracer::pop_params_log_cache(); \ + Tracer::amend_last_msg_cache(");\n"); \ + Tracer::msg_cache_sync_to_file(); \ + auto op_impl_sp = impl_->CreateOperation( \ + SEQ_TO_VARIADICS(ARGS_DESC_TO_PARAMS(args_desc))); \ + auto op_sp = std::make_shared(op_impl_sp); \ + Tracer::insert_obj_name(static_cast(op_impl_sp.get()), obj_name); \ + return op_sp; \ +} + +struct Graph : public TraceClassBase { + DEF_INTERFACE_CONSTRUCTOR(Graph) + + #define SPECIAL_MACRO_(params) \ + if (BOOST_PP_SEQ_ELEM(1, params) == nullptr) { \ + Tracer::insert_params_log_cache("nullptr", 1); \ + } else { \ + uint32_t count = \ + BOOST_PP_SEQ_ELEM(0, params).TraceGetImpl().GetByteSize(); \ + uint32_t offset = \ + Tracer::dump_data( \ + BOOST_PP_SEQ_ELEM(1, params), sizeof(char), count); \ + LOGGING_POINTER_MSG(offset, count, 0) \ + } + + // DEF_MEMFN_SP(Tensor, + // CreateTensor, + // ((const TensorSpec&))((const void*)(nullptr)), + // SPECIAL_MACRO_) + std::shared_ptr CreateTensor( + const TensorSpec ¶m_0, const void *param_1 = nullptr) { + std::string 
this_obj_name = TraceGetObjName(); + std::string obj_name = + Tracer::allocate_obj_name(Tracer::get_obj_prefix("Tensor")); + Tracer::push_back_msg_cache("auto " + obj_name + " = " + this_obj_name + + "->" + __FUNCTION__ + "("); + Tracer::init_params_log_cache(2); + Tracer::logging_param(param_0, 0); + Tracer::logging_param(param_1, 1); + SPECIAL_MACRO_((param_0)(param_1)) + Tracer::pop_params_log_cache(); + Tracer::amend_last_msg_cache(");\n"); + Tracer::msg_cache_sync_to_file(); + + #if 1 /* if use fake input data */ + if (param_0.TraceGetImpl().GetTensorAttribute() == + TensorAttribute::TRANSIENT && param_1 == nullptr) { + auto fake_vec_name = Tracer::allocate_obj_name("fake_vec_"); + switch (param_0.TraceGetImpl().GetDataType()) { + case DataType::INT32: + Tracer::logging_msg("std::vector %s(%d, 42);\n" + "%s->CopyDataToTensor(%s.data());\n", + fake_vec_name.c_str(), (int)param_0.TraceGetImpl().GetElementNum(), + obj_name.c_str(), fake_vec_name.c_str()); + break; + case DataType::FLOAT32: + Tracer::logging_msg("std::vector %s(%d, 0.42);\n" + "%s->CopyDataToTensor(%s.data());\n", + fake_vec_name.c_str(), (int)param_0.TraceGetImpl().GetElementNum(), + obj_name.c_str(), fake_vec_name.c_str()); + break; + case DataType::FLOAT16: + Tracer::logging_msg("std::vector %s(%d, 0);\n" + "%s->CopyDataToTensor(%s.data());\n", + fake_vec_name.c_str(), (int)param_0.TraceGetImpl().GetElementNum(), + obj_name.c_str(), fake_vec_name.c_str()); + break; + case DataType::BOOL8: + Tracer::logging_msg("std::vector %s(%d, 0);\n" + "%s->CopyDataToTensor(%s.data());\n", + fake_vec_name.c_str(), (int)param_0.TraceGetImpl().GetElementNum(), + obj_name.c_str(), fake_vec_name.c_str()); + break; + default: + TCLOGE("Unimplemented fake data type:%d\n", + (int)param_0.TraceGetImpl().GetDataType()); + break; + } + } + #endif /* if use fake input data */ + + auto obj_impl_sp = impl_->CreateTensor( + Tracer::proc_param(param_0), Tracer::proc_param(param_1)); + auto obj_sp = std::make_shared(obj_impl_sp, param_0); + Tracer::insert_obj_name(static_cast(obj_impl_sp.get()), obj_name); + tensor_sp_keeper_[obj_impl_sp] = obj_sp; // need keep obj_sp alive. + spec_keeper_.push_back(param_0); + return obj_sp; + } + + // DEF_MEMFN_SP(Tensor, + // CreateIOTensor, + // ((const TensorSpec&))((void*)(nullptr)), + // SPECIAL_MACRO_) + std::shared_ptr CreateIOTensor( + const TensorSpec ¶m_0, void *param_1 = nullptr) { + std::string this_obj_name = TraceGetObjName(); + std::string obj_name = + Tracer::allocate_obj_name(Tracer::get_obj_prefix("Tensor")); + Tracer::push_back_msg_cache("auto " + obj_name + " = " + this_obj_name + + "->" + __FUNCTION__ + "("); + Tracer::init_params_log_cache(2); + Tracer::logging_param(param_0, 0); + Tracer::logging_param(param_1, 1); + SPECIAL_MACRO_((param_0)(param_1)) + Tracer::pop_params_log_cache(); + Tracer::amend_last_msg_cache(");\n"); + Tracer::msg_cache_sync_to_file(); + auto obj_impl_sp = impl_->CreateIOTensor( + Tracer::proc_param(param_0), Tracer::proc_param(param_1)); + auto obj_sp = std::make_shared(obj_impl_sp); + Tracer::insert_obj_name(static_cast(obj_impl_sp.get()), obj_name); + tensor_sp_keeper_[obj_impl_sp] = obj_sp; // need keep obj_sp alive. 
+ spec_keeper_.push_back(param_0); + return obj_sp; + } + + #undef SPECIAL_MACRO_ + + // DEF_MEMFN_SP(Tensor, CreateTensor, + // ((const TensorSpec&))((const DmaBufferDesc&)) + // ) + std::shared_ptr CreateTensor( + const TensorSpec ¶m_0, const DmaBufferDesc ¶m_1) { + std::string this_obj_name = TraceGetObjName(); + std::string obj_name = + Tracer::allocate_obj_name(Tracer::get_obj_prefix("Tensor")); + Tracer::logging_msg("auto %s = %s->%s(", obj_name.c_str(), + this_obj_name.c_str(), __FUNCTION__); + Tracer::init_params_log_cache(2); + Tracer::logging_param(param_0, 0); + Tracer::logging_param(param_1, 1); + Tracer::dump_params_log_cache(); + Tracer::logging_msg(");\n"); + auto obj_impl_sp = impl_->CreateTensor(Tracer::proc_param(param_0), + Tracer::proc_param(param_1)); + auto obj_sp = std::make_shared(obj_impl_sp); + Tracer::insert_obj_name(static_cast(obj_impl_sp.get()), obj_name); + tensor_sp_keeper_[obj_impl_sp] = obj_sp; // need keep obj_sp alive. + spec_keeper_.push_back(param_0); + return obj_sp; + } + + // DEF_MEMFN_SP(Tensor, CreateTensorPlaceHolder) + std::shared_ptr CreateTensorPlaceHolder() { + std::string this_obj_name = TraceGetObjName(); + std::string obj_name = + Tracer::allocate_obj_name(Tracer::get_obj_prefix("Tensor")); + Tracer::logging_msg("auto %s = %s->%s();\n", obj_name.c_str(), + this_obj_name.c_str(), __FUNCTION__); + auto obj_impl_sp = impl_->CreateTensorPlaceHolder(); + auto obj_sp = std::make_shared(obj_impl_sp); + Tracer::insert_obj_name(static_cast(obj_impl_sp.get()), obj_name); + tensor_sp_keeper_[obj_impl_sp] = obj_sp; // need keep obj_sp alive. + return obj_sp; + } + + DEF_TRACED_API(bool, Compile) + + // DEF_TRACED_API(bool, CompileToBinary, [&] () { + // if (boost::hana::at_c<0>(params_tuple) == nullptr) { + // auto nbg_size_var = Tracer::allocate_obj_name("nbg_size_"); + // Tracer::insert_obj_name(boost::hana::at_c<1>(params_tuple), nbg_size_var); + // Tracer::insert_before_last_msg_cache( + // "size_t " + nbg_size_var + " = -1;\n"); + // Tracer::insert_params_log_cache("nullptr", 0); + // } else { + // uint32_t count = *boost::hana::at_c<1>(params_tuple); + // uint32_t offset = Tracer::dump_data(boost::hana::at_c<0>(params_tuple), + // sizeof(char), count); + // LOGGING_POINTER_MSG(offset, count, 0) + // } + // Tracer::insert_params_log_cache( + // "&" + Tracer::get_obj_name(boost::hana::at_c<1>(params_tuple)), 1); + // }) + + #define SPECIAL_MACRO_(params) \ + if (BOOST_PP_SEQ_ELEM(0, params) == nullptr) { \ + std::string size_name = Tracer::allocate_obj_name("nbg_size_"); \ + Tracer::insert_obj_name(BOOST_PP_SEQ_ELEM(1, params), size_name); \ + Tracer::insert_before_last_msg_cache( \ + "size_t " + size_name + " = -1;\n"); \ + Tracer::insert_params_log_cache("nullptr", 0); \ + } else { \ + uint32_t data_length = *BOOST_PP_SEQ_ELEM(1, params); \ + uint32_t offset = Tracer::dump_data( \ + BOOST_PP_SEQ_ELEM(0, params), sizeof(char), data_length); \ + LOGGING_POINTER_MSG(offset, data_length, 0) \ + } \ + Tracer::insert_params_log_cache( \ + "&" + Tracer::get_obj_name(BOOST_PP_SEQ_ELEM(1, params)), 1); + + DEF_MEMFN(bool, + CompileToBinary, + ((void*))((size_t*)), + SPECIAL_MACRO_) + + #undef SPECIAL_MACRO_ + + DEF_TRACED_API(bool, Run) + + template + std::shared_ptr CreateOperation(Params... 
params) { + return CreateOperationImpl( + typename ops::TagDispatchTrait::tag {}, params...); + } + + const std::vector> InputsTensor() { + std::vector> ret; + for (auto& x : impl_->InputsTensor()) { + ret.push_back(tensor_sp_keeper_[x]); + } + return ret; + } + + const std::vector> OutputsTensor() { + std::vector> ret; + for (auto& x : impl_->OutputsTensor()) { + ret.push_back(tensor_sp_keeper_[x]); + } + return ret; + } + + DEF_TRACED_API(void, UpdateTensorConsumersMap) + + DEF_TRACED_API(void, UpdateTensorProducerMap) + + // DEF_TRACED_API(cosnt std::vector>, GetConsumersOp) + + // DEF_TRACED_API(std::shared_ptr, GetProducerOp) + + DEF_TRACED_API(void, PrintGraph) + + // DEF_TRACED_API(const std::vector>, GetConstantInputs) + + // DEF_TRACED_API(const std::vector>, GetOpVector) + + private: + std::unordered_map, std::shared_ptr> + tensor_sp_keeper_; + std::vector spec_keeper_; + + DECL_CREATE_OPS(TVX_OPS_SEQ) + DECL_CREATE_OP_IMPL_(_, _, Pad) + DECL_CREATE_OP_IMPL_(_, _, PadV2) + +}; +#define SPECIAL_MACRO_(params) \ + std::string buf_name = Tracer::allocate_obj_name("nbg_buf_vec_"); \ + FILE* nbg_dumped = fopen("network_binary_graph.nb", "r"); \ + fseek(nbg_dumped, 0L, SEEK_END); \ + uint32_t count = ftell(nbg_dumped); \ + fclose(nbg_dumped); \ + uint32_t offset = Tracer::dump_data( \ + BOOST_PP_SEQ_ELEM(0, params), sizeof(char), count); \ + Tracer::insert_before_last_msg_cache("std::vector " + buf_name + \ + " = trace::Replayer::get_vector(" + std::to_string(offset) + \ + "," + std::to_string(count) + ");\n"); \ + Tracer::insert_params_log_cache(buf_name + ".data()", 0); + +SPECIALIZATION_CREATE_OP(NBG, ((const char*))((size_t))((size_t)), + SPECIAL_MACRO_) + +#undef SPECIAL_MACRO_ + +DEF_CREATE_OP_IMPLS(TVX_OPS_SEQ) +DEF_CREATE_OP_IMPL_(_, _, Pad) +DEF_CREATE_OP_IMPL_(_, _, PadV2) + +} /* namespace trace */ + +#endif // TIM_EXPERIMENTAL_TRACE_TVX_GRAPH_H_ diff --git a/include/tim/experimental/trace/tvx/ops.h b/include/tim/experimental/trace/tvx/ops.h new file mode 100755 index 0000000..6397f33 --- /dev/null +++ b/include/tim/experimental/trace/tvx/ops.h @@ -0,0 +1,323 @@ +/**************************************************************************** +* +* Copyright (c) 2020-2023 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. 
+* +*****************************************************************************/ +#ifndef TIM_EXPERIMENTAL_TRACE_TVX_OPS_H_ +#define TIM_EXPERIMENTAL_TRACE_TVX_OPS_H_ +#include "tim/vx/operation.h" +#include "tim/vx/ops.h" +#include "tim/experimental/trace/tvx/tensor.h" +#include "tim/experimental/trace/tracer.h" + +#define TVX_OPS_SEQ \ + (Relu) \ + (Relu1) \ + (Relu6) \ + (Tanh) \ + (Sigmoid) \ + (Swish) \ + (HardSwish) \ + (Mish) \ + (SoftRelu) \ + (Sign) \ + (SoftSign) \ + (Elu) \ + (Prelu) \ + (HardSigmoid) \ + (LeakyRelu) \ + (Linear) \ + (Gelu) \ + (Selu) \ + (Celu) \ + (AddN) \ + (ArgMin) \ + (ArgMax) \ + (Batch2Space) \ + (BatchNorm) \ + (BidirectionalSequenceRnn) \ + (BidirectionalSequenceRnnExt) \ + (Broadcast) \ + (Clip) \ + (Concat) \ + (Conv1d) \ + (Conv2d) \ + (DeConv1d) \ + (DeConv2d) \ + (DepthToSpace) \ + (Dropout) \ + (Minimum) \ + (Maximum) \ + (Add) \ + (Sub) \ + (Pow) \ + (FloorDiv) \ + (Multiply) \ + (Div) \ + (Erf) \ + (FullyConnected) \ + (Gather) \ + (GatherElements) \ + (GatherNd) \ + (GroupedConv2d) \ + (InstanceNormalization) \ + (L2Normalization) \ + (LayerNormalization) \ + (LRN) \ + (LogicalAnd) \ + (LogicalOr) \ + (LogSoftmax) \ + (Matmul) \ + (MaxpoolWithArgmax) \ + (MaxpoolWithArgmax2) \ + (MaxpoolGrad) \ + (MaxUnpool2d) \ + (Moments) \ + (OneHot) \ + /* (Pad) enum defined in class */ \ + /* (PadV2) enum defined in class */ \ + (Pool1d) \ + (Pool2d) \ + (ReduceMin) \ + (ReduceMax) \ + (ReduceAny) \ + (ReduceAll) \ + (ReduceProd) \ + (ReduceMean) \ + (ReduceSum) \ + (Greater) \ + (GreaterOrEqual) \ + (Less) \ + (LessOrEqual) \ + (NotEqual) \ + (Equal) \ + (Reorg) \ + (Reshape) \ + (Resize1d) \ + (Resize) \ + (Reverse) \ + (RNNCell) \ + (RoiAlign) \ + (RoiPool) \ + (ScatterND) \ + (Select) \ + (ShuffleChannel) \ + (DataConvert) \ + (Neg) \ + (Abs) \ + (Sin) \ + (Exp) \ + (Log) \ + (Sqrt) \ + (Rsqrt) \ + (Square) \ + (LogicalNot) \ + (Floor) \ + (Ceil) \ + (Round) \ + (Cast) \ + (Rcp) \ + (SignalFrame) \ + (Slice) \ + (Softmax) \ + (Space2Batch) \ + (SpaceToDepth) \ + (SpatialTransformer) \ + (Split) \ + (Squeeze) \ + (Stack) \ + (StridedSlice) \ + (Svdf) \ + (Tile) \ + (Transpose) \ + (UnidirectionalSequenceLstm) \ + (UnidirectionalSequenceRnn) \ + (UnidirectionalSequenceRnnExt) \ + (Unstack) \ + (Conv3d) \ + (CustomOpBase) \ + (Topk) \ + (BidirectionalSequenceLstm) \ + (HashtableLookup) \ + (EmbeddingLookup) \ + (NBG) + +namespace trace { + +namespace target = ::tim::vx; + +struct Operation : public TraceClassBase { + DEF_INTERFACE_CONSTRUCTOR(Operation) + + DEF_TRACED_API(std::shared_ptr, Clone) + + // unfixed issue when use DEF_TRACED_API to define BindOutput + DEF_INPLACE_MEMFN(Operation&, BindInput, ((const std::shared_ptr&))) + DEF_INPLACE_MEMFN(Operation&, BindOutput, ((const std::shared_ptr&))) + // DEF_TRACED_API(Operation&, BindInput) + // template + // typename std::enable_if_t::value, R> BindInput(Args... 
params) + // { + // std::string this_obj_name = TraceGetObjName(); + // Tracer::push_back_msg_cache(this_obj_name + "->" + __FUNCTION__ + "("); + // Tracer::clear_params_log_cache(); + // boost::hana::tuple params_tuple = {params...}; + // boost::hana::for_each(params_tuple, [&](auto x) + // { Tracer::logging_param(x, -1); }); + // Tracer::pop_params_log_cache(); + // Tracer::amend_last_msg_cache(");\n"); + // Tracer::msg_cache_sync_to_file(); + // auto api_impl = [&](auto &&...ts) + // { return &(impl_->BindInput(ts...)); }; + // auto proc_param_lambda = [](auto &&t) + // { return Tracer::proc_param(t); }; + // auto params_impl = boost::hana::transform(params_tuple, proc_param_lambda); + // auto ret_impl_p = boost::hana::unpack(params_impl, api_impl); + // return *this; + // // return *static_cast *>(Tracer::get_traced_obj(ret_impl_p)); + // } + // DEF_TRACED_API(Operation&, BindOutput) + // template + // typename std::enable_if_t::value, R> BindOutput(Args... params) + // { + // std::string this_obj_name = TraceGetObjName(); + // Tracer::push_back_msg_cache(this_obj_name + "->" + __FUNCTION__ + "("); + // Tracer::clear_params_log_cache(); + // boost::hana::tuple params_tuple = {params...}; + // boost::hana::for_each(params_tuple, [&](auto x) + // { Tracer::logging_param(x, -1); }); + // Tracer::pop_params_log_cache(); + // Tracer::amend_last_msg_cache(");\n"); + // Tracer::msg_cache_sync_to_file(); + // auto api_impl = [&](auto &&...ts) + // { return &(impl_->BindOutput(ts...)); }; + // auto proc_param_lambda = [](auto &&t) + // { return Tracer::proc_param(t); }; + // auto params_impl = boost::hana::transform(params_tuple, proc_param_lambda); + // auto ret_impl_p = boost::hana::unpack(params_impl, api_impl); + // return *this; + // // return *static_cast *>(Tracer::get_traced_obj(ret_impl_p)); + // } + + // compiler can not deduce the type for + // template variadic, so can't use DEF_TRACED_API to define BindInputs/Outputs + // DEF_INPLACE_MEMFN(Operation&, BindInputs, + // ((const std::vector>&)) + // ) + Operation &BindInputs(const std::vector> ¶m_0) + { + std::string this_obj_name = TraceGetObjName(); + Tracer::logging_msg("%s->%s(", this_obj_name.c_str(), __FUNCTION__); + Tracer::init_params_log_cache(1); + Tracer::logging_param(param_0, 0); + Tracer::dump_params_log_cache(); + Tracer::logging_msg(");\n"); + impl_->BindInputs(Tracer::proc_param(param_0)); + return *this; + } + // DEF_INPLACE_MEMFN(Operation&, BindOutputs, + // ((const std::vector>&)) + // ) + Operation &BindOutputs(const std::vector> ¶m_0) + { + std::string this_obj_name = TraceGetObjName(); + Tracer::logging_msg("%s->%s(", this_obj_name.c_str(), __FUNCTION__); + Tracer::init_params_log_cache(1); + Tracer::logging_param(param_0, 0); + Tracer::dump_params_log_cache(); + Tracer::logging_msg(");\n"); + impl_->BindOutputs(Tracer::proc_param(param_0)); + return *this; + } + DEF_INPLACE_MEMFN(Operation&, SetRoundingPolicy, + ((OverflowPolicy)(OverflowPolicy::SATURATE)) + ((RoundingPolicy)(RoundingPolicy::RTNE)) + ((RoundType)(RoundType::FLOOR)) + ((uint32_t)(0)) + ) + + // do not support unique_ptr and OpImpl by now + // DEF_TRACED_API(std::unique_ptr&, impl) + + // do not support unique_ptr and OpImpl by now + // DEF_TRACED_API(const std::unique_ptr&, impl) + + DEF_TRACED_API(const std::vector>, ConstantInputsTensor) + + protected: + DEF_TRACED_API(bool, IsAllInputsConst) +}; + +} /* namespace trace */ + +namespace trace { +namespace ops { + +struct DefaultTag {}; + +template +struct TagDispatchTrait { + using tag = 
DefaultTag; +}; + +#define DEF_TIMVX_OP_AND_TAG_IMPL_(r, _, op) \ +struct op : Operation { \ + op(const std::shared_ptr& impl) : Operation(impl) {} \ +}; \ +struct BOOST_PP_CAT(_VSI_Tag_of_, op) {}; \ +template<> \ +struct TagDispatchTrait { \ + using tag = BOOST_PP_CAT(_VSI_Tag_of_, op); \ +}; + +#define DEF_TIMVX_OPS_AND_TAGS(ops) \ + BOOST_PP_SEQ_FOR_EACH(DEF_TIMVX_OP_AND_TAG_IMPL_, _, ops) + +DEF_TIMVX_OPS_AND_TAGS(TVX_OPS_SEQ) + +// DEF_TIMVX_OP_AND_TAG_IMPL_(_, _, Pad) +struct Pad : Operation { + // must be used as tvx::ops::Pad::pad_mode_type::PAD_MODE_CONSTANT, but not + // tvx::ops::Pad::PAD_MODE_CONSTANT + using pad_mode_type = target::ops::Pad::pad_mode_type; + Pad(const std::shared_ptr &impl) : Operation(impl) {} +}; +struct _VSI_Tag_of_Pad {}; +template <> +struct TagDispatchTrait { + using tag = _VSI_Tag_of_Pad; +}; + +// DEF_TIMVX_OP_AND_TAG_IMPL_(_, _, PadV2) +struct PadV2 : Operation { + using pad_mode_type = target::ops::PadV2::pad_mode_type; + PadV2(const std::shared_ptr &impl) : Operation(impl) {} +}; +struct _VSI_Tag_of_PadV2 {}; +template <> +struct TagDispatchTrait { + using tag = _VSI_Tag_of_PadV2; +}; + +} /* namespace ops */ +} /* namespace trace */ + +#endif // TIM_EXPERIMENTAL_TRACE_TVX_OPS_H_ diff --git a/include/tim/experimental/trace/tvx/tensor.h b/include/tim/experimental/trace/tvx/tensor.h new file mode 100755 index 0000000..d9d63a6 --- /dev/null +++ b/include/tim/experimental/trace/tvx/tensor.h @@ -0,0 +1,246 @@ +/**************************************************************************** +* +* Copyright (c) 2020-2023 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#ifndef TIM_EXPERIMENTAL_TRACE_TVX_TENSOR_H_ +#define TIM_EXPERIMENTAL_TRACE_TVX_TENSOR_H_ +#include "tim/vx/tensor.h" +#include "tim/experimental/trace/tvx/types.h" +#include "tim/experimental/trace/tracer.h" + +namespace trace { + +namespace target = ::tim::vx; + +struct Quantization : public TraceClassBase { + DEF_CONSTRUCTOR(Quantization) + + DEF_CONSTRUCTOR(Quantization, ((QuantType)) + ((float)) + ((int32_t)) + ) + + DEF_CONSTRUCTOR(Quantization, ((QuantType)) + ((int32_t)) + ((std::vector)) + ((std::vector)) + ) + + DEF_CONSTRUCTOR(Quantization, ((QuantType)) + ((int8_t)) + ) + + DEF_TRACED_API(QuantType&, Type) + + // DEF_TRACED_API(const QuantType&, Type) + template + typename std::enable_if_t::value, R> Type( + Args... 
params) const { + std::string this_obj_name = TraceGetObjName(); + Tracer::push_back_msg_cache(this_obj_name + "->" + __FUNCTION__ + "("); + Tracer::clear_params_log_cache(); + boost::hana::tuple params_tuple = {params...}; + boost::hana::for_each(params_tuple, [&](auto x) { + Tracer::logging_param(x, -1);}); + Tracer::pop_params_log_cache(); + Tracer::amend_last_msg_cache(");\n"); + Tracer::msg_cache_sync_to_file(); + auto api_impl = [&](auto &&...ts) { return impl_->Type(ts...); }; + auto proc_param_lambda = [](auto &&t) { return Tracer::proc_param(t); }; + auto params_impl = boost::hana::transform(params_tuple, proc_param_lambda); + return boost::hana::unpack(params_impl, api_impl); + } + + DEF_TRACED_API(int32_t&, ChannelDim) + + // DEF_TRACED_API(const int32_t&, ChannelDim) + template + typename std::enable_if_t::value, R> ChannelDim( + Args... params) const { + std::string this_obj_name = TraceGetObjName(); + Tracer::push_back_msg_cache(this_obj_name + "->" + __FUNCTION__ + "("); + Tracer::clear_params_log_cache(); + boost::hana::tuple params_tuple = {params...}; + boost::hana::for_each(params_tuple, [&](auto x) { + Tracer::logging_param(x, -1); }); + Tracer::pop_params_log_cache(); + Tracer::amend_last_msg_cache(");\n"); + Tracer::msg_cache_sync_to_file(); + auto api_impl = [&](auto &&...ts) { return impl_->ChannelDim(ts...); }; + auto proc_param_lambda = [](auto &&t) { return Tracer::proc_param(t); }; + auto params_impl = boost::hana::transform(params_tuple, proc_param_lambda); + return boost::hana::unpack(params_impl, api_impl); + } + + DEF_TRACED_API(Quantization&, SetChannelDim) + + DEF_TRACED_API(std::vector&, Scales) + + DEF_TRACED_API(std::vector&, ZeroPoints) + + // DEF_TRACED_API(const std::vector&, ZeroPoints) + template &, class... Args> + typename std::enable_if_t::value, R> ZeroPoints( + Args... 
params) const { + std::string this_obj_name = TraceGetObjName(); + Tracer::push_back_msg_cache(this_obj_name + "->" + __FUNCTION__ + "("); + Tracer::clear_params_log_cache(); + boost::hana::tuple params_tuple = {params...}; + boost::hana::for_each(params_tuple, [&](auto x) { + Tracer::logging_param(x, -1); }); + Tracer::pop_params_log_cache(); + Tracer::amend_last_msg_cache(");\n"); + Tracer::msg_cache_sync_to_file(); + auto api_impl = [&](auto &&...ts) { return impl_->ZeroPoints(ts...); }; + auto proc_param_lambda = [](auto &&t) { return Tracer::proc_param(t); }; + auto params_impl = boost::hana::transform(params_tuple, proc_param_lambda); + return boost::hana::unpack(params_impl, api_impl); + } + + DEF_TRACED_API(Quantization&, SetZeroPoints) + + DEF_TRACED_API(const std::int8_t&, Fl) +}; + +struct TensorSpec : public TraceClassBase { + DEF_CONSTRUCTOR(TensorSpec) + + DEF_CONSTRUCTOR(TensorSpec, ((DataType)) + ((const ShapeType&)) + ((TensorAttribute)) + ) + + DEF_CONSTRUCTOR(TensorSpec, ((DataType)) + ((const ShapeType&)) + ((TensorAttribute)) + ((const Quantization&)) + ) + + // DEF_CONSTRUCTOR(TensorSpec, ((const TensorSpec&)) + // ) + + DEF_TRACED_API(TensorSpec&, operator=) + + DEF_TRACED_API(TensorSpec&, SetDataType) + + DEF_TRACED_API(TensorSpec&, SetShape) + + DEF_TRACED_API(TensorSpec&, SetAttribute) + + DEF_TRACED_API(TensorSpec&, SetQuantization) + + DEF_TRACED_API(TensorSpec&, AsTransientSpec) + + DEF_TRACED_API(int64_t, GetElementNum) + + DEF_TRACED_API(int64_t, GetElementByteSize) + + DEF_TRACED_API(int64_t, GetByteSize) + + DEF_SIMPLE_UNTRACED_API(DataType&, GetDataType) + + DEF_SIMPLE_UNTRACED_API(ShapeType&, GetShapeType) + + DEF_SIMPLE_UNTRACED_API(TensorAttribute&, GetTensorAttribute) + + DEF_TRACED_API(Quantization&, GetQuantization) +}; + +struct DmaBufferDesc : public TraceClassBase { + DEF_CONSTRUCTOR(DmaBufferDesc) +}; + +struct Tensor : public TraceClassBase { + // DEF_INTERFACE_CONSTRUCTOR(Tensor) + Tensor(const std::shared_ptr &impl) { + Tracer::insert_traced_obj(static_cast(impl.get()), + static_cast(this)); + impl_ = impl; + } + Tensor(const std::shared_ptr &impl, const TensorSpec& spec) { + Tracer::insert_traced_obj(static_cast(impl.get()), + static_cast(this)); + impl_ = impl; + spec_ = spec; + } + // DEF_TRACED_API(const ShapeType&, GetShape) + // unfixed issue of trace GetShape + DEF_SIMPLE_UNTRACED_API(const ShapeType&, GetShape) + + DEF_TRACED_API(DataType, GetDataType) + + DEF_TRACED_API(const Quantization&, GetQuantization) + + // DEF_TRACED_API(TensorSpec&, GetSpec) + TensorSpec& GetSpec() { + return spec_; + } + + + // DEF_TRACED_API(uint32_t, GetId) + DEF_SIMPLE_UNTRACED_API(uint32_t, GetId) + + #define SPECIAL_MACRO_(params) \ + uint32_t count = BOOST_PP_SEQ_ELEM(1, params); \ + uint32_t offset = \ + Tracer::dump_data(BOOST_PP_SEQ_ELEM(0, params), \ + sizeof(char), count); \ + LOGGING_POINTER_MSG(offset, count, 0) + + DEF_MEMFN(bool, CopyDataToTensor, ((const void*))((uint32_t)(0)), + SPECIAL_MACRO_ + ) + + #undef SPECIAL_MACRO_ + + DEF_TRACED_API(bool, CopyDataFromTensor, [&]() { + uint32_t count = impl_->GetSpec().GetByteSize(); + uint32_t offset = Tracer::dump_data(boost::hana::at_c<0>(params_tuple), + sizeof(char), count); + LOGGING_POINTER_MSG(offset, count, 0) + }) + + DEF_TRACED_API(bool, FlushCacheForHandle) + + DEF_TRACED_API(bool, InvalidateCacheForHandle) + + DEF_MEMFN(void*, map, ((bool)(false)) + ) + + DEF_TRACED_API(void, unmap) + + DEF_TRACED_API(bool, IsPlaceHolder) + + DEF_TRACED_API(bool, IsConstTensor) + + DEF_TRACED_API(bool, 
SaveTensorToTextByFp32) + + DEF_TRACED_API(void*, ConvertTensorToData) + + private: + TensorSpec spec_; +}; + +} /* namespace trace */ + + +#endif // TIM_EXPERIMENTAL_TRACE_TVX_TENSOR_H_ diff --git a/include/tim/experimental/trace/tvx/types.h b/include/tim/experimental/trace/tvx/types.h new file mode 100755 index 0000000..d0327fe --- /dev/null +++ b/include/tim/experimental/trace/tvx/types.h @@ -0,0 +1,56 @@ +/**************************************************************************** +* +* Copyright (c) 2020-2023 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#ifndef TIM_EXPERIMENTAL_TRACE_TVX_TYPES_H_ +#define TIM_EXPERIMENTAL_TRACE_TVX_TYPES_H_ +#include "tim/vx/types.h" + +namespace trace { + +namespace target = ::tim::vx; + +using ShapeType = std::vector; + +using DataType = target::DataType; + +using QuantType = target::QuantType; + +using TensorAttribute = target::TensorAttribute; + +using PadType = target::PadType; + +using PoolType = target::PoolType; + +using RoundType = target::RoundType; + +using OverflowPolicy = target::OverflowPolicy; + +using RoundingPolicy = target::RoundingPolicy; + +using ResizeType = target::ResizeType; + +using DataLayout = target::DataLayout; + +} /* namespace trace */ + +#endif // TIM_EXPERIMENTAL_TRACE_TVX_TYPES_H_ diff --git a/include/tim/vx/tensor.h b/include/tim/vx/tensor.h index ac50e69..211250a 100644 --- a/include/tim/vx/tensor.h +++ b/include/tim/vx/tensor.h @@ -122,6 +122,14 @@ struct TensorSpec { int64_t GetByteSize() const; + inline DataType& GetDataType() { return datatype_; } + + inline ShapeType& GetShapeType() { return shape_; } + + inline TensorAttribute& GetTensorAttribute() { return attr_; } + + inline Quantization& GetQuantization() { return quantization_; } + DataType datatype_; ShapeType shape_; TensorAttribute attr_; diff --git a/src/tim/CMakeLists.txt b/src/tim/CMakeLists.txt index a2cc363..181aeeb 100644 --- a/src/tim/CMakeLists.txt +++ b/src/tim/CMakeLists.txt @@ -227,6 +227,10 @@ endif() if(TIM_VX_ENABLE_TEST) include(GoogleTest) + if(TIM_VX_ENABLE_API_TRACE) + add_definitions("-DENABLE_API_TRACE") + endif() + add_executable(unit_test ${${TARGET_NAME}_TEST_SRCS}) target_link_libraries(unit_test PRIVATE -Wl,--whole-archive ${TARGET_NAME} @@ -235,7 +239,9 @@ if(TIM_VX_ENABLE_TEST) ${PROJECT_SOURCE_DIR}/include ${CMAKE_CURRENT_SOURCE_DIR}/vx ${OVXLIB_INCLUDE_DIR} - ${LITE_INC_DIRS}) + ${LITE_INC_DIRS} + 
${PROJECT_SOURCE_DIR}
+                          )
   install(TARGETS unit_test DESTINATION ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_BINDIR})
 endif()
diff --git a/src/tim/vx/graph_test.cc b/src/tim/vx/graph_test.cc
index 26ea564..60e4f81 100644
--- a/src/tim/vx/graph_test.cc
+++ b/src/tim/vx/graph_test.cc
@@ -90,3 +90,65 @@ TEST(graph, gen_binary_graph_with_simple_add) {
   EXPECT_TRUE(nbg_out->CopyDataFromTensor(&output));
   EXPECT_EQ(output, expected_out);
 }
+
+#ifdef ENABLE_API_TRACE
+#define API_REPLAYER_IMPLEMENTATION
+#define API_TRACER_IMPLEMENTATION
+#define TARGET_NAMESPACE_NAME "tim::vx"
+#include "tim/experimental/trace/trace_tvx.h"
+
+namespace tvx = trace;
+
+TEST(graph, trace_test) {
+  auto ctx = tvx::Context::Create();
+  auto graph = ctx->CreateGraph();
+
+  tvx::ShapeType io_shape({1, 2, 2, 1});
+  tvx::TensorSpec input_spec(tvx::DataType::FLOAT32, io_shape, tvx::TensorAttribute::INPUT);
+  tvx::TensorSpec output_spec(tvx::DataType::FLOAT32, io_shape, tvx::TensorAttribute::OUTPUT);
+  auto input_t0 = graph->CreateTensor(input_spec);
+  auto input_t1 = graph->CreateTensor(input_spec);
+  auto input_t2 = graph->CreateTensor(input_spec);
+  auto output_t0 = graph->CreateTensor(output_spec);
+
+  auto reshape = graph->CreateOperation<tvx::ops::Reshape>(io_shape);
+  (*reshape).BindInput(input_t0).BindOutput(input_t1);
+  auto add = graph->CreateOperation<tvx::ops::Add>();
+  (*add).BindInputs({input_t0, input_t2}).BindOutputs({output_t0});
+
+  size_t bin_size = -1;
+  EXPECT_TRUE(graph->CompileToBinary(nullptr, &bin_size));
+  EXPECT_NE(bin_size, -1);
+  std::vector<char> nbg_buf(bin_size);
+
+  // generating the binary graph doesn't require input data
+  EXPECT_TRUE(graph->CompileToBinary(nbg_buf.data(), &bin_size));
+
+  // binary graph compilation doesn't impact current graph's execution
+  std::vector<float> in = {1.1f, 2.2f, 3.3f, 4.4f};
+  std::vector<float> expected_out = {2.2f, 4.4f, 6.6f, 8.8f};
+  EXPECT_TRUE(input_t0->CopyDataToTensor(in.data(), sizeof(float) * in.size()));
+  EXPECT_TRUE(input_t2->CopyDataToTensor(in.data(), sizeof(float) * in.size()));
+
+  EXPECT_TRUE(graph->Run());
+  std::vector<float> output(in.size());
+  EXPECT_TRUE(output_t0->CopyDataFromTensor(output.data()));
+  EXPECT_EQ(output, expected_out);
+
+}
+
+TEST(graph, replay_test) {
+  /*
+   * Copy trace_log.cc into the tim-vx root dir and rename it trace_log.rpl.cc;
+   * copy the trace_bin.bin file to the runtime workspace and rename it
+   * trace_bin.rpl.bin.
+   */
+
+  #include "trace_log.rpl.cc"
+  // manually compile and run the selected graph, e.g.:
+  // graph_12->Compile();
+  // graph_12->Run();
+  // Finally, rebuild the unit test and execute this case.
+}
+
+#endif /* #ifdef ENABLE_API_TRACE */
diff --git a/src/tim/vx/ops/maxpoolgrad.cc b/src/tim/vx/ops/maxpoolgrad.cc
index 6d4699d..6b81ef6 100644
--- a/src/tim/vx/ops/maxpoolgrad.cc
+++ b/src/tim/vx/ops/maxpoolgrad.cc
@@ -67,19 +67,25 @@ class MaxpoolGradImpl : public OpImpl {
     tim::vx::ShapeType idx_flattened_shape({CalFlattenedShape(grad_shape)});
     tim::vx::ShapeType out_flattened_shape({CalFlattenedShape(in_shape)});
 
-    tim::vx::TensorSpec pool_out_spec_values(tim::vx::DataType::FLOAT32,
-      grad_shape, tim::vx::TensorAttribute::TRANSIENT);
+    auto in_type = in_tensors_[POOL_INPUT_TENSOR]->GetDataType();
+    auto in_quant = in_tensors_[POOL_INPUT_TENSOR]->GetQuantization();
+    if (in_quant.Type() != tim::vx::QuantType::NONE) {
+      VSILOGW("MaxPoolGrad with quantized tensors is not validated yet!");
+    }
+    tim::vx::TensorSpec pool_out_spec_values(in_type,
+      grad_shape, tim::vx::TensorAttribute::TRANSIENT, in_quant);
     tim::vx::TensorSpec pool_out_spec_indices(tim::vx::DataType::INT32,
       grad_shape, tim::vx::TensorAttribute::TRANSIENT);
     tim::vx::TensorSpec idx_flattened_spec(tim::vx::DataType::INT32,
       idx_flattened_shape,tim::vx::TensorAttribute::TRANSIENT);
-    tim::vx::TensorSpec upd_flattened_spec(tim::vx::DataType::FLOAT32,
-      idx_flattened_shape, tim::vx::TensorAttribute::TRANSIENT);
-    tim::vx::TensorSpec out_flattened_spec(tim::vx::DataType::FLOAT32,
-      out_flattened_shape, tim::vx::TensorAttribute::TRANSIENT);
+    tim::vx::TensorSpec upd_flattened_spec(in_type,
+      idx_flattened_shape, tim::vx::TensorAttribute::TRANSIENT, in_quant);
+    tim::vx::TensorSpec out_flattened_spec(in_type,
+      out_flattened_shape, tim::vx::TensorAttribute::TRANSIENT, in_quant);
 
     auto pool_out_values_tensor = graph_->CreateTensor(pool_out_spec_values);
-    auto pool_out_indices_tensor = graph_->CreateTensor(pool_out_spec_indices);
+    auto pool_out_indices_tensor =
+        graph_->CreateTensor(pool_out_spec_indices);
     auto idx_flattened_tensor = graph_->CreateTensor(idx_flattened_spec);
     auto upd_flattened_tensor = graph_->CreateTensor(upd_flattened_spec);
     auto out_flattened_tensor = graph_->CreateTensor(out_flattened_spec);
@@ -88,36 +94,43 @@ class MaxpoolGradImpl : public OpImpl {
         .BindOutputs({pool_out_values_tensor, pool_out_indices_tensor});
 
     // eliminate pool out of maxpoolwithargmax begin
-    tim::vx::TensorSpec sliced_spec(tim::vx::DataType::FLOAT32,
-      {1, 1, 1, 1}, tim::vx::TensorAttribute::TRANSIENT);
+    tim::vx::TensorSpec sliced_spec(in_type,
+      {1, 1, 1, 1}, tim::vx::TensorAttribute::TRANSIENT, in_quant);
     auto sliced_tensor = graph_->CreateTensor(sliced_spec);
     auto one_zero_tensor = graph_->CreateTensor(sliced_spec);
     auto grad_tensor = graph_->CreateTensor(pool_out_spec_values);
     std::vector<int32_t> start = {0, 0, 0, 0};
     std::vector<int32_t> length = {1, 1, 1, 1};
-    slice_one_ = graph_->CreateOperation<tim::vx::ops::Slice>(0, start, length);
-    (*slice_one_).BindInput(pool_out_values_tensor).BindOutput(sliced_tensor);
+    auto slice_one =
+        graph_->CreateOperation<tim::vx::ops::Slice>(0, start, length);
+    (*slice_one).BindInput(pool_out_values_tensor).BindOutput(sliced_tensor);
 
-    self_sub_ = graph_->CreateOperation<tim::vx::ops::Sub>();
-    (*self_sub_).BindInputs({sliced_tensor, sliced_tensor}).BindOutput(one_zero_tensor);
+    auto self_sub = graph_->CreateOperation<tim::vx::ops::Sub>();
+    (*self_sub).BindInputs({sliced_tensor, sliced_tensor})
+        .BindOutput(one_zero_tensor);
 
-    add_zeros_ = graph_->CreateOperation<tim::vx::ops::Add>();
-    (*add_zeros_).BindInputs({one_zero_tensor, in_tensors_[GRADIENT_TENSOR]})
+    auto add_zeros = graph_->CreateOperation<tim::vx::ops::Add>();
+    (*add_zeros).BindInputs({one_zero_tensor, in_tensors_[GRADIENT_TENSOR]})
         .BindOutput(grad_tensor);
     // eliminate pool out of maxpoolwithargmax end
 
-    flatten_idx_ = graph_->CreateOperation<tim::vx::ops::Reshape>(idx_flattened_shape);
-    (*flatten_idx_).BindInput(pool_out_indices_tensor).BindOutput(idx_flattened_tensor);
+    auto flatten_idx =
+        graph_->CreateOperation<tim::vx::ops::Reshape>(idx_flattened_shape);
+    (*flatten_idx).BindInput(pool_out_indices_tensor)
+        .BindOutput(idx_flattened_tensor);
 
-    flatten_upd_ = graph_->CreateOperation<tim::vx::ops::Reshape>(idx_flattened_shape);
-    (*flatten_upd_).BindInput(grad_tensor).BindOutput(upd_flattened_tensor);
+    auto flatten_upd =
+        graph_->CreateOperation<tim::vx::ops::Reshape>(idx_flattened_shape);
+    (*flatten_upd).BindInput(grad_tensor).BindOutput(upd_flattened_tensor);
 
-    scatternd_ = graph_->CreateOperation<tim::vx::ops::ScatterND>(out_flattened_shape);
-    (*scatternd_).BindInputs({idx_flattened_tensor, upd_flattened_tensor})
-        .BindOutput(out_flattened_tensor);
+    auto scatternd =
+        graph_->CreateOperation<tim::vx::ops::ScatterND>(out_flattened_shape);
+    (*scatternd).BindInputs({idx_flattened_tensor, upd_flattened_tensor})
+        .BindOutput(out_flattened_tensor);
 
-    reshape_like_input_ = graph_->CreateOperation<tim::vx::ops::Reshape>(in_shape);
+    reshape_like_input_ =
+        graph_->CreateOperation<tim::vx::ops::Reshape>(in_shape);
     (*reshape_like_input_).BindInput(out_flattened_tensor);
   }
 
@@ -150,14 +163,7 @@ class MaxpoolGradImpl : public OpImpl {
   const RoundType round_type_;
   std::shared_ptr<tim::vx::Operation> maxpoolwithargmax2_;
-  std::shared_ptr<tim::vx::Operation> slice_one_;
-  std::shared_ptr<tim::vx::Operation> self_sub_;
-  std::shared_ptr<tim::vx::Operation> add_zeros_;
-  std::shared_ptr<tim::vx::Operation> flatten_idx_;
-  std::shared_ptr<tim::vx::Operation> flatten_upd_;
-  std::shared_ptr<tim::vx::Operation> scatternd_;
   std::shared_ptr<tim::vx::Operation> reshape_like_input_;
-  std::shared_ptr<tim::vx::Operation> reshape_pool_output_;
   std::array<std::shared_ptr<tim::vx::Tensor>, INPUT_CNT> in_tensors_;
   std::array<std::shared_ptr<tim::vx::Tensor>, OUTPUT_CNT> out_tensors_;
   uint32_t CalFlattenedShape(const tim::vx::ShapeType& shape) {