Added customize operator APIs(#315)
Co-authored-by: zhouheng.zheng <zhouheng.zheng@ouotlook.com>
This commit is contained in:
parent
161bb8a7c4
commit
b02aa8b8c4
|
|
@ -85,5 +85,6 @@
|
||||||
#include "tim/vx/ops/unidirectional_sequence_lstm.h"
|
#include "tim/vx/ops/unidirectional_sequence_lstm.h"
|
||||||
#include "tim/vx/ops/unstack.h"
|
#include "tim/vx/ops/unstack.h"
|
||||||
#include "tim/vx/ops/conv3d.h"
|
#include "tim/vx/ops/conv3d.h"
|
||||||
|
#include "tim/vx/ops/custom_base.h"
|
||||||
|
|
||||||
#endif /* TIM_VX_OPS_H_ */
|
#endif /* TIM_VX_OPS_H_ */
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,123 @@
|
||||||
|
/****************************************************************************
|
||||||
|
*
|
||||||
|
* Copyright (c) 2021 Vivante Corporation
|
||||||
|
*
|
||||||
|
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
* copy of this software and associated documentation files (the "Software"),
|
||||||
|
* to deal in the Software without restriction, including without limitation
|
||||||
|
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
* and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
* Software is furnished to do so, subject to the following conditions:
|
||||||
|
*
|
||||||
|
* The above copyright notice and this permission notice shall be included in
|
||||||
|
* all copies or substantial portions of the Software.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
* DEALINGS IN THE SOFTWARE.
|
||||||
|
*
|
||||||
|
*****************************************************************************/
|
||||||
|
#ifndef TIM_VX_OPS_CUSTOM_BASE_H_
|
||||||
|
#define TIM_VX_OPS_CUSTOM_BASE_H_
|
||||||
|
|
||||||
|
#include "tim/vx/direct_map_op.h"
|
||||||
|
#include <typeinfo>
|
||||||
|
#include <tuple>
|
||||||
|
|
||||||
|
namespace tim {
|
||||||
|
namespace vx {
|
||||||
|
namespace ops {
|
||||||
|
#define gpu_align(n, align) ((n) + ((align)-1)) & ~((align)-1)
|
||||||
|
|
||||||
|
__attribute__((unused)) static int32_t gobal_kernel_id_ = 0;
|
||||||
|
|
||||||
|
struct Param {
|
||||||
|
union {
|
||||||
|
float f;
|
||||||
|
int i;
|
||||||
|
bool b;
|
||||||
|
uint ui;
|
||||||
|
} data;
|
||||||
|
tim::vx::DataType type;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <size_t I = 0, typename... Ts>
|
||||||
|
typename std::enable_if<I == sizeof...(Ts), void>::type
|
||||||
|
transform_tuple_to_param_list(std::tuple<Ts...> tup,
|
||||||
|
std::vector<Param>& param_list) {
|
||||||
|
auto size = std::tuple_size<decltype(tup)>::value;
|
||||||
|
if (size == I && param_list.size() != 0) {
|
||||||
|
//nothing to do
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <size_t I = 0, typename... Ts>
|
||||||
|
typename std::enable_if<(I < sizeof...(Ts)), void>::type
|
||||||
|
transform_tuple_to_param_list(std::tuple<Ts...> tup,
|
||||||
|
std::vector<Param>& param_list) {
|
||||||
|
if (typeid(std::get<I>(tup)) == typeid(float)) {
|
||||||
|
Param p;
|
||||||
|
p.data.f = std::get<I>(tup);
|
||||||
|
p.type = tim::vx::DataType::FLOAT32;
|
||||||
|
param_list.push_back(p);
|
||||||
|
} else if (typeid(std::get<I>(tup)) == typeid(uint32_t)) {
|
||||||
|
Param p;
|
||||||
|
p.data.ui = std::get<I>(tup);
|
||||||
|
p.type = tim::vx::DataType::UINT32;
|
||||||
|
param_list.push_back(p);
|
||||||
|
} else if (typeid(std::get<I>(tup)) == typeid(int32_t)) {
|
||||||
|
Param p;
|
||||||
|
p.data.i = std::get<I>(tup);
|
||||||
|
p.type = tim::vx::DataType::INT32;
|
||||||
|
param_list.push_back(p);
|
||||||
|
} else if (typeid(std::get<I>(tup)) == typeid(bool)) {
|
||||||
|
Param p;
|
||||||
|
p.data.b = std::get<I>(tup);
|
||||||
|
p.type = tim::vx::DataType::BOOL8;
|
||||||
|
param_list.push_back(p);
|
||||||
|
} else {
|
||||||
|
std::cout << "need more else if type of param list " << std::endl;
|
||||||
|
}
|
||||||
|
// Go to next element
|
||||||
|
transform_tuple_to_param_list<I + 1>(tup, param_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
class CustomOpBase : public Operation {
|
||||||
|
public:
|
||||||
|
CustomOpBase(Graph* graph, uint32_t input_num, uint32_t output_num,
|
||||||
|
int32_t kernel_id, const char* kernel_name);
|
||||||
|
|
||||||
|
~CustomOpBase();
|
||||||
|
virtual void extract_parameter_and_register(
|
||||||
|
std::vector<tim::vx::DataType> input_types,
|
||||||
|
std::string& build_option) = 0;
|
||||||
|
|
||||||
|
virtual void setup_output_shape_info() = 0;
|
||||||
|
|
||||||
|
virtual void setup_kernel_param(uint32_t& dim,
|
||||||
|
std::vector<size_t>& gobal_size,
|
||||||
|
std::vector<size_t>& local_size) = 0;
|
||||||
|
|
||||||
|
std::vector<Param> param_list_;
|
||||||
|
std::vector<tim::vx::ShapeType> inputs_size_;
|
||||||
|
std::vector<tim::vx::ShapeType> outputs_size_;
|
||||||
|
|
||||||
|
const char* func_name_;
|
||||||
|
const char* kernel_resource_;
|
||||||
|
void* init_kernel_;
|
||||||
|
void* vx_node_;
|
||||||
|
|
||||||
|
uint32_t input_num_;
|
||||||
|
uint32_t output_num_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace vx
|
||||||
|
} // namespace tim
|
||||||
|
|
||||||
|
#endif /* TIM_VX_OPS_MATMUL_H_ */
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
add_subdirectory("benchmark_test")
|
add_subdirectory("benchmark_test")
|
||||||
|
add_subdirectory("custom_op_test")
|
||||||
add_subdirectory("lenet")
|
add_subdirectory("lenet")
|
||||||
if(${TIM_VX_ENABLE_VIPLITE})
|
if(${TIM_VX_ENABLE_VIPLITE})
|
||||||
add_subdirectory("lenet_lite")
|
add_subdirectory("lenet_lite")
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
message("samples/custom_op_test")
|
||||||
|
|
||||||
|
set(TARGET_NAME "custom_op_test")
|
||||||
|
|
||||||
|
aux_source_directory(. ${TARGET_NAME}_SRCS)
|
||||||
|
add_executable(${TARGET_NAME} ${${TARGET_NAME}_SRCS})
|
||||||
|
|
||||||
|
target_link_libraries(${TARGET_NAME} PRIVATE tim-vx)
|
||||||
|
target_include_directories(${TARGET_NAME} PRIVATE
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
${PROJECT_SOURCE_DIR}/include
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,36 @@
|
||||||
|
/****************************************************************************
|
||||||
|
*
|
||||||
|
* Copyright (c) 2021 Vivante Corporation
|
||||||
|
*
|
||||||
|
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
* copy of this software and associated documentation files (the "Software"),
|
||||||
|
* to deal in the Software without restriction, including without limitation
|
||||||
|
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
* and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
* Software is furnished to do so, subject to the following conditions:
|
||||||
|
*
|
||||||
|
* The above copyright notice and this permission notice shall be included in
|
||||||
|
* all copies or substantial portions of the Software.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
* DEALINGS IN THE SOFTWARE.
|
||||||
|
*
|
||||||
|
*****************************************************************************/
|
||||||
|
#include "custom_gemm.h"
|
||||||
|
|
||||||
|
namespace tim {
|
||||||
|
namespace vx {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
const char* CustomGemm::kernel_name_ = "xxxx_name12345";
|
||||||
|
int32_t CustomGemm::kernel_id_ = -1 * (++gobal_kernel_id_);
|
||||||
|
|
||||||
|
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace vx
|
||||||
|
} // namespace tim
|
||||||
|
|
@ -0,0 +1,148 @@
|
||||||
|
/****************************************************************************
|
||||||
|
*
|
||||||
|
* Copyright (c) 2021 Vivante Corporation
|
||||||
|
*
|
||||||
|
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
* copy of this software and associated documentation files (the "Software"),
|
||||||
|
* to deal in the Software without restriction, including without limitation
|
||||||
|
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
* and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
* Software is furnished to do so, subject to the following conditions:
|
||||||
|
*
|
||||||
|
* The above copyright notice and this permission notice shall be included in
|
||||||
|
* all copies or substantial portions of the Software.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
* DEALINGS IN THE SOFTWARE.
|
||||||
|
*
|
||||||
|
*****************************************************************************/
|
||||||
|
|
||||||
|
#ifndef TIM_VX_OPS_CUSTOM_GEMM_H_
|
||||||
|
#define TIM_VX_OPS_CUSTOM_GEMM_H_
|
||||||
|
|
||||||
|
#include "tim/vx/ops/custom_base.h"
|
||||||
|
|
||||||
|
namespace tim {
|
||||||
|
namespace vx {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
//scalar param for kernel function input
|
||||||
|
using DeriveParmaTuple = std::tuple<int, int, int, int, int, float, float,
|
||||||
|
float, float, float, float>;
|
||||||
|
|
||||||
|
class CustomGemm : public CustomOpBase {
|
||||||
|
public:
|
||||||
|
CustomGemm(Graph* graph, bool trans_a, bool trans_b,
|
||||||
|
DeriveParmaTuple tuple_list,
|
||||||
|
uint32_t input_num = 2, uint32_t output_num = 1)
|
||||||
|
: CustomOpBase(graph, input_num, output_num,
|
||||||
|
CustomGemm::kernel_id_, CustomGemm::kernel_name_),
|
||||||
|
trans_a_(trans_a),trans_b_(trans_b) {
|
||||||
|
tuple_list_.swap(tuple_list);
|
||||||
|
transform_tuple_to_param_list(tuple_list_, param_list_);
|
||||||
|
|
||||||
|
kernel_resource_ =
|
||||||
|
"__kernel void gemm_F32F32toF32_2D(\n\
|
||||||
|
__read_only image2d_t inputA,\n\
|
||||||
|
__read_only image2d_t inputB,\n\
|
||||||
|
__write_only image2d_t output,\n\
|
||||||
|
int M,\n\
|
||||||
|
int K,\n\
|
||||||
|
int N,\n\
|
||||||
|
int ac2zero,\n\
|
||||||
|
int bc2zero,\n\
|
||||||
|
float scale_a,\n\
|
||||||
|
float zp_a,\n\
|
||||||
|
float scale_b,\n\
|
||||||
|
float zp_b,\n\
|
||||||
|
float scale_out,\n\
|
||||||
|
float zp_out\n\
|
||||||
|
)\n\
|
||||||
|
{\n\
|
||||||
|
int4 coord = (int4)(get_global_id(0), get_global_id(1), 0, 0);\n\
|
||||||
|
float4 sum = (float4)(0);\n\
|
||||||
|
\n\
|
||||||
|
for(; coord.z < K;)\n\
|
||||||
|
{\n\
|
||||||
|
float4 tempA0;\n\
|
||||||
|
float4 tempB0;\n\
|
||||||
|
\n\
|
||||||
|
tempA0 = read_imagef(inputA, coord.zy);\n\
|
||||||
|
tempB0 = read_imagef(inputB, coord.xz);\n\
|
||||||
|
coord.z++;\n\
|
||||||
|
\n\
|
||||||
|
sum = sum + tempA0 * tempB0;\n\
|
||||||
|
}\n\
|
||||||
|
write_imagef(output, coord.xy, sum);\n\
|
||||||
|
}\n\
|
||||||
|
\n\
|
||||||
|
";
|
||||||
|
};
|
||||||
|
|
||||||
|
protected:
|
||||||
|
const char* kernel_NotTransA_NotTransB = "gemm_F32F32toF32_2D";
|
||||||
|
const char* kernel_TransA_NotTransB = ".....";
|
||||||
|
DeriveParmaTuple tuple_list_;
|
||||||
|
bool trans_a_;
|
||||||
|
bool trans_b_;
|
||||||
|
static const char* kernel_name_;
|
||||||
|
static int32_t kernel_id_;
|
||||||
|
|
||||||
|
//function for setup output
|
||||||
|
void setup_output_shape_info() override {
|
||||||
|
if (!trans_a_ && !trans_a_) {
|
||||||
|
outputs_size_[0].push_back(inputs_size_[0][1]);
|
||||||
|
outputs_size_[0].push_back(inputs_size_[1][0]);
|
||||||
|
} else {
|
||||||
|
//other situation: set up outputs_size
|
||||||
|
//outputs_size_[0].push_back()......
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//function for kernel select and build option
|
||||||
|
void extract_parameter_and_register(
|
||||||
|
std::vector<tim::vx::DataType> input_types,
|
||||||
|
std::string& build_option) override {
|
||||||
|
if (trans_a_ == false &&
|
||||||
|
trans_a_ == false &&
|
||||||
|
input_types[0] == tim::vx::DataType::FLOAT32 &&
|
||||||
|
input_types[1] == tim::vx::DataType::FLOAT32) {
|
||||||
|
func_name_ = kernel_NotTransA_NotTransB;
|
||||||
|
build_option = "";
|
||||||
|
} else {
|
||||||
|
//other situation: named func_name_ and setup param_list
|
||||||
|
//func_name_ = "......";
|
||||||
|
//std::get<2>(param_list) = ......
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
//function for kernel local size and gobal size
|
||||||
|
void setup_kernel_param(uint32_t& dim, std::vector<size_t>& global_size,
|
||||||
|
std::vector<size_t>& local_size) {
|
||||||
|
dim = 3;
|
||||||
|
local_size[0] = 0;
|
||||||
|
local_size[1] = 0;
|
||||||
|
local_size[2] = 0;
|
||||||
|
|
||||||
|
global_size[0] = gpu_align(outputs_size_[0][0], 4);
|
||||||
|
global_size[1] = gpu_align(outputs_size_[0][1], 4);
|
||||||
|
global_size[2] = outputs_size_[0].size() > 2 ? outputs_size_[0][2] : 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<Operation> Clone(
|
||||||
|
std::shared_ptr<Graph>& graph) const override {
|
||||||
|
return graph->CreateOperation<CustomGemm>(trans_a_,trans_b_,
|
||||||
|
this->tuple_list_, this->input_num_, this->output_num_);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace vx
|
||||||
|
} // namespace tim
|
||||||
|
#endif
|
||||||
|
|
@ -0,0 +1,242 @@
|
||||||
|
/****************************************************************************
|
||||||
|
*
|
||||||
|
* Copyright (c) 2021 Vivante Corporation
|
||||||
|
*
|
||||||
|
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
* copy of this software and associated documentation files (the "Software"),
|
||||||
|
* to deal in the Software without restriction, including without limitation
|
||||||
|
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
* and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
* Software is furnished to do so, subject to the following conditions:
|
||||||
|
*
|
||||||
|
* The above copyright notice and this permission notice shall be included in
|
||||||
|
* all copies or substantial portions of the Software.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
* DEALINGS IN THE SOFTWARE.
|
||||||
|
*
|
||||||
|
*****************************************************************************/
|
||||||
|
#include "tim/vx/context.h"
|
||||||
|
#include "tim/vx/graph.h"
|
||||||
|
#include "tim/vx/ops.h"
|
||||||
|
#include "custom_gemm.h"
|
||||||
|
#include <tuple>
|
||||||
|
|
||||||
|
void custom_gemm_single_test(){
|
||||||
|
auto ctx = tim::vx::Context::Create();
|
||||||
|
auto graph = ctx->CreateGraph();
|
||||||
|
|
||||||
|
tim::vx::ShapeType a_shape({6, 2});
|
||||||
|
tim::vx::ShapeType b_shape({2, 6});
|
||||||
|
tim::vx::ShapeType out_shape({2, 2});
|
||||||
|
tim::vx::TensorSpec a_spec(tim::vx::DataType::FLOAT32,
|
||||||
|
a_shape, tim::vx::TensorAttribute::INPUT);
|
||||||
|
tim::vx::TensorSpec b_spec(tim::vx::DataType::FLOAT32,
|
||||||
|
b_shape, tim::vx::TensorAttribute::INPUT);
|
||||||
|
tim::vx::TensorSpec out_spec(tim::vx::DataType::FLOAT32,
|
||||||
|
out_shape, tim::vx::TensorAttribute::OUTPUT);
|
||||||
|
|
||||||
|
auto a_tensor = graph->CreateTensor(a_spec);
|
||||||
|
auto b_tensor = graph->CreateTensor(b_spec);
|
||||||
|
auto out_tensor = graph->CreateTensor(out_spec);
|
||||||
|
|
||||||
|
std::vector<float> a_data = {
|
||||||
|
1, 2, 3, 4, 5, 6,
|
||||||
|
-1, -2, -3, -4, -5, -6
|
||||||
|
};
|
||||||
|
std::vector<float> b_data = {
|
||||||
|
6, 5,
|
||||||
|
4, 3,
|
||||||
|
2, 1,
|
||||||
|
-6, -5,
|
||||||
|
-4, -3,
|
||||||
|
-2, -1
|
||||||
|
};
|
||||||
|
std::vector<float> golden = {
|
||||||
|
-36, -27,
|
||||||
|
36, 27
|
||||||
|
};
|
||||||
|
|
||||||
|
a_tensor->CopyDataToTensor(a_data.data(), a_data.size() * sizeof(float));
|
||||||
|
b_tensor->CopyDataToTensor(b_data.data(), b_data.size() * sizeof(float));
|
||||||
|
|
||||||
|
tim::vx::ops::DeriveParmaTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0);
|
||||||
|
|
||||||
|
auto op = graph->CreateOperation<tim::vx::ops::CustomGemm>(
|
||||||
|
false,false,tuple_list);
|
||||||
|
|
||||||
|
(*op).BindInputs({a_tensor, b_tensor}).BindOutputs({out_tensor});
|
||||||
|
|
||||||
|
graph->Compile();
|
||||||
|
graph->Run();
|
||||||
|
|
||||||
|
std::vector<float> output(golden.size());
|
||||||
|
out_tensor->CopyDataFromTensor(output.data());
|
||||||
|
|
||||||
|
std::cout<<"the diff between golan and result:"<<std::endl;
|
||||||
|
for(uint32_t i=0;i<output.size();i++){
|
||||||
|
std::cout<<output[i] - golden[i]<<" ";
|
||||||
|
}
|
||||||
|
std::cout<<std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void custom_gemm_op_and_add_op_test(){
|
||||||
|
auto ctx = tim::vx::Context::Create();
|
||||||
|
auto graph = ctx->CreateGraph();
|
||||||
|
|
||||||
|
tim::vx::ShapeType a_shape({6, 2});
|
||||||
|
tim::vx::ShapeType b_shape({6, 2});
|
||||||
|
tim::vx::ShapeType c_shape({6, 2});
|
||||||
|
tim::vx::ShapeType d_shape({2, 6});
|
||||||
|
|
||||||
|
tim::vx::ShapeType out_shape({2, 2});
|
||||||
|
|
||||||
|
tim::vx::TensorSpec a_spec(tim::vx::DataType::FLOAT32,
|
||||||
|
a_shape, tim::vx::TensorAttribute::INPUT);
|
||||||
|
tim::vx::TensorSpec b_spec(tim::vx::DataType::FLOAT32,
|
||||||
|
b_shape, tim::vx::TensorAttribute::INPUT);
|
||||||
|
tim::vx::TensorSpec c_spec(tim::vx::DataType::FLOAT32,
|
||||||
|
c_shape, tim::vx::TensorAttribute::TRANSIENT);
|
||||||
|
tim::vx::TensorSpec d_spec(tim::vx::DataType::FLOAT32,
|
||||||
|
d_shape, tim::vx::TensorAttribute::INPUT);
|
||||||
|
tim::vx::TensorSpec out_spec(tim::vx::DataType::FLOAT32,
|
||||||
|
out_shape, tim::vx::TensorAttribute::OUTPUT);
|
||||||
|
|
||||||
|
auto a_tensor = graph->CreateTensor(a_spec);
|
||||||
|
auto b_tensor = graph->CreateTensor(b_spec);
|
||||||
|
auto c_tensor = graph->CreateTensor(c_spec);
|
||||||
|
auto d_tensor = graph->CreateTensor(d_spec);
|
||||||
|
auto out_tensor = graph->CreateTensor(out_spec);
|
||||||
|
|
||||||
|
std::vector<float> a_data = {
|
||||||
|
0, 1, 2, 3, 4, 5,
|
||||||
|
-1, -2, -3, -4, -5, -6
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<float> b_data = {
|
||||||
|
1, 1, 1, 1, 1, 1,
|
||||||
|
0, 0, 0, 0, 0, 0
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<float> d_data = {
|
||||||
|
6, 5,
|
||||||
|
4, 3,
|
||||||
|
2, 1,
|
||||||
|
-6, -5,
|
||||||
|
-4, -3,
|
||||||
|
-2, -1
|
||||||
|
};
|
||||||
|
std::vector<float> golden = {
|
||||||
|
-36, -27,
|
||||||
|
36, 27
|
||||||
|
};
|
||||||
|
|
||||||
|
a_tensor->CopyDataToTensor(a_data.data(), a_data.size() * sizeof(float));
|
||||||
|
b_tensor->CopyDataToTensor(b_data.data(), b_data.size() * sizeof(float));
|
||||||
|
d_tensor->CopyDataToTensor(d_data.data(), d_data.size() * sizeof(float));
|
||||||
|
|
||||||
|
auto op_add = graph->CreateOperation<tim::vx::ops::AddN>(2);
|
||||||
|
(*op_add).BindInputs({a_tensor, b_tensor}).BindOutputs({c_tensor});
|
||||||
|
|
||||||
|
tim::vx::ops::DeriveParmaTuple tuple_list(2,6,6,0,0,1.0,0,1.0,0,1.0,0);
|
||||||
|
|
||||||
|
auto op_gemm = graph->CreateOperation<tim::vx::ops::CustomGemm>(
|
||||||
|
false,false,tuple_list);
|
||||||
|
|
||||||
|
(*op_gemm).BindInputs({c_tensor, d_tensor}).BindOutputs({out_tensor});
|
||||||
|
|
||||||
|
graph->Compile();
|
||||||
|
graph->Run();
|
||||||
|
|
||||||
|
std::vector<float> output(golden.size());
|
||||||
|
out_tensor->CopyDataFromTensor(output.data());
|
||||||
|
std::cout<<"the diff between golan and result:"<<std::endl;
|
||||||
|
for(uint32_t i=0;i<output.size();i++){
|
||||||
|
std::cout<<output[i] - golden[i]<<" ";
|
||||||
|
}
|
||||||
|
std::cout<<std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void custom_gemm_op_and_custom_gemm_op_test(){
|
||||||
|
auto ctx = tim::vx::Context::Create();
|
||||||
|
auto graph = ctx->CreateGraph();
|
||||||
|
|
||||||
|
|
||||||
|
tim::vx::ShapeType a_shape({2, 2});
|
||||||
|
tim::vx::ShapeType b_shape({2, 2});
|
||||||
|
tim::vx::ShapeType c_shape({2, 2});
|
||||||
|
tim::vx::ShapeType d_shape({2, 2});
|
||||||
|
|
||||||
|
tim::vx::ShapeType out_shape({2, 2});
|
||||||
|
|
||||||
|
tim::vx::TensorSpec a_spec(tim::vx::DataType::FLOAT32,
|
||||||
|
a_shape, tim::vx::TensorAttribute::INPUT);
|
||||||
|
tim::vx::TensorSpec b_spec(tim::vx::DataType::FLOAT32,
|
||||||
|
b_shape, tim::vx::TensorAttribute::INPUT);
|
||||||
|
tim::vx::TensorSpec c_spec(tim::vx::DataType::FLOAT32,
|
||||||
|
c_shape, tim::vx::TensorAttribute::TRANSIENT);
|
||||||
|
tim::vx::TensorSpec d_spec(tim::vx::DataType::FLOAT32,
|
||||||
|
d_shape, tim::vx::TensorAttribute::INPUT);
|
||||||
|
tim::vx::TensorSpec out_spec(tim::vx::DataType::FLOAT32,
|
||||||
|
out_shape, tim::vx::TensorAttribute::OUTPUT);
|
||||||
|
|
||||||
|
auto a_tensor = graph->CreateTensor(a_spec);
|
||||||
|
auto b_tensor = graph->CreateTensor(b_spec);
|
||||||
|
auto c_tensor = graph->CreateTensor(c_spec);
|
||||||
|
auto d_tensor = graph->CreateTensor(d_spec);
|
||||||
|
auto out_tensor = graph->CreateTensor(out_spec);
|
||||||
|
|
||||||
|
std::vector<float> a_data = {
|
||||||
|
1, 1, 1,1
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<float> b_data = {
|
||||||
|
1, 1, 1, 1
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<float> d_data = {
|
||||||
|
1,1,1,1
|
||||||
|
};
|
||||||
|
std::vector<float> golden = {
|
||||||
|
4,4,4,4
|
||||||
|
};
|
||||||
|
|
||||||
|
a_tensor->CopyDataToTensor(a_data.data(), a_data.size() * sizeof(float));
|
||||||
|
b_tensor->CopyDataToTensor(b_data.data(), b_data.size() * sizeof(float));
|
||||||
|
d_tensor->CopyDataToTensor(d_data.data(), d_data.size() * sizeof(float));
|
||||||
|
|
||||||
|
tim::vx::ops::DeriveParmaTuple tuple_list(2,2,2,0,0,1.0,0,1.0,0,1.0,0);
|
||||||
|
|
||||||
|
auto op_gemm = graph->CreateOperation<tim::vx::ops::CustomGemm>(
|
||||||
|
false,false,tuple_list);
|
||||||
|
|
||||||
|
(*op_gemm).BindInputs({a_tensor, b_tensor}).BindOutputs({c_tensor});
|
||||||
|
|
||||||
|
auto op_gemm2 = graph->CreateOperation<tim::vx::ops::CustomGemm>(
|
||||||
|
false,false,tuple_list);
|
||||||
|
|
||||||
|
(*op_gemm2).BindInputs({c_tensor, d_tensor}).BindOutputs({out_tensor});
|
||||||
|
|
||||||
|
graph->Compile();
|
||||||
|
graph->Run();
|
||||||
|
|
||||||
|
std::vector<float> output(golden.size());
|
||||||
|
out_tensor->CopyDataFromTensor(output.data());
|
||||||
|
std::cout<<"the diff between golan and result:"<<std::endl;
|
||||||
|
for(uint32_t i=0;i<output.size();i++){
|
||||||
|
std::cout<<output[i] - golden[i]<<" ";
|
||||||
|
}
|
||||||
|
std::cout<<std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(){
|
||||||
|
custom_gemm_single_test();
|
||||||
|
custom_gemm_op_and_add_op_test();
|
||||||
|
custom_gemm_op_and_custom_gemm_op_test();
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
@ -105,6 +105,8 @@ if(TIM_VX_ENABLE_TEST)
|
||||||
target_include_directories(unit_test PRIVATE
|
target_include_directories(unit_test PRIVATE
|
||||||
${PROJECT_SOURCE_DIR}/include
|
${PROJECT_SOURCE_DIR}/include
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/vx
|
${CMAKE_CURRENT_SOURCE_DIR}/vx
|
||||||
|
${OVXLIB_INCLUDE_DIR}
|
||||||
|
${INC_DIRS}
|
||||||
)
|
)
|
||||||
|
|
||||||
install(TARGETS unit_test DESTINATION ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_BINDIR})
|
install(TARGETS unit_test DESTINATION ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_BINDIR})
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,10 @@ DirectMapOpImpl::DirectMapOpImpl(Graph* graph, uint32_t kind, int input_cnt,
|
||||||
node_->uid = graph_->graph()->cur_nid;
|
node_->uid = graph_->graph()->cur_nid;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DirectMapOpImpl::DirectMapOpImpl(Graph* graph,DataLayout layout)
|
||||||
|
: OpImpl(graph, layout){}
|
||||||
|
|
||||||
|
|
||||||
DirectMapOpImpl& DirectMapOpImpl::BindInput(
|
DirectMapOpImpl& DirectMapOpImpl::BindInput(
|
||||||
const std::shared_ptr<Tensor>& tensor) {
|
const std::shared_ptr<Tensor>& tensor) {
|
||||||
inputs_tensor_.push_back(tensor);
|
inputs_tensor_.push_back(tensor);
|
||||||
|
|
@ -71,5 +75,16 @@ void DirectMapOpImpl::SetRoundingPolicy(OverflowPolicy overflow_policy,
|
||||||
node_->vx_param.accumulator_bits = accumulator_bits;
|
node_->vx_param.accumulator_bits = accumulator_bits;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CustomOpBaseImpl::CustomOpBaseImpl(Graph* graph, uint32_t operation_id, const void* proc,
|
||||||
|
const char* kernel_name, DataLayout layout)
|
||||||
|
: DirectMapOpImpl(graph, layout) {
|
||||||
|
op_proc_ = proc;
|
||||||
|
vsi_nn_node_t* node = vsi_nn_AddExternalNode(graph_->graph(), operation_id,
|
||||||
|
proc, NULL, kernel_name);
|
||||||
|
node->uid = graph_->graph()->cur_nid;
|
||||||
|
SetNode(node);
|
||||||
|
SetRoundingPolicy();
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace vx
|
} // namespace vx
|
||||||
} // namespace tim
|
} // namespace tim
|
||||||
|
|
@ -34,16 +34,16 @@ namespace vx {
|
||||||
|
|
||||||
class DirectMapOpImpl : public OpImpl {
|
class DirectMapOpImpl : public OpImpl {
|
||||||
public:
|
public:
|
||||||
// DirectMapOpImpl(Graph* graph, uint32_t kind, int input_cnt = 0,
|
|
||||||
// int output_cnt = 0);
|
|
||||||
DirectMapOpImpl(Graph* graph, uint32_t kind, int input_cnt = 0,
|
DirectMapOpImpl(Graph* graph, uint32_t kind, int input_cnt = 0,
|
||||||
int output_cnt = 0, DataLayout layout = DataLayout::ANY);
|
int output_cnt = 0, DataLayout layout = DataLayout::ANY);
|
||||||
|
DirectMapOpImpl(Graph* graph,DataLayout layout = DataLayout::ANY);
|
||||||
~DirectMapOpImpl() {}
|
~DirectMapOpImpl() {}
|
||||||
|
|
||||||
DirectMapOpImpl& BindInput(const std::shared_ptr<Tensor>& tensor) override;
|
DirectMapOpImpl& BindInput(const std::shared_ptr<Tensor>& tensor) override;
|
||||||
DirectMapOpImpl& BindOutput(const std::shared_ptr<Tensor>& tensor) override;
|
DirectMapOpImpl& BindOutput(const std::shared_ptr<Tensor>& tensor) override;
|
||||||
|
|
||||||
vsi_nn_node_t* node() override { return this->node_; }
|
vsi_nn_node_t* node() override { return this->node_; }
|
||||||
|
void SetNode(vsi_nn_node_t* node) {this->node_ = node; }
|
||||||
|
|
||||||
void SetRoundingPolicy(
|
void SetRoundingPolicy(
|
||||||
OverflowPolicy overflow_policy = OverflowPolicy::SATURATE,
|
OverflowPolicy overflow_policy = OverflowPolicy::SATURATE,
|
||||||
|
|
@ -62,6 +62,14 @@ class DirectMapOpImpl : public OpImpl {
|
||||||
vsi_nn_node_t* node_{nullptr};
|
vsi_nn_node_t* node_{nullptr};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class CustomOpBaseImpl : public DirectMapOpImpl {
|
||||||
|
public:
|
||||||
|
CustomOpBaseImpl(Graph* graph, uint32_t operation_id, const void* proc,
|
||||||
|
const char* kernel_name, DataLayout layout = DataLayout::ANY);
|
||||||
|
protected:
|
||||||
|
const void* op_proc_;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace vx
|
} // namespace vx
|
||||||
} // namespace tim
|
} // namespace tim
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,5 +33,10 @@ OpImpl::OpImpl(Graph* graph, uint32_t kind, int input_cnt, int output_cnt,
|
||||||
input_cnt_(input_cnt),
|
input_cnt_(input_cnt),
|
||||||
output_cnt_(output_cnt),
|
output_cnt_(output_cnt),
|
||||||
layout_(layout) {}
|
layout_(layout) {}
|
||||||
|
|
||||||
|
OpImpl::OpImpl(Graph* graph, DataLayout layout)
|
||||||
|
: graph_(reinterpret_cast<GraphImpl*>(graph)),
|
||||||
|
layout_(layout) {}
|
||||||
|
|
||||||
} // namespace vx
|
} // namespace vx
|
||||||
} // namespace tim
|
} // namespace tim
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,8 @@ class OpImpl {
|
||||||
public:
|
public:
|
||||||
OpImpl(Graph* graph, uint32_t kind, int input_cnt, int output_cnt,
|
OpImpl(Graph* graph, uint32_t kind, int input_cnt, int output_cnt,
|
||||||
DataLayout layout);
|
DataLayout layout);
|
||||||
|
OpImpl(Graph* graph, DataLayout layout);
|
||||||
|
|
||||||
virtual ~OpImpl() = default;
|
virtual ~OpImpl() = default;
|
||||||
virtual OpImpl& BindInput(const std::shared_ptr<Tensor>& tensor) = 0;
|
virtual OpImpl& BindInput(const std::shared_ptr<Tensor>& tensor) = 0;
|
||||||
virtual OpImpl& BindOutput(const std::shared_ptr<Tensor>& tensor) = 0;
|
virtual OpImpl& BindOutput(const std::shared_ptr<Tensor>& tensor) = 0;
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,250 @@
|
||||||
|
/****************************************************************************
|
||||||
|
*
|
||||||
|
* Copyright (c) 2021 Vivante Corporation
|
||||||
|
*
|
||||||
|
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
* copy of this software and associated documentation files (the "Software"),
|
||||||
|
* to deal in the Software without restriction, including without limitation
|
||||||
|
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
* and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
* Software is furnished to do so, subject to the following conditions:
|
||||||
|
*
|
||||||
|
* The above copyright notice and this permission notice shall be included in
|
||||||
|
* all copies or substantial portions of the Software.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
* DEALINGS IN THE SOFTWARE.
|
||||||
|
*
|
||||||
|
*****************************************************************************/
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <assert.h>
|
||||||
|
#include "tim/vx/ops.h"
|
||||||
|
#include "direct_map_op_impl.h"
|
||||||
|
#include "vsi_nn_pub.h"
|
||||||
|
|
||||||
|
#include "kernel/vsi_nn_kernel.h"
|
||||||
|
|
||||||
|
namespace tim {
|
||||||
|
namespace vx {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
static vsi_bool op_setup(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
|
||||||
|
vsi_nn_tensor_t** outputs);
|
||||||
|
|
||||||
|
static vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
|
||||||
|
vsi_nn_tensor_t** outputs);
|
||||||
|
|
||||||
|
static vx_status derive_kernel_init(vx_node node, const vx_reference* param,
|
||||||
|
vx_uint32 param_size);
|
||||||
|
|
||||||
|
static std::map<void*, CustomOpBase*> node_base_map_;
|
||||||
|
|
||||||
|
CustomOpBase::CustomOpBase(Graph* graph, uint32_t input_num,
|
||||||
|
uint32_t output_num, int32_t kernel_id,
|
||||||
|
const char* kernel_name)
|
||||||
|
: input_num_(input_num), output_num_(output_num) {
|
||||||
|
init_kernel_ = reinterpret_cast<void*>(derive_kernel_init);
|
||||||
|
vsi_nn_op_proc_t proc = {NULL, op_compute, NULL, NULL,
|
||||||
|
op_setup, NULL, input_num_, output_num_};
|
||||||
|
this->impl() = std::make_unique<CustomOpBaseImpl>(
|
||||||
|
graph, kernel_id, reinterpret_cast<void*>(&proc), kernel_name);
|
||||||
|
this->impl()->node()->nn_param.client_param = reinterpret_cast<void*>(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
CustomOpBase::~CustomOpBase(){
|
||||||
|
auto iter = node_base_map_.find(this->vx_node_);
|
||||||
|
if (iter != node_base_map_.end()) {
|
||||||
|
node_base_map_.erase(this->vx_node_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
vsi_bool op_setup(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
|
||||||
|
vsi_nn_tensor_t** outputs) {
|
||||||
|
CustomOpBase* op_this =
|
||||||
|
reinterpret_cast<CustomOpBase*>(self->nn_param.client_param);
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < op_this->output_num_; i++) {
|
||||||
|
std::vector<uint32_t> output_size;
|
||||||
|
op_this->outputs_size_.push_back(output_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < op_this->input_num_; i++) {
|
||||||
|
std::vector<uint32_t> input_size;
|
||||||
|
for (uint32_t j = 0; j < inputs[i]->attr.dim_num; j++) {
|
||||||
|
input_size.push_back(inputs[i]->attr.size[j]);
|
||||||
|
}
|
||||||
|
op_this->inputs_size_.push_back(input_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
op_this->setup_output_shape_info();
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < op_this->outputs_size_.size(); i++) {
|
||||||
|
outputs[i]->attr.dim_num = op_this->outputs_size_[i].size();
|
||||||
|
for (uint32_t j = 0; j < op_this->outputs_size_[i].size(); j++) {
|
||||||
|
outputs[i]->attr.size[j] = op_this->outputs_size_[i][j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return TRUE;
|
||||||
|
};
|
||||||
|
|
||||||
|
vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
|
||||||
|
vsi_nn_tensor_t** outputs) {
|
||||||
|
vsi_status status = VSI_FAILURE;
|
||||||
|
auto kernel = vsi_nn_kernel_create(VSI_NN_KERNEL_TYPE_CL);
|
||||||
|
CustomOpBase* op_this =
|
||||||
|
reinterpret_cast<CustomOpBase*>(self->nn_param.client_param);
|
||||||
|
|
||||||
|
uint32_t param_num = op_this->param_list_.size();
|
||||||
|
|
||||||
|
std::vector<tim::vx::DataType> input_types;
|
||||||
|
for (uint32_t i = 0; i < op_this->input_num_; i++) {
|
||||||
|
if (inputs[i]->attr.dtype.vx_type == VSI_NN_TYPE_FLOAT32) {
|
||||||
|
input_types.push_back(tim::vx::DataType::FLOAT32);
|
||||||
|
} else if (inputs[i]->attr.dtype.vx_type == VSI_NN_TYPE_UINT32) {
|
||||||
|
input_types.push_back(tim::vx::DataType::UINT32);
|
||||||
|
} else if (inputs[i]->attr.dtype.vx_type == VSI_NN_TYPE_INT32) {
|
||||||
|
input_types.push_back(tim::vx::DataType::INT32);
|
||||||
|
} else if (inputs[i]->attr.dtype.vx_type == VSI_NN_TYPE_BOOL8) {
|
||||||
|
input_types.push_back(tim::vx::DataType::BOOL8);
|
||||||
|
} else if (inputs[i]->attr.dtype.vx_type == VSI_NN_TYPE_UINT8) {
|
||||||
|
input_types.push_back(tim::vx::DataType::UINT8);
|
||||||
|
} else if (inputs[i]->attr.dtype.vx_type == VSI_NN_TYPE_INT8) {
|
||||||
|
input_types.push_back(tim::vx::DataType::INT8);
|
||||||
|
} else {
|
||||||
|
std::cout << "Can not find att type in op compute" << std::endl;
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string build_option;
|
||||||
|
op_this->extract_parameter_and_register(input_types, build_option);
|
||||||
|
|
||||||
|
snprintf(kernel->info.name, VX_MAX_KERNEL_NAME, "%s", op_this->func_name_);
|
||||||
|
kernel->unique_id =
|
||||||
|
std::hash<std::string>()(std::string(op_this->func_name_));
|
||||||
|
vx_param_description_t kernel_param_def[param_num];
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < op_this->input_num_; i++) {
|
||||||
|
kernel_param_def[i] = {VX_INPUT, VX_TYPE_TENSOR,
|
||||||
|
VX_PARAMETER_STATE_REQUIRED};
|
||||||
|
}
|
||||||
|
for (uint32_t i = 0; i < op_this->output_num_; i++) {
|
||||||
|
kernel_param_def[op_this->input_num_ + i] = {VX_OUTPUT, VX_TYPE_TENSOR,
|
||||||
|
VX_PARAMETER_STATE_REQUIRED};
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < param_num; i++) {
|
||||||
|
kernel_param_def[op_this->input_num_ + op_this->output_num_ + i] = {
|
||||||
|
VX_INPUT, VX_TYPE_SCALAR, VX_PARAMETER_STATE_REQUIRED};
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel->info.parameters = kernel_param_def;
|
||||||
|
kernel->info.enumeration = KERNEL_ID_PLACEHOLDER;
|
||||||
|
kernel->info.numParams = param_num;
|
||||||
|
kernel->info.initialize =
|
||||||
|
reinterpret_cast<vx_kernel_initialize_f>(op_this->init_kernel_);
|
||||||
|
|
||||||
|
vsi_nn_kernel_add_source(kernel, VSI_NN_GPU_SOURCE_FMT_EXECUTABLE, 1,
|
||||||
|
"executable_name");
|
||||||
|
|
||||||
|
vsi_nn_kernel_add_source(kernel, VSI_NN_GPU_SOURCE_FMT_CODE, 2, "helper",
|
||||||
|
"fmt_code_name");
|
||||||
|
|
||||||
|
const char* tmp[] = {"", op_this->kernel_resource_};
|
||||||
|
const char** resource = tmp;
|
||||||
|
|
||||||
|
vsi_nn_kernel_add_build_option(kernel, build_option.c_str());
|
||||||
|
|
||||||
|
auto node = vsi_nn_kernel_create_node_ext(self->graph, kernel, resource);
|
||||||
|
if (node) {
|
||||||
|
vsi_nn_kernel_node_param_t node_params[param_num] = {NULL};
|
||||||
|
vsi_nn_kernel_node_pack_io(node_params, param_num, inputs,
|
||||||
|
op_this->input_num_, outputs,
|
||||||
|
op_this->output_num_);
|
||||||
|
|
||||||
|
uint32_t input_start = op_this->input_num_ + op_this->output_num_;
|
||||||
|
for (uint32_t i = 0; i < op_this->param_list_.size(); i++) {
|
||||||
|
if (op_this->param_list_[i].type == tim::vx::DataType::FLOAT32) {
|
||||||
|
node_params[input_start++] = vsi_nn_kernel_scalar_create(
|
||||||
|
self->graph, F32, &(op_this->param_list_[i].data.f));
|
||||||
|
} else if (op_this->param_list_[i].type == tim::vx::DataType::UINT32) {
|
||||||
|
node_params[input_start++] = vsi_nn_kernel_scalar_create(
|
||||||
|
self->graph, U32, &(op_this->param_list_[i].data.ui));
|
||||||
|
} else if (op_this->param_list_[i].type == tim::vx::DataType::INT32) {
|
||||||
|
node_params[input_start++] = vsi_nn_kernel_scalar_create(
|
||||||
|
self->graph, I32, &(op_this->param_list_[i].data.i));
|
||||||
|
} else if (op_this->param_list_[i].type == tim::vx::DataType::BOOL8) {
|
||||||
|
node_params[input_start++] = vsi_nn_kernel_scalar_create(
|
||||||
|
self->graph, BOOL8, &(op_this->param_list_[i].data.b));
|
||||||
|
}else if (op_this->param_list_[i].type == tim::vx::DataType::UINT8) {
|
||||||
|
node_params[input_start++] = vsi_nn_kernel_scalar_create(
|
||||||
|
self->graph, U8, &(op_this->param_list_[i].data.b));
|
||||||
|
} else if (op_this->param_list_[i].type == tim::vx::DataType::INT8) {
|
||||||
|
node_params[input_start++] = vsi_nn_kernel_scalar_create(
|
||||||
|
self->graph, I8, &(op_this->param_list_[i].data.b));
|
||||||
|
} else{
|
||||||
|
std::cout << "Can not find scalar type in op compute" << std::endl;
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
input_start = op_this->input_num_ + op_this->output_num_;
|
||||||
|
status = vsi_nn_kernel_node_pass_param(node, node_params, param_num);
|
||||||
|
for (uint32_t i = 0; i < param_num; i++) {
|
||||||
|
vsi_nn_kernel_scalar_release(&node_params[input_start + i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self->n = (vx_node)node;
|
||||||
|
|
||||||
|
node_base_map_.insert(std::pair<void*, CustomOpBase*>(reinterpret_cast<void*>(self->n), op_this));
|
||||||
|
op_this->vx_node_ = reinterpret_cast<void*>(self->n);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
vx_status derive_kernel_init(vx_node node, const vx_reference* param,
|
||||||
|
vx_uint32 param_size) {
|
||||||
|
vsi_status status = VSI_FAILURE;
|
||||||
|
if (param_size == 0 && param == nullptr) {
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
gpu_param_t gpu_param = {3, {0, 0, 0}, {0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
|
||||||
|
|
||||||
|
std::vector<size_t> global_size(3);
|
||||||
|
std::vector<size_t> local_size(3);
|
||||||
|
uint32_t dim = 0;
|
||||||
|
|
||||||
|
auto iter = node_base_map_.find(reinterpret_cast<void*>(node));
|
||||||
|
if (iter != node_base_map_.end()) {
|
||||||
|
iter->second->setup_kernel_param(dim, global_size, local_size);
|
||||||
|
} else {
|
||||||
|
std::cout << "Something wrong in finding gpu param setup function"
|
||||||
|
<< std::endl;
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
gpu_param.dim = dim;
|
||||||
|
gpu_param.global_scale[0] = 1;
|
||||||
|
gpu_param.global_scale[1] = 1;
|
||||||
|
gpu_param.global_scale[2] = 1;
|
||||||
|
|
||||||
|
gpu_param.global_size[0] = global_size[0];
|
||||||
|
gpu_param.global_size[1] = global_size[1];
|
||||||
|
gpu_param.global_size[2] = global_size[2];
|
||||||
|
|
||||||
|
gpu_param.local_size[0] = local_size[0];
|
||||||
|
gpu_param.local_size[1] = local_size[1];
|
||||||
|
gpu_param.local_size[2] = local_size[2];
|
||||||
|
status = vsi_nn_kernel_gpu_config(node, &gpu_param);
|
||||||
|
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace vx
|
||||||
|
} // namespace tim
|
||||||
Loading…
Reference in New Issue