diff --git a/include/tim/vx/ops/layernormalization.h b/include/tim/vx/ops/layernormalization.h index d45aa57..0cdc46b 100644 --- a/include/tim/vx/ops/layernormalization.h +++ b/include/tim/vx/ops/layernormalization.h @@ -38,7 +38,7 @@ class LayerNormalization : public BuiltinOp { protected: int32_t axis_; - int32_t eps_; + float eps_; }; } // namespace ops diff --git a/src/tim/vx/ops/layernormalization.cc b/src/tim/vx/ops/layernormalization.cc index d111e4f..038a2c0 100644 --- a/src/tim/vx/ops/layernormalization.cc +++ b/src/tim/vx/ops/layernormalization.cc @@ -37,7 +37,7 @@ LayerNormalization::LayerNormalization(Graph* graph, int32_t axis, float eps) VSILOGE("Layer norm only support axis 0."); assert(false); } - this->impl()->node()->nn_param.instancenorm.eps = eps_; + this->impl()->node()->nn_param.layernorm.eps = eps_; } std::shared_ptr LayerNormalization::Clone(