diff --git a/CMakeLists.txt b/CMakeLists.txt index a266d0b..a3d0a6f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,6 +34,11 @@ if(${TIM_VX_ENABLE_40BIT}) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DVSI_40BIT_VA_SUPPORT") endif() +if(${TIM_VX_ENABLE_CUSTOM_OP}) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DTIM_VX_ENABLE_CUSTOM_OP") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DTIM_VX_ENABLE_CUSTOM_OP") +endif() + set(CMAKE_C_VISIBILITY_PRESET hidden) if(EXTERNAL_VIV_SDK AND EXISTS ${EXTERNAL_VIV_SDK}) diff --git a/src/tim/vx/direct_map_op_impl.cc b/src/tim/vx/direct_map_op_impl.cc index 9b4371b..335014b 100644 --- a/src/tim/vx/direct_map_op_impl.cc +++ b/src/tim/vx/direct_map_op_impl.cc @@ -75,6 +75,7 @@ void DirectMapOpImpl::SetRoundingPolicy(OverflowPolicy overflow_policy, node_->vx_param.accumulator_bits = accumulator_bits; } +#ifdef TIM_VX_ENABLE_CUSTOM_OP CustomOpBaseImpl::CustomOpBaseImpl(Graph* graph, uint32_t operation_id, const void* proc, const char* kernel_name, DataLayout layout) : DirectMapOpImpl(graph, layout) { @@ -85,6 +86,7 @@ CustomOpBaseImpl::CustomOpBaseImpl(Graph* graph, uint32_t operation_id, const vo SetNode(node); SetRoundingPolicy(); }; +#endif } // namespace vx } // namespace tim \ No newline at end of file diff --git a/src/tim/vx/direct_map_op_impl.h b/src/tim/vx/direct_map_op_impl.h index ca57323..06aa723 100644 --- a/src/tim/vx/direct_map_op_impl.h +++ b/src/tim/vx/direct_map_op_impl.h @@ -62,6 +62,7 @@ class DirectMapOpImpl : public OpImpl { vsi_nn_node_t* node_{nullptr}; }; +#ifdef TIM_VX_ENABLE_CUSTOM_OP class CustomOpBaseImpl : public DirectMapOpImpl { public: CustomOpBaseImpl(Graph* graph, uint32_t operation_id, const void* proc, @@ -69,6 +70,7 @@ class CustomOpBaseImpl : public DirectMapOpImpl { protected: const void* op_proc_; }; +#endif } // namespace vx } // namespace tim diff --git a/src/tim/vx/ops/activations.cc b/src/tim/vx/ops/activations.cc index 39e59c6..89d658c 100644 --- a/src/tim/vx/ops/activations.cc +++ b/src/tim/vx/ops/activations.cc @@ -135,6 +135,7 @@ std::shared_ptr Gelu::Clone(std::shared_ptr& graph) const { this->impl()->node()->nn_param.gelu.approximate); } +#ifdef _VSI_NN_OP_SELU_H Selu::Selu(Graph* graph, float alpha, float gamma) : DirectMapOp(graph, VSI_NN_OP_SELU), alpha_(alpha), gamma_(gamma) { this->impl()->node()->nn_param.selu.alpha = alpha; @@ -144,7 +145,9 @@ Selu::Selu(Graph* graph, float alpha, float gamma) std::shared_ptr Selu::Clone(std::shared_ptr& graph) const { return graph->CreateOperation(this->alpha_, this->gamma_); } +#endif +#ifdef _VSI_NN_OP_CELU_H Celu::Celu(Graph* graph, float alpha) : DirectMapOp(graph, VSI_NN_OP_CELU), alpha_(alpha) { this->impl()->node()->nn_param.selu.alpha = alpha; @@ -153,6 +156,7 @@ Celu::Celu(Graph* graph, float alpha) std::shared_ptr Celu::Clone(std::shared_ptr& graph) const { return graph->CreateOperation(this->alpha_); } +#endif } // namespace ops } // namespace vx diff --git a/src/tim/vx/ops/activations_test.cc b/src/tim/vx/ops/activations_test.cc index 0721664..644fd71 100644 --- a/src/tim/vx/ops/activations_test.cc +++ b/src/tim/vx/ops/activations_test.cc @@ -331,6 +331,7 @@ TEST(Elu, shape_5_1_fp32_a) { EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f)); } +#ifdef _VSI_NN_OP_SELU_H TEST(Selu, shape_2_2) { auto ctx = tim::vx::Context::Create(); auto graph = ctx->CreateGraph(); @@ -363,7 +364,9 @@ TEST(Selu, shape_2_2) { EXPECT_TRUE(out_tensor->CopyDataFromTensor(output.data())); EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f)); } +#endif +#ifdef _VSI_NN_OP_CELU_H TEST(Celu, shape_2_2) { auto ctx = tim::vx::Context::Create(); auto graph = ctx->CreateGraph(); @@ -396,3 +399,4 @@ TEST(Celu, shape_2_2) { EXPECT_TRUE(out_tensor->CopyDataFromTensor(output.data())); EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f)); } +#endif diff --git a/src/tim/vx/ops/custom_base.cc b/src/tim/vx/ops/custom_base.cc index d26aa13..8c538da 100644 --- a/src/tim/vx/ops/custom_base.cc +++ b/src/tim/vx/ops/custom_base.cc @@ -21,7 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ - +#ifdef TIM_VX_ENABLE_CUSTOM_OP #include #include #include "tim/vx/ops.h" @@ -250,4 +250,5 @@ vx_status derive_kernel_init(vx_node node, const vx_reference* param, } // namespace ops } // namespace vx -} // namespace tim \ No newline at end of file +} // namespace tim +#endif \ No newline at end of file diff --git a/src/tim/vx/ops/gather_elements.cc b/src/tim/vx/ops/gather_elements.cc index 0c02f5f..89876a9 100644 --- a/src/tim/vx/ops/gather_elements.cc +++ b/src/tim/vx/ops/gather_elements.cc @@ -29,7 +29,7 @@ namespace tim { namespace vx { namespace ops { - +#ifdef _VSI_NN_OP_GATHER_ELEMENTS_H Gather_elements::Gather_elements(Graph* graph, int axis) : DirectMapOp(graph, VSI_NN_OP_GATHER_ELEMENTS), axis_(axis) { this->impl()->node()->nn_param.gather_elements.axis = axis_; @@ -39,6 +39,7 @@ std::shared_ptr Gather_elements::Clone( std::shared_ptr& graph) const { return graph->CreateOperation(this->axis_); } +#endif } // namespace ops } // namespace vx diff --git a/src/tim/vx/ops/gather_elements_test.cc b/src/tim/vx/ops/gather_elements_test.cc index 92195dc..6977518 100644 --- a/src/tim/vx/ops/gather_elements_test.cc +++ b/src/tim/vx/ops/gather_elements_test.cc @@ -28,6 +28,8 @@ #include "gtest/gtest.h" #include "test_utils.h" +#ifdef _VSI_NN_OP_GATHER_ELEMENTS_H + TEST(Gather_elements, shape_3_2_1_int32_axis_0) { auto ctx = tim::vx::Context::Create(); auto graph = ctx->CreateGraph(); @@ -168,3 +170,4 @@ TEST(Gather_elements, shape_3_2_1_float32_axis_2) { EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); EXPECT_EQ(golden, output); } +#endif