From c033cfc582e82b7a4254d0d64bfd7a2dc30c307c Mon Sep 17 00:00:00 2001 From: chxin66 <57057788+chxin66@users.noreply.github.com> Date: Tue, 12 Apr 2022 18:42:50 +0800 Subject: [PATCH] Fixed compiler fail for elu (#358) Signed-off-by: Chen Xin Co-authored-by: Chen Xin --- include/tim/vx/ops/activations.h | 1 + src/tim/vx/ops/activations.cc | 2 ++ src/tim/vx/ops/activations_test.cc | 29 +++++++++++++++++++++++++++++ 3 files changed, 32 insertions(+) diff --git a/include/tim/vx/ops/activations.h b/include/tim/vx/ops/activations.h index 782ce9a..3025aac 100644 --- a/include/tim/vx/ops/activations.h +++ b/include/tim/vx/ops/activations.h @@ -90,6 +90,7 @@ DECLARE_NO_PARAMETER_ACTIVATION(SoftRelu) class Elu : public DirectMapOp { public: + Elu(Graph* graph); Elu(Graph* graph, float alpha); std::shared_ptr Clone( std::shared_ptr& graph) const override; diff --git a/src/tim/vx/ops/activations.cc b/src/tim/vx/ops/activations.cc index 7688c7f..9398f92 100644 --- a/src/tim/vx/ops/activations.cc +++ b/src/tim/vx/ops/activations.cc @@ -46,6 +46,8 @@ DEFINE_NO_PARAMETER_ACTIVATION(SoftRelu, VSI_NN_OP_SOFTRELU) #undef DEFINE_NO_PARAMETER_ACTIVATION +Elu::Elu(Graph* graph) : Elu(graph, 1) {} + Elu::Elu(Graph* graph, float alpha) : DirectMapOp(graph, VSI_NN_OP_ELU), alpha_(alpha) { this->impl()->node()->nn_param.elu.alpha = alpha_; diff --git a/src/tim/vx/ops/activations_test.cc b/src/tim/vx/ops/activations_test.cc index 3adb245..889e134 100644 --- a/src/tim/vx/ops/activations_test.cc +++ b/src/tim/vx/ops/activations_test.cc @@ -239,6 +239,35 @@ TEST(Elu, shape_5_1_fp32) { auto input_tensor = graph->CreateTensor(input_spec); auto output_tensor = graph->CreateTensor(output_spec); + std::vector in_data = {-2.5, -0.1, 0, 0.55, 99}; + std::vector golden = {-0.917915, -0.0951626, 0, 0.55, 99}; + + EXPECT_TRUE( + input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * 4)); + + auto op = graph->CreateOperation(); + (*op).BindInputs({input_tensor}).BindOutputs({output_tensor}); + + EXPECT_TRUE(graph->Compile()); + EXPECT_TRUE(graph->Run()); + std::vector output(5, 0); + EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); + EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f)); +} + +TEST(Elu, shape_5_1_fp32_a) { + auto ctx = tim::vx::Context::Create(); + auto graph = ctx->CreateGraph(); + + tim::vx::ShapeType io_shape({5, 1}); + tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, io_shape, + tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, io_shape, + tim::vx::TensorAttribute::OUTPUT); + + auto input_tensor = graph->CreateTensor(input_spec); + auto output_tensor = graph->CreateTensor(output_spec); + std::vector in_data = {-2.5, -0.1, 0, 0.55, 99}; std::vector golden = {-0.458957, -0.0475813, 0, 0.55, 99};