Refine customized op support (#327)

Signed-off-by: ZhangXiang <Xiang.Zhang@verisilicon.com>
This commit is contained in:
Sven 2022-03-22 23:00:52 +08:00 committed by GitHub
parent 5bab9964e9
commit 097f8d74cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 43 additions and 37 deletions

View File

@ -46,9 +46,8 @@ struct Param {
}; };
template <size_t I = 0, typename... Ts> template <size_t I = 0, typename... Ts>
typename std::enable_if<I == sizeof...(Ts), void>::type typename std::enable_if<I == sizeof...(Ts), void>::type param_transform(
transform_tuple_to_param_list(std::tuple<Ts...> tup, std::tuple<Ts...> tup, std::vector<Param>& param_list) {
std::vector<Param>& param_list) {
auto size = std::tuple_size<decltype(tup)>::value; auto size = std::tuple_size<decltype(tup)>::value;
if (size == I && param_list.size() != 0) { if (size == I && param_list.size() != 0) {
//nothing to do //nothing to do
@ -57,9 +56,8 @@ transform_tuple_to_param_list(std::tuple<Ts...> tup,
} }
template <size_t I = 0, typename... Ts> template <size_t I = 0, typename... Ts>
typename std::enable_if<(I < sizeof...(Ts)), void>::type typename std::enable_if<(I < sizeof...(Ts)), void>::type param_transform(
transform_tuple_to_param_list(std::tuple<Ts...> tup, std::tuple<Ts...> tup, std::vector<Param>& param_list) {
std::vector<Param>& param_list) {
if (typeid(std::get<I>(tup)) == typeid(float)) { if (typeid(std::get<I>(tup)) == typeid(float)) {
Param p; Param p;
p.data.f = std::get<I>(tup); p.data.f = std::get<I>(tup);
@ -84,7 +82,7 @@ transform_tuple_to_param_list(std::tuple<Ts...> tup,
std::cout << "need more else if type of param list " << std::endl; std::cout << "need more else if type of param list " << std::endl;
} }
// Go to next element // Go to next element
transform_tuple_to_param_list<I + 1>(tup, param_list); param_transform<I + 1>(tup, param_list);
} }
class CustomOpBase : public Operation { class CustomOpBase : public Operation {
@ -93,15 +91,14 @@ class CustomOpBase : public Operation {
int32_t kernel_id, const char* kernel_name); int32_t kernel_id, const char* kernel_name);
~CustomOpBase(); ~CustomOpBase();
virtual void extract_parameter_and_register( virtual void SetupParams(
std::vector<tim::vx::DataType> input_types, std::vector<tim::vx::DataType> input_types,
std::string& build_option) = 0; std::string& build_option) = 0;
virtual void setup_output_shape_info() = 0; virtual void SetupShapeInfor() = 0;
virtual void setup_kernel_param(uint32_t& dim, virtual void SetupEnqueue(uint32_t& dim, std::vector<size_t>& gobal_size,
std::vector<size_t>& gobal_size, std::vector<size_t>& local_size) = 0;
std::vector<size_t>& local_size) = 0;
std::vector<Param> param_list_; std::vector<Param> param_list_;
std::vector<tim::vx::ShapeType> inputs_size_; std::vector<tim::vx::ShapeType> inputs_size_;
@ -120,4 +117,4 @@ class CustomOpBase : public Operation {
} // namespace vx } // namespace vx
} // namespace tim } // namespace tim
#endif /* TIM_VX_OPS_MATMUL_H_ */ #endif /* TIM_VX_OPS_CUSTOM_BASE_H_ */

View File

@ -31,20 +31,32 @@ namespace tim {
namespace vx { namespace vx {
namespace ops { namespace ops {
//scalar param for kernel function input
using DeriveParmaTuple = std::tuple<int, int, int, int, int, float, float,
float, float, float, float>;
class CustomGemm : public CustomOpBase { class CustomGemm : public CustomOpBase {
public: public:
//scalar param for kernel function input
using ParamTuple = std::tuple<int, /* M */
int, /* K */
int, /* N */
int, /* ac2zero */
int, /* bc2zero */
float, /* scale_a */
float, /* zp_a */
float, /* scale_b */
float, /* zp_b */
float, /* scale_out */
float /* zp_out */
>;
CustomGemm(Graph* graph, bool trans_a, bool trans_b, CustomGemm(Graph* graph, bool trans_a, bool trans_b,
DeriveParmaTuple tuple_list, ParamTuple tuple_list, uint32_t input_num = 2,
uint32_t input_num = 2, uint32_t output_num = 1) uint32_t output_num = 1)
: CustomOpBase(graph, input_num, output_num, : CustomOpBase(graph, input_num, output_num, CustomGemm::kernel_id_,
CustomGemm::kernel_id_, CustomGemm::kernel_name_), CustomGemm::kernel_name_),
trans_a_(trans_a),trans_b_(trans_b) { trans_a_(trans_a),
trans_b_(trans_b) {
tuple_list_.swap(tuple_list); tuple_list_.swap(tuple_list);
transform_tuple_to_param_list(tuple_list_, param_list_); param_transform(tuple_list_, param_list_);
kernel_resource_ = kernel_resource_ =
"__kernel void gemm_F32F32toF32_2D(\n\ "__kernel void gemm_F32F32toF32_2D(\n\
@ -87,14 +99,14 @@ class CustomGemm : public CustomOpBase {
protected: protected:
const char* kernel_NotTransA_NotTransB = "gemm_F32F32toF32_2D"; const char* kernel_NotTransA_NotTransB = "gemm_F32F32toF32_2D";
const char* kernel_TransA_NotTransB = "....."; const char* kernel_TransA_NotTransB = ".....";
DeriveParmaTuple tuple_list_; ParamTuple tuple_list_;
bool trans_a_; bool trans_a_;
bool trans_b_; bool trans_b_;
static const char* kernel_name_; static const char* kernel_name_;
static int32_t kernel_id_; static int32_t kernel_id_;
//function for setup output //function for setup output
void setup_output_shape_info() override { void SetupShapeInfor() override {
if (!trans_a_ && !trans_a_) { if (!trans_a_ && !trans_a_) {
outputs_size_[0].push_back(inputs_size_[0][1]); outputs_size_[0].push_back(inputs_size_[0][1]);
outputs_size_[0].push_back(inputs_size_[1][0]); outputs_size_[0].push_back(inputs_size_[1][0]);
@ -105,7 +117,7 @@ class CustomGemm : public CustomOpBase {
} }
//function for kernel select and build option //function for kernel select and build option
void extract_parameter_and_register( void SetupParams(
std::vector<tim::vx::DataType> input_types, std::vector<tim::vx::DataType> input_types,
std::string& build_option) override { std::string& build_option) override {
if (trans_a_ == false && if (trans_a_ == false &&
@ -115,16 +127,13 @@ class CustomGemm : public CustomOpBase {
func_name_ = kernel_NotTransA_NotTransB; func_name_ = kernel_NotTransA_NotTransB;
build_option = ""; build_option = "";
} else { } else {
//other situation: named func_name_ and setup param_list // other situation: named func_name_ and setup param_list
//func_name_ = "......";
//std::get<2>(param_list) = ......
} }
return;
} }
//function for kernel local size and gobal size //function for kernel local size and gobal size
void setup_kernel_param(uint32_t& dim, std::vector<size_t>& global_size, void SetupEnqueue(uint32_t& dim, std::vector<size_t>& global_size,
std::vector<size_t>& local_size) { std::vector<size_t>& local_size) {
dim = 3; dim = 3;
local_size[0] = 0; local_size[0] = 0;
local_size[1] = 0; local_size[1] = 0;

View File

@ -65,7 +65,7 @@ void custom_gemm_single_test(){
a_tensor->CopyDataToTensor(a_data.data(), a_data.size() * sizeof(float)); a_tensor->CopyDataToTensor(a_data.data(), a_data.size() * sizeof(float));
b_tensor->CopyDataToTensor(b_data.data(), b_data.size() * sizeof(float)); b_tensor->CopyDataToTensor(b_data.data(), b_data.size() * sizeof(float));
tim::vx::ops::DeriveParmaTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0); tim::vx::ops::CustomGemm::ParamTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0);
auto op = graph->CreateOperation<tim::vx::ops::CustomGemm>( auto op = graph->CreateOperation<tim::vx::ops::CustomGemm>(
false,false,tuple_list); false,false,tuple_list);
@ -143,7 +143,7 @@ void custom_gemm_op_and_add_op_test(){
auto op_add = graph->CreateOperation<tim::vx::ops::AddN>(2); auto op_add = graph->CreateOperation<tim::vx::ops::AddN>(2);
(*op_add).BindInputs({a_tensor, b_tensor}).BindOutputs({c_tensor}); (*op_add).BindInputs({a_tensor, b_tensor}).BindOutputs({c_tensor});
tim::vx::ops::DeriveParmaTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0); tim::vx::ops::CustomGemm::ParamTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0);
auto op_gemm = graph->CreateOperation<tim::vx::ops::CustomGemm>( auto op_gemm = graph->CreateOperation<tim::vx::ops::CustomGemm>(
false,false,tuple_list); false,false,tuple_list);
@ -210,7 +210,7 @@ void custom_gemm_op_and_custom_gemm_op_test(){
b_tensor->CopyDataToTensor(b_data.data(), b_data.size() * sizeof(float)); b_tensor->CopyDataToTensor(b_data.data(), b_data.size() * sizeof(float));
d_tensor->CopyDataToTensor(d_data.data(), d_data.size() * sizeof(float)); d_tensor->CopyDataToTensor(d_data.data(), d_data.size() * sizeof(float));
tim::vx::ops::DeriveParmaTuple tuple_list(2,2,2,0,0,1.0,0,1.0,0,1.0,0); tim::vx::ops::CustomGemm::ParamTuple tuple_list(2,2,2,0,0,1.0,0,1.0,0,1.0,0);
auto op_gemm = graph->CreateOperation<tim::vx::ops::CustomGemm>( auto op_gemm = graph->CreateOperation<tim::vx::ops::CustomGemm>(
false,false,tuple_list); false,false,tuple_list);

View File

@ -88,7 +88,7 @@ vsi_bool op_setup(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
op_this->inputs_size_.push_back(input_size); op_this->inputs_size_.push_back(input_size);
} }
op_this->setup_output_shape_info(); op_this->SetupShapeInfor();
for (uint32_t i = 0; i < op_this->outputs_size_.size(); i++) { for (uint32_t i = 0; i < op_this->outputs_size_.size(); i++) {
outputs[i]->attr.dim_num = op_this->outputs_size_[i].size(); outputs[i]->attr.dim_num = op_this->outputs_size_[i].size();
@ -129,7 +129,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
} }
std::string build_option; std::string build_option;
op_this->extract_parameter_and_register(input_types, build_option); op_this->SetupParams(input_types, build_option);
snprintf(kernel->info.name, VX_MAX_KERNEL_NAME, "%s", op_this->func_name_); snprintf(kernel->info.name, VX_MAX_KERNEL_NAME, "%s", op_this->func_name_);
kernel->unique_id = kernel->unique_id =
@ -233,7 +233,7 @@ vx_status derive_kernel_init(vx_node node, const vx_reference* param,
auto iter = node_base_map_.find(reinterpret_cast<void*>(node)); auto iter = node_base_map_.find(reinterpret_cast<void*>(node));
if (iter != node_base_map_.end()) { if (iter != node_base_map_.end()) {
iter->second->setup_kernel_param(dim, global_size, local_size); iter->second->SetupEnqueue(dim, global_size, local_size);
} else { } else {
std::cout << "Something wrong in finding gpu param setup function" std::cout << "Something wrong in finding gpu param setup function"
<< std::endl; << std::endl;