Refine customized op support (#327)
Signed-off-by: ZhangXiang <Xiang.Zhang@verisilicon.com>
This commit is contained in:
parent
5bab9964e9
commit
097f8d74cd
|
|
@ -46,9 +46,8 @@ struct Param {
|
|||
};
|
||||
|
||||
template <size_t I = 0, typename... Ts>
|
||||
typename std::enable_if<I == sizeof...(Ts), void>::type
|
||||
transform_tuple_to_param_list(std::tuple<Ts...> tup,
|
||||
std::vector<Param>& param_list) {
|
||||
typename std::enable_if<I == sizeof...(Ts), void>::type param_transform(
|
||||
std::tuple<Ts...> tup, std::vector<Param>& param_list) {
|
||||
auto size = std::tuple_size<decltype(tup)>::value;
|
||||
if (size == I && param_list.size() != 0) {
|
||||
//nothing to do
|
||||
|
|
@ -57,9 +56,8 @@ transform_tuple_to_param_list(std::tuple<Ts...> tup,
|
|||
}
|
||||
|
||||
template <size_t I = 0, typename... Ts>
|
||||
typename std::enable_if<(I < sizeof...(Ts)), void>::type
|
||||
transform_tuple_to_param_list(std::tuple<Ts...> tup,
|
||||
std::vector<Param>& param_list) {
|
||||
typename std::enable_if<(I < sizeof...(Ts)), void>::type param_transform(
|
||||
std::tuple<Ts...> tup, std::vector<Param>& param_list) {
|
||||
if (typeid(std::get<I>(tup)) == typeid(float)) {
|
||||
Param p;
|
||||
p.data.f = std::get<I>(tup);
|
||||
|
|
@ -84,7 +82,7 @@ transform_tuple_to_param_list(std::tuple<Ts...> tup,
|
|||
std::cout << "need more else if type of param list " << std::endl;
|
||||
}
|
||||
// Go to next element
|
||||
transform_tuple_to_param_list<I + 1>(tup, param_list);
|
||||
param_transform<I + 1>(tup, param_list);
|
||||
}
|
||||
|
||||
class CustomOpBase : public Operation {
|
||||
|
|
@ -93,15 +91,14 @@ class CustomOpBase : public Operation {
|
|||
int32_t kernel_id, const char* kernel_name);
|
||||
|
||||
~CustomOpBase();
|
||||
virtual void extract_parameter_and_register(
|
||||
virtual void SetupParams(
|
||||
std::vector<tim::vx::DataType> input_types,
|
||||
std::string& build_option) = 0;
|
||||
|
||||
virtual void setup_output_shape_info() = 0;
|
||||
virtual void SetupShapeInfor() = 0;
|
||||
|
||||
virtual void setup_kernel_param(uint32_t& dim,
|
||||
std::vector<size_t>& gobal_size,
|
||||
std::vector<size_t>& local_size) = 0;
|
||||
virtual void SetupEnqueue(uint32_t& dim, std::vector<size_t>& gobal_size,
|
||||
std::vector<size_t>& local_size) = 0;
|
||||
|
||||
std::vector<Param> param_list_;
|
||||
std::vector<tim::vx::ShapeType> inputs_size_;
|
||||
|
|
@ -120,4 +117,4 @@ class CustomOpBase : public Operation {
|
|||
} // namespace vx
|
||||
} // namespace tim
|
||||
|
||||
#endif /* TIM_VX_OPS_MATMUL_H_ */
|
||||
#endif /* TIM_VX_OPS_CUSTOM_BASE_H_ */
|
||||
|
|
@ -31,20 +31,32 @@ namespace tim {
|
|||
namespace vx {
|
||||
namespace ops {
|
||||
|
||||
//scalar param for kernel function input
|
||||
using DeriveParmaTuple = std::tuple<int, int, int, int, int, float, float,
|
||||
float, float, float, float>;
|
||||
|
||||
|
||||
class CustomGemm : public CustomOpBase {
|
||||
public:
|
||||
//scalar param for kernel function input
|
||||
using ParamTuple = std::tuple<int, /* M */
|
||||
int, /* K */
|
||||
int, /* N */
|
||||
int, /* ac2zero */
|
||||
int, /* bc2zero */
|
||||
float, /* scale_a */
|
||||
float, /* zp_a */
|
||||
float, /* scale_b */
|
||||
float, /* zp_b */
|
||||
float, /* scale_out */
|
||||
float /* zp_out */
|
||||
>;
|
||||
CustomGemm(Graph* graph, bool trans_a, bool trans_b,
|
||||
DeriveParmaTuple tuple_list,
|
||||
uint32_t input_num = 2, uint32_t output_num = 1)
|
||||
: CustomOpBase(graph, input_num, output_num,
|
||||
CustomGemm::kernel_id_, CustomGemm::kernel_name_),
|
||||
trans_a_(trans_a),trans_b_(trans_b) {
|
||||
ParamTuple tuple_list, uint32_t input_num = 2,
|
||||
uint32_t output_num = 1)
|
||||
: CustomOpBase(graph, input_num, output_num, CustomGemm::kernel_id_,
|
||||
CustomGemm::kernel_name_),
|
||||
trans_a_(trans_a),
|
||||
trans_b_(trans_b) {
|
||||
tuple_list_.swap(tuple_list);
|
||||
transform_tuple_to_param_list(tuple_list_, param_list_);
|
||||
param_transform(tuple_list_, param_list_);
|
||||
|
||||
kernel_resource_ =
|
||||
"__kernel void gemm_F32F32toF32_2D(\n\
|
||||
|
|
@ -87,14 +99,14 @@ class CustomGemm : public CustomOpBase {
|
|||
protected:
|
||||
const char* kernel_NotTransA_NotTransB = "gemm_F32F32toF32_2D";
|
||||
const char* kernel_TransA_NotTransB = ".....";
|
||||
DeriveParmaTuple tuple_list_;
|
||||
ParamTuple tuple_list_;
|
||||
bool trans_a_;
|
||||
bool trans_b_;
|
||||
static const char* kernel_name_;
|
||||
static int32_t kernel_id_;
|
||||
|
||||
//function for setup output
|
||||
void setup_output_shape_info() override {
|
||||
void SetupShapeInfor() override {
|
||||
if (!trans_a_ && !trans_a_) {
|
||||
outputs_size_[0].push_back(inputs_size_[0][1]);
|
||||
outputs_size_[0].push_back(inputs_size_[1][0]);
|
||||
|
|
@ -105,7 +117,7 @@ class CustomGemm : public CustomOpBase {
|
|||
}
|
||||
|
||||
//function for kernel select and build option
|
||||
void extract_parameter_and_register(
|
||||
void SetupParams(
|
||||
std::vector<tim::vx::DataType> input_types,
|
||||
std::string& build_option) override {
|
||||
if (trans_a_ == false &&
|
||||
|
|
@ -115,16 +127,13 @@ class CustomGemm : public CustomOpBase {
|
|||
func_name_ = kernel_NotTransA_NotTransB;
|
||||
build_option = "";
|
||||
} else {
|
||||
//other situation: named func_name_ and setup param_list
|
||||
//func_name_ = "......";
|
||||
//std::get<2>(param_list) = ......
|
||||
// other situation: named func_name_ and setup param_list
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
//function for kernel local size and gobal size
|
||||
void setup_kernel_param(uint32_t& dim, std::vector<size_t>& global_size,
|
||||
std::vector<size_t>& local_size) {
|
||||
void SetupEnqueue(uint32_t& dim, std::vector<size_t>& global_size,
|
||||
std::vector<size_t>& local_size) {
|
||||
dim = 3;
|
||||
local_size[0] = 0;
|
||||
local_size[1] = 0;
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ void custom_gemm_single_test(){
|
|||
a_tensor->CopyDataToTensor(a_data.data(), a_data.size() * sizeof(float));
|
||||
b_tensor->CopyDataToTensor(b_data.data(), b_data.size() * sizeof(float));
|
||||
|
||||
tim::vx::ops::DeriveParmaTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0);
|
||||
tim::vx::ops::CustomGemm::ParamTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0);
|
||||
|
||||
auto op = graph->CreateOperation<tim::vx::ops::CustomGemm>(
|
||||
false,false,tuple_list);
|
||||
|
|
@ -143,7 +143,7 @@ void custom_gemm_op_and_add_op_test(){
|
|||
auto op_add = graph->CreateOperation<tim::vx::ops::AddN>(2);
|
||||
(*op_add).BindInputs({a_tensor, b_tensor}).BindOutputs({c_tensor});
|
||||
|
||||
tim::vx::ops::DeriveParmaTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0);
|
||||
tim::vx::ops::CustomGemm::ParamTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0);
|
||||
|
||||
auto op_gemm = graph->CreateOperation<tim::vx::ops::CustomGemm>(
|
||||
false,false,tuple_list);
|
||||
|
|
@ -210,7 +210,7 @@ void custom_gemm_op_and_custom_gemm_op_test(){
|
|||
b_tensor->CopyDataToTensor(b_data.data(), b_data.size() * sizeof(float));
|
||||
d_tensor->CopyDataToTensor(d_data.data(), d_data.size() * sizeof(float));
|
||||
|
||||
tim::vx::ops::DeriveParmaTuple tuple_list(2,2,2,0,0,1.0,0,1.0,0,1.0,0);
|
||||
tim::vx::ops::CustomGemm::ParamTuple tuple_list(2,2,2,0,0,1.0,0,1.0,0,1.0,0);
|
||||
|
||||
auto op_gemm = graph->CreateOperation<tim::vx::ops::CustomGemm>(
|
||||
false,false,tuple_list);
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ vsi_bool op_setup(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
|
|||
op_this->inputs_size_.push_back(input_size);
|
||||
}
|
||||
|
||||
op_this->setup_output_shape_info();
|
||||
op_this->SetupShapeInfor();
|
||||
|
||||
for (uint32_t i = 0; i < op_this->outputs_size_.size(); i++) {
|
||||
outputs[i]->attr.dim_num = op_this->outputs_size_[i].size();
|
||||
|
|
@ -129,7 +129,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
|
|||
}
|
||||
|
||||
std::string build_option;
|
||||
op_this->extract_parameter_and_register(input_types, build_option);
|
||||
op_this->SetupParams(input_types, build_option);
|
||||
|
||||
snprintf(kernel->info.name, VX_MAX_KERNEL_NAME, "%s", op_this->func_name_);
|
||||
kernel->unique_id =
|
||||
|
|
@ -233,7 +233,7 @@ vx_status derive_kernel_init(vx_node node, const vx_reference* param,
|
|||
|
||||
auto iter = node_base_map_.find(reinterpret_cast<void*>(node));
|
||||
if (iter != node_base_map_.end()) {
|
||||
iter->second->setup_kernel_param(dim, global_size, local_size);
|
||||
iter->second->SetupEnqueue(dim, global_size, local_size);
|
||||
} else {
|
||||
std::cout << "Something wrong in finding gpu param setup function"
|
||||
<< std::endl;
|
||||
|
|
|
|||
Loading…
Reference in New Issue