Fix build issue (#397)

Signed-off-by: xiang.zhang <xiang.zhang@verisilicon.com>
This commit is contained in:
Sven 2022-05-16 14:24:44 +08:00 committed by GitHub
parent 4f2991c853
commit a9764291b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 25 additions and 3 deletions

View File

@ -34,6 +34,11 @@ if(${TIM_VX_ENABLE_40BIT})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DVSI_40BIT_VA_SUPPORT") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DVSI_40BIT_VA_SUPPORT")
endif() endif()
if(${TIM_VX_ENABLE_CUSTOM_OP})
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DTIM_VX_ENABLE_CUSTOM_OP")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DTIM_VX_ENABLE_CUSTOM_OP")
endif()
set(CMAKE_C_VISIBILITY_PRESET hidden) set(CMAKE_C_VISIBILITY_PRESET hidden)
if(EXTERNAL_VIV_SDK AND EXISTS ${EXTERNAL_VIV_SDK}) if(EXTERNAL_VIV_SDK AND EXISTS ${EXTERNAL_VIV_SDK})

View File

@ -75,6 +75,7 @@ void DirectMapOpImpl::SetRoundingPolicy(OverflowPolicy overflow_policy,
node_->vx_param.accumulator_bits = accumulator_bits; node_->vx_param.accumulator_bits = accumulator_bits;
} }
#ifdef TIM_VX_ENABLE_CUSTOM_OP
CustomOpBaseImpl::CustomOpBaseImpl(Graph* graph, uint32_t operation_id, const void* proc, CustomOpBaseImpl::CustomOpBaseImpl(Graph* graph, uint32_t operation_id, const void* proc,
const char* kernel_name, DataLayout layout) const char* kernel_name, DataLayout layout)
: DirectMapOpImpl(graph, layout) { : DirectMapOpImpl(graph, layout) {
@ -85,6 +86,7 @@ CustomOpBaseImpl::CustomOpBaseImpl(Graph* graph, uint32_t operation_id, const vo
SetNode(node); SetNode(node);
SetRoundingPolicy(); SetRoundingPolicy();
}; };
#endif
} // namespace vx } // namespace vx
} // namespace tim } // namespace tim

View File

@ -62,6 +62,7 @@ class DirectMapOpImpl : public OpImpl {
vsi_nn_node_t* node_{nullptr}; vsi_nn_node_t* node_{nullptr};
}; };
#ifdef TIM_VX_ENABLE_CUSTOM_OP
class CustomOpBaseImpl : public DirectMapOpImpl { class CustomOpBaseImpl : public DirectMapOpImpl {
public: public:
CustomOpBaseImpl(Graph* graph, uint32_t operation_id, const void* proc, CustomOpBaseImpl(Graph* graph, uint32_t operation_id, const void* proc,
@ -69,6 +70,7 @@ class CustomOpBaseImpl : public DirectMapOpImpl {
protected: protected:
const void* op_proc_; const void* op_proc_;
}; };
#endif
} // namespace vx } // namespace vx
} // namespace tim } // namespace tim

View File

@ -135,6 +135,7 @@ std::shared_ptr<Operation> Gelu::Clone(std::shared_ptr<Graph>& graph) const {
this->impl()->node()->nn_param.gelu.approximate); this->impl()->node()->nn_param.gelu.approximate);
} }
#ifdef _VSI_NN_OP_SELU_H
Selu::Selu(Graph* graph, float alpha, float gamma) Selu::Selu(Graph* graph, float alpha, float gamma)
: DirectMapOp(graph, VSI_NN_OP_SELU), alpha_(alpha), gamma_(gamma) { : DirectMapOp(graph, VSI_NN_OP_SELU), alpha_(alpha), gamma_(gamma) {
this->impl()->node()->nn_param.selu.alpha = alpha; this->impl()->node()->nn_param.selu.alpha = alpha;
@ -144,7 +145,9 @@ Selu::Selu(Graph* graph, float alpha, float gamma)
std::shared_ptr<Operation> Selu::Clone(std::shared_ptr<Graph>& graph) const { std::shared_ptr<Operation> Selu::Clone(std::shared_ptr<Graph>& graph) const {
return graph->CreateOperation<Selu>(this->alpha_, this->gamma_); return graph->CreateOperation<Selu>(this->alpha_, this->gamma_);
} }
#endif
#ifdef _VSI_NN_OP_CELU_H
Celu::Celu(Graph* graph, float alpha) Celu::Celu(Graph* graph, float alpha)
: DirectMapOp(graph, VSI_NN_OP_CELU), alpha_(alpha) { : DirectMapOp(graph, VSI_NN_OP_CELU), alpha_(alpha) {
this->impl()->node()->nn_param.selu.alpha = alpha; this->impl()->node()->nn_param.selu.alpha = alpha;
@ -153,6 +156,7 @@ Celu::Celu(Graph* graph, float alpha)
std::shared_ptr<Operation> Celu::Clone(std::shared_ptr<Graph>& graph) const { std::shared_ptr<Operation> Celu::Clone(std::shared_ptr<Graph>& graph) const {
return graph->CreateOperation<Selu>(this->alpha_); return graph->CreateOperation<Selu>(this->alpha_);
} }
#endif
} // namespace ops } // namespace ops
} // namespace vx } // namespace vx

View File

@ -331,6 +331,7 @@ TEST(Elu, shape_5_1_fp32_a) {
EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f)); EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
} }
#ifdef _VSI_NN_OP_SELU_H
TEST(Selu, shape_2_2) { TEST(Selu, shape_2_2) {
auto ctx = tim::vx::Context::Create(); auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph(); auto graph = ctx->CreateGraph();
@ -363,7 +364,9 @@ TEST(Selu, shape_2_2) {
EXPECT_TRUE(out_tensor->CopyDataFromTensor(output.data())); EXPECT_TRUE(out_tensor->CopyDataFromTensor(output.data()));
EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f)); EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
} }
#endif
#ifdef _VSI_NN_OP_CELU_H
TEST(Celu, shape_2_2) { TEST(Celu, shape_2_2) {
auto ctx = tim::vx::Context::Create(); auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph(); auto graph = ctx->CreateGraph();
@ -396,3 +399,4 @@ TEST(Celu, shape_2_2) {
EXPECT_TRUE(out_tensor->CopyDataFromTensor(output.data())); EXPECT_TRUE(out_tensor->CopyDataFromTensor(output.data()));
EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f)); EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
} }
#endif

View File

@ -21,7 +21,7 @@
* DEALINGS IN THE SOFTWARE. * DEALINGS IN THE SOFTWARE.
* *
*****************************************************************************/ *****************************************************************************/
#ifdef TIM_VX_ENABLE_CUSTOM_OP
#include <map> #include <map>
#include <assert.h> #include <assert.h>
#include "tim/vx/ops.h" #include "tim/vx/ops.h"
@ -251,3 +251,4 @@ vx_status derive_kernel_init(vx_node node, const vx_reference* param,
} // namespace ops } // namespace ops
} // namespace vx } // namespace vx
} // namespace tim } // namespace tim
#endif

View File

@ -29,7 +29,7 @@
namespace tim { namespace tim {
namespace vx { namespace vx {
namespace ops { namespace ops {
#ifdef _VSI_NN_OP_GATHER_ELEMENTS_H
Gather_elements::Gather_elements(Graph* graph, int axis) Gather_elements::Gather_elements(Graph* graph, int axis)
: DirectMapOp(graph, VSI_NN_OP_GATHER_ELEMENTS), axis_(axis) { : DirectMapOp(graph, VSI_NN_OP_GATHER_ELEMENTS), axis_(axis) {
this->impl()->node()->nn_param.gather_elements.axis = axis_; this->impl()->node()->nn_param.gather_elements.axis = axis_;
@ -39,6 +39,7 @@ std::shared_ptr<Operation> Gather_elements::Clone(
std::shared_ptr<Graph>& graph) const { std::shared_ptr<Graph>& graph) const {
return graph->CreateOperation<Gather_elements>(this->axis_); return graph->CreateOperation<Gather_elements>(this->axis_);
} }
#endif
} // namespace ops } // namespace ops
} // namespace vx } // namespace vx

View File

@ -28,6 +28,8 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "test_utils.h" #include "test_utils.h"
#ifdef _VSI_NN_OP_GATHER_ELEMENTS_H
TEST(Gather_elements, shape_3_2_1_int32_axis_0) { TEST(Gather_elements, shape_3_2_1_int32_axis_0) {
auto ctx = tim::vx::Context::Create(); auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph(); auto graph = ctx->CreateGraph();
@ -168,3 +170,4 @@ TEST(Gather_elements, shape_3_2_1_float32_axis_2) {
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
EXPECT_EQ(golden, output); EXPECT_EQ(golden, output);
} }
#endif