diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1e388a6..911bc0c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -15,6 +15,7 @@ option(TIM_VX_ENABLE_PLATFORM "Enable multi devices support"
 option(TIM_VX_ENABLE_PLATFORM_LITE "Enable lite multi-device support" OFF)
 option(TIM_VX_ENABLE_GRPC "Enable gRPC support" OFF)
 option(TIM_VX_DBG_ENABLE_TENSOR_HNDL "Enable built-in tensor from handle: use malloced memory instead of VideoMemory by kernel driver" ON)
+option(TIM_VX_ENABLE_TENSOR_CACHE "Enable tensor cache for const tensor" ON)
 
 set(CMAKE_CXX_STANDARD 14)
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)
@@ -46,6 +47,11 @@ if(${TIM_VX_ENABLE_40BIT})
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DVSI_40BIT_VA_SUPPORT")
 endif()
 
+if(${TIM_VX_ENABLE_TENSOR_CACHE})
+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DENABLE_TENSOR_CACHE")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_TENSOR_CACHE")
+endif()
+
 if(${TIM_VX_ENABLE_CUSTOM_OP})
     set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DTIM_VX_ENABLE_CUSTOM_OP")
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DTIM_VX_ENABLE_CUSTOM_OP")
@@ -93,6 +99,9 @@ if(TIM_VX_ENABLE_GRPC)
     include(cmake/gRPC.cmake)
 endif()
 
+if(TIM_VX_ENABLE_TENSOR_CACHE)
+    find_package(OpenSSL REQUIRED)
+endif()
 add_subdirectory("src/tim")
 
 if(TIM_VX_BUILD_EXAMPLES)
diff --git a/README.md b/README.md
index 5967a3f..3724958 100644
--- a/README.md
+++ b/README.md
@@ -84,6 +84,7 @@ cmake options:
 |`VIP_LITE_SDK` | full path to VIPLite sdk, required when `TIM_VX_ENABLE_PLATFORM_LITE`=ON | Not set |
 |`TIM_VX_ENABLE_GRPC` | Enable gRPC support, only works when `TIM_VX_ENABLE_PLATFORM`=ON | OFF |
 |`TIM_VX_DBG_ENABLE_TENSOR_HNDL` | Enable built-in tensor from handle | ON |
+|`TIM_VX_ENABLE_TENSOR_CACHE` | Enable tensor cache for const tensor | ON |
 
 ----
 Run unit test:
diff --git a/src/tim/CMakeLists.txt b/src/tim/CMakeLists.txt
index d0a7eb9..22cd93e 100644
--- a/src/tim/CMakeLists.txt
+++ b/src/tim/CMakeLists.txt
@@ -141,6 +141,10 @@ target_include_directories(${TARGET_NAME} PRIVATE ${INC_DIRS})
 target_link_libraries(${TARGET_NAME} PUBLIC -Wl,--no-whole-archive
     ${OVXDRV_LIBRARIES} ${EXTERNAL_LIBS})
 
+if(${TIM_VX_ENABLE_TENSOR_CACHE})
+    target_link_libraries(${TARGET_NAME} PUBLIC ${OPENSSL_CRYPTO_LIBRARY})
+endif()
+
 if(${TIM_VX_USE_EXTERNAL_OVXLIB})
     #-Wl,--whole-archive should not be applied to an external library, only to a shared library
     target_link_libraries(${TARGET_NAME} PUBLIC tim_internal)
diff --git a/src/tim/transform/layout_inference_test.cc b/src/tim/transform/layout_inference_test.cc
index ebd659c..a9f5c83 100644
--- a/src/tim/transform/layout_inference_test.cc
+++ b/src/tim/transform/layout_inference_test.cc
@@ -167,6 +167,56 @@ TEST(GroupedConv2d, kernel_bigger_than_input_SAME) {
   infer_input->CopyDataToTensor(in_data.data(), in_data.size() * sizeof(float));
   infer_graph->Run();
 
+  std::vector<float> output(golden.size());
+  EXPECT_TRUE(infer_output->CopyDataFromTensor(output.data()));
+  EXPECT_EQ(golden, output);
+}
+
+TEST(FC, share_const_tensor) {
+  auto ctx = tim::vx::Context::Create();
+  auto src_graph = ctx->CreateGraph();
+
+  tim::vx::ShapeType input_shape({2, 1});
+  tim::vx::ShapeType kernel_shape({2, 2});
+  tim::vx::ShapeType bias_shape({2});
+  tim::vx::ShapeType output_shape({2, 1});
+  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
+                                 tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec kernel_spec(tim::vx::DataType::FLOAT32, kernel_shape,
+                                  tim::vx::TensorAttribute::CONSTANT);
+  tim::vx::TensorSpec bias_spec(tim::vx::DataType::FLOAT32, bias_shape,
+                                tim::vx::TensorAttribute::CONSTANT);
+  tim::vx::TensorSpec tran_spec(tim::vx::DataType::FLOAT32, output_shape,
+                                tim::vx::TensorAttribute::TRANSIENT);
+  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
+                                  tim::vx::TensorAttribute::OUTPUT);
+  std::vector<float> in_data = {1, 4};
+  std::vector<float> weight = {-3, 3, 2, 1};
+  std::vector<float> bias = {0.1, 0.4};
+  std::vector<float> golden = {-8, 25};
+  auto input_tensor = src_graph->CreateTensor(input_spec);
+  auto weight_tensor = src_graph->CreateTensor(kernel_spec, weight.data());
+  auto bias_tensor = src_graph->CreateTensor(bias_spec, bias.data());
+  auto tran_tensor = src_graph->CreateTensor(tran_spec);
+  auto output_tensor = src_graph->CreateTensor(output_spec);
+
+  auto op1 = src_graph->CreateOperation<tim::vx::ops::FullyConnected>(0, 2);
+  (*op1).BindInputs({input_tensor, weight_tensor, bias_tensor}).BindOutputs({tran_tensor});
+
+  auto op2 = src_graph->CreateOperation<tim::vx::ops::FullyConnected>(0, 2);
+  (*op2).BindInputs({tran_tensor, weight_tensor, bias_tensor}).BindOutputs({output_tensor});
+  // Do layout inference
+  auto transform = tim::transform::LayoutInference(src_graph, ctx);
+  auto infer_graph = transform.first;
+  auto graph_io_map = transform.second;
+  infer_graph->Compile();
+
+  auto infer_input = graph_io_map[src_graph->InputsTensor()[0]];
+  auto infer_output = graph_io_map[src_graph->OutputsTensor()[0]];
+
+  infer_input->CopyDataToTensor(in_data.data(), in_data.size() * sizeof(float));
+  infer_graph->Run();
+
   std::vector<float> output(golden.size());
   EXPECT_TRUE(infer_output->CopyDataFromTensor(output.data()));
   EXPECT_EQ(golden, output);
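The golden values in the new FC test can be checked by hand: each fully connected layer computes y_j = sum_i(x_i * W[j][i]) + b[j], and both layers reuse the same const weight and bias tensors. Below is a minimal standalone sketch of that reference math (plain C++, independent of TIM-VX; all names here are illustrative only):

```cpp
// Reference computation for TEST(FC, share_const_tensor) above.
#include <cstdio>

int main() {
  const float x[2] = {1, 4};
  const float w[2][2] = {{-3, 3}, {2, 1}};  // one row of weights per output unit
  const float b[2] = {0.1f, 0.4f};
  float h[2], y[2];
  for (int j = 0; j < 2; ++j)  // first FC layer: h = {9.1, 6.4}
    h[j] = x[0] * w[j][0] + x[1] * w[j][1] + b[j];
  for (int j = 0; j < 2; ++j)  // second FC layer reuses the same weights/bias
    y[j] = h[0] * w[j][0] + h[1] * w[j][1] + b[j];
  printf("%g %g\n", y[0], y[1]);  // approximately -8 25, matching `golden`
  return 0;
}
```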
diff --git a/src/tim/vx/graph.cc b/src/tim/vx/graph.cc
index 3dbc16e..db65934 100644
--- a/src/tim/vx/graph.cc
+++ b/src/tim/vx/graph.cc
@@ -24,6 +24,10 @@
 #include "tim/vx/graph.h"
 
 #include <algorithm>
+#ifdef ENABLE_TENSOR_CACHE
+#include <openssl/evp.h>
+#endif
+
 #include "context_private.h"
 #include "graph_private.h"
 #include "op_impl.h"
@@ -55,6 +59,106 @@ GraphImpl::GraphImpl(ContextImpl* context, const CompileOption& options)
 
 GraphImpl::~GraphImpl() { vsi_nn_ReleaseGraph(&graph_); }
 
+#ifdef ENABLE_TENSOR_CACHE
+std::map<std::string, std::shared_ptr<Tensor>>& GraphImpl::GetTensorCacheMap() {
+  return cached_tensor_;
+}
+
+#define MD5_SECRET_LEN_16 (16)
+#define MD5_BYTE_STRING_LEN (4)
+
+const std::string GraphImpl::calculateMd5Secret32(const std::string& src) {
+  std::string md5String;
+  EVP_MD_CTX* mdctx;
+  const EVP_MD* md;
+  uint32_t md_len;
+  unsigned char md_value[MD5_SECRET_LEN_16] = {0};
+  char tmp[MD5_BYTE_STRING_LEN] = {0};
+
+  OpenSSL_add_all_digests();
+  md = EVP_md5();
+  if (md == NULL) {
+    VSILOGE("EVP_md5 digest is not available.");
+    return md5String;
+  }
+  mdctx = EVP_MD_CTX_new();
+  if (!EVP_DigestInit_ex(mdctx, md, NULL)) {
+    VSILOGE("EVP_MD_CTX initialization failed.");
+    EVP_MD_CTX_free(mdctx);
+    return md5String;
+  }
+  if (!EVP_DigestUpdate(mdctx, src.c_str(), src.size())) {
+    VSILOGE("EVP_MD_CTX update failed.");
+    EVP_MD_CTX_free(mdctx);
+    return md5String;
+  }
+  if (!EVP_DigestFinal_ex(mdctx, md_value, &md_len)) {
+    VSILOGE("EVP_MD_CTX finalization failed.");
+    EVP_MD_CTX_free(mdctx);
+    return md5String;
+  }
+  EVP_MD_CTX_free(mdctx);
+
+  // Format the 16-byte digest as a 32-character uppercase hex string.
+  for (int i = 0; i < MD5_SECRET_LEN_16; ++i) {
+    memset(tmp, 0x00, sizeof(tmp));
+    snprintf(tmp, sizeof(tmp), "%02X", md_value[i]);
+    md5String += tmp;
+  }
+  return md5String;
+}
+
+const std::string GraphImpl::CalculateCacheKey(const TensorSpec& spec,
+                                               const void* data) {
+  std::string md5_key;
+  uint32_t data_size = 1;
+  for (auto it = spec.shape_.begin(); it != spec.shape_.end(); ++it) {
+    data_size *= *it;
+  }
+  switch (spec.datatype_) {
+    case DataType::INT16:
+    case DataType::UINT16:
+    case DataType::FLOAT16:
+      data_size *= 2;
+      break;
+    case DataType::INT32:
+    case DataType::UINT32:
+    case DataType::FLOAT32:
+      data_size *= 4;
+      break;
+    case DataType::INT64:
+      data_size *= 8;
+      break;
+    default:
+      break;
+  }
+  if (data_size < 512) {
+    md5_key = calculateMd5Secret32(std::string((const char*)data, data_size));
+  } else {
+    // Large tensors are keyed by their first 512 bytes only.
+    md5_key = calculateMd5Secret32(std::string((const char*)data, 512));
+  }
+  return md5_key;
+}
+
+std::shared_ptr<Tensor> GraphImpl::GetTensorFromCache(const TensorSpec& spec,
+                                                      const void* data) {
+  std::shared_ptr<Tensor> tensor;
+  std::string md5_key = CalculateCacheKey(spec, data);
+  if (GetTensorCacheMap().find(md5_key) != GetTensorCacheMap().end() &&
+      GetTensorCacheMap()[md5_key]->GetQuantization().Scales() ==
+          spec.quantization_.Scales() &&
+      GetTensorCacheMap()[md5_key]->GetQuantization().ZeroPoints() ==
+          spec.quantization_.ZeroPoints()) {
+    tensor = GetTensorCacheMap()[md5_key];
+  } else {
+    tensor = std::make_shared<TensorImpl>(this, spec, data);
+    GetTensorCacheMap()[md5_key] = tensor;
+  }
+  return tensor;
+}
+#endif
+
 vsi_nn_graph_t* GraphImpl::graph() { return graph_; }
 
 void GraphImpl::AddInput(vsi_nn_tensor_id_t id) {
@@ -135,6 +239,11 @@ void GraphImpl::PrintGraph() const { vsi_nn_PrintGraph(this->graph_); }
 
 std::shared_ptr<Tensor> GraphImpl::CreateTensor(const TensorSpec& spec,
                                                 const void* data) {
+#ifdef ENABLE_TENSOR_CACHE
+  if ((spec.attr_ & TensorAttribute::CONSTANT) && data != NULL) {
+    return GetTensorFromCache(spec, data);
+  }
+#endif
   auto tensor = std::make_shared<TensorImpl>(this, spec, data);
   if (spec.attr_ & TensorAttribute::INPUT) {
     this->AddInput(tensor);
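A side note on the digest code: OpenSSL also provides the one-shot EVP_Digest() helper, which folds the init/update/final sequence and its error paths into a single call. A minimal sketch of the equivalent key derivation (the md5_hex helper is hypothetical, not part of this patch):

```cpp
// Hypothetical one-shot alternative to calculateMd5Secret32().
#include <openssl/evp.h>

#include <cstdio>
#include <string>

static std::string md5_hex(const void* data, size_t len) {
  unsigned char md[EVP_MAX_MD_SIZE];
  unsigned int md_len = 0;
  // EVP_Digest() wraps EVP_DigestInit/Update/Final in one call.
  if (!EVP_Digest(data, len, md, &md_len, EVP_md5(), NULL)) return "";
  std::string hex;
  char buf[3];
  for (unsigned int i = 0; i < md_len; ++i) {
    snprintf(buf, sizeof(buf), "%02X", md[i]);
    hex += buf;
  }
  return hex;
}
```

MD5 serves here as a fast fingerprint, not a security measure. Note that tensors of 512 bytes or more are keyed by their first 512 bytes only, and the quantization scale/zero-point comparison in GetTensorFromCache is the only additional collision guard: two large constants that agree in their leading 512 bytes and quantization would be shared, which appears to be an accepted trade-off.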
diff --git a/src/tim/vx/graph_private.h b/src/tim/vx/graph_private.h
index c693fb3..318b01d 100644
--- a/src/tim/vx/graph_private.h
+++ b/src/tim/vx/graph_private.h
@@ -26,6 +26,7 @@
 #include "tim/vx/graph.h"
 
 #include <map>
+#include <string>
 #include <memory>
 #include <mutex>
 #include <vector>
@@ -43,7 +44,12 @@ class GraphImpl : public Graph {
  public:
   GraphImpl(ContextImpl* context, const CompileOption& options = CompileOption::DefaultOptions);
   ~GraphImpl();
-
+#ifdef ENABLE_TENSOR_CACHE
+  std::shared_ptr<Tensor> GetTensorFromCache(const TensorSpec& spec, const void* data);
+  const std::string CalculateCacheKey(const TensorSpec& spec, const void* data);
+  const std::string calculateMd5Secret32(const std::string& src);
+  std::map<std::string, std::shared_ptr<Tensor>>& GetTensorCacheMap();
+#endif
   /// Return the low-level graph object
   vsi_nn_graph_t* graph();
   void AddInput(vsi_nn_tensor_id_t id);
@@ -97,7 +103,9 @@ class GraphImpl : public Graph {
   int32_t not_consumed_output_cnt_;
   std::map<std::shared_ptr<Tensor>, std::vector<std::shared_ptr<Operation>>> tensor_consumers_;
   std::map<std::shared_ptr<Tensor>, std::shared_ptr<Operation>> tensor_producer_;
-
+#ifdef ENABLE_TENSOR_CACHE
+  std::map<std::string, std::shared_ptr<Tensor>> cached_tensor_;
+#endif
   CompileOption options_;
  private:
   /// Setup graph
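Taken together, the patch makes const-tensor creation idempotent per graph: a cache hit in CreateTensor returns the previously created tensor, so duplicated weights are only materialized once. A minimal usage sketch against the public API (a sketch only, assuming a build configured with TIM_VX_ENABLE_TENSOR_CACHE=ON and OpenSSL development headers installed, per the find_package(OpenSSL REQUIRED) above):

```cpp
#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/tensor.h"

#include <vector>

int main() {
  auto ctx = tim::vx::Context::Create();
  auto graph = ctx->CreateGraph();

  std::vector<float> weight = {-3, 3, 2, 1};
  tim::vx::TensorSpec spec(tim::vx::DataType::FLOAT32, {2, 2},
                           tim::vx::TensorAttribute::CONSTANT);

  // Identical spec and bytes: both calls hash to the same MD5 key, so the
  // second CreateTensor returns the cached tensor instead of a new one.
  auto t0 = graph->CreateTensor(spec, weight.data());
  auto t1 = graph->CreateTensor(spec, weight.data());
  return t0 == t1 ? 0 : 1;  // shared_ptrs compare equal when the cache hits
}
```

This is exactly the sharing that TEST(FC, share_const_tensor) exercises through layout inference: both FullyConnected ops bind the same weight and bias tensors, and the cache keeps them as single instances in the inferred graph.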