Added customized operator APIs (#315)
Co-authored-by: zhouheng.zheng <zhouheng.zheng@ouotlook.com>
parent 161bb8a7c4
commit b02aa8b8c4

@@ -85,5 +85,6 @@
#include "tim/vx/ops/unidirectional_sequence_lstm.h"
#include "tim/vx/ops/unstack.h"
#include "tim/vx/ops/conv3d.h"
#include "tim/vx/ops/custom_base.h"

#endif /* TIM_VX_OPS_H_ */

@@ -0,0 +1,123 @@
/****************************************************************************
 *
 * Copyright (c) 2021 Vivante Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 * DEALINGS IN THE SOFTWARE.
 *
 *****************************************************************************/
#ifndef TIM_VX_OPS_CUSTOM_BASE_H_
#define TIM_VX_OPS_CUSTOM_BASE_H_

#include "tim/vx/direct_map_op.h"
#include <cstdint>
#include <iostream>
#include <tuple>
#include <type_traits>
#include <typeinfo>
#include <vector>

namespace tim {
namespace vx {
namespace ops {

// Round n up to the next multiple of align (align must be a power of two).
#define gpu_align(n, align) (((n) + ((align)-1)) & ~((align)-1))
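// Worked example of the alignment arithmetic (hypothetical values):
//   gpu_align(6, 4) = (6 + 3) & ~3 = 9  & ~3 = 8
//   gpu_align(8, 4) = (8 + 3) & ~3 = 11 & ~3 = 8
// i.e. work sizes are rounded up to the next multiple of `align`.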

__attribute__((unused)) static int32_t global_kernel_id_ = 0;

struct Param {
  union {
    float f;
    int32_t i;
    bool b;
    uint32_t ui;
    uint8_t u8;  // 8-bit members so the scalar-creation code below has a
    int8_t i8;   // proper slot to point at for U8/I8 parameters.
  } data;
  tim::vx::DataType type;
};

template <size_t I = 0, typename... Ts>
typename std::enable_if<I == sizeof...(Ts), void>::type
transform_tuple_to_param_list(std::tuple<Ts...>, std::vector<Param>&) {
  // Recursion terminator: every tuple element has been visited.
}

template <size_t I = 0, typename... Ts>
typename std::enable_if<(I < sizeof...(Ts)), void>::type
transform_tuple_to_param_list(std::tuple<Ts...> tup,
                              std::vector<Param>& param_list) {
  if (typeid(std::get<I>(tup)) == typeid(float)) {
    Param p;
    p.data.f = std::get<I>(tup);
    p.type = tim::vx::DataType::FLOAT32;
    param_list.push_back(p);
  } else if (typeid(std::get<I>(tup)) == typeid(uint32_t)) {
    Param p;
    p.data.ui = std::get<I>(tup);
    p.type = tim::vx::DataType::UINT32;
    param_list.push_back(p);
  } else if (typeid(std::get<I>(tup)) == typeid(int32_t)) {
    Param p;
    p.data.i = std::get<I>(tup);
    p.type = tim::vx::DataType::INT32;
    param_list.push_back(p);
  } else if (typeid(std::get<I>(tup)) == typeid(bool)) {
    Param p;
    p.data.b = std::get<I>(tup);
    p.type = tim::vx::DataType::BOOL8;
    param_list.push_back(p);
  } else {
    std::cout << "Unsupported parameter type in transform_tuple_to_param_list"
              << std::endl;
  }
  // Go to the next tuple element.
  transform_tuple_to_param_list<I + 1>(tup, param_list);
}
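
// A minimal usage sketch (hypothetical values): the two overloads above form a
// compile-time recursion that visits each tuple element and appends a tagged
// Param, e.g.:
//   std::tuple<float, int32_t, bool> t{1.5f, 42, true};
//   std::vector<Param> params;
//   transform_tuple_to_param_list(t, params);
//   // params[0].type == FLOAT32, params[1].type == INT32,
//   // params[2].type == BOOL8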

class CustomOpBase : public Operation {
 public:
  CustomOpBase(Graph* graph, uint32_t input_num, uint32_t output_num,
               int32_t kernel_id, const char* kernel_name);

  ~CustomOpBase();

  virtual void extract_parameter_and_register(
      std::vector<tim::vx::DataType> input_types,
      std::string& build_option) = 0;

  virtual void setup_output_shape_info() = 0;

  virtual void setup_kernel_param(uint32_t& dim,
                                  std::vector<size_t>& global_size,
                                  std::vector<size_t>& local_size) = 0;

  std::vector<Param> param_list_;
  std::vector<tim::vx::ShapeType> inputs_size_;
  std::vector<tim::vx::ShapeType> outputs_size_;

  const char* func_name_;
  const char* kernel_resource_;
  void* init_kernel_;
  void* vx_node_;

  uint32_t input_num_;
  uint32_t output_num_;
};

}  // namespace ops
}  // namespace vx
}  // namespace tim

#endif /* TIM_VX_OPS_CUSTOM_BASE_H_ */
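
A minimal sketch of how a derived operator plugs into this base class
(hypothetical op name and kernel; CustomGemm below is the shipped example):

    class CustomCopy : public CustomOpBase {  // hypothetical
     public:
      CustomCopy(Graph* graph)
          : CustomOpBase(graph, 1, 1, kernel_id_, kernel_name_) {
        kernel_resource_ = "...";  // OpenCL source for the kernel
      }
     protected:
      void setup_output_shape_info() override {
        outputs_size_[0] = inputs_size_[0];  // same shape out as in
      }
      void extract_parameter_and_register(std::vector<tim::vx::DataType>,
                                          std::string& build_option) override {
        func_name_ = "copy_kernel";  // hypothetical kernel function name
        build_option = "";
      }
      void setup_kernel_param(uint32_t& dim, std::vector<size_t>& global_size,
                              std::vector<size_t>& local_size) override {
        dim = 1;
        global_size[0] = gpu_align(outputs_size_[0][0], 4);
        local_size[0] = 0;
      }
      static const char* kernel_name_;
      static int32_t kernel_id_;
    };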

@@ -1,4 +1,5 @@
add_subdirectory("benchmark_test")
add_subdirectory("custom_op_test")
add_subdirectory("lenet")
if(${TIM_VX_ENABLE_VIPLITE})
  add_subdirectory("lenet_lite")
@@ -0,0 +1,12 @@
message("samples/custom_op_test")

set(TARGET_NAME "custom_op_test")

aux_source_directory(. ${TARGET_NAME}_SRCS)
add_executable(${TARGET_NAME} ${${TARGET_NAME}_SRCS})

target_link_libraries(${TARGET_NAME} PRIVATE tim-vx)
target_include_directories(${TARGET_NAME} PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}
    ${PROJECT_SOURCE_DIR}/include
)

@@ -0,0 +1,36 @@
/****************************************************************************
 *
 * Copyright (c) 2021 Vivante Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 * DEALINGS IN THE SOFTWARE.
 *
 *****************************************************************************/
#include "custom_gemm.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
namespace ops {
|
||||
|
||||
const char* CustomGemm::kernel_name_ = "xxxx_name12345";
|
||||
int32_t CustomGemm::kernel_id_ = -1 * (++gobal_kernel_id_);
|
||||
|
||||
|
||||
} // namespace ops
|
||||
} // namespace vx
|
||||
} // namespace tim
|

@@ -0,0 +1,148 @@
/****************************************************************************
 *
 * Copyright (c) 2021 Vivante Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 * DEALINGS IN THE SOFTWARE.
 *
 *****************************************************************************/

#ifndef TIM_VX_OPS_CUSTOM_GEMM_H_
#define TIM_VX_OPS_CUSTOM_GEMM_H_

#include "tim/vx/ops/custom_base.h"

namespace tim {
namespace vx {
namespace ops {

// Scalar parameters passed as kernel function inputs.
using DeriveParamTuple = std::tuple<int, int, int, int, int, float, float,
                                    float, float, float, float>;

class CustomGemm : public CustomOpBase {
 public:
  CustomGemm(Graph* graph, bool trans_a, bool trans_b,
             DeriveParamTuple tuple_list,
             uint32_t input_num = 2, uint32_t output_num = 1)
      : CustomOpBase(graph, input_num, output_num,
                     CustomGemm::kernel_id_, CustomGemm::kernel_name_),
        trans_a_(trans_a), trans_b_(trans_b) {
    tuple_list_.swap(tuple_list);
    transform_tuple_to_param_list(tuple_list_, param_list_);

    kernel_resource_ =
        "__kernel void gemm_F32F32toF32_2D(\n\
            __read_only image2d_t inputA,\n\
            __read_only image2d_t inputB,\n\
            __write_only image2d_t output,\n\
            int M,\n\
            int K,\n\
            int N,\n\
            int ac2zero,\n\
            int bc2zero,\n\
            float scale_a,\n\
            float zp_a,\n\
            float scale_b,\n\
            float zp_b,\n\
            float scale_out,\n\
            float zp_out\n\
        )\n\
        {\n\
            int4 coord = (int4)(get_global_id(0), get_global_id(1), 0, 0);\n\
            float4 sum = (float4)(0);\n\
        \n\
            for(; coord.z < K;)\n\
            {\n\
                float4 tempA0;\n\
                float4 tempB0;\n\
        \n\
                tempA0 = read_imagef(inputA, coord.zy);\n\
                tempB0 = read_imagef(inputB, coord.xz);\n\
                coord.z++;\n\
        \n\
                sum = sum + tempA0 * tempB0;\n\
            }\n\
            write_imagef(output, coord.xy, sum);\n\
        }\n\
        \n\
        ";
  }

 protected:
  const char* kernel_NotTransA_NotTransB = "gemm_F32F32toF32_2D";
  const char* kernel_TransA_NotTransB = ".....";
  DeriveParamTuple tuple_list_;
  bool trans_a_;
  bool trans_b_;
  static const char* kernel_name_;
  static int32_t kernel_id_;

  // Set up the output shape from the input shapes.
  void setup_output_shape_info() override {
    if (!trans_a_ && !trans_b_) {
      outputs_size_[0].push_back(inputs_size_[0][1]);
      outputs_size_[0].push_back(inputs_size_[1][0]);
    } else {
      // Other cases: set up outputs_size_ accordingly.
      // outputs_size_[0].push_back(...)
    }
  }

  // Select the kernel function and build options from the input types.
  void extract_parameter_and_register(
      std::vector<tim::vx::DataType> input_types,
      std::string& build_option) override {
    if (trans_a_ == false &&
        trans_b_ == false &&
        input_types[0] == tim::vx::DataType::FLOAT32 &&
        input_types[1] == tim::vx::DataType::FLOAT32) {
      func_name_ = kernel_NotTransA_NotTransB;
      build_option = "";
    } else {
      // Other cases: pick func_name_ and set up param_list_.
      // func_name_ = "......";
      // std::get<2>(param_list) = ......
    }
    return;
  }

  // Set up the kernel's local and global work sizes.
  void setup_kernel_param(uint32_t& dim, std::vector<size_t>& global_size,
                          std::vector<size_t>& local_size) override {
    dim = 3;
    local_size[0] = 0;
    local_size[1] = 0;
    local_size[2] = 0;

    global_size[0] = gpu_align(outputs_size_[0][0], 4);
    global_size[1] = gpu_align(outputs_size_[0][1], 4);
    global_size[2] = outputs_size_[0].size() > 2 ? outputs_size_[0][2] : 1;
  }

  std::shared_ptr<Operation> Clone(
      std::shared_ptr<Graph>& graph) const override {
    return graph->CreateOperation<CustomGemm>(
        trans_a_, trans_b_, this->tuple_list_,
        this->input_num_, this->output_num_);
  }
};

}  // namespace ops
}  // namespace vx
}  // namespace tim

#endif /* TIM_VX_OPS_CUSTOM_GEMM_H_ */

@@ -0,0 +1,242 @@
/****************************************************************************
 *
 * Copyright (c) 2021 Vivante Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 * DEALINGS IN THE SOFTWARE.
 *
 *****************************************************************************/
#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/ops.h"
#include "custom_gemm.h"
#include <iostream>
#include <tuple>
#include <vector>

void custom_gemm_single_test() {
  auto ctx = tim::vx::Context::Create();
  auto graph = ctx->CreateGraph();

  tim::vx::ShapeType a_shape({6, 2});
  tim::vx::ShapeType b_shape({2, 6});
  tim::vx::ShapeType out_shape({2, 2});
  tim::vx::TensorSpec a_spec(tim::vx::DataType::FLOAT32,
                             a_shape, tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec b_spec(tim::vx::DataType::FLOAT32,
                             b_shape, tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec out_spec(tim::vx::DataType::FLOAT32,
                               out_shape, tim::vx::TensorAttribute::OUTPUT);

  auto a_tensor = graph->CreateTensor(a_spec);
  auto b_tensor = graph->CreateTensor(b_spec);
  auto out_tensor = graph->CreateTensor(out_spec);

  std::vector<float> a_data = {
      1, 2, 3, 4, 5, 6,
      -1, -2, -3, -4, -5, -6
  };
  std::vector<float> b_data = {
      6, 5,
      4, 3,
      2, 1,
      -6, -5,
      -4, -3,
      -2, -1
  };
  std::vector<float> golden = {
      -36, -27,
      36, 27
  };

  a_tensor->CopyDataToTensor(a_data.data(), a_data.size() * sizeof(float));
  b_tensor->CopyDataToTensor(b_data.data(), b_data.size() * sizeof(float));

  tim::vx::ops::DeriveParamTuple tuple_list(2, 6, 6, 0, 0, 1.0, 0, 1.0, 0,
                                            1.0, 0);

  auto op = graph->CreateOperation<tim::vx::ops::CustomGemm>(
      false, false, tuple_list);

  (*op).BindInputs({a_tensor, b_tensor}).BindOutputs({out_tensor});

  graph->Compile();
  graph->Run();

  std::vector<float> output(golden.size());
  out_tensor->CopyDataFromTensor(output.data());

  std::cout << "diff between golden and result:" << std::endl;
  for (uint32_t i = 0; i < output.size(); i++) {
    std::cout << output[i] - golden[i] << " ";
  }
  std::cout << std::endl;
}
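
// Sanity check of the golden values above, derived by hand: with TIM-VX's
// low-dimension-first shape order, a_shape {6, 2} behaves as a 2x6 matrix A
// and b_shape {2, 6} as a 6x2 matrix B, so out = A * B is 2x2:
//   out[0][0] = 1*6 + 2*4 + 3*2 + 4*(-6) + 5*(-4) + 6*(-2) = -36
//   out[0][1] = 1*5 + 2*3 + 3*1 + 4*(-5) + 5*(-3) + 6*(-1) = -27
// The second row of A negates the first, giving 36 and 27.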

void custom_gemm_op_and_add_op_test() {
  auto ctx = tim::vx::Context::Create();
  auto graph = ctx->CreateGraph();

  tim::vx::ShapeType a_shape({6, 2});
  tim::vx::ShapeType b_shape({6, 2});
  tim::vx::ShapeType c_shape({6, 2});
  tim::vx::ShapeType d_shape({2, 6});

  tim::vx::ShapeType out_shape({2, 2});

  tim::vx::TensorSpec a_spec(tim::vx::DataType::FLOAT32,
                             a_shape, tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec b_spec(tim::vx::DataType::FLOAT32,
                             b_shape, tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec c_spec(tim::vx::DataType::FLOAT32,
                             c_shape, tim::vx::TensorAttribute::TRANSIENT);
  tim::vx::TensorSpec d_spec(tim::vx::DataType::FLOAT32,
                             d_shape, tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec out_spec(tim::vx::DataType::FLOAT32,
                               out_shape, tim::vx::TensorAttribute::OUTPUT);

  auto a_tensor = graph->CreateTensor(a_spec);
  auto b_tensor = graph->CreateTensor(b_spec);
  auto c_tensor = graph->CreateTensor(c_spec);
  auto d_tensor = graph->CreateTensor(d_spec);
  auto out_tensor = graph->CreateTensor(out_spec);

  std::vector<float> a_data = {
      0, 1, 2, 3, 4, 5,
      -1, -2, -3, -4, -5, -6
  };

  std::vector<float> b_data = {
      1, 1, 1, 1, 1, 1,
      0, 0, 0, 0, 0, 0
  };

  std::vector<float> d_data = {
      6, 5,
      4, 3,
      2, 1,
      -6, -5,
      -4, -3,
      -2, -1
  };
  std::vector<float> golden = {
      -36, -27,
      36, 27
  };

  a_tensor->CopyDataToTensor(a_data.data(), a_data.size() * sizeof(float));
  b_tensor->CopyDataToTensor(b_data.data(), b_data.size() * sizeof(float));
  d_tensor->CopyDataToTensor(d_data.data(), d_data.size() * sizeof(float));

  auto op_add = graph->CreateOperation<tim::vx::ops::AddN>(2);
  (*op_add).BindInputs({a_tensor, b_tensor}).BindOutputs({c_tensor});

  tim::vx::ops::DeriveParamTuple tuple_list(2, 6, 6, 0, 0, 1.0, 0, 1.0, 0,
                                            1.0, 0);

  auto op_gemm = graph->CreateOperation<tim::vx::ops::CustomGemm>(
      false, false, tuple_list);

  (*op_gemm).BindInputs({c_tensor, d_tensor}).BindOutputs({out_tensor});

  graph->Compile();
  graph->Run();

  std::vector<float> output(golden.size());
  out_tensor->CopyDataFromTensor(output.data());
  std::cout << "diff between golden and result:" << std::endl;
  for (uint32_t i = 0; i < output.size(); i++) {
    std::cout << output[i] - golden[i] << " ";
  }
  std::cout << std::endl;
}

void custom_gemm_op_and_custom_gemm_op_test() {
  auto ctx = tim::vx::Context::Create();
  auto graph = ctx->CreateGraph();

  tim::vx::ShapeType a_shape({2, 2});
  tim::vx::ShapeType b_shape({2, 2});
  tim::vx::ShapeType c_shape({2, 2});
  tim::vx::ShapeType d_shape({2, 2});

  tim::vx::ShapeType out_shape({2, 2});

  tim::vx::TensorSpec a_spec(tim::vx::DataType::FLOAT32,
                             a_shape, tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec b_spec(tim::vx::DataType::FLOAT32,
                             b_shape, tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec c_spec(tim::vx::DataType::FLOAT32,
                             c_shape, tim::vx::TensorAttribute::TRANSIENT);
  tim::vx::TensorSpec d_spec(tim::vx::DataType::FLOAT32,
                             d_shape, tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec out_spec(tim::vx::DataType::FLOAT32,
                               out_shape, tim::vx::TensorAttribute::OUTPUT);

  auto a_tensor = graph->CreateTensor(a_spec);
  auto b_tensor = graph->CreateTensor(b_spec);
  auto c_tensor = graph->CreateTensor(c_spec);
  auto d_tensor = graph->CreateTensor(d_spec);
  auto out_tensor = graph->CreateTensor(out_spec);

  std::vector<float> a_data = {
      1, 1, 1, 1
  };

  std::vector<float> b_data = {
      1, 1, 1, 1
  };

  std::vector<float> d_data = {
      1, 1, 1, 1
  };
  std::vector<float> golden = {
      4, 4, 4, 4
  };

  a_tensor->CopyDataToTensor(a_data.data(), a_data.size() * sizeof(float));
  b_tensor->CopyDataToTensor(b_data.data(), b_data.size() * sizeof(float));
  d_tensor->CopyDataToTensor(d_data.data(), d_data.size() * sizeof(float));

  tim::vx::ops::DeriveParamTuple tuple_list(2, 2, 2, 0, 0, 1.0, 0, 1.0, 0,
                                            1.0, 0);

  auto op_gemm = graph->CreateOperation<tim::vx::ops::CustomGemm>(
      false, false, tuple_list);

  (*op_gemm).BindInputs({a_tensor, b_tensor}).BindOutputs({c_tensor});

  auto op_gemm2 = graph->CreateOperation<tim::vx::ops::CustomGemm>(
      false, false, tuple_list);

  (*op_gemm2).BindInputs({c_tensor, d_tensor}).BindOutputs({out_tensor});

  graph->Compile();
  graph->Run();

  std::vector<float> output(golden.size());
  out_tensor->CopyDataFromTensor(output.data());
  std::cout << "diff between golden and result:" << std::endl;
  for (uint32_t i = 0; i < output.size(); i++) {
    std::cout << output[i] - golden[i] << " ";
  }
  std::cout << std::endl;
}

int main() {
  custom_gemm_single_test();
  custom_gemm_op_and_add_op_test();
  custom_gemm_op_and_custom_gemm_op_test();
  return 0;
}
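
// If the custom kernel builds and executes correctly, every difference
// printed by the three tests above should be 0.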

@@ -105,6 +105,8 @@ if(TIM_VX_ENABLE_TEST)
target_include_directories(unit_test PRIVATE
    ${PROJECT_SOURCE_DIR}/include
    ${CMAKE_CURRENT_SOURCE_DIR}/vx
    ${OVXLIB_INCLUDE_DIR}
    ${INC_DIRS}
)

install(TARGETS unit_test DESTINATION ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_BINDIR})
@@ -36,6 +36,10 @@ DirectMapOpImpl::DirectMapOpImpl(Graph* graph, uint32_t kind, int input_cnt,
  node_->uid = graph_->graph()->cur_nid;
}

DirectMapOpImpl::DirectMapOpImpl(Graph* graph, DataLayout layout)
    : OpImpl(graph, layout) {}

DirectMapOpImpl& DirectMapOpImpl::BindInput(
    const std::shared_ptr<Tensor>& tensor) {
  inputs_tensor_.push_back(tensor);

@@ -71,5 +75,16 @@ void DirectMapOpImpl::SetRoundingPolicy(OverflowPolicy overflow_policy,
  node_->vx_param.accumulator_bits = accumulator_bits;
}

CustomOpBaseImpl::CustomOpBaseImpl(Graph* graph, uint32_t operation_id,
                                   const void* proc, const char* kernel_name,
                                   DataLayout layout)
    : DirectMapOpImpl(graph, layout) {
  op_proc_ = proc;
  vsi_nn_node_t* node = vsi_nn_AddExternalNode(graph_->graph(), operation_id,
                                               proc, NULL, kernel_name);
  node->uid = graph_->graph()->cur_nid;
  SetNode(node);
  SetRoundingPolicy();
}

}  // namespace vx
}  // namespace tim

@@ -34,16 +34,16 @@ namespace vx {

class DirectMapOpImpl : public OpImpl {
 public:
  // DirectMapOpImpl(Graph* graph, uint32_t kind, int input_cnt = 0,
  //                 int output_cnt = 0);
  DirectMapOpImpl(Graph* graph, uint32_t kind, int input_cnt = 0,
                  int output_cnt = 0, DataLayout layout = DataLayout::ANY);
  DirectMapOpImpl(Graph* graph, DataLayout layout = DataLayout::ANY);
  ~DirectMapOpImpl() {}

  DirectMapOpImpl& BindInput(const std::shared_ptr<Tensor>& tensor) override;
  DirectMapOpImpl& BindOutput(const std::shared_ptr<Tensor>& tensor) override;

  vsi_nn_node_t* node() override { return this->node_; }
  void SetNode(vsi_nn_node_t* node) { this->node_ = node; }

  void SetRoundingPolicy(
      OverflowPolicy overflow_policy = OverflowPolicy::SATURATE,
@@ -62,6 +62,14 @@ class DirectMapOpImpl : public OpImpl {
  vsi_nn_node_t* node_{nullptr};
};

class CustomOpBaseImpl : public DirectMapOpImpl {
 public:
  CustomOpBaseImpl(Graph* graph, uint32_t operation_id, const void* proc,
                   const char* kernel_name, DataLayout layout = DataLayout::ANY);

 protected:
  const void* op_proc_;
};

}  // namespace vx
}  // namespace tim
@@ -33,5 +33,10 @@ OpImpl::OpImpl(Graph* graph, uint32_t kind, int input_cnt, int output_cnt,
      input_cnt_(input_cnt),
      output_cnt_(output_cnt),
      layout_(layout) {}

OpImpl::OpImpl(Graph* graph, DataLayout layout)
    : graph_(reinterpret_cast<GraphImpl*>(graph)),
      layout_(layout) {}

}  // namespace vx
}  // namespace tim
@@ -35,6 +35,8 @@ class OpImpl {
 public:
  OpImpl(Graph* graph, uint32_t kind, int input_cnt, int output_cnt,
         DataLayout layout);
  OpImpl(Graph* graph, DataLayout layout);

  virtual ~OpImpl() = default;
  virtual OpImpl& BindInput(const std::shared_ptr<Tensor>& tensor) = 0;
  virtual OpImpl& BindOutput(const std::shared_ptr<Tensor>& tensor) = 0;

@@ -0,0 +1,250 @@
/****************************************************************************
 *
 * Copyright (c) 2021 Vivante Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 * DEALINGS IN THE SOFTWARE.
 *
 *****************************************************************************/
#include <map>
|
||||
#include <assert.h>
|
||||
#include "tim/vx/ops.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
#include "vsi_nn_pub.h"
|
||||
|
||||
#include "kernel/vsi_nn_kernel.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
namespace ops {
|
||||
|
||||
static vsi_bool op_setup(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
|
||||
vsi_nn_tensor_t** outputs);
|
||||
|
||||
static vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
|
||||
vsi_nn_tensor_t** outputs);
|
||||
|
||||
static vx_status derive_kernel_init(vx_node node, const vx_reference* param,
|
||||
vx_uint32 param_size);
|
||||
|
||||
static std::map<void*, CustomOpBase*> node_base_map_;
|
||||
|
||||
CustomOpBase::CustomOpBase(Graph* graph, uint32_t input_num,
|
||||
uint32_t output_num, int32_t kernel_id,
|
||||
const char* kernel_name)
|
||||
: input_num_(input_num), output_num_(output_num) {
|
||||
init_kernel_ = reinterpret_cast<void*>(derive_kernel_init);
|
||||
vsi_nn_op_proc_t proc = {NULL, op_compute, NULL, NULL,
|
||||
op_setup, NULL, input_num_, output_num_};
|
||||
this->impl() = std::make_unique<CustomOpBaseImpl>(
|
||||
graph, kernel_id, reinterpret_cast<void*>(&proc), kernel_name);
|
||||
this->impl()->node()->nn_param.client_param = reinterpret_cast<void*>(this);
|
||||
}
|
||||
|
||||
CustomOpBase::~CustomOpBase(){
|
||||
auto iter = node_base_map_.find(this->vx_node_);
|
||||
if (iter != node_base_map_.end()) {
|
||||
node_base_map_.erase(this->vx_node_);
|
||||
}
|
||||
}

vsi_bool op_setup(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
                  vsi_nn_tensor_t** outputs) {
  CustomOpBase* op_this =
      reinterpret_cast<CustomOpBase*>(self->nn_param.client_param);

  for (uint32_t i = 0; i < op_this->output_num_; i++) {
    std::vector<uint32_t> output_size;
    op_this->outputs_size_.push_back(output_size);
  }

  for (uint32_t i = 0; i < op_this->input_num_; i++) {
    std::vector<uint32_t> input_size;
    for (uint32_t j = 0; j < inputs[i]->attr.dim_num; j++) {
      input_size.push_back(inputs[i]->attr.size[j]);
    }
    op_this->inputs_size_.push_back(input_size);
  }

  // Let the derived op compute its output shapes from the input shapes.
  op_this->setup_output_shape_info();

  for (uint32_t i = 0; i < op_this->outputs_size_.size(); i++) {
    outputs[i]->attr.dim_num = op_this->outputs_size_[i].size();
    for (uint32_t j = 0; j < op_this->outputs_size_[i].size(); j++) {
      outputs[i]->attr.size[j] = op_this->outputs_size_[i][j];
    }
  }
  return TRUE;
}

vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
                    vsi_nn_tensor_t** outputs) {
  vsi_status status = VSI_FAILURE;
  auto kernel = vsi_nn_kernel_create(VSI_NN_KERNEL_TYPE_CL);
  CustomOpBase* op_this =
      reinterpret_cast<CustomOpBase*>(self->nn_param.client_param);

  uint32_t param_num = op_this->param_list_.size();
  // Total node parameters: input tensors + output tensors + scalar params.
  uint32_t total_param_num =
      op_this->input_num_ + op_this->output_num_ + param_num;

  std::vector<tim::vx::DataType> input_types;
  for (uint32_t i = 0; i < op_this->input_num_; i++) {
    if (inputs[i]->attr.dtype.vx_type == VSI_NN_TYPE_FLOAT32) {
      input_types.push_back(tim::vx::DataType::FLOAT32);
    } else if (inputs[i]->attr.dtype.vx_type == VSI_NN_TYPE_UINT32) {
      input_types.push_back(tim::vx::DataType::UINT32);
    } else if (inputs[i]->attr.dtype.vx_type == VSI_NN_TYPE_INT32) {
      input_types.push_back(tim::vx::DataType::INT32);
    } else if (inputs[i]->attr.dtype.vx_type == VSI_NN_TYPE_BOOL8) {
      input_types.push_back(tim::vx::DataType::BOOL8);
    } else if (inputs[i]->attr.dtype.vx_type == VSI_NN_TYPE_UINT8) {
      input_types.push_back(tim::vx::DataType::UINT8);
    } else if (inputs[i]->attr.dtype.vx_type == VSI_NN_TYPE_INT8) {
      input_types.push_back(tim::vx::DataType::INT8);
    } else {
      std::cout << "Unsupported input tensor type in op_compute" << std::endl;
      assert(false);
    }
  }
  std::string build_option;
  op_this->extract_parameter_and_register(input_types, build_option);

  snprintf(kernel->info.name, VX_MAX_KERNEL_NAME, "%s", op_this->func_name_);
  kernel->unique_id =
      std::hash<std::string>()(std::string(op_this->func_name_));
  // Parameter table: input tensors first, then output tensors, then scalars.
  std::vector<vx_param_description_t> kernel_param_def(total_param_num);

  for (uint32_t i = 0; i < op_this->input_num_; i++) {
    kernel_param_def[i] = {VX_INPUT, VX_TYPE_TENSOR,
                           VX_PARAMETER_STATE_REQUIRED};
  }
  for (uint32_t i = 0; i < op_this->output_num_; i++) {
    kernel_param_def[op_this->input_num_ + i] = {VX_OUTPUT, VX_TYPE_TENSOR,
                                                 VX_PARAMETER_STATE_REQUIRED};
  }

  for (uint32_t i = 0; i < param_num; i++) {
    kernel_param_def[op_this->input_num_ + op_this->output_num_ + i] = {
        VX_INPUT, VX_TYPE_SCALAR, VX_PARAMETER_STATE_REQUIRED};
  }

  kernel->info.parameters = kernel_param_def.data();
  kernel->info.enumeration = KERNEL_ID_PLACEHOLDER;
  kernel->info.numParams = total_param_num;
  kernel->info.initialize =
      reinterpret_cast<vx_kernel_initialize_f>(op_this->init_kernel_);

  vsi_nn_kernel_add_source(kernel, VSI_NN_GPU_SOURCE_FMT_EXECUTABLE, 1,
                           "executable_name");

  vsi_nn_kernel_add_source(kernel, VSI_NN_GPU_SOURCE_FMT_CODE, 2, "helper",
                           "fmt_code_name");

  const char* tmp[] = {"", op_this->kernel_resource_};
  const char** resource = tmp;

  vsi_nn_kernel_add_build_option(kernel, build_option.c_str());
  auto node = vsi_nn_kernel_create_node_ext(self->graph, kernel, resource);
  if (node) {
    // One slot per tensor and scalar parameter.
    std::vector<vsi_nn_kernel_node_param_t> node_params(total_param_num, NULL);
    vsi_nn_kernel_node_pack_io(node_params.data(), total_param_num, inputs,
                               op_this->input_num_, outputs,
                               op_this->output_num_);

    uint32_t input_start = op_this->input_num_ + op_this->output_num_;
    for (uint32_t i = 0; i < op_this->param_list_.size(); i++) {
      if (op_this->param_list_[i].type == tim::vx::DataType::FLOAT32) {
        node_params[input_start++] = vsi_nn_kernel_scalar_create(
            self->graph, F32, &(op_this->param_list_[i].data.f));
      } else if (op_this->param_list_[i].type == tim::vx::DataType::UINT32) {
        node_params[input_start++] = vsi_nn_kernel_scalar_create(
            self->graph, U32, &(op_this->param_list_[i].data.ui));
      } else if (op_this->param_list_[i].type == tim::vx::DataType::INT32) {
        node_params[input_start++] = vsi_nn_kernel_scalar_create(
            self->graph, I32, &(op_this->param_list_[i].data.i));
      } else if (op_this->param_list_[i].type == tim::vx::DataType::BOOL8) {
        node_params[input_start++] = vsi_nn_kernel_scalar_create(
            self->graph, BOOL8, &(op_this->param_list_[i].data.b));
      } else if (op_this->param_list_[i].type == tim::vx::DataType::UINT8) {
        node_params[input_start++] = vsi_nn_kernel_scalar_create(
            self->graph, U8, &(op_this->param_list_[i].data.u8));
      } else if (op_this->param_list_[i].type == tim::vx::DataType::INT8) {
        node_params[input_start++] = vsi_nn_kernel_scalar_create(
            self->graph, I8, &(op_this->param_list_[i].data.i8));
      } else {
        std::cout << "Unsupported scalar type in op_compute" << std::endl;
        assert(false);
      }
    }

    input_start = op_this->input_num_ + op_this->output_num_;
    status = vsi_nn_kernel_node_pass_param(node, node_params.data(),
                                           total_param_num);
    for (uint32_t i = 0; i < param_num; i++) {
      vsi_nn_kernel_scalar_release(&node_params[input_start + i]);
    }
  }
  self->n = (vx_node)node;

  node_base_map_.insert(std::pair<void*, CustomOpBase*>(
      reinterpret_cast<void*>(self->n), op_this));
  op_this->vx_node_ = reinterpret_cast<void*>(self->n);
  return status;
}
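
// For reference, the node parameter layout assembled above (indices into
// node_params; derived from the loops, not a separate library contract):
//   [0, input_num)                       input tensors
//   [input_num, input_num + output_num)  output tensors
//   [input_num + output_num, total)      scalar params from param_list_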

vx_status derive_kernel_init(vx_node node, const vx_reference* param,
                             vx_uint32 param_size) {
  vsi_status status = VSI_FAILURE;
  if (param_size == 0 && param == nullptr) {
    return status;
  }

  gpu_param_t gpu_param = {3, {0, 0, 0}, {0, 0, 0}, {0, 0, 0}, {0, 0, 0}};

  std::vector<size_t> global_size(3);
  std::vector<size_t> local_size(3);
  uint32_t dim = 0;

  // Look up the C++ wrapper for this node and let it choose the work sizes.
  auto iter = node_base_map_.find(reinterpret_cast<void*>(node));
  if (iter != node_base_map_.end()) {
    iter->second->setup_kernel_param(dim, global_size, local_size);
  } else {
    std::cout << "Failed to find the GPU param setup function for this node"
              << std::endl;
    assert(false);
  }

  gpu_param.dim = dim;
  gpu_param.global_scale[0] = 1;
  gpu_param.global_scale[1] = 1;
  gpu_param.global_scale[2] = 1;

  gpu_param.global_size[0] = global_size[0];
  gpu_param.global_size[1] = global_size[1];
  gpu_param.global_size[2] = global_size[2];

  gpu_param.local_size[0] = local_size[0];
  gpu_param.local_size[1] = local_size[1];
  gpu_param.local_size[2] = local_size[2];
  status = vsi_nn_kernel_gpu_config(node, &gpu_param);

  return status;
}

}  // namespace ops
}  // namespace vx
}  // namespace tim