Fix build issue (#397)

Signed-off-by: xiang.zhang <xiang.zhang@verisilicon.com>
This commit is contained in:
Sven 2022-05-16 14:24:44 +08:00 committed by GitHub
parent 4f2991c853
commit a9764291b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 25 additions and 3 deletions

View File

@ -34,6 +34,11 @@ if(${TIM_VX_ENABLE_40BIT})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DVSI_40BIT_VA_SUPPORT")
endif()
if(${TIM_VX_ENABLE_CUSTOM_OP})
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DTIM_VX_ENABLE_CUSTOM_OP")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DTIM_VX_ENABLE_CUSTOM_OP")
endif()
set(CMAKE_C_VISIBILITY_PRESET hidden)
if(EXTERNAL_VIV_SDK AND EXISTS ${EXTERNAL_VIV_SDK})

View File

@ -75,6 +75,7 @@ void DirectMapOpImpl::SetRoundingPolicy(OverflowPolicy overflow_policy,
node_->vx_param.accumulator_bits = accumulator_bits;
}
#ifdef TIM_VX_ENABLE_CUSTOM_OP
CustomOpBaseImpl::CustomOpBaseImpl(Graph* graph, uint32_t operation_id, const void* proc,
const char* kernel_name, DataLayout layout)
: DirectMapOpImpl(graph, layout) {
@ -85,6 +86,7 @@ CustomOpBaseImpl::CustomOpBaseImpl(Graph* graph, uint32_t operation_id, const vo
SetNode(node);
SetRoundingPolicy();
};
#endif
} // namespace vx
} // namespace tim

View File

@ -62,6 +62,7 @@ class DirectMapOpImpl : public OpImpl {
vsi_nn_node_t* node_{nullptr};
};
#ifdef TIM_VX_ENABLE_CUSTOM_OP
class CustomOpBaseImpl : public DirectMapOpImpl {
public:
CustomOpBaseImpl(Graph* graph, uint32_t operation_id, const void* proc,
@ -69,6 +70,7 @@ class CustomOpBaseImpl : public DirectMapOpImpl {
protected:
const void* op_proc_;
};
#endif
} // namespace vx
} // namespace tim

View File

@ -135,6 +135,7 @@ std::shared_ptr<Operation> Gelu::Clone(std::shared_ptr<Graph>& graph) const {
this->impl()->node()->nn_param.gelu.approximate);
}
#ifdef _VSI_NN_OP_SELU_H
Selu::Selu(Graph* graph, float alpha, float gamma)
: DirectMapOp(graph, VSI_NN_OP_SELU), alpha_(alpha), gamma_(gamma) {
this->impl()->node()->nn_param.selu.alpha = alpha;
@ -144,7 +145,9 @@ Selu::Selu(Graph* graph, float alpha, float gamma)
std::shared_ptr<Operation> Selu::Clone(std::shared_ptr<Graph>& graph) const {
return graph->CreateOperation<Selu>(this->alpha_, this->gamma_);
}
#endif
#ifdef _VSI_NN_OP_CELU_H
Celu::Celu(Graph* graph, float alpha)
: DirectMapOp(graph, VSI_NN_OP_CELU), alpha_(alpha) {
this->impl()->node()->nn_param.selu.alpha = alpha;
@ -153,6 +156,7 @@ Celu::Celu(Graph* graph, float alpha)
std::shared_ptr<Operation> Celu::Clone(std::shared_ptr<Graph>& graph) const {
return graph->CreateOperation<Selu>(this->alpha_);
}
#endif
} // namespace ops
} // namespace vx

View File

@ -331,6 +331,7 @@ TEST(Elu, shape_5_1_fp32_a) {
EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
}
#ifdef _VSI_NN_OP_SELU_H
TEST(Selu, shape_2_2) {
auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph();
@ -363,7 +364,9 @@ TEST(Selu, shape_2_2) {
EXPECT_TRUE(out_tensor->CopyDataFromTensor(output.data()));
EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
}
#endif
#ifdef _VSI_NN_OP_CELU_H
TEST(Celu, shape_2_2) {
auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph();
@ -396,3 +399,4 @@ TEST(Celu, shape_2_2) {
EXPECT_TRUE(out_tensor->CopyDataFromTensor(output.data()));
EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
}
#endif

View File

@ -21,7 +21,7 @@
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifdef TIM_VX_ENABLE_CUSTOM_OP
#include <map>
#include <assert.h>
#include "tim/vx/ops.h"
@ -250,4 +250,5 @@ vx_status derive_kernel_init(vx_node node, const vx_reference* param,
} // namespace ops
} // namespace vx
} // namespace tim
} // namespace tim
#endif

View File

@ -29,7 +29,7 @@
namespace tim {
namespace vx {
namespace ops {
#ifdef _VSI_NN_OP_GATHER_ELEMENTS_H
Gather_elements::Gather_elements(Graph* graph, int axis)
: DirectMapOp(graph, VSI_NN_OP_GATHER_ELEMENTS), axis_(axis) {
this->impl()->node()->nn_param.gather_elements.axis = axis_;
@ -39,6 +39,7 @@ std::shared_ptr<Operation> Gather_elements::Clone(
std::shared_ptr<Graph>& graph) const {
return graph->CreateOperation<Gather_elements>(this->axis_);
}
#endif
} // namespace ops
} // namespace vx

View File

@ -28,6 +28,8 @@
#include "gtest/gtest.h"
#include "test_utils.h"
#ifdef _VSI_NN_OP_GATHER_ELEMENTS_H
TEST(Gather_elements, shape_3_2_1_int32_axis_0) {
auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph();
@ -168,3 +170,4 @@ TEST(Gather_elements, shape_3_2_1_float32_axis_2) {
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
EXPECT_EQ(golden, output);
}
#endif