From 51faf286c29353e335d70146f9aaf32a1e2d51c8 Mon Sep 17 00:00:00 2001 From: shijie001 <25165513+shijie001@users.noreply.github.com> Date: Mon, 22 May 2023 14:13:44 +0800 Subject: [PATCH] Fixed LayerNormalization eps bug (#589) --- include/tim/vx/ops/layernormalization.h | 2 +- src/tim/vx/ops/layernormalization.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/tim/vx/ops/layernormalization.h b/include/tim/vx/ops/layernormalization.h index d45aa57..0cdc46b 100644 --- a/include/tim/vx/ops/layernormalization.h +++ b/include/tim/vx/ops/layernormalization.h @@ -38,7 +38,7 @@ class LayerNormalization : public BuiltinOp { protected: int32_t axis_; - int32_t eps_; + float eps_; }; } // namespace ops diff --git a/src/tim/vx/ops/layernormalization.cc b/src/tim/vx/ops/layernormalization.cc index d111e4f..038a2c0 100644 --- a/src/tim/vx/ops/layernormalization.cc +++ b/src/tim/vx/ops/layernormalization.cc @@ -37,7 +37,7 @@ LayerNormalization::LayerNormalization(Graph* graph, int32_t axis, float eps) VSILOGE("Layer norm only support axis 0."); assert(false); } - this->impl()->node()->nn_param.instancenorm.eps = eps_; + this->impl()->node()->nn_param.layernorm.eps = eps_; } std::shared_ptr LayerNormalization::Clone(