From d0af7ae8df830c98323690ae5806d6af53dbb2c1 Mon Sep 17 00:00:00 2001
From: chxin66 <57057788+chxin66@users.noreply.github.com>
Date: Mon, 11 Apr 2022 19:04:30 +0800
Subject: [PATCH] Support alpha in elu (#354)

https://github.com/VeriSilicon/TIM-VX/issues/305

Signed-off-by: Chen Xin
---
 include/tim/vx/ops/activations.h   | 13 ++++++--
 src/tim/vx/ops/activations.cc      | 28 +++++++++++------
 src/tim/vx/ops/activations_test.cc | 50 +++++++++++++++++++++++-------
 3 files changed, 67 insertions(+), 24 deletions(-)

diff --git a/include/tim/vx/ops/activations.h b/include/tim/vx/ops/activations.h
index e8f3fd4..782ce9a 100644
--- a/include/tim/vx/ops/activations.h
+++ b/include/tim/vx/ops/activations.h
@@ -69,7 +69,7 @@ namespace ops {
  */
 
 #define DECLARE_NO_PARAMETER_ACTIVATION(NAME)      \
-  class NAME : public DirectMapOp {                \
+  class NAME : public DirectMapOp {                \
    public:                                         \
     NAME(Graph* graph);                            \
     std::shared_ptr<Operation> Clone(              \
@@ -79,7 +79,6 @@ namespace ops {
 DECLARE_NO_PARAMETER_ACTIVATION(Relu)
 DECLARE_NO_PARAMETER_ACTIVATION(Relu1)
 DECLARE_NO_PARAMETER_ACTIVATION(Relu6)
-DECLARE_NO_PARAMETER_ACTIVATION(Elu)
 DECLARE_NO_PARAMETER_ACTIVATION(Tanh)
 DECLARE_NO_PARAMETER_ACTIVATION(Sigmoid)
 DECLARE_NO_PARAMETER_ACTIVATION(Swish)
@@ -89,6 +88,16 @@ DECLARE_NO_PARAMETER_ACTIVATION(SoftRelu)
 
 #undef DEFINE_NO_PARAMETER_ACTIVATION
 
+class Elu : public DirectMapOp {
+ public:
+  Elu(Graph* graph, float alpha);
+  std::shared_ptr<Operation> Clone(
+      std::shared_ptr<Graph>& graph) const override;
+
+ protected:
+  float alpha_;
+};
+
 class Prelu : public DirectMapOp {
  public:
   Prelu(Graph* graph, int axis);
diff --git a/src/tim/vx/ops/activations.cc b/src/tim/vx/ops/activations.cc
index 3aca7d0..7688c7f 100644
--- a/src/tim/vx/ops/activations.cc
+++ b/src/tim/vx/ops/activations.cc
@@ -31,7 +31,7 @@ namespace vx {
 namespace ops {
 
 #define DEFINE_NO_PARAMETER_ACTIVATION(NAME, VSI_OP_CODE)               \
-  NAME::NAME(Graph* graph) : DirectMapOp(graph, VSI_OP_CODE) {}         \
+  NAME::NAME(Graph* graph) : DirectMapOp(graph, VSI_OP_CODE) {}         \
   std::shared_ptr<Operation> NAME::Clone(std::shared_ptr<Graph>& graph) \
       const {                                                           \
     return graph->CreateOperation<NAME>();                              \
@@ -40,14 +40,21 @@ namespace ops {
 DEFINE_NO_PARAMETER_ACTIVATION(Relu, VSI_NN_OP_RELU)
 DEFINE_NO_PARAMETER_ACTIVATION(Relu1, VSI_NN_OP_RELU1)
 DEFINE_NO_PARAMETER_ACTIVATION(Relu6, VSI_NN_OP_RELU6)
-DEFINE_NO_PARAMETER_ACTIVATION(Elu, VSI_NN_OP_ELU)
 DEFINE_NO_PARAMETER_ACTIVATION(Sigmoid, VSI_NN_OP_SIGMOID)
 DEFINE_NO_PARAMETER_ACTIVATION(Mish, VSI_NN_OP_MISH)
 DEFINE_NO_PARAMETER_ACTIVATION(SoftRelu, VSI_NN_OP_SOFTRELU)
-
 #undef DEFINE_NO_PARAMETER_ACTIVATION
 
+Elu::Elu(Graph* graph, float alpha)
+    : DirectMapOp(graph, VSI_NN_OP_ELU), alpha_(alpha) {
+  this->impl()->node()->nn_param.elu.alpha = alpha_;
+}
+
+std::shared_ptr<Operation> Elu::Clone(std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<Elu>(this->alpha_);
+}
+
 HardSwish::HardSwish(Graph* graph) : DirectMapOp(graph, VSI_NN_OP_SWISH) {
   this->impl()->node()->nn_param.swish.type = VSI_NN_HSWISH;
   this->impl()->node()->nn_param.swish.beta = 1.0f;
@@ -63,8 +70,7 @@ Swish::Swish(Graph* graph) : DirectMapOp(graph, VSI_NN_OP_SWISH) {
   this->impl()->node()->nn_param.swish.beta = 1.0f;
 }
 
-std::shared_ptr<Operation> Swish::Clone(
-    std::shared_ptr<Graph>& graph) const {
+std::shared_ptr<Operation> Swish::Clone(std::shared_ptr<Graph>& graph) const {
   return graph->CreateOperation<Swish>();
 }
 
@@ -83,7 +89,8 @@ HardSigmoid::HardSigmoid(Graph* graph, float alpha, float beta)
   this->impl()->node()->nn_param.hard_sigmoid.beta = beta_;
 }
 
-std::shared_ptr<Operation> HardSigmoid::Clone(std::shared_ptr<Graph>& graph) const {
+std::shared_ptr<Operation> HardSigmoid::Clone(
+    std::shared_ptr<Graph>& graph) const {
   return graph->CreateOperation<HardSigmoid>(this->alpha_, this->beta_);
 }
 
@@ -117,12 +124,13 @@ std::shared_ptr<Operation> Linear::Clone(std::shared_ptr<Graph>& graph) const {
 }
 
 Gelu::Gelu(Graph* graph, bool approximate)
-    : DirectMapOp(graph, VSI_NN_OP_GELU){
-  this->impl()->node()->nn_param.gelu.approximate = approximate;
-  }
+    : DirectMapOp(graph, VSI_NN_OP_GELU) {
+  this->impl()->node()->nn_param.gelu.approximate = approximate;
+}
 
 std::shared_ptr<Operation> Gelu::Clone(std::shared_ptr<Graph>& graph) const {
-  return graph->CreateOperation<Gelu>(this->impl()->node()->nn_param.gelu.approximate);
+  return graph->CreateOperation<Gelu>(
+      this->impl()->node()->nn_param.gelu.approximate);
 }
 
 }  // namespace ops
diff --git a/src/tim/vx/ops/activations_test.cc b/src/tim/vx/ops/activations_test.cc
index 62331e8..3adb245 100644
--- a/src/tim/vx/ops/activations_test.cc
+++ b/src/tim/vx/ops/activations_test.cc
@@ -186,8 +186,8 @@ TEST(HardSigmoid, shape_5_1_uint8_Quantized) {
   tim::vx::ShapeType in_shape({20, 1});
   tim::vx::ShapeType out_shape({20, 1});
 
-  std::vector<float> scalesInput = {0.00228914};  //scale
-  std::vector<int32_t> zeroPointsInput = {128};  //zero point
+  std::vector<float> scalesInput = {0.00228914};  //scale
+  std::vector<int32_t> zeroPointsInput = {128};   //zero point
   std::vector<float> scalesOutput = {0.005};
   std::vector<int32_t> zeroPointsOutput = {128};
 
@@ -207,17 +207,14 @@ TEST(HardSigmoid, shape_5_1_uint8_Quantized) {
   auto input_tensor = graph->CreateTensor(input_spec);
   auto output_tensor = graph->CreateTensor(output_spec);
 
-  std::vector<uint8_t> in_data = {65, 255, 140, 92, 142,
-                                  122, 117, 167, 132, 117,
-                                  44, 99, 109, 96, 216,
-                                  222, 135, 126, 113, 100};
-  std::vector<uint8_t> golden_data = {222, 240, 229, 225, 229,
-                                      227, 227, 232, 228, 227,
-                                      220, 225, 226, 225, 236,
-                                      237, 229, 228, 227, 225};
+  std::vector<uint8_t> in_data = {65,  255, 140, 92,  142, 122, 117,
+                                  167, 132, 117, 44,  99,  109, 96,
+                                  216, 222, 135, 126, 113, 100};
+  std::vector<uint8_t> golden_data = {222, 240, 229, 225, 229, 227, 227,
+                                      232, 228, 227, 220, 225, 226, 225,
+                                      236, 237, 229, 228, 227, 225};
 
-  EXPECT_TRUE(
-      input_tensor->CopyDataToTensor(in_data.data(), in_data.size()));
+  EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(), in_data.size()));
 
   auto op = graph->CreateOperation<tim::vx::ops::HardSigmoid>(0.2, 0.5);
   (*op).BindInput(input_tensor).BindOutput(output_tensor);
@@ -228,3 +225,32 @@ TEST(HardSigmoid, shape_5_1_uint8_Quantized) {
   EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
   EXPECT_TRUE(ArraysMatch(golden_data, output, (uint8_t)1));
 }
+
+TEST(Elu, shape_5_1_fp32) {
+  auto ctx = tim::vx::Context::Create();
+  auto graph = ctx->CreateGraph();
+
+  tim::vx::ShapeType io_shape({5, 1});
+  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, io_shape,
+                                 tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, io_shape,
+                                  tim::vx::TensorAttribute::OUTPUT);
+
+  auto input_tensor = graph->CreateTensor(input_spec);
+  auto output_tensor = graph->CreateTensor(output_spec);
+
+  std::vector<float> in_data = {-2.5, -0.1, 0, 0.55, 99};
+  std::vector<float> golden = {-0.458957, -0.0475813, 0, 0.55, 99};
+
+  EXPECT_TRUE(
+      input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * 4));
+
+  auto op = graph->CreateOperation<tim::vx::ops::Elu>(0.5);
+  (*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
+
+  EXPECT_TRUE(graph->Compile());
+  EXPECT_TRUE(graph->Run());
+  std::vector<float> output(5, 0);
+  EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
+  EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
+}
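
Note: after this patch, Elu takes its alpha value at construction time; elu(x) = x for x >= 0 and alpha * (exp(x) - 1) for x < 0, which is what the new test's golden data encodes (alpha = 0.5 gives elu(-2.5) = -0.458957). A minimal usage sketch of the patched API follows. Every call in it comes from this patch or its test; only the include list and main() scaffolding are assumptions, laid out the way TIM-VX samples usually are.

#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/tensor.h"
#include "tim/vx/ops/activations.h"

#include <vector>

int main() {
  auto ctx = tim::vx::Context::Create();
  auto graph = ctx->CreateGraph();

  // One 5-element float tensor in and out, mirroring the new test.
  tim::vx::ShapeType shape({5, 1});
  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, shape,
                                 tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, shape,
                                  tim::vx::TensorAttribute::OUTPUT);
  auto input = graph->CreateTensor(input_spec);
  auto output = graph->CreateTensor(output_spec);

  // alpha is now a required constructor argument; passing 1.0f should
  // reproduce the old fixed-alpha behavior, assuming the driver's
  // previous default was the standard ELU alpha of 1.0.
  auto elu = graph->CreateOperation<tim::vx::ops::Elu>(0.5f);
  (*elu).BindInputs({input}).BindOutputs({output});

  std::vector<float> in_data = {-2.5f, -0.1f, 0.0f, 0.55f, 99.0f};
  input->CopyDataToTensor(in_data.data(), in_data.size() * sizeof(float));

  if (!graph->Compile() || !graph->Run()) return -1;

  std::vector<float> out_data(in_data.size());
  output->CopyDataFromTensor(out_data.data());
  // With alpha = 0.5, out_data matches the test's golden values:
  // {-0.458957, -0.0475813, 0, 0.55, 99}.
  return 0;
}

This is also an API-breaking change: existing callers of graph->CreateOperation<tim::vx::ops::Elu>() with no arguments will no longer compile and must supply an alpha explicitly.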