Refine UnidirectionalGRU and GRUCell (#587)

Refine unidirectional_gru and gru_cell code to avoid including ovxlib files
in header of some op
Introduce TranslateToVsibool function to support above code refinement

Type: Code Improvement

Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>
This commit is contained in:
Chen Feiyue 2023-05-15 16:44:51 +08:00 committed by GitHub
parent b81f7979fa
commit 3c372dd646
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 17 additions and 17 deletions

View File

@ -26,7 +26,6 @@
#include <array> #include <array>
#include "tim/vx/builtin_op.h" #include "tim/vx/builtin_op.h"
#include "vsi_nn_pub.h"
namespace tim { namespace tim {
namespace vx { namespace vx {
@ -56,7 +55,7 @@ class GRUCell : public BuiltinOp {
GRUCell(Graph* graph, uint32_t num_units, GRUCell(Graph* graph, uint32_t num_units,
ActivationType activation = ActivationType::kTANH, ActivationType activation = ActivationType::kTANH,
ActivationType recurrent_activation = ActivationType::kSIGMOID, ActivationType recurrent_activation = ActivationType::kSIGMOID,
vsi_bool reset_after = TRUE); bool reset_after = true);
std::shared_ptr<Operation> Clone( std::shared_ptr<Operation> Clone(
std::shared_ptr<Graph>& graph) const override; std::shared_ptr<Graph>& graph) const override;
@ -65,7 +64,7 @@ class GRUCell : public BuiltinOp {
const uint32_t num_units_; const uint32_t num_units_;
const ActivationType activation_; const ActivationType activation_;
const ActivationType recurrent_activation_; const ActivationType recurrent_activation_;
const int32_t reset_after_; const bool reset_after_;
}; };
} // namespace ops } // namespace ops

View File

@ -26,7 +26,6 @@
#include <array> #include <array>
#include "tim/vx/builtin_op.h" #include "tim/vx/builtin_op.h"
#include "vsi_nn_pub.h"
namespace tim { namespace tim {
namespace vx { namespace vx {
@ -61,9 +60,9 @@ class UnidirectionalSequenceGRU : public BuiltinOp {
Graph* graph, uint32_t num_units, Graph* graph, uint32_t num_units,
ActivationType activation = ActivationType::kTANH, ActivationType activation = ActivationType::kTANH,
ActivationType recurrent_activation = ActivationType::kSIGMOID, ActivationType recurrent_activation = ActivationType::kSIGMOID,
vsi_bool reset_after = TRUE, bool reset_after = true,
vsi_bool return_sequences = FALSE, /*False: only return last state*/ bool return_sequences = false, /*False: only return last state*/
vsi_bool time_major = TRUE); bool time_major = true);
std::shared_ptr<Operation> Clone( std::shared_ptr<Operation> Clone(
std::shared_ptr<Graph>& graph) const override; std::shared_ptr<Graph>& graph) const override;
@ -72,9 +71,9 @@ class UnidirectionalSequenceGRU : public BuiltinOp {
const uint32_t num_units_; const uint32_t num_units_;
const ActivationType activation_; const ActivationType activation_;
const ActivationType recurrent_activation_; const ActivationType recurrent_activation_;
const int32_t reset_after_; const bool reset_after_;
const int32_t return_sequences_; const bool return_sequences_;
const int32_t time_major_; const bool time_major_;
}; };
} // namespace ops } // namespace ops

View File

@ -30,7 +30,7 @@ namespace tim {
namespace vx { namespace vx {
namespace ops { namespace ops {
GRUCell::GRUCell(Graph* graph, uint32_t num_units, ActivationType activation, GRUCell::GRUCell(Graph* graph, uint32_t num_units, ActivationType activation,
ActivationType recurrent_activation, vsi_bool reset_after) ActivationType recurrent_activation, bool reset_after)
: BuiltinOp(graph, VSI_NN_OP_GRUCELL), : BuiltinOp(graph, VSI_NN_OP_GRUCELL),
num_units_(num_units), num_units_(num_units),
activation_(activation), activation_(activation),
@ -39,7 +39,7 @@ GRUCell::GRUCell(Graph* graph, uint32_t num_units, ActivationType activation,
this->impl()->node()->nn_param.grucell.num_units = num_units; this->impl()->node()->nn_param.grucell.num_units = num_units;
this->impl()->node()->nn_param.grucell.activation = activation; this->impl()->node()->nn_param.grucell.activation = activation;
this->impl()->node()->nn_param.grucell.recurrent_activation = recurrent_activation; this->impl()->node()->nn_param.grucell.recurrent_activation = recurrent_activation;
this->impl()->node()->nn_param.grucell.reset_after = reset_after; this->impl()->node()->nn_param.grucell.reset_after = TranslateToVsibool(reset_after);
} }
std::shared_ptr<Operation> GRUCell::Clone(std::shared_ptr<Graph>& graph) const { std::shared_ptr<Operation> GRUCell::Clone(std::shared_ptr<Graph>& graph) const {

View File

@ -31,8 +31,8 @@ namespace vx {
namespace ops { namespace ops {
UnidirectionalSequenceGRU::UnidirectionalSequenceGRU( UnidirectionalSequenceGRU::UnidirectionalSequenceGRU(
Graph* graph, uint32_t num_units, ActivationType activation, Graph* graph, uint32_t num_units, ActivationType activation,
ActivationType recurrent_activation, vsi_bool reset_after, ActivationType recurrent_activation, bool reset_after,
vsi_bool return_sequences, vsi_bool time_major) bool return_sequences, bool time_major)
: BuiltinOp(graph, VSI_NN_OP_GRU), : BuiltinOp(graph, VSI_NN_OP_GRU),
num_units_(num_units), num_units_(num_units),
activation_(activation), activation_(activation),
@ -43,9 +43,9 @@ UnidirectionalSequenceGRU::UnidirectionalSequenceGRU(
this->impl()->node()->nn_param.gru.num_units = num_units; this->impl()->node()->nn_param.gru.num_units = num_units;
this->impl()->node()->nn_param.gru.activation = activation; this->impl()->node()->nn_param.gru.activation = activation;
this->impl()->node()->nn_param.gru.recurrent_activation = recurrent_activation; this->impl()->node()->nn_param.gru.recurrent_activation = recurrent_activation;
this->impl()->node()->nn_param.gru.reset_after = reset_after; this->impl()->node()->nn_param.gru.reset_after = TranslateToVsibool(reset_after);
this->impl()->node()->nn_param.gru.return_sequences = return_sequences; this->impl()->node()->nn_param.gru.return_sequences = TranslateToVsibool(return_sequences);
this->impl()->node()->nn_param.gru.time_major = time_major; this->impl()->node()->nn_param.gru.time_major = TranslateToVsibool(time_major);
} }
std::shared_ptr<Operation> UnidirectionalSequenceGRU::Clone( std::shared_ptr<Operation> UnidirectionalSequenceGRU::Clone(

View File

@ -163,6 +163,7 @@ vsi_enum TranslateResizeType(ResizeType type) {
} }
vx_bool_e ToVxBool(bool val) { return val ? vx_true_e : vx_false_e; } vx_bool_e ToVxBool(bool val) { return val ? vx_true_e : vx_false_e; }
vsi_bool TranslateToVsibool(bool val) { return val? TRUE : FALSE; }
} // namespace vx } // namespace vx
} // namespace tim } // namespace tim

View File

@ -39,6 +39,7 @@ vsi_enum TranslateRoundingPolicy(RoundingPolicy type);
vsi_enum TranslateDownScaleSizeRounding(RoundType type); vsi_enum TranslateDownScaleSizeRounding(RoundType type);
vsi_enum TranslateResizeType(ResizeType type); vsi_enum TranslateResizeType(ResizeType type);
vx_bool_e ToVxBool(bool val); vx_bool_e ToVxBool(bool val);
vsi_bool TranslateToVsibool(bool val);
} // namespace vx } // namespace vx
} // namespace tim } // namespace tim