From 097f8d74cdd809b7fffe15027bca21b9385247d6 Mon Sep 17 00:00:00 2001
From: Sven
Date: Tue, 22 Mar 2022 23:00:52 +0800
Subject: [PATCH] Refine customized op support (#327)

Signed-off-by: ZhangXiang
---
 include/tim/vx/ops/custom_base.h         | 23 ++++++------
 samples/custom_op_test/custom_gemm.h     | 45 ++++++++++++++----------
 samples/custom_op_test/custom_op_test.cc |  6 ++--
 src/tim/vx/ops/custom_base.cc            |  6 ++--
 4 files changed, 43 insertions(+), 37 deletions(-)

diff --git a/include/tim/vx/ops/custom_base.h b/include/tim/vx/ops/custom_base.h
index 1376660..1059924 100644
--- a/include/tim/vx/ops/custom_base.h
+++ b/include/tim/vx/ops/custom_base.h
@@ -46,9 +46,8 @@ struct Param {
 };
 
 template <std::size_t I = 0, typename... Ts>
-typename std::enable_if<I == sizeof...(Ts), void>::type
-transform_tuple_to_param_list(std::tuple<Ts...> tup,
-                              std::vector<Param>& param_list) {
+typename std::enable_if<I == sizeof...(Ts), void>::type param_transform(
+    std::tuple<Ts...> tup, std::vector<Param>& param_list) {
   auto size = std::tuple_size<decltype(tup)>::value;
   if (size == I && param_list.size() != 0) {
     //nothing to do
@@ -55,10 +54,9 @@
   }
 }
 
 template <std::size_t I = 0, typename... Ts>
-typename std::enable_if<(I < sizeof...(Ts)), void>::type
-transform_tuple_to_param_list(std::tuple<Ts...> tup,
-                              std::vector<Param>& param_list) {
+typename std::enable_if<(I < sizeof...(Ts)), void>::type param_transform(
+    std::tuple<Ts...> tup, std::vector<Param>& param_list) {
   if (typeid(std::get<I>(tup)) == typeid(float)) {
     Param p;
     p.data.f = std::get<I>(tup);
@@ -84,7 +82,7 @@ transform_tuple_to_param_list(std::tuple<Ts...> tup,
     std::cout << "need more else if type of param list " << std::endl;
   }
   // Go to next element
-  transform_tuple_to_param_list<I + 1, Ts...>(tup, param_list);
+  param_transform<I + 1, Ts...>(tup, param_list);
 }
 
 class CustomOpBase : public Operation {
@@ -93,15 +91,14 @@ class CustomOpBase : public Operation {
               int32_t kernel_id, const char* kernel_name);
   ~CustomOpBase();
 
-  virtual void extract_parameter_and_register(
+  virtual void SetupParams(
       std::vector<tim::vx::DataType> input_types,
       std::string& build_option) = 0;
 
-  virtual void setup_output_shape_info() = 0;
+  virtual void SetupShapeInfor() = 0;
 
-  virtual void setup_kernel_param(uint32_t& dim,
-                                  std::vector<size_t>& gobal_size,
-                                  std::vector<size_t>& local_size) = 0;
+  virtual void SetupEnqueue(uint32_t& dim, std::vector<size_t>& gobal_size,
+                            std::vector<size_t>& local_size) = 0;
 
   std::vector<Param> param_list_;
   std::vector<std::vector<uint32_t>> inputs_size_;
@@ -120,4 +117,4 @@ class CustomOpBase : public Operation {
 }  // namespace vx
 }  // namespace tim
 
-#endif /* TIM_VX_OPS_MATMUL_H_ */
\ No newline at end of file
+#endif /* TIM_VX_OPS_CUSTOM_BASE_H_ */
\ No newline at end of file
diff --git a/samples/custom_op_test/custom_gemm.h b/samples/custom_op_test/custom_gemm.h
index 2a16db4..5fbb8f8 100644
--- a/samples/custom_op_test/custom_gemm.h
+++ b/samples/custom_op_test/custom_gemm.h
@@ -31,20 +31,32 @@
 namespace tim {
 namespace vx {
 namespace ops {
-//scalar param for kernel function input
-using DeriveParmaTuple = std::tuple<int, int, int, int, int, float, int, float, int, float, int>;
+
 class CustomGemm : public CustomOpBase {
  public:
+  //scalar param for kernel function input
+  using ParamTuple = std::tuple<int, int, int, int, int, float, int, float, int, float, int>;
   CustomGemm(Graph* graph, bool trans_a, bool trans_b,
-             DeriveParmaTuple tuple_list,
-             uint32_t input_num = 2, uint32_t output_num = 1)
-      : CustomOpBase(graph, input_num, output_num,
-                     CustomGemm::kernel_id_, CustomGemm::kernel_name_),
-        trans_a_(trans_a),trans_b_(trans_b) {
+             ParamTuple tuple_list, uint32_t input_num = 2,
+             uint32_t output_num = 1)
+      : CustomOpBase(graph, input_num, output_num, CustomGemm::kernel_id_, CustomGemm::kernel_name_),
+        trans_a_(trans_a),
+        trans_b_(trans_b) {
     tuple_list_.swap(tuple_list);
-    transform_tuple_to_param_list(tuple_list_, param_list_);
+    param_transform(tuple_list_, param_list_);
 
     kernel_resource_ =
         "__kernel void gemm_F32F32toF32_2D(\n\
@@ -87,14 +99,14 @@ class CustomGemm : public CustomOpBase {
  protected:
   const char* kernel_NotTransA_NotTransB = "gemm_F32F32toF32_2D";
   const char* kernel_TransA_NotTransB = ".....";
-  DeriveParmaTuple tuple_list_;
+  ParamTuple tuple_list_;
   bool trans_a_;
   bool trans_b_;
   static const char* kernel_name_;
   static int32_t kernel_id_;
 
   //function for setup output
-  void setup_output_shape_info() override {
+  void SetupShapeInfor() override {
     if (!trans_a_ && !trans_a_) {
       outputs_size_[0].push_back(inputs_size_[0][1]);
       outputs_size_[0].push_back(inputs_size_[1][0]);
@@ -105,7 +117,7 @@ class CustomGemm : public CustomOpBase {
   }
 
   //function for kernel select and build option
-  void extract_parameter_and_register(
+  void SetupParams(
       std::vector<tim::vx::DataType> input_types,
       std::string& build_option) override {
     if (trans_a_ == false &&
@@ -115,16 +127,13 @@ class CustomGemm : public CustomOpBase {
       func_name_ = kernel_NotTransA_NotTransB;
       build_option = "";
     } else {
-      //other situation: named func_name_ and setup param_list
-      //func_name_ = "......";
-      //std::get<2>(param_list) = ......
+      // other situation: named func_name_ and setup param_list
     }
-    return;
   }
 
   //function for kernel local size and gobal size
-  void setup_kernel_param(uint32_t& dim, std::vector<size_t>& global_size,
-                          std::vector<size_t>& local_size) {
+  void SetupEnqueue(uint32_t& dim, std::vector<size_t>& global_size,
+                    std::vector<size_t>& local_size) {
     dim = 3;
     local_size[0] = 0;
     local_size[1] = 0;
diff --git a/samples/custom_op_test/custom_op_test.cc b/samples/custom_op_test/custom_op_test.cc
index dfbcea9..07441ca 100644
--- a/samples/custom_op_test/custom_op_test.cc
+++ b/samples/custom_op_test/custom_op_test.cc
@@ -65,7 +65,7 @@ void custom_gemm_single_test(){
   a_tensor->CopyDataToTensor(a_data.data(), a_data.size() * sizeof(float));
   b_tensor->CopyDataToTensor(b_data.data(), b_data.size() * sizeof(float));
 
-  tim::vx::ops::DeriveParmaTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0);
+  tim::vx::ops::CustomGemm::ParamTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0);
   auto op = graph->CreateOperation<tim::vx::ops::CustomGemm>(
       false,false,tuple_list);
 
@@ -143,7 +143,7 @@ void custom_gemm_op_and_add_op_test(){
   auto op_add = graph->CreateOperation(2);
   (*op_add).BindInputs({a_tensor, b_tensor}).BindOutputs({c_tensor});
 
-  tim::vx::ops::DeriveParmaTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0);
+  tim::vx::ops::CustomGemm::ParamTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0);
   auto op_gemm = graph->CreateOperation<tim::vx::ops::CustomGemm>(
       false,false,tuple_list);
 
@@ -210,7 +210,7 @@ void custom_gemm_op_and_custom_gemm_op_test(){
   b_tensor->CopyDataToTensor(b_data.data(), b_data.size() * sizeof(float));
   d_tensor->CopyDataToTensor(d_data.data(), d_data.size() * sizeof(float));
 
-  tim::vx::ops::DeriveParmaTuple tuple_list(2,2,2,0,0,1.0,0,1.0,0,1.0,0);
+  tim::vx::ops::CustomGemm::ParamTuple tuple_list(2,2,2,0,0,1.0,0,1.0,0,1.0,0);
   auto op_gemm = graph->CreateOperation<tim::vx::ops::CustomGemm>(
       false,false,tuple_list);
 
diff --git a/src/tim/vx/ops/custom_base.cc b/src/tim/vx/ops/custom_base.cc
index 28c223f..d8e748c 100644
--- a/src/tim/vx/ops/custom_base.cc
+++ b/src/tim/vx/ops/custom_base.cc
@@ -88,7 +88,7 @@ vsi_bool op_setup(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
     op_this->inputs_size_.push_back(input_size);
   }
 
-  op_this->setup_output_shape_info();
+  op_this->SetupShapeInfor();
 
   for (uint32_t i = 0; i < op_this->outputs_size_.size(); i++) {
     outputs[i]->attr.dim_num = op_this->outputs_size_[i].size();
@@ -129,7 +129,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
   }
 
   std::string build_option;
-  op_this->extract_parameter_and_register(input_types, build_option);
+  op_this->SetupParams(input_types, build_option);
 
   snprintf(kernel->info.name, VX_MAX_KERNEL_NAME, "%s", op_this->func_name_);
   kernel->unique_id =
@@ -233,7 +233,7 @@ vx_status derive_kernel_init(vx_node node, const vx_reference* param,
 
   auto iter = node_base_map_.find(reinterpret_cast(node));
   if (iter != node_base_map_.end()) {
-    iter->second->setup_kernel_param(dim, global_size, local_size);
+    iter->second->SetupEnqueue(dim, global_size, local_size);
   } else {
     std::cout << "Something wrong in finding gpu param setup function"
               << std::endl;
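
After this patch, a derived customized op implements the three renamed hooks: SetupParams selects the kernel function and build option from the input types, SetupShapeInfor derives outputs_size_ from inputs_size_, and SetupEnqueue reports the enqueue geometry. The sketch below is illustrative only and is not part of the patch: the op name CustomCopy, its OpenCL source, and the kernel_id_/kernel_name_ values are invented, while the CustomOpBase interface and members it touches (func_name_, kernel_resource_, param_list_, inputs_size_, outputs_size_) follow custom_base.h as patched above.

// Illustrative sketch only -- not part of this patch. "CustomCopy", its
// kernel source, and the id/name values are invented for demonstration.
#include <tuple>
#include <vector>
#include <string>
#include "tim/vx/ops/custom_base.h"

namespace tim {
namespace vx {
namespace ops {

class CustomCopy : public CustomOpBase {
 public:
  // Scalar kernel arguments, passed the same way CustomGemm passes its
  // ParamTuple: here a single int (the row width).
  using ParamTuple = std::tuple<int>;
  CustomCopy(Graph* graph, ParamTuple tuple_list, uint32_t input_num = 1,
             uint32_t output_num = 1)
      : CustomOpBase(graph, input_num, output_num, CustomCopy::kernel_id_,
                     CustomCopy::kernel_name_) {
    tuple_list_.swap(tuple_list);
    param_transform(tuple_list_, param_list_);
    kernel_resource_ =
        "__kernel void copy_F32toF32_2D(__global const float* input,\n\
                                        __global float* output,\n\
                                        int width) {\n\
           int x = get_global_id(0);\n\
           int y = get_global_id(1);\n\
           output[y * width + x] = input[y * width + x];\n\
         }";
  }

 protected:
  ParamTuple tuple_list_;
  static const char* kernel_name_;
  static int32_t kernel_id_;

  // Output shape equals input shape for an element-wise copy.
  void SetupShapeInfor() override { outputs_size_[0] = inputs_size_[0]; }

  // Pick the kernel function and build option from the input types.
  void SetupParams(std::vector<tim::vx::DataType> input_types,
                   std::string& build_option) override {
    if (input_types[0] == tim::vx::DataType::FLOAT32) {
      func_name_ = "copy_F32toF32_2D";
      build_option = "";
    }
  }

  // Report the enqueue geometry; 0 leaves the local size to the driver,
  // as in the CustomGemm sample.
  void SetupEnqueue(uint32_t& dim, std::vector<size_t>& global_size,
                    std::vector<size_t>& local_size) override {
    dim = 2;
    local_size[0] = 0;
    local_size[1] = 0;
    global_size[0] = inputs_size_[0][0];
    global_size[1] = inputs_size_[0][1];
  }
};

// Out-of-class definitions, normally in a .cc file (values are placeholders;
// the id must not collide with other customized ops).
const char* CustomCopy::kernel_name_ = "CustomCopy";
int32_t CustomCopy::kernel_id_ = 2;

}  // namespace ops
}  // namespace vx
}  // namespace tim

A graph would then create it the same way the tests above create CustomGemm, e.g. graph->CreateOperation<tim::vx::ops::CustomCopy>(tim::vx::ops::CustomCopy::ParamTuple(16)).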