Refine customized op support (#327)

Signed-off-by: ZhangXiang <Xiang.Zhang@verisilicon.com>
Author: Sven
Date: 2022-03-22 23:00:52 +08:00 (committed by GitHub)
parent 5bab9964e9
commit 097f8d74cd
4 changed files with 43 additions and 37 deletions


@@ -46,9 +46,8 @@ struct Param {
 };

 template <size_t I = 0, typename... Ts>
-typename std::enable_if<I == sizeof...(Ts), void>::type
-transform_tuple_to_param_list(std::tuple<Ts...> tup,
-                              std::vector<Param>& param_list) {
+typename std::enable_if<I == sizeof...(Ts), void>::type param_transform(
+    std::tuple<Ts...> tup, std::vector<Param>& param_list) {
   auto size = std::tuple_size<decltype(tup)>::value;
   if (size == I && param_list.size() != 0) {
     //nothing to do
@@ -57,9 +56,8 @@ transform_tuple_to_param_list(std::tuple<Ts...> tup,
 }

 template <size_t I = 0, typename... Ts>
-typename std::enable_if<(I < sizeof...(Ts)), void>::type
-transform_tuple_to_param_list(std::tuple<Ts...> tup,
-                              std::vector<Param>& param_list) {
+typename std::enable_if<(I < sizeof...(Ts)), void>::type param_transform(
+    std::tuple<Ts...> tup, std::vector<Param>& param_list) {
   if (typeid(std::get<I>(tup)) == typeid(float)) {
     Param p;
     p.data.f = std::get<I>(tup);
@@ -84,7 +82,7 @@ transform_tuple_to_param_list(std::tuple<Ts...> tup,
     std::cout << "need more else if type of param list " << std::endl;
   }
   // Go to next element
-  transform_tuple_to_param_list<I + 1>(tup, param_list);
+  param_transform<I + 1>(tup, param_list);
 }

 class CustomOpBase : public Operation {
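The two param_transform overloads above follow a standard C++11 idiom: std::enable_if selects the empty overload once the index I reaches the tuple size, and the recursive overload otherwise, so the whole tuple walk resolves at compile time. Below is a minimal standalone sketch of the same pattern; the names are illustrative, not part of TIM-VX.

#include <cstddef>
#include <iostream>
#include <tuple>
#include <type_traits>

// Base case: I has walked past the last element, nothing left to visit.
template <std::size_t I = 0, typename... Ts>
typename std::enable_if<I == sizeof...(Ts), void>::type visit_tuple(
    const std::tuple<Ts...>&) {}

// Recursive case: handle element I, then instantiate the visitor for I + 1.
template <std::size_t I = 0, typename... Ts>
typename std::enable_if<(I < sizeof...(Ts)), void>::type visit_tuple(
    const std::tuple<Ts...>& tup) {
  std::cout << std::get<I>(tup) << std::endl;
  visit_tuple<I + 1>(tup);
}

int main() {
  visit_tuple(std::make_tuple(2, 6, 1.0f));  // prints 2, 6, then 1
  return 0;
}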
@@ -93,15 +91,14 @@ class CustomOpBase : public Operation {
                int32_t kernel_id, const char* kernel_name);
   ~CustomOpBase();

-  virtual void extract_parameter_and_register(
+  virtual void SetupParams(
       std::vector<tim::vx::DataType> input_types,
       std::string& build_option) = 0;

-  virtual void setup_output_shape_info() = 0;
+  virtual void SetupShapeInfor() = 0;

-  virtual void setup_kernel_param(uint32_t& dim,
-                                  std::vector<size_t>& gobal_size,
-                                  std::vector<size_t>& local_size) = 0;
+  virtual void SetupEnqueue(uint32_t& dim, std::vector<size_t>& gobal_size,
+                            std::vector<size_t>& local_size) = 0;

   std::vector<Param> param_list_;
   std::vector<tim::vx::ShapeType> inputs_size_;
@@ -120,4 +117,4 @@ class CustomOpBase : public Operation {
 } // namespace vx
 } // namespace tim

-#endif /* TIM_VX_OPS_MATMUL_H_ */
+#endif /* TIM_VX_OPS_CUSTOM_BASE_H_ */
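With the interface renamed, a new user-defined op only has to derive from CustomOpBase and override the three hooks. The skeleton below is a hedged sketch of such a derivation, not code from this commit: the op name, kernel source, and output-shape rule are all hypothetical; only the CustomOpBase API comes from the header above.

class CustomCopy : public CustomOpBase {
 public:
  CustomCopy(Graph* graph, uint32_t input_num = 1, uint32_t output_num = 1)
      : CustomOpBase(graph, input_num, output_num, CustomCopy::kernel_id_,
                     CustomCopy::kernel_name_) {
    kernel_resource_ = "__kernel void copy_F32toF32(...) {...}";  // OpenCL source
  }

 protected:
  static const char* kernel_name_;
  static int32_t kernel_id_;

  // Output shape mirrors the input shape for this toy op.
  void SetupShapeInfor() override { outputs_size_[0] = inputs_size_[0]; }

  // Select the kernel function and build options from the input dtypes.
  void SetupParams(std::vector<tim::vx::DataType> input_types,
                   std::string& build_option) override {
    (void)input_types;  // a real op would branch on the dtypes here
    func_name_ = "copy_F32toF32";
    build_option = "";
  }

  // Report the enqueue geometry: work dimension plus global/local sizes.
  void SetupEnqueue(uint32_t& dim, std::vector<size_t>& global_size,
                    std::vector<size_t>& local_size) override {
    dim = 1;
    global_size[0] = inputs_size_[0][0];
    local_size[0] = 0;  // 0, as in CustomGemm: leave the choice to the driver
  }
};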


@@ -31,20 +31,32 @@ namespace tim {
 namespace vx {
 namespace ops {

-//scalar param for kernel function input
-using DeriveParmaTuple = std::tuple<int, int, int, int, int, float, float,
-                                    float, float, float, float>;
 class CustomGemm : public CustomOpBase {
  public:
+  //scalar param for kernel function input
+  using ParamTuple = std::tuple<int,   /* M */
+                                int,   /* K */
+                                int,   /* N */
+                                int,   /* ac2zero */
+                                int,   /* bc2zero */
+                                float, /* scale_a */
+                                float, /* zp_a */
+                                float, /* scale_b */
+                                float, /* zp_b */
+                                float, /* scale_out */
+                                float  /* zp_out */
+                                >;
   CustomGemm(Graph* graph, bool trans_a, bool trans_b,
-             DeriveParmaTuple tuple_list,
-             uint32_t input_num = 2, uint32_t output_num = 1)
-      : CustomOpBase(graph, input_num, output_num,
-                     CustomGemm::kernel_id_, CustomGemm::kernel_name_),
-        trans_a_(trans_a),trans_b_(trans_b) {
+             ParamTuple tuple_list, uint32_t input_num = 2,
+             uint32_t output_num = 1)
+      : CustomOpBase(graph, input_num, output_num, CustomGemm::kernel_id_,
+                     CustomGemm::kernel_name_),
+        trans_a_(trans_a),
+        trans_b_(trans_b) {
     tuple_list_.swap(tuple_list);
-    transform_tuple_to_param_list(tuple_list_, param_list_);
+    param_transform(tuple_list_, param_list_);
     kernel_resource_ =
         "__kernel void gemm_F32F32toF32_2D(\n\
@@ -87,14 +99,14 @@ class CustomGemm : public CustomOpBase {
  protected:
   const char* kernel_NotTransA_NotTransB = "gemm_F32F32toF32_2D";
   const char* kernel_TransA_NotTransB = ".....";
-  DeriveParmaTuple tuple_list_;
+  ParamTuple tuple_list_;
   bool trans_a_;
   bool trans_b_;
   static const char* kernel_name_;
   static int32_t kernel_id_;

   //function for setup output
-  void setup_output_shape_info() override {
+  void SetupShapeInfor() override {
     if (!trans_a_ && !trans_a_) {
       outputs_size_[0].push_back(inputs_size_[0][1]);
       outputs_size_[0].push_back(inputs_size_[1][0]);
@@ -105,7 +117,7 @@ class CustomGemm : public CustomOpBase {
   }

   //function for kernel select and build option
-  void extract_parameter_and_register(
+  void SetupParams(
       std::vector<tim::vx::DataType> input_types,
       std::string& build_option) override {
     if (trans_a_ == false &&
@@ -115,16 +127,13 @@ class CustomGemm : public CustomOpBase {
       func_name_ = kernel_NotTransA_NotTransB;
       build_option = "";
     } else {
-      //other situation: named func_name_ and setup param_list
-      //func_name_ = "......";
-      //std::get<2>(param_list) = ......
+      // other situation: named func_name_ and setup param_list
     }
     return;
   }

   //function for kernel local size and gobal size
-  void setup_kernel_param(uint32_t& dim, std::vector<size_t>& global_size,
-                          std::vector<size_t>& local_size) {
+  void SetupEnqueue(uint32_t& dim, std::vector<size_t>& global_size,
+                    std::vector<size_t>& local_size) {
     dim = 3;
     local_size[0] = 0;
     local_size[1] = 0;
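For readers unfamiliar with the enqueue geometry above: dim, global_size, and local_size correspond to the NDRange arguments of an OpenCL dispatch, and zeroed local sizes play the role of a null local work size, leaving the work-group shape to the driver. TIM-VX routes the kernel through OpenVX, so the raw OpenCL call below is only an analogy, with queue, kernel, and the sample sizes assumed.

#include <CL/cl.h>

/* Dispatch with the geometry SetupEnqueue reports: 3 dimensions, an
 * explicit global size, and a driver-chosen local (work-group) size. */
cl_int enqueue_3d(cl_command_queue queue, cl_kernel kernel) {
  size_t global_size[3] = {64, 64, 1}; /* sample global work size */
  return clEnqueueNDRangeKernel(queue, kernel, /*work_dim=*/3,
                                NULL,        /* no global offset */
                                global_size, /* global work size */
                                NULL,        /* NULL here is analogous to the
                                                zeroed local_size entries:
                                                the driver decides */
                                0, NULL, NULL);
}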


@@ -65,7 +65,7 @@ void custom_gemm_single_test(){
   a_tensor->CopyDataToTensor(a_data.data(), a_data.size() * sizeof(float));
   b_tensor->CopyDataToTensor(b_data.data(), b_data.size() * sizeof(float));

-  tim::vx::ops::DeriveParmaTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0);
+  tim::vx::ops::CustomGemm::ParamTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0);
   auto op = graph->CreateOperation<tim::vx::ops::CustomGemm>(
       false,false,tuple_list);
@@ -143,7 +143,7 @@ void custom_gemm_op_and_add_op_test(){
   auto op_add = graph->CreateOperation<tim::vx::ops::AddN>(2);
   (*op_add).BindInputs({a_tensor, b_tensor}).BindOutputs({c_tensor});

-  tim::vx::ops::DeriveParmaTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0);
+  tim::vx::ops::CustomGemm::ParamTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0);
   auto op_gemm = graph->CreateOperation<tim::vx::ops::CustomGemm>(
       false,false,tuple_list);
@@ -210,7 +210,7 @@ void custom_gemm_op_and_custom_gemm_op_test(){
   b_tensor->CopyDataToTensor(b_data.data(), b_data.size() * sizeof(float));
   d_tensor->CopyDataToTensor(d_data.data(), d_data.size() * sizeof(float));

-  tim::vx::ops::DeriveParmaTuple tuple_list(2,2,2,0,0,1.0,0,1.0,0,1.0,0);
+  tim::vx::ops::CustomGemm::ParamTuple tuple_list(2,2,2,0,0,1.0,0,1.0,0,1.0,0);
   auto op_gemm = graph->CreateOperation<tim::vx::ops::CustomGemm>(
       false,false,tuple_list);
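Since the ParamTuple arguments in these tests are positional, it helps to read them against the field comments in the header. The first test's values decode as:

tim::vx::ops::CustomGemm::ParamTuple tuple_list(
    2, 6, 6,      // M, K, N
    0, 0,         // ac2zero, bc2zero
    1.0f, 0.0f,   // scale_a, zp_a
    1.0f, 0.0f,   // scale_b, zp_b
    1.0f, 0.0f);  // scale_out, zp_out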


@@ -88,7 +88,7 @@ vsi_bool op_setup(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
     op_this->inputs_size_.push_back(input_size);
   }

-  op_this->setup_output_shape_info();
+  op_this->SetupShapeInfor();
   for (uint32_t i = 0; i < op_this->outputs_size_.size(); i++) {
     outputs[i]->attr.dim_num = op_this->outputs_size_[i].size();
@@ -129,7 +129,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
   }
   std::string build_option;

-  op_this->extract_parameter_and_register(input_types, build_option);
+  op_this->SetupParams(input_types, build_option);
   snprintf(kernel->info.name, VX_MAX_KERNEL_NAME, "%s", op_this->func_name_);
   kernel->unique_id =
@@ -233,7 +233,7 @@ vx_status derive_kernel_init(vx_node node, const vx_reference* param,
   auto iter = node_base_map_.find(reinterpret_cast<void*>(node));
   if (iter != node_base_map_.end()) {
-    iter->second->setup_kernel_param(dim, global_size, local_size);
+    iter->second->SetupEnqueue(dim, global_size, local_size);
   } else {
     std::cout << "Something wrong in finding gpu param setup function"
               << std::endl;
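node_base_map_ exists because OpenVX invokes plain C callbacks that only hand back an opaque vx_node: the map recovers the owning C++ op so the call can be forwarded to its virtual SetupEnqueue. The same pattern in miniature, with illustrative names:

#include <iostream>
#include <map>

struct Handler {
  virtual ~Handler() = default;
  virtual void OnInit() = 0;
};

// handle -> C++ object, filled in when each node is created
static std::map<void*, Handler*> handler_map;

// A C-style callback: it receives only an opaque handle, so it cannot
// reach the C++ instance without the side table above.
extern "C" void on_node_init(void* handle) {
  auto iter = handler_map.find(handle);
  if (iter != handler_map.end()) {
    iter->second->OnInit();  // forward to the registered object
  } else {
    std::cout << "no handler registered for this node" << std::endl;
  }
}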