From 3c372dd646f9c28a8bb7a87b11dd9c336e774a14 Mon Sep 17 00:00:00 2001 From: Chen Feiyue <69809761+chenfeiyue-cfy@users.noreply.github.com> Date: Mon, 15 May 2023 16:44:51 +0800 Subject: [PATCH] Refine UnidirectionalGRU and GRUCell (#587) Refine unidirectional_gru and gru_cell code to avoid including ovxlib files in header of some op Introduce TranslateToVsibool function to support above code refinement Type: Code Improvement Signed-off-by: Feiyue Chen --- include/tim/vx/ops/grucell.h | 5 ++--- include/tim/vx/ops/unidirectional_sequence_gru.h | 13 ++++++------- src/tim/vx/ops/grucell.cc | 4 ++-- src/tim/vx/ops/unidirectional_sequence_gru.cc | 10 +++++----- src/tim/vx/type_utils.cc | 1 + src/tim/vx/type_utils.h | 1 + 6 files changed, 17 insertions(+), 17 deletions(-) diff --git a/include/tim/vx/ops/grucell.h b/include/tim/vx/ops/grucell.h index f9f3106..ed577da 100644 --- a/include/tim/vx/ops/grucell.h +++ b/include/tim/vx/ops/grucell.h @@ -26,7 +26,6 @@ #include #include "tim/vx/builtin_op.h" -#include "vsi_nn_pub.h" namespace tim { namespace vx { @@ -56,7 +55,7 @@ class GRUCell : public BuiltinOp { GRUCell(Graph* graph, uint32_t num_units, ActivationType activation = ActivationType::kTANH, ActivationType recurrent_activation = ActivationType::kSIGMOID, - vsi_bool reset_after = TRUE); + bool reset_after = true); std::shared_ptr Clone( std::shared_ptr& graph) const override; @@ -65,7 +64,7 @@ class GRUCell : public BuiltinOp { const uint32_t num_units_; const ActivationType activation_; const ActivationType recurrent_activation_; - const int32_t reset_after_; + const bool reset_after_; }; } // namespace ops diff --git a/include/tim/vx/ops/unidirectional_sequence_gru.h b/include/tim/vx/ops/unidirectional_sequence_gru.h index bbd16f6..de64a0e 100644 --- a/include/tim/vx/ops/unidirectional_sequence_gru.h +++ b/include/tim/vx/ops/unidirectional_sequence_gru.h @@ -26,7 +26,6 @@ #include #include "tim/vx/builtin_op.h" -#include "vsi_nn_pub.h" namespace tim { namespace vx { @@ -61,9 +60,9 @@ class UnidirectionalSequenceGRU : public BuiltinOp { Graph* graph, uint32_t num_units, ActivationType activation = ActivationType::kTANH, ActivationType recurrent_activation = ActivationType::kSIGMOID, - vsi_bool reset_after = TRUE, - vsi_bool return_sequences = FALSE, /*False: only return last state*/ - vsi_bool time_major = TRUE); + bool reset_after = true, + bool return_sequences = false, /*False: only return last state*/ + bool time_major = true); std::shared_ptr Clone( std::shared_ptr& graph) const override; @@ -72,9 +71,9 @@ class UnidirectionalSequenceGRU : public BuiltinOp { const uint32_t num_units_; const ActivationType activation_; const ActivationType recurrent_activation_; - const int32_t reset_after_; - const int32_t return_sequences_; - const int32_t time_major_; + const bool reset_after_; + const bool return_sequences_; + const bool time_major_; }; } // namespace ops diff --git a/src/tim/vx/ops/grucell.cc b/src/tim/vx/ops/grucell.cc index 1453156..eff40d7 100644 --- a/src/tim/vx/ops/grucell.cc +++ b/src/tim/vx/ops/grucell.cc @@ -30,7 +30,7 @@ namespace tim { namespace vx { namespace ops { GRUCell::GRUCell(Graph* graph, uint32_t num_units, ActivationType activation, - ActivationType recurrent_activation, vsi_bool reset_after) + ActivationType recurrent_activation, bool reset_after) : BuiltinOp(graph, VSI_NN_OP_GRUCELL), num_units_(num_units), activation_(activation), @@ -39,7 +39,7 @@ GRUCell::GRUCell(Graph* graph, uint32_t num_units, ActivationType activation, this->impl()->node()->nn_param.grucell.num_units = num_units; this->impl()->node()->nn_param.grucell.activation = activation; this->impl()->node()->nn_param.grucell.recurrent_activation = recurrent_activation; - this->impl()->node()->nn_param.grucell.reset_after = reset_after; + this->impl()->node()->nn_param.grucell.reset_after = TranslateToVsibool(reset_after); } std::shared_ptr GRUCell::Clone(std::shared_ptr& graph) const { diff --git a/src/tim/vx/ops/unidirectional_sequence_gru.cc b/src/tim/vx/ops/unidirectional_sequence_gru.cc index 036a141..f56b222 100644 --- a/src/tim/vx/ops/unidirectional_sequence_gru.cc +++ b/src/tim/vx/ops/unidirectional_sequence_gru.cc @@ -31,8 +31,8 @@ namespace vx { namespace ops { UnidirectionalSequenceGRU::UnidirectionalSequenceGRU( Graph* graph, uint32_t num_units, ActivationType activation, - ActivationType recurrent_activation, vsi_bool reset_after, - vsi_bool return_sequences, vsi_bool time_major) + ActivationType recurrent_activation, bool reset_after, + bool return_sequences, bool time_major) : BuiltinOp(graph, VSI_NN_OP_GRU), num_units_(num_units), activation_(activation), @@ -43,9 +43,9 @@ UnidirectionalSequenceGRU::UnidirectionalSequenceGRU( this->impl()->node()->nn_param.gru.num_units = num_units; this->impl()->node()->nn_param.gru.activation = activation; this->impl()->node()->nn_param.gru.recurrent_activation = recurrent_activation; - this->impl()->node()->nn_param.gru.reset_after = reset_after; - this->impl()->node()->nn_param.gru.return_sequences = return_sequences; - this->impl()->node()->nn_param.gru.time_major = time_major; + this->impl()->node()->nn_param.gru.reset_after = TranslateToVsibool(reset_after); + this->impl()->node()->nn_param.gru.return_sequences = TranslateToVsibool(return_sequences); + this->impl()->node()->nn_param.gru.time_major = TranslateToVsibool(time_major); } std::shared_ptr UnidirectionalSequenceGRU::Clone( diff --git a/src/tim/vx/type_utils.cc b/src/tim/vx/type_utils.cc index 383bfcb..11d5681 100644 --- a/src/tim/vx/type_utils.cc +++ b/src/tim/vx/type_utils.cc @@ -163,6 +163,7 @@ vsi_enum TranslateResizeType(ResizeType type) { } vx_bool_e ToVxBool(bool val) { return val ? vx_true_e : vx_false_e; } +vsi_bool TranslateToVsibool(bool val) { return val? TRUE : FALSE; } } // namespace vx } // namespace tim diff --git a/src/tim/vx/type_utils.h b/src/tim/vx/type_utils.h index 5536050..0454c14 100644 --- a/src/tim/vx/type_utils.h +++ b/src/tim/vx/type_utils.h @@ -39,6 +39,7 @@ vsi_enum TranslateRoundingPolicy(RoundingPolicy type); vsi_enum TranslateDownScaleSizeRounding(RoundType type); vsi_enum TranslateResizeType(ResizeType type); vx_bool_e ToVxBool(bool val); +vsi_bool TranslateToVsibool(bool val); } // namespace vx } // namespace tim