add Swish op

This commit is contained in:
jing.tang 2021-08-16 12:02:21 +08:00 committed by Sven
parent 4d53e042c8
commit a364c3eafb
3 changed files with 15 additions and 1 deletions

View File

@ -47,6 +47,8 @@ namespace ops {
*
* Sigmoid(x) : 1/(1 + e^{-x})
*
* Swish(x) : x * sigmoid(x)
*
* HardSwish(x) : 0 if x <= -3; x(x + 3)/6 if -3 < x < 3; x if x >= 3
*
* Mish(x) : x if x >= 0 else alpha * x
@ -78,6 +80,7 @@ DECLARE_NO_PARAMETER_ACTIVATION(Relu6)
DECLARE_NO_PARAMETER_ACTIVATION(Elu)
DECLARE_NO_PARAMETER_ACTIVATION(Tanh)
DECLARE_NO_PARAMETER_ACTIVATION(Sigmoid)
DECLARE_NO_PARAMETER_ACTIVATION(Swish)
DECLARE_NO_PARAMETER_ACTIVATION(HardSwish)
DECLARE_NO_PARAMETER_ACTIVATION(Mish)
DECLARE_NO_PARAMETER_ACTIVATION(HardSigmoid)

View File

@ -81,7 +81,8 @@ Sin|SIN|Mapped|[tf.math.sin](https://tensorflow.google.cn/api_docs/python/tf/mat
Log|LOG|Mapped|[tf.math.log](https://tensorflow.google.cn/api_docs/python/tf/math/log)
ArgMin|ARGMIN|Mapped|[tf.math.argmin](https://tensorflow.google.cn/api_docs/python/tf/math/argmin)
LogSoftmax|LOG_SOFTMAX|Mapped|[tf.nn.log_softmax](https://tensorflow.google.cn/api_docs/python/tf/nn/log_softmax)
HardSwish|SWISH|Mapped|[tf.keras.activations.swish](https://tensorflow.google.cn/api_docs/python/tf/keras/activations/swish)
Swish|SWISH|Mapped|[tf.keras.activations.swish](https://tensorflow.google.cn/api_docs/python/tf/keras/activations/swish)
HardSwish|SWISH|Mapped|[torch.nn.Hardswish](https://pytorch.org/docs/stable/generated/torch.nn.Hardswish.html)
GatherNd|GATHER_ND|Mapped|[tf.gather_nd](https://tensorflow.google.cn/api_docs/python/tf/gather_nd)
Cast|CAST|Mapped|[tf.cast](https://tensorflow.google.cn/api_docs/python/tf/cast)
Moments|MOMENTS|Mapped|[tf.moments](https://tensorflow.google.cn/api_docs/python/tf/nn/moments)

View File

@ -59,6 +59,16 @@ std::shared_ptr<Operation> HardSwish::Clone(
return graph->CreateOperation<HardSwish>();
}
Swish::Swish(Graph* graph) : Operation(graph, VSI_NN_OP_SWISH) {
this->impl()->node()->nn_param.swish.type = VSI_NN_SWISH;
this->impl()->node()->nn_param.swish.beta = 1.0f;
}
std::shared_ptr<Operation> Swish::Clone(
std::shared_ptr<Graph>& graph) const {
return graph->CreateOperation<Swish>();
}
Prelu::Prelu(Graph* graph, int axis)
: Operation(graph, VSI_NN_OP_PRELU), axis_(axis) {
this->impl()->node()->nn_param.prelu.axis = axis_;