Refine customized op support (#327)
Signed-off-by: ZhangXiang <Xiang.Zhang@verisilicon.com>
parent 5bab9964e9
commit 097f8d74cd
@@ -46,9 +46,8 @@ struct Param {
 };
 
 template <size_t I = 0, typename... Ts>
-typename std::enable_if<I == sizeof...(Ts), void>::type
-transform_tuple_to_param_list(std::tuple<Ts...> tup,
-                              std::vector<Param>& param_list) {
+typename std::enable_if<I == sizeof...(Ts), void>::type param_transform(
+    std::tuple<Ts...> tup, std::vector<Param>& param_list) {
   auto size = std::tuple_size<decltype(tup)>::value;
   if (size == I && param_list.size() != 0) {
     //nothing to do
@@ -57,9 +56,8 @@ transform_tuple_to_param_list(std::tuple<Ts...> tup,
 }
 
 template <size_t I = 0, typename... Ts>
-typename std::enable_if<(I < sizeof...(Ts)), void>::type
-transform_tuple_to_param_list(std::tuple<Ts...> tup,
-                              std::vector<Param>& param_list) {
+typename std::enable_if<(I < sizeof...(Ts)), void>::type param_transform(
+    std::tuple<Ts...> tup, std::vector<Param>& param_list) {
   if (typeid(std::get<I>(tup)) == typeid(float)) {
     Param p;
     p.data.f = std::get<I>(tup);
@@ -84,7 +82,7 @@ transform_tuple_to_param_list(std::tuple<Ts...> tup,
     std::cout << "need more else if type of param list " << std::endl;
   }
   // Go to next element
-  transform_tuple_to_param_list<I + 1>(tup, param_list);
+  param_transform<I + 1>(tup, param_list);
 }
 
 class CustomOpBase : public Operation {
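The two overloads above implement compile-time iteration over a std::tuple: SFINAE on the index I selects the empty terminating overload once I reaches the pack size, and each recursive instantiation converts element I into a Param before recursing on I + 1. A minimal, self-contained sketch of the renamed pattern, assuming a Param trimmed to what the recursion needs (the main() driver is illustrative, not part of the commit):

#include <cstddef>
#include <iostream>
#include <tuple>
#include <type_traits>
#include <typeinfo>
#include <vector>

struct Param {
  union { int i; float f; } data;
};

// Terminating overload: I has reached the end of the pack, nothing to do.
template <size_t I = 0, typename... Ts>
typename std::enable_if<I == sizeof...(Ts), void>::type param_transform(
    std::tuple<Ts...>, std::vector<Param>&) {}

// Recursive overload: convert element I to a Param, then instantiate I + 1.
template <size_t I = 0, typename... Ts>
typename std::enable_if<(I < sizeof...(Ts)), void>::type param_transform(
    std::tuple<Ts...> tup, std::vector<Param>& param_list) {
  Param p;
  if (typeid(std::get<I>(tup)) == typeid(float)) {
    p.data.f = std::get<I>(tup);
  } else {
    p.data.i = std::get<I>(tup);
  }
  param_list.push_back(p);
  param_transform<I + 1>(tup, param_list);
}

int main() {
  std::vector<Param> param_list;
  param_transform(std::make_tuple(2, 6, 6, 1.0f), param_list);
  std::cout << param_list.size() << " params extracted\n";  // 4 params extracted
}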
@@ -93,14 +91,13 @@ class CustomOpBase : public Operation {
                int32_t kernel_id, const char* kernel_name);
 
   ~CustomOpBase();
-  virtual void extract_parameter_and_register(
+  virtual void SetupParams(
       std::vector<tim::vx::DataType> input_types,
       std::string& build_option) = 0;
 
-  virtual void setup_output_shape_info() = 0;
+  virtual void SetupShapeInfor() = 0;
 
-  virtual void setup_kernel_param(uint32_t& dim,
-                                  std::vector<size_t>& gobal_size,
+  virtual void SetupEnqueue(uint32_t& dim, std::vector<size_t>& gobal_size,
                             std::vector<size_t>& local_size) = 0;
 
   std::vector<Param> param_list_;
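The rename matches each pure-virtual hook to the stage that calls it: SetupParams picks the kernel and build options once input types are known (op_compute below), SetupShapeInfor infers output shapes (op_setup below), and SetupEnqueue fills the NDRange before launch (derive_kernel_init below). A hypothetical minimal derived op, to show what an implementer now overrides; only func_name_, inputs_size_, outputs_size_, and the constructor signature are taken from this commit, the rest is illustrative:

class CustomRelu : public CustomOpBase {
 public:
  CustomRelu(Graph* graph)
      : CustomOpBase(graph, /*input_num=*/1, /*output_num=*/1,
                     kernel_id_, kernel_name_) {}

  void SetupParams(std::vector<tim::vx::DataType> input_types,
                   std::string& build_option) override {
    func_name_ = "relu_F32toF32";  // choose a kernel by input types
    build_option = "";
  }

  void SetupShapeInfor() override {
    outputs_size_[0] = inputs_size_[0];  // elementwise: shape passes through
  }

  void SetupEnqueue(uint32_t& dim, std::vector<size_t>& gobal_size,
                    std::vector<size_t>& local_size) override {
    dim = 1;
    gobal_size[0] = element_count_;  // hypothetical member: one work-item each
    local_size[0] = 0;
  }

 private:
  static const char* kernel_name_;  // as in CustomGemm below
  static int32_t kernel_id_;
};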
@@ -120,4 +117,4 @@ class CustomOpBase : public Operation {
 }  // namespace vx
 }  // namespace tim
 
-#endif /* TIM_VX_OPS_MATMUL_H_ */
+#endif /* TIM_VX_OPS_CUSTOM_BASE_H_ */
@@ -31,20 +31,32 @@ namespace tim {
 namespace vx {
 namespace ops {
 
-//scalar param for kernel function input
-using DeriveParmaTuple = std::tuple<int, int, int, int, int, float, float,
-                                    float, float, float, float>;
-
 class CustomGemm : public CustomOpBase {
  public:
+  //scalar param for kernel function input
+  using ParamTuple = std::tuple<int,   /* M */
+                                int,   /* K */
+                                int,   /* N */
+                                int,   /* ac2zero */
+                                int,   /* bc2zero */
+                                float, /* scale_a */
+                                float, /* zp_a */
+                                float, /* scale_b */
+                                float, /* zp_b */
+                                float, /* scale_out */
+                                float  /* zp_out */
+                                >;
   CustomGemm(Graph* graph, bool trans_a, bool trans_b,
-             DeriveParmaTuple tuple_list,
-             uint32_t input_num = 2, uint32_t output_num = 1)
-      : CustomOpBase(graph, input_num, output_num,
-                     CustomGemm::kernel_id_, CustomGemm::kernel_name_),
-        trans_a_(trans_a),trans_b_(trans_b) {
+             ParamTuple tuple_list, uint32_t input_num = 2,
+             uint32_t output_num = 1)
+      : CustomOpBase(graph, input_num, output_num, CustomGemm::kernel_id_,
+                     CustomGemm::kernel_name_),
+        trans_a_(trans_a),
+        trans_b_(trans_b) {
     tuple_list_.swap(tuple_list);
-    transform_tuple_to_param_list(tuple_list_, param_list_);
+    param_transform(tuple_list_, param_list_);
 
     kernel_resource_ =
         "__kernel void gemm_F32F32toF32_2D(\n\
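Scoping the alias inside the class (CustomGemm::ParamTuple rather than the namespace-level DeriveParmaTuple) ties the tuple layout to the op that consumes it, and the per-field comments document the positional arguments. A call site now reads as in the updated tests below (values illustrative):

tim::vx::ops::CustomGemm::ParamTuple tuple_list(
    2, 6, 6,       /* M, K, N */
    0, 0,          /* ac2zero, bc2zero */
    1.0f, 0.0f,    /* scale_a, zp_a */
    1.0f, 0.0f,    /* scale_b, zp_b */
    1.0f, 0.0f);   /* scale_out, zp_out */
auto op = graph->CreateOperation<tim::vx::ops::CustomGemm>(
    /*trans_a=*/false, /*trans_b=*/false, tuple_list);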
@@ -87,14 +99,14 @@ class CustomGemm : public CustomOpBase {
  protected:
   const char* kernel_NotTransA_NotTransB = "gemm_F32F32toF32_2D";
   const char* kernel_TransA_NotTransB = ".....";
-  DeriveParmaTuple tuple_list_;
+  ParamTuple tuple_list_;
   bool trans_a_;
   bool trans_b_;
   static const char* kernel_name_;
   static int32_t kernel_id_;
 
   //function for setup output
-  void setup_output_shape_info() override {
+  void SetupShapeInfor() override {
     if (!trans_a_ && !trans_a_) {
       outputs_size_[0].push_back(inputs_size_[0][1]);
       outputs_size_[0].push_back(inputs_size_[1][0]);
@@ -105,7 +117,7 @@ class CustomGemm : public CustomOpBase {
   }
 
   //function for kernel select and build option
-  void extract_parameter_and_register(
+  void SetupParams(
       std::vector<tim::vx::DataType> input_types,
       std::string& build_option) override {
     if (trans_a_ == false &&
@@ -116,14 +128,11 @@ class CustomGemm : public CustomOpBase {
       build_option = "";
     } else {
       // other situation: named func_name_ and setup param_list
-      //func_name_ = "......";
-      //std::get<2>(param_list) = ......
     }
-    return;
   }
 
   //function for kernel local size and gobal size
-  void setup_kernel_param(uint32_t& dim, std::vector<size_t>& global_size,
+  void SetupEnqueue(uint32_t& dim, std::vector<size_t>& global_size,
                     std::vector<size_t>& local_size) {
     dim = 3;
     local_size[0] = 0;
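SetupEnqueue's out-parameters mirror the arguments of clEnqueueNDRangeKernel: dim is the NDRange rank, global_size the total work-items per dimension, local_size the work-group shape. A sketch of a full override, assuming (our reading, not stated in the diff) that a zero local size lets the runtime pick the work-group size and that outputs_size_[0] holds the output's width/height dimensions:

void SetupEnqueue(uint32_t& dim, std::vector<size_t>& global_size,
                  std::vector<size_t>& local_size) override {
  dim = 3;                               // 3-D NDRange over the output tensor
  global_size[0] = outputs_size_[0][0];  // one work-item per output column
  global_size[1] = outputs_size_[0][1];  // ... and per output row
  global_size[2] = 1;                    // single slice
  local_size[0] = 0;                     // 0: runtime-chosen work-group (assumed)
}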
@@ -65,7 +65,7 @@ void custom_gemm_single_test(){
   a_tensor->CopyDataToTensor(a_data.data(), a_data.size() * sizeof(float));
   b_tensor->CopyDataToTensor(b_data.data(), b_data.size() * sizeof(float));
 
-  tim::vx::ops::DeriveParmaTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0);
+  tim::vx::ops::CustomGemm::ParamTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0);
 
   auto op = graph->CreateOperation<tim::vx::ops::CustomGemm>(
       false,false,tuple_list);
@@ -143,7 +143,7 @@ void custom_gemm_op_and_add_op_test(){
   auto op_add = graph->CreateOperation<tim::vx::ops::AddN>(2);
   (*op_add).BindInputs({a_tensor, b_tensor}).BindOutputs({c_tensor});
 
-  tim::vx::ops::DeriveParmaTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0);
+  tim::vx::ops::CustomGemm::ParamTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0);
 
   auto op_gemm = graph->CreateOperation<tim::vx::ops::CustomGemm>(
       false,false,tuple_list);
@@ -210,7 +210,7 @@ void custom_gemm_op_and_custom_gemm_op_test(){
   b_tensor->CopyDataToTensor(b_data.data(), b_data.size() * sizeof(float));
   d_tensor->CopyDataToTensor(d_data.data(), d_data.size() * sizeof(float));
 
-  tim::vx::ops::DeriveParmaTuple tuple_list(2,2,2,0,0,1.0,0,1.0,0,1.0,0);
+  tim::vx::ops::CustomGemm::ParamTuple tuple_list(2,2,2,0,0,1.0,0,1.0,0,1.0,0);
 
   auto op_gemm = graph->CreateOperation<tim::vx::ops::CustomGemm>(
       false,false,tuple_list);
@@ -88,7 +88,7 @@ vsi_bool op_setup(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
     op_this->inputs_size_.push_back(input_size);
   }
 
-  op_this->setup_output_shape_info();
+  op_this->SetupShapeInfor();
 
   for (uint32_t i = 0; i < op_this->outputs_size_.size(); i++) {
     outputs[i]->attr.dim_num = op_this->outputs_size_[i].size();
@@ -129,7 +129,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
   }
 
   std::string build_option;
-  op_this->extract_parameter_and_register(input_types, build_option);
+  op_this->SetupParams(input_types, build_option);
 
   snprintf(kernel->info.name, VX_MAX_KERNEL_NAME, "%s", op_this->func_name_);
   kernel->unique_id =
@@ -233,7 +233,7 @@ vx_status derive_kernel_init(vx_node node, const vx_reference* param,
 
   auto iter = node_base_map_.find(reinterpret_cast<void*>(node));
   if (iter != node_base_map_.end()) {
-    iter->second->setup_kernel_param(dim, global_size, local_size);
+    iter->second->SetupEnqueue(dim, global_size, local_size);
   } else {
     std::cout << "Something wrong in finding gpu param setup function"
               << std::endl;
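derive_kernel_init is a C-style OpenVX callback with no user-data pointer back to the C++ object, so the code keeps a process-wide node_base_map_ from the vx_node handle to the owning CustomOpBase and dispatches the virtual SetupEnqueue through it. The pattern in isolation; the map's exact value type and the registration site are assumptions, only the lookup matches the hunk above:

#include <map>

// Keyed by the backend node handle; populated when the node is created.
static std::map<void*, CustomOpBase*> node_base_map_;

// Registration, at node-creation time (assumed location):
node_base_map_[reinterpret_cast<void*>(node)] = op_this;

// Lookup inside the C callback, as in the hunk above:
auto iter = node_base_map_.find(reinterpret_cast<void*>(node));
if (iter != node_base_map_.end()) {
  iter->second->SetupEnqueue(dim, global_size, local_size);
}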