From a364c3eafb12f790a548f6f276ea46b2feae73b9 Mon Sep 17 00:00:00 2001 From: "jing.tang" Date: Mon, 16 Aug 2021 12:02:21 +0800 Subject: [PATCH] add Swish op --- include/tim/vx/ops/activations.h | 3 +++ src/tim/vx/ops/README.md | 3 ++- src/tim/vx/ops/activations.cc | 10 ++++++++++ 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/include/tim/vx/ops/activations.h b/include/tim/vx/ops/activations.h index a4b5002..8f8a034 100644 --- a/include/tim/vx/ops/activations.h +++ b/include/tim/vx/ops/activations.h @@ -47,6 +47,8 @@ namespace ops { * * Sigmoid(x) : 1/(1 + e^{-x}) * + * Swish(x) : x * sigmoid(x) + * * HardSwish(x) : 0 if x <= -3; x(x + 3)/6 if -3 < x < 3; x if x >= 3 * * Mish(x) : x if x >= 0 else alpha * x @@ -78,6 +80,7 @@ DECLARE_NO_PARAMETER_ACTIVATION(Relu6) DECLARE_NO_PARAMETER_ACTIVATION(Elu) DECLARE_NO_PARAMETER_ACTIVATION(Tanh) DECLARE_NO_PARAMETER_ACTIVATION(Sigmoid) +DECLARE_NO_PARAMETER_ACTIVATION(Swish) DECLARE_NO_PARAMETER_ACTIVATION(HardSwish) DECLARE_NO_PARAMETER_ACTIVATION(Mish) DECLARE_NO_PARAMETER_ACTIVATION(HardSigmoid) diff --git a/src/tim/vx/ops/README.md b/src/tim/vx/ops/README.md index 65e5aac..6e307a4 100644 --- a/src/tim/vx/ops/README.md +++ b/src/tim/vx/ops/README.md @@ -81,7 +81,8 @@ Sin|SIN|Mapped|[tf.math.sin](https://tensorflow.google.cn/api_docs/python/tf/mat Log|LOG|Mapped|[tf.math.log](https://tensorflow.google.cn/api_docs/python/tf/math/log) ArgMin|ARGMIN|Mapped|[tf.math.argmin](https://tensorflow.google.cn/api_docs/python/tf/math/argmin) LogSoftmax|LOG_SOFTMAX|Mapped|[tf.nn.log_softmax](https://tensorflow.google.cn/api_docs/python/tf/nn/log_softmax) -HardSwish|SWISH|Mapped|[tf.keras.activations.swish](https://tensorflow.google.cn/api_docs/python/tf/keras/activations/swish) +Swish|SWISH|Mapped|[tf.keras.activations.swish](https://tensorflow.google.cn/api_docs/python/tf/keras/activations/swish) +HardSwish|SWISH|Mapped|[torch.nn.Hardswish](https://pytorch.org/docs/stable/generated/torch.nn.Hardswish.html) GatherNd|GATHER_ND|Mapped|[tf.gather_nd](https://tensorflow.google.cn/api_docs/python/tf/gather_nd) Cast|CAST|Mapped|[tf.cast](https://tensorflow.google.cn/api_docs/python/tf/cast) Moments|MOMENTS|Mapped|[tf.moments](https://tensorflow.google.cn/api_docs/python/tf/nn/moments) diff --git a/src/tim/vx/ops/activations.cc b/src/tim/vx/ops/activations.cc index 16439e9..81a1d36 100644 --- a/src/tim/vx/ops/activations.cc +++ b/src/tim/vx/ops/activations.cc @@ -59,6 +59,16 @@ std::shared_ptr HardSwish::Clone( return graph->CreateOperation(); } +Swish::Swish(Graph* graph) : Operation(graph, VSI_NN_OP_SWISH) { + this->impl()->node()->nn_param.swish.type = VSI_NN_SWISH; + this->impl()->node()->nn_param.swish.beta = 1.0f; +} + +std::shared_ptr Swish::Clone( + std::shared_ptr& graph) const { + return graph->CreateOperation(); +} + Prelu::Prelu(Graph* graph, int axis) : Operation(graph, VSI_NN_OP_PRELU), axis_(axis) { this->impl()->node()->nn_param.prelu.axis = axis_;