diff --git a/include/tim/vx/ops/activations.h b/include/tim/vx/ops/activations.h index 3025aac..3e93300 100644 --- a/include/tim/vx/ops/activations.h +++ b/include/tim/vx/ops/activations.h @@ -65,6 +65,10 @@ namespace ops { * Linear(x, a, b) : a*x + b. * * Gelu(x) : x * P(X <= x), where P(x) ~ N(0, 1). https://tensorflow.google.cn/api_docs/python/tf/nn/gelu + * + * Selu(x, alpha, gamma) : gamma * x if(x>=0), gamma * alpha * (exp(x)-1) x<0 + * + * Celu(x, alpha) : x if x >= 0; alpha * (exp(x/alpha) - 1) * ``` */ @@ -152,6 +156,29 @@ class Gelu : public DirectMapOp { std::shared_ptr& graph) const override; }; +class Selu : public DirectMapOp { + public: + Selu(Graph* graph, float alpha = 1.67326, float gamma = 1.0507); + + std::shared_ptr Clone( + std::shared_ptr& graph) const override; + + protected: + float alpha_; + float gamma_; +}; + +class Celu : public DirectMapOp { + public: + Celu(Graph* graph, float alpha); + + std::shared_ptr Clone( + std::shared_ptr& graph) const override; + + protected: + float alpha_; +}; + } // namespace ops } // namespace vx } // namespace tim diff --git a/src/tim/vx/ops/activations.cc b/src/tim/vx/ops/activations.cc index 9398f92..39e59c6 100644 --- a/src/tim/vx/ops/activations.cc +++ b/src/tim/vx/ops/activations.cc @@ -135,6 +135,25 @@ std::shared_ptr Gelu::Clone(std::shared_ptr& graph) const { this->impl()->node()->nn_param.gelu.approximate); } +Selu::Selu(Graph* graph, float alpha, float gamma) + : DirectMapOp(graph, VSI_NN_OP_SELU), alpha_(alpha), gamma_(gamma) { + this->impl()->node()->nn_param.selu.alpha = alpha; + this->impl()->node()->nn_param.selu.gamma = gamma; +} + +std::shared_ptr Selu::Clone(std::shared_ptr& graph) const { + return graph->CreateOperation(this->alpha_, this->gamma_); +} + +Celu::Celu(Graph* graph, float alpha) + : DirectMapOp(graph, VSI_NN_OP_CELU), alpha_(alpha) { + this->impl()->node()->nn_param.selu.alpha = alpha; +} + +std::shared_ptr Celu::Clone(std::shared_ptr& graph) const { + return graph->CreateOperation(this->alpha_); +} + } // namespace ops } // namespace vx } // namespace tim diff --git a/src/tim/vx/ops/activations_test.cc b/src/tim/vx/ops/activations_test.cc index 0470204..0721664 100644 --- a/src/tim/vx/ops/activations_test.cc +++ b/src/tim/vx/ops/activations_test.cc @@ -330,3 +330,69 @@ TEST(Elu, shape_5_1_fp32_a) { EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f)); } + +TEST(Selu, shape_2_2) { + auto ctx = tim::vx::Context::Create(); + auto graph = ctx->CreateGraph(); + + tim::vx::ShapeType in_shape({2, 2}); + tim::vx::ShapeType out_shape({2, 2}); + + tim::vx::TensorSpec in_spec(tim::vx::DataType::FLOAT32, in_shape, + tim::vx::TensorAttribute::INPUT); + + tim::vx::TensorSpec out_spec(tim::vx::DataType::FLOAT32, out_shape, + tim::vx::TensorAttribute::OUTPUT); + + auto in_tensor = graph->CreateTensor(in_spec); + auto out_tensor = graph->CreateTensor(out_spec); + + std::vector in_data = {2, 1, 3, 10}; + std::vector golden = {2.1014, 1.0507, 3.1521, 10.507}; + + EXPECT_TRUE(in_tensor->CopyDataToTensor(in_data.data(), + in_data.size() * sizeof(float))); + + auto op = graph->CreateOperation(); + (*op).BindInputs({in_tensor}).BindOutputs({out_tensor}); + + EXPECT_TRUE(graph->Compile()); + EXPECT_TRUE(graph->Run()); + + std::vector output(golden.size()); + EXPECT_TRUE(out_tensor->CopyDataFromTensor(output.data())); + EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f)); +} + +TEST(Celu, shape_2_2) { + auto ctx = tim::vx::Context::Create(); + auto graph = ctx->CreateGraph(); + + tim::vx::ShapeType in_shape({2, 2}); + tim::vx::ShapeType out_shape({2, 2}); + + tim::vx::TensorSpec in_spec(tim::vx::DataType::FLOAT32, in_shape, + tim::vx::TensorAttribute::INPUT); + + tim::vx::TensorSpec out_spec(tim::vx::DataType::FLOAT32, out_shape, + tim::vx::TensorAttribute::OUTPUT); + + auto in_tensor = graph->CreateTensor(in_spec); + auto out_tensor = graph->CreateTensor(out_spec); + + std::vector in_data = {-1, 0.71, 3, 10}; + std::vector golden = {-0.69762, 0.71, 3, 10}; + + EXPECT_TRUE(in_tensor->CopyDataToTensor(in_data.data(), + in_data.size() * sizeof(float))); + + auto op = graph->CreateOperation(1.3); + (*op).BindInputs({in_tensor}).BindOutputs({out_tensor}); + + EXPECT_TRUE(graph->Compile()); + EXPECT_TRUE(graph->Run()); + + std::vector output(golden.size()); + EXPECT_TRUE(out_tensor->CopyDataFromTensor(output.data())); + EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f)); +}