From b02aa8b8c498fb0b36b377339dc78dd7d3e782e7 Mon Sep 17 00:00:00 2001 From: Zhouheng Zheng Date: Wed, 9 Mar 2022 12:10:08 +0800 Subject: [PATCH] Added customize operator APIs(#315) Co-authored-by: zhouheng.zheng --- include/tim/vx/ops.h | 1 + include/tim/vx/ops/custom_base.h | 123 +++++++++++ samples/CMakeLists.txt | 1 + samples/custom_op_test/CMakeLists.txt | 12 ++ samples/custom_op_test/custom_gemm.cc | 36 ++++ samples/custom_op_test/custom_gemm.h | 148 ++++++++++++++ samples/custom_op_test/custom_op_test.cc | 242 ++++++++++++++++++++++ src/tim/CMakeLists.txt | 2 + src/tim/vx/direct_map_op_impl.cc | 15 ++ src/tim/vx/direct_map_op_impl.h | 12 +- src/tim/vx/op_impl.cc | 5 + src/tim/vx/op_impl.h | 2 + src/tim/vx/ops/custom_base.cc | 250 +++++++++++++++++++++++ 13 files changed, 847 insertions(+), 2 deletions(-) create mode 100644 include/tim/vx/ops/custom_base.h create mode 100644 samples/custom_op_test/CMakeLists.txt create mode 100644 samples/custom_op_test/custom_gemm.cc create mode 100644 samples/custom_op_test/custom_gemm.h create mode 100644 samples/custom_op_test/custom_op_test.cc create mode 100644 src/tim/vx/ops/custom_base.cc diff --git a/include/tim/vx/ops.h b/include/tim/vx/ops.h index 7e36e3f..9f2cc87 100644 --- a/include/tim/vx/ops.h +++ b/include/tim/vx/ops.h @@ -85,5 +85,6 @@ #include "tim/vx/ops/unidirectional_sequence_lstm.h" #include "tim/vx/ops/unstack.h" #include "tim/vx/ops/conv3d.h" +#include "tim/vx/ops/custom_base.h" #endif /* TIM_VX_OPS_H_ */ diff --git a/include/tim/vx/ops/custom_base.h b/include/tim/vx/ops/custom_base.h new file mode 100644 index 0000000..1376660 --- /dev/null +++ b/include/tim/vx/ops/custom_base.h @@ -0,0 +1,123 @@ +/**************************************************************************** +* +* Copyright (c) 2021 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. 
+*
+*****************************************************************************/
+#ifndef TIM_VX_OPS_CUSTOM_BASE_H_
+#define TIM_VX_OPS_CUSTOM_BASE_H_
+
+#include "tim/vx/direct_map_op.h"
+#include <tuple>
+#include <iostream>
+
+namespace tim {
+namespace vx {
+namespace ops {
+#define gpu_align(n, align) (((n) + ((align)-1)) & ~((align)-1))
+
+__attribute__((unused)) static int32_t gobal_kernel_id_ = 0;
+
+struct Param {
+  union {
+    float f;
+    int i;
+    bool b;
+    uint32_t ui;
+  } data;
+  tim::vx::DataType type;
+};
+
+template <size_t I = 0, typename... Ts>
+typename std::enable_if<I == sizeof...(Ts), void>::type
+transform_tuple_to_param_list(std::tuple<Ts...> tup,
+                              std::vector<Param>& param_list) {
+  auto size = std::tuple_size<std::tuple<Ts...>>::value;
+  if (size == I && param_list.size() != 0) {
+    //nothing to do
+  }
+  return;
+}
+
+template <size_t I = 0, typename... Ts>
+typename std::enable_if<(I < sizeof...(Ts)), void>::type
+transform_tuple_to_param_list(std::tuple<Ts...> tup,
+                              std::vector<Param>& param_list) {
+  if (typeid(std::get<I>(tup)) == typeid(float)) {
+    Param p;
+    p.data.f = std::get<I>(tup);
+    p.type = tim::vx::DataType::FLOAT32;
+    param_list.push_back(p);
+  } else if (typeid(std::get<I>(tup)) == typeid(uint32_t)) {
+    Param p;
+    p.data.ui = std::get<I>(tup);
+    p.type = tim::vx::DataType::UINT32;
+    param_list.push_back(p);
+  } else if (typeid(std::get<I>(tup)) == typeid(int32_t)) {
+    Param p;
+    p.data.i = std::get<I>(tup);
+    p.type = tim::vx::DataType::INT32;
+    param_list.push_back(p);
+  } else if (typeid(std::get<I>(tup)) == typeid(bool)) {
+    Param p;
+    p.data.b = std::get<I>(tup);
+    p.type = tim::vx::DataType::BOOL8;
+    param_list.push_back(p);
+  } else {
+    std::cout << "unsupported param type: add an else-if branch for it" << std::endl;
+  }
+  // Go to next element
+  transform_tuple_to_param_list<I + 1>(tup, param_list);
+}
+
+class CustomOpBase : public Operation {
+ public:
+  CustomOpBase(Graph* graph, uint32_t input_num, uint32_t output_num,
+               int32_t kernel_id, const char* kernel_name);
+
+  ~CustomOpBase();
+  virtual void extract_parameter_and_register(
+      std::vector<tim::vx::DataType> input_types,
+      std::string& build_option) = 0;
+
+  virtual void setup_output_shape_info() = 0;
+
+  virtual void setup_kernel_param(uint32_t& dim,
+                                  std::vector<size_t>& global_size,
+                                  std::vector<size_t>& local_size) = 0;
+
+  std::vector<Param> param_list_;
+  std::vector<std::vector<uint32_t>> inputs_size_;
+  std::vector<std::vector<uint32_t>> outputs_size_;
+
+  const char* func_name_;
+  const char* kernel_resource_;
+  void* init_kernel_;
+  void* vx_node_;
+
+  uint32_t input_num_;
+  uint32_t output_num_;
+};
+
+} // namespace ops
+} // namespace vx
+} // namespace tim
+
+#endif /* TIM_VX_OPS_CUSTOM_BASE_H_ */
\ No newline at end of file
diff --git a/samples/CMakeLists.txt b/samples/CMakeLists.txt
index 9df246b..b80d9e4 100644
--- a/samples/CMakeLists.txt
+++ b/samples/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory("benchmark_test")
+add_subdirectory("custom_op_test")
 add_subdirectory("lenet")
 if(${TIM_VX_ENABLE_VIPLITE})
   add_subdirectory("lenet_lite")
diff --git a/samples/custom_op_test/CMakeLists.txt b/samples/custom_op_test/CMakeLists.txt
new file mode 100644
index 0000000..649d192
--- /dev/null
+++ b/samples/custom_op_test/CMakeLists.txt
@@ -0,0 +1,12 @@
+message("samples/custom_op_test")
+
+set(TARGET_NAME "custom_op_test")
+
+aux_source_directory(.
${TARGET_NAME}_SRCS) +add_executable(${TARGET_NAME} ${${TARGET_NAME}_SRCS}) + +target_link_libraries(${TARGET_NAME} PRIVATE tim-vx) +target_include_directories(${TARGET_NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${PROJECT_SOURCE_DIR}/include +) \ No newline at end of file diff --git a/samples/custom_op_test/custom_gemm.cc b/samples/custom_op_test/custom_gemm.cc new file mode 100644 index 0000000..74306c6 --- /dev/null +++ b/samples/custom_op_test/custom_gemm.cc @@ -0,0 +1,36 @@ +/**************************************************************************** +* +* Copyright (c) 2021 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#include "custom_gemm.h" + +namespace tim { +namespace vx { +namespace ops { + +const char* CustomGemm::kernel_name_ = "xxxx_name12345"; +int32_t CustomGemm::kernel_id_ = -1 * (++gobal_kernel_id_); + + +} // namespace ops +} // namespace vx +} // namespace tim \ No newline at end of file diff --git a/samples/custom_op_test/custom_gemm.h b/samples/custom_op_test/custom_gemm.h new file mode 100644 index 0000000..2a16db4 --- /dev/null +++ b/samples/custom_op_test/custom_gemm.h @@ -0,0 +1,148 @@ +/**************************************************************************** +* +* Copyright (c) 2021 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. 
+*
+*****************************************************************************/
+
+#ifndef TIM_VX_OPS_CUSTOM_GEMM_H_
+#define TIM_VX_OPS_CUSTOM_GEMM_H_
+
+#include "tim/vx/ops/custom_base.h"
+
+namespace tim {
+namespace vx {
+namespace ops {
+
+//scalar params passed to the kernel function:
+//M, K, N, ac2zero, bc2zero, scale_a, zp_a, scale_b, zp_b, scale_out, zp_out
+using DeriveParmaTuple = std::tuple<int, int, int, int, int, float, float,
+                                    float, float, float, float>;
+
+class CustomGemm : public CustomOpBase {
+ public:
+  CustomGemm(Graph* graph, bool trans_a, bool trans_b,
+             DeriveParmaTuple tuple_list,
+             uint32_t input_num = 2, uint32_t output_num = 1)
+      : CustomOpBase(graph, input_num, output_num,
+                     CustomGemm::kernel_id_, CustomGemm::kernel_name_),
+        trans_a_(trans_a), trans_b_(trans_b) {
+    tuple_list_.swap(tuple_list);
+    transform_tuple_to_param_list(tuple_list_, param_list_);
+
+    kernel_resource_ =
+        "__kernel void gemm_F32F32toF32_2D(\n\
+    __read_only image2d_t inputA,\n\
+    __read_only image2d_t inputB,\n\
+    __write_only image2d_t output,\n\
+    int M,\n\
+    int K,\n\
+    int N,\n\
+    int ac2zero,\n\
+    int bc2zero,\n\
+    float scale_a,\n\
+    float zp_a,\n\
+    float scale_b,\n\
+    float zp_b,\n\
+    float scale_out,\n\
+    float zp_out\n\
+    )\n\
+{\n\
+    int4 coord = (int4)(get_global_id(0), get_global_id(1), 0, 0);\n\
+    float4 sum = (float4)(0);\n\
+\n\
+    for(; coord.z < K;)\n\
+    {\n\
+        float4 tempA0;\n\
+        float4 tempB0;\n\
+\n\
+        tempA0 = read_imagef(inputA, coord.zy);\n\
+        tempB0 = read_imagef(inputB, coord.xz);\n\
+        coord.z++;\n\
+\n\
+        sum = sum + tempA0 * tempB0;\n\
+    }\n\
+    write_imagef(output, coord.xy, sum);\n\
+}\n\
+\n\
+";
+  }
+
+ protected:
+  const char* kernel_NotTransA_NotTransB = "gemm_F32F32toF32_2D";
+  const char* kernel_TransA_NotTransB = ".....";
+  DeriveParmaTuple tuple_list_;
+  bool trans_a_;
+  bool trans_b_;
+  static const char* kernel_name_;
+  static int32_t kernel_id_;
+
+  //function for setting up the output shape
+  void setup_output_shape_info() override {
+    if (!trans_a_ && !trans_b_) {
+      outputs_size_[0].push_back(inputs_size_[0][1]);
+      outputs_size_[0].push_back(inputs_size_[1][0]);
+    } else {
+      //other situations: set up outputs_size_
+      //outputs_size_[0].push_back()......
+    }
+  }
+
+  //function for kernel selection and build option
+  void extract_parameter_and_register(
+      std::vector<tim::vx::DataType> input_types,
+      std::string& build_option) override {
+    if (trans_a_ == false &&
+        trans_b_ == false &&
+        input_types[0] == tim::vx::DataType::FLOAT32 &&
+        input_types[1] == tim::vx::DataType::FLOAT32) {
+      func_name_ = kernel_NotTransA_NotTransB;
+      build_option = "";
+    } else {
+      //other situations: set func_name_ and the param list
+      //func_name_ = "......";
+      //std::get<2>(param_list) = ......
+    }
+    return;
+  }
+
+  //function for kernel local size and global size
+  void setup_kernel_param(uint32_t& dim, std::vector<size_t>& global_size,
+                          std::vector<size_t>& local_size) override {
+    dim = 3;
+    local_size[0] = 0;
+    local_size[1] = 0;
+    local_size[2] = 0;
+
+    global_size[0] = gpu_align(outputs_size_[0][0], 4);
+    global_size[1] = gpu_align(outputs_size_[0][1], 4);
+    global_size[2] = outputs_size_[0].size() > 2 ?
outputs_size_[0][2] : 1; + } + + std::shared_ptr Clone( + std::shared_ptr& graph) const override { + return graph->CreateOperation(trans_a_,trans_b_, + this->tuple_list_, this->input_num_, this->output_num_); + } +}; + +} // namespace ops +} // namespace vx +} // namespace tim +#endif \ No newline at end of file diff --git a/samples/custom_op_test/custom_op_test.cc b/samples/custom_op_test/custom_op_test.cc new file mode 100644 index 0000000..dfbcea9 --- /dev/null +++ b/samples/custom_op_test/custom_op_test.cc @@ -0,0 +1,242 @@ +/**************************************************************************** +* +* Copyright (c) 2021 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#include "tim/vx/context.h" +#include "tim/vx/graph.h" +#include "tim/vx/ops.h" +#include "custom_gemm.h" +#include + +void custom_gemm_single_test(){ + auto ctx = tim::vx::Context::Create(); + auto graph = ctx->CreateGraph(); + + tim::vx::ShapeType a_shape({6, 2}); + tim::vx::ShapeType b_shape({2, 6}); + tim::vx::ShapeType out_shape({2, 2}); + tim::vx::TensorSpec a_spec(tim::vx::DataType::FLOAT32, + a_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec b_spec(tim::vx::DataType::FLOAT32, + b_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec out_spec(tim::vx::DataType::FLOAT32, + out_shape, tim::vx::TensorAttribute::OUTPUT); + + auto a_tensor = graph->CreateTensor(a_spec); + auto b_tensor = graph->CreateTensor(b_spec); + auto out_tensor = graph->CreateTensor(out_spec); + + std::vector a_data = { + 1, 2, 3, 4, 5, 6, + -1, -2, -3, -4, -5, -6 + }; + std::vector b_data = { + 6, 5, + 4, 3, + 2, 1, + -6, -5, + -4, -3, + -2, -1 + }; + std::vector golden = { + -36, -27, + 36, 27 + }; + + a_tensor->CopyDataToTensor(a_data.data(), a_data.size() * sizeof(float)); + b_tensor->CopyDataToTensor(b_data.data(), b_data.size() * sizeof(float)); + + tim::vx::ops::DeriveParmaTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0); + + auto op = graph->CreateOperation( + false,false,tuple_list); + + (*op).BindInputs({a_tensor, b_tensor}).BindOutputs({out_tensor}); + + graph->Compile(); + graph->Run(); + + std::vector output(golden.size()); + out_tensor->CopyDataFromTensor(output.data()); + + std::cout<<"the diff between golan and result:"<CreateGraph(); + + tim::vx::ShapeType a_shape({6, 2}); + tim::vx::ShapeType b_shape({6, 2}); + tim::vx::ShapeType c_shape({6, 2}); + tim::vx::ShapeType d_shape({2, 6}); + + 
tim::vx::ShapeType out_shape({2, 2}); + + tim::vx::TensorSpec a_spec(tim::vx::DataType::FLOAT32, + a_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec b_spec(tim::vx::DataType::FLOAT32, + b_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec c_spec(tim::vx::DataType::FLOAT32, + c_shape, tim::vx::TensorAttribute::TRANSIENT); + tim::vx::TensorSpec d_spec(tim::vx::DataType::FLOAT32, + d_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec out_spec(tim::vx::DataType::FLOAT32, + out_shape, tim::vx::TensorAttribute::OUTPUT); + + auto a_tensor = graph->CreateTensor(a_spec); + auto b_tensor = graph->CreateTensor(b_spec); + auto c_tensor = graph->CreateTensor(c_spec); + auto d_tensor = graph->CreateTensor(d_spec); + auto out_tensor = graph->CreateTensor(out_spec); + + std::vector a_data = { + 0, 1, 2, 3, 4, 5, + -1, -2, -3, -4, -5, -6 + }; + + std::vector b_data = { + 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 0, 0 + }; + + std::vector d_data = { + 6, 5, + 4, 3, + 2, 1, + -6, -5, + -4, -3, + -2, -1 + }; + std::vector golden = { + -36, -27, + 36, 27 + }; + + a_tensor->CopyDataToTensor(a_data.data(), a_data.size() * sizeof(float)); + b_tensor->CopyDataToTensor(b_data.data(), b_data.size() * sizeof(float)); + d_tensor->CopyDataToTensor(d_data.data(), d_data.size() * sizeof(float)); + + auto op_add = graph->CreateOperation(2); + (*op_add).BindInputs({a_tensor, b_tensor}).BindOutputs({c_tensor}); + + tim::vx::ops::DeriveParmaTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0); + + auto op_gemm = graph->CreateOperation( + false,false,tuple_list); + + (*op_gemm).BindInputs({c_tensor, d_tensor}).BindOutputs({out_tensor}); + + graph->Compile(); + graph->Run(); + + std::vector output(golden.size()); + out_tensor->CopyDataFromTensor(output.data()); + std::cout<<"the diff between golan and result:"<CreateGraph(); + + + tim::vx::ShapeType a_shape({2, 2}); + tim::vx::ShapeType b_shape({2, 2}); + tim::vx::ShapeType c_shape({2, 2}); + tim::vx::ShapeType d_shape({2, 2}); + + tim::vx::ShapeType out_shape({2, 2}); + + tim::vx::TensorSpec a_spec(tim::vx::DataType::FLOAT32, + a_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec b_spec(tim::vx::DataType::FLOAT32, + b_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec c_spec(tim::vx::DataType::FLOAT32, + c_shape, tim::vx::TensorAttribute::TRANSIENT); + tim::vx::TensorSpec d_spec(tim::vx::DataType::FLOAT32, + d_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec out_spec(tim::vx::DataType::FLOAT32, + out_shape, tim::vx::TensorAttribute::OUTPUT); + + auto a_tensor = graph->CreateTensor(a_spec); + auto b_tensor = graph->CreateTensor(b_spec); + auto c_tensor = graph->CreateTensor(c_spec); + auto d_tensor = graph->CreateTensor(d_spec); + auto out_tensor = graph->CreateTensor(out_spec); + + std::vector a_data = { + 1, 1, 1,1 + }; + + std::vector b_data = { + 1, 1, 1, 1 + }; + + std::vector d_data = { + 1,1,1,1 + }; + std::vector golden = { + 4,4,4,4 + }; + + a_tensor->CopyDataToTensor(a_data.data(), a_data.size() * sizeof(float)); + b_tensor->CopyDataToTensor(b_data.data(), b_data.size() * sizeof(float)); + d_tensor->CopyDataToTensor(d_data.data(), d_data.size() * sizeof(float)); + + tim::vx::ops::DeriveParmaTuple tuple_list(2,2,2,0,0,1.0,0,1.0,0,1.0,0); + + auto op_gemm = graph->CreateOperation( + false,false,tuple_list); + + (*op_gemm).BindInputs({a_tensor, b_tensor}).BindOutputs({c_tensor}); + + auto op_gemm2 = graph->CreateOperation( + false,false,tuple_list); + + (*op_gemm2).BindInputs({c_tensor, 
d_tensor}).BindOutputs({out_tensor}); + + graph->Compile(); + graph->Run(); + + std::vector output(golden.size()); + out_tensor->CopyDataFromTensor(output.data()); + std::cout<<"the diff between golan and result:"<uid = graph_->graph()->cur_nid; } +DirectMapOpImpl::DirectMapOpImpl(Graph* graph,DataLayout layout) + : OpImpl(graph, layout){} + + DirectMapOpImpl& DirectMapOpImpl::BindInput( const std::shared_ptr& tensor) { inputs_tensor_.push_back(tensor); @@ -71,5 +75,16 @@ void DirectMapOpImpl::SetRoundingPolicy(OverflowPolicy overflow_policy, node_->vx_param.accumulator_bits = accumulator_bits; } +CustomOpBaseImpl::CustomOpBaseImpl(Graph* graph, uint32_t operation_id, const void* proc, + const char* kernel_name, DataLayout layout) + : DirectMapOpImpl(graph, layout) { + op_proc_ = proc; + vsi_nn_node_t* node = vsi_nn_AddExternalNode(graph_->graph(), operation_id, + proc, NULL, kernel_name); + node->uid = graph_->graph()->cur_nid; + SetNode(node); + SetRoundingPolicy(); + }; + } // namespace vx } // namespace tim \ No newline at end of file diff --git a/src/tim/vx/direct_map_op_impl.h b/src/tim/vx/direct_map_op_impl.h index 88c49c5..ca57323 100644 --- a/src/tim/vx/direct_map_op_impl.h +++ b/src/tim/vx/direct_map_op_impl.h @@ -34,16 +34,16 @@ namespace vx { class DirectMapOpImpl : public OpImpl { public: - // DirectMapOpImpl(Graph* graph, uint32_t kind, int input_cnt = 0, - // int output_cnt = 0); DirectMapOpImpl(Graph* graph, uint32_t kind, int input_cnt = 0, int output_cnt = 0, DataLayout layout = DataLayout::ANY); + DirectMapOpImpl(Graph* graph,DataLayout layout = DataLayout::ANY); ~DirectMapOpImpl() {} DirectMapOpImpl& BindInput(const std::shared_ptr& tensor) override; DirectMapOpImpl& BindOutput(const std::shared_ptr& tensor) override; vsi_nn_node_t* node() override { return this->node_; } + void SetNode(vsi_nn_node_t* node) {this->node_ = node; } void SetRoundingPolicy( OverflowPolicy overflow_policy = OverflowPolicy::SATURATE, @@ -62,6 +62,14 @@ class DirectMapOpImpl : public OpImpl { vsi_nn_node_t* node_{nullptr}; }; +class CustomOpBaseImpl : public DirectMapOpImpl { + public: + CustomOpBaseImpl(Graph* graph, uint32_t operation_id, const void* proc, + const char* kernel_name, DataLayout layout = DataLayout::ANY); + protected: + const void* op_proc_; +}; + } // namespace vx } // namespace tim diff --git a/src/tim/vx/op_impl.cc b/src/tim/vx/op_impl.cc index a0366b2..68f280f 100644 --- a/src/tim/vx/op_impl.cc +++ b/src/tim/vx/op_impl.cc @@ -33,5 +33,10 @@ OpImpl::OpImpl(Graph* graph, uint32_t kind, int input_cnt, int output_cnt, input_cnt_(input_cnt), output_cnt_(output_cnt), layout_(layout) {} + +OpImpl::OpImpl(Graph* graph, DataLayout layout) + : graph_(reinterpret_cast(graph)), + layout_(layout) {} + } // namespace vx } // namespace tim diff --git a/src/tim/vx/op_impl.h b/src/tim/vx/op_impl.h index 637deee..b27f320 100644 --- a/src/tim/vx/op_impl.h +++ b/src/tim/vx/op_impl.h @@ -35,6 +35,8 @@ class OpImpl { public: OpImpl(Graph* graph, uint32_t kind, int input_cnt, int output_cnt, DataLayout layout); + OpImpl(Graph* graph, DataLayout layout); + virtual ~OpImpl() = default; virtual OpImpl& BindInput(const std::shared_ptr& tensor) = 0; virtual OpImpl& BindOutput(const std::shared_ptr& tensor) = 0; diff --git a/src/tim/vx/ops/custom_base.cc b/src/tim/vx/ops/custom_base.cc new file mode 100644 index 0000000..b5fff8b --- /dev/null +++ b/src/tim/vx/ops/custom_base.cc @@ -0,0 +1,250 @@ +/**************************************************************************** +* +* Copyright (c) 2021 
Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ + +#include +#include +#include "tim/vx/ops.h" +#include "direct_map_op_impl.h" +#include "vsi_nn_pub.h" + +#include "kernel/vsi_nn_kernel.h" + +namespace tim { +namespace vx { +namespace ops { + +static vsi_bool op_setup(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs, + vsi_nn_tensor_t** outputs); + +static vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs, + vsi_nn_tensor_t** outputs); + +static vx_status derive_kernel_init(vx_node node, const vx_reference* param, + vx_uint32 param_size); + +static std::map node_base_map_; + +CustomOpBase::CustomOpBase(Graph* graph, uint32_t input_num, + uint32_t output_num, int32_t kernel_id, + const char* kernel_name) + : input_num_(input_num), output_num_(output_num) { + init_kernel_ = reinterpret_cast(derive_kernel_init); + vsi_nn_op_proc_t proc = {NULL, op_compute, NULL, NULL, + op_setup, NULL, input_num_, output_num_}; + this->impl() = std::make_unique( + graph, kernel_id, reinterpret_cast(&proc), kernel_name); + this->impl()->node()->nn_param.client_param = reinterpret_cast(this); +} + +CustomOpBase::~CustomOpBase(){ + auto iter = node_base_map_.find(this->vx_node_); + if (iter != node_base_map_.end()) { + node_base_map_.erase(this->vx_node_); + } +} +vsi_bool op_setup(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs, + vsi_nn_tensor_t** outputs) { + CustomOpBase* op_this = + reinterpret_cast(self->nn_param.client_param); + + for (uint32_t i = 0; i < op_this->output_num_; i++) { + std::vector output_size; + op_this->outputs_size_.push_back(output_size); + } + + for (uint32_t i = 0; i < op_this->input_num_; i++) { + std::vector input_size; + for (uint32_t j = 0; j < inputs[i]->attr.dim_num; j++) { + input_size.push_back(inputs[i]->attr.size[j]); + } + op_this->inputs_size_.push_back(input_size); + } + + op_this->setup_output_shape_info(); + + for (uint32_t i = 0; i < op_this->outputs_size_.size(); i++) { + outputs[i]->attr.dim_num = op_this->outputs_size_[i].size(); + for (uint32_t j = 0; j < op_this->outputs_size_[i].size(); j++) { + outputs[i]->attr.size[j] = op_this->outputs_size_[i][j]; + } + } + return TRUE; +}; + +vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs, + vsi_nn_tensor_t** outputs) { + vsi_status status = VSI_FAILURE; + auto kernel = vsi_nn_kernel_create(VSI_NN_KERNEL_TYPE_CL); + CustomOpBase* op_this = + 
reinterpret_cast(self->nn_param.client_param); + + uint32_t param_num = op_this->param_list_.size(); + + std::vector input_types; + for (uint32_t i = 0; i < op_this->input_num_; i++) { + if (inputs[i]->attr.dtype.vx_type == VSI_NN_TYPE_FLOAT32) { + input_types.push_back(tim::vx::DataType::FLOAT32); + } else if (inputs[i]->attr.dtype.vx_type == VSI_NN_TYPE_UINT32) { + input_types.push_back(tim::vx::DataType::UINT32); + } else if (inputs[i]->attr.dtype.vx_type == VSI_NN_TYPE_INT32) { + input_types.push_back(tim::vx::DataType::INT32); + } else if (inputs[i]->attr.dtype.vx_type == VSI_NN_TYPE_BOOL8) { + input_types.push_back(tim::vx::DataType::BOOL8); + } else if (inputs[i]->attr.dtype.vx_type == VSI_NN_TYPE_UINT8) { + input_types.push_back(tim::vx::DataType::UINT8); + } else if (inputs[i]->attr.dtype.vx_type == VSI_NN_TYPE_INT8) { + input_types.push_back(tim::vx::DataType::INT8); + } else { + std::cout << "Can not find att type in op compute" << std::endl; + assert(false); + } + } + + std::string build_option; + op_this->extract_parameter_and_register(input_types, build_option); + + snprintf(kernel->info.name, VX_MAX_KERNEL_NAME, "%s", op_this->func_name_); + kernel->unique_id = + std::hash()(std::string(op_this->func_name_)); + vx_param_description_t kernel_param_def[param_num]; + + for (uint32_t i = 0; i < op_this->input_num_; i++) { + kernel_param_def[i] = {VX_INPUT, VX_TYPE_TENSOR, + VX_PARAMETER_STATE_REQUIRED}; + } + for (uint32_t i = 0; i < op_this->output_num_; i++) { + kernel_param_def[op_this->input_num_ + i] = {VX_OUTPUT, VX_TYPE_TENSOR, + VX_PARAMETER_STATE_REQUIRED}; + } + + for (uint32_t i = 0; i < param_num; i++) { + kernel_param_def[op_this->input_num_ + op_this->output_num_ + i] = { + VX_INPUT, VX_TYPE_SCALAR, VX_PARAMETER_STATE_REQUIRED}; + } + + kernel->info.parameters = kernel_param_def; + kernel->info.enumeration = KERNEL_ID_PLACEHOLDER; + kernel->info.numParams = param_num; + kernel->info.initialize = + reinterpret_cast(op_this->init_kernel_); + + vsi_nn_kernel_add_source(kernel, VSI_NN_GPU_SOURCE_FMT_EXECUTABLE, 1, + "executable_name"); + + vsi_nn_kernel_add_source(kernel, VSI_NN_GPU_SOURCE_FMT_CODE, 2, "helper", + "fmt_code_name"); + + const char* tmp[] = {"", op_this->kernel_resource_}; + const char** resource = tmp; + + vsi_nn_kernel_add_build_option(kernel, build_option.c_str()); + + auto node = vsi_nn_kernel_create_node_ext(self->graph, kernel, resource); + if (node) { + vsi_nn_kernel_node_param_t node_params[param_num] = {NULL}; + vsi_nn_kernel_node_pack_io(node_params, param_num, inputs, + op_this->input_num_, outputs, + op_this->output_num_); + + uint32_t input_start = op_this->input_num_ + op_this->output_num_; + for (uint32_t i = 0; i < op_this->param_list_.size(); i++) { + if (op_this->param_list_[i].type == tim::vx::DataType::FLOAT32) { + node_params[input_start++] = vsi_nn_kernel_scalar_create( + self->graph, F32, &(op_this->param_list_[i].data.f)); + } else if (op_this->param_list_[i].type == tim::vx::DataType::UINT32) { + node_params[input_start++] = vsi_nn_kernel_scalar_create( + self->graph, U32, &(op_this->param_list_[i].data.ui)); + } else if (op_this->param_list_[i].type == tim::vx::DataType::INT32) { + node_params[input_start++] = vsi_nn_kernel_scalar_create( + self->graph, I32, &(op_this->param_list_[i].data.i)); + } else if (op_this->param_list_[i].type == tim::vx::DataType::BOOL8) { + node_params[input_start++] = vsi_nn_kernel_scalar_create( + self->graph, BOOL8, &(op_this->param_list_[i].data.b)); + }else if (op_this->param_list_[i].type == 
tim::vx::DataType::UINT8) { + node_params[input_start++] = vsi_nn_kernel_scalar_create( + self->graph, U8, &(op_this->param_list_[i].data.b)); + } else if (op_this->param_list_[i].type == tim::vx::DataType::INT8) { + node_params[input_start++] = vsi_nn_kernel_scalar_create( + self->graph, I8, &(op_this->param_list_[i].data.b)); + } else{ + std::cout << "Can not find scalar type in op compute" << std::endl; + assert(false); + } + } + + input_start = op_this->input_num_ + op_this->output_num_; + status = vsi_nn_kernel_node_pass_param(node, node_params, param_num); + for (uint32_t i = 0; i < param_num; i++) { + vsi_nn_kernel_scalar_release(&node_params[input_start + i]); + } + } + self->n = (vx_node)node; + + node_base_map_.insert(std::pair(reinterpret_cast(self->n), op_this)); + op_this->vx_node_ = reinterpret_cast(self->n); + return status; +} + +vx_status derive_kernel_init(vx_node node, const vx_reference* param, + vx_uint32 param_size) { + vsi_status status = VSI_FAILURE; + if (param_size == 0 && param == nullptr) { + return status; + } + + gpu_param_t gpu_param = {3, {0, 0, 0}, {0, 0, 0}, {0, 0, 0}, {0, 0, 0}}; + + std::vector global_size(3); + std::vector local_size(3); + uint32_t dim = 0; + + auto iter = node_base_map_.find(reinterpret_cast(node)); + if (iter != node_base_map_.end()) { + iter->second->setup_kernel_param(dim, global_size, local_size); + } else { + std::cout << "Something wrong in finding gpu param setup function" + << std::endl; + assert(false); + } + + gpu_param.dim = dim; + gpu_param.global_scale[0] = 1; + gpu_param.global_scale[1] = 1; + gpu_param.global_scale[2] = 1; + + gpu_param.global_size[0] = global_size[0]; + gpu_param.global_size[1] = global_size[1]; + gpu_param.global_size[2] = global_size[2]; + + gpu_param.local_size[0] = local_size[0]; + gpu_param.local_size[1] = local_size[1]; + gpu_param.local_size[2] = local_size[2]; + status = vsi_nn_kernel_gpu_config(node, &gpu_param); + + return status; +} + +} // namespace ops +} // namespace vx +} // namespace tim \ No newline at end of file
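
For reference, below is a minimal sketch of what a user-defined operator built on the CustomOpBase API added by this patch could look like. It mirrors the CustomGemm sample above; the operator name MyCopy, the kernel my_copy_F32toF32_2D, its OpenCL body, and the chosen work sizes are illustrative assumptions and are not part of the patch.

// my_copy.h -- illustrative sketch only, modeled on samples/custom_op_test/custom_gemm.h
#include "tim/vx/ops/custom_base.h"

namespace tim {
namespace vx {
namespace ops {

class MyCopy : public CustomOpBase {
 public:
  MyCopy(Graph* graph, uint32_t input_num = 1, uint32_t output_num = 1)
      : CustomOpBase(graph, input_num, output_num,
                     MyCopy::kernel_id_, MyCopy::kernel_name_) {
    // OpenCL C source registered for this operator (assumed kernel).
    kernel_resource_ =
        "__kernel void my_copy_F32toF32_2D(\n\
    __read_only  image2d_t input,\n\
    __write_only image2d_t output)\n\
{\n\
    int2 coord = (int2)(get_global_id(0), get_global_id(1));\n\
    write_imagef(output, coord, read_imagef(input, coord));\n\
}\n";
  }

  // Output keeps the input shape for this element-wise copy.
  void setup_output_shape_info() override {
    outputs_size_[0] = inputs_size_[0];
  }

  // Pick the kernel function and its build options.
  void extract_parameter_and_register(
      std::vector<tim::vx::DataType> input_types,
      std::string& build_option) override {
    (void)input_types;
    func_name_ = "my_copy_F32toF32_2D";
    build_option = "";
  }

  // Report global/local work sizes for the kernel launch.
  void setup_kernel_param(uint32_t& dim, std::vector<size_t>& global_size,
                          std::vector<size_t>& local_size) override {
    dim = 2;
    local_size[0] = local_size[1] = local_size[2] = 0;
    global_size[0] = gpu_align(outputs_size_[0][0], 4);
    global_size[1] = gpu_align(outputs_size_[0][1], 4);
    global_size[2] = 1;
  }

  std::shared_ptr<Operation> Clone(
      std::shared_ptr<Graph>& graph) const override {
    return graph->CreateOperation<MyCopy>(this->input_num_,
                                          this->output_num_);
  }

  static const char* kernel_name_;  // defined in a .cc, e.g. "my_copy"
  static int32_t kernel_id_;        // defined in a .cc, e.g. -1 * (++gobal_kernel_id_)
};

}  // namespace ops
}  // namespace vx
}  // namespace tim

As in the custom_op_test sample, such an operator would be created with graph->CreateOperation<MyCopy>(), bound to input/output tensors, and then compiled and run like any built-in operation.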