Refine UnidirectionalGRU and GRUCell (#587)

Refine unidirectional_gru and gru_cell code to avoid including ovxlib files
in header of some op
Introduce TranslateToVsibool function to support above code refinement

Type: Code Improvement

Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>
This commit is contained in:
Chen Feiyue 2023-05-15 16:44:51 +08:00 committed by GitHub
parent b81f7979fa
commit 3c372dd646
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 17 additions and 17 deletions

View File

@ -26,7 +26,6 @@
#include <array>
#include "tim/vx/builtin_op.h"
#include "vsi_nn_pub.h"
namespace tim {
namespace vx {
@ -56,7 +55,7 @@ class GRUCell : public BuiltinOp {
GRUCell(Graph* graph, uint32_t num_units,
ActivationType activation = ActivationType::kTANH,
ActivationType recurrent_activation = ActivationType::kSIGMOID,
vsi_bool reset_after = TRUE);
bool reset_after = true);
std::shared_ptr<Operation> Clone(
std::shared_ptr<Graph>& graph) const override;
@ -65,7 +64,7 @@ class GRUCell : public BuiltinOp {
const uint32_t num_units_;
const ActivationType activation_;
const ActivationType recurrent_activation_;
const int32_t reset_after_;
const bool reset_after_;
};
} // namespace ops

View File

@ -26,7 +26,6 @@
#include <array>
#include "tim/vx/builtin_op.h"
#include "vsi_nn_pub.h"
namespace tim {
namespace vx {
@ -61,9 +60,9 @@ class UnidirectionalSequenceGRU : public BuiltinOp {
Graph* graph, uint32_t num_units,
ActivationType activation = ActivationType::kTANH,
ActivationType recurrent_activation = ActivationType::kSIGMOID,
vsi_bool reset_after = TRUE,
vsi_bool return_sequences = FALSE, /*False: only return last state*/
vsi_bool time_major = TRUE);
bool reset_after = true,
bool return_sequences = false, /*False: only return last state*/
bool time_major = true);
std::shared_ptr<Operation> Clone(
std::shared_ptr<Graph>& graph) const override;
@ -72,9 +71,9 @@ class UnidirectionalSequenceGRU : public BuiltinOp {
const uint32_t num_units_;
const ActivationType activation_;
const ActivationType recurrent_activation_;
const int32_t reset_after_;
const int32_t return_sequences_;
const int32_t time_major_;
const bool reset_after_;
const bool return_sequences_;
const bool time_major_;
};
} // namespace ops

View File

@ -30,7 +30,7 @@ namespace tim {
namespace vx {
namespace ops {
GRUCell::GRUCell(Graph* graph, uint32_t num_units, ActivationType activation,
ActivationType recurrent_activation, vsi_bool reset_after)
ActivationType recurrent_activation, bool reset_after)
: BuiltinOp(graph, VSI_NN_OP_GRUCELL),
num_units_(num_units),
activation_(activation),
@ -39,7 +39,7 @@ GRUCell::GRUCell(Graph* graph, uint32_t num_units, ActivationType activation,
this->impl()->node()->nn_param.grucell.num_units = num_units;
this->impl()->node()->nn_param.grucell.activation = activation;
this->impl()->node()->nn_param.grucell.recurrent_activation = recurrent_activation;
this->impl()->node()->nn_param.grucell.reset_after = reset_after;
this->impl()->node()->nn_param.grucell.reset_after = TranslateToVsibool(reset_after);
}
std::shared_ptr<Operation> GRUCell::Clone(std::shared_ptr<Graph>& graph) const {

View File

@ -31,8 +31,8 @@ namespace vx {
namespace ops {
UnidirectionalSequenceGRU::UnidirectionalSequenceGRU(
Graph* graph, uint32_t num_units, ActivationType activation,
ActivationType recurrent_activation, vsi_bool reset_after,
vsi_bool return_sequences, vsi_bool time_major)
ActivationType recurrent_activation, bool reset_after,
bool return_sequences, bool time_major)
: BuiltinOp(graph, VSI_NN_OP_GRU),
num_units_(num_units),
activation_(activation),
@ -43,9 +43,9 @@ UnidirectionalSequenceGRU::UnidirectionalSequenceGRU(
this->impl()->node()->nn_param.gru.num_units = num_units;
this->impl()->node()->nn_param.gru.activation = activation;
this->impl()->node()->nn_param.gru.recurrent_activation = recurrent_activation;
this->impl()->node()->nn_param.gru.reset_after = reset_after;
this->impl()->node()->nn_param.gru.return_sequences = return_sequences;
this->impl()->node()->nn_param.gru.time_major = time_major;
this->impl()->node()->nn_param.gru.reset_after = TranslateToVsibool(reset_after);
this->impl()->node()->nn_param.gru.return_sequences = TranslateToVsibool(return_sequences);
this->impl()->node()->nn_param.gru.time_major = TranslateToVsibool(time_major);
}
std::shared_ptr<Operation> UnidirectionalSequenceGRU::Clone(

View File

@ -163,6 +163,7 @@ vsi_enum TranslateResizeType(ResizeType type) {
}
vx_bool_e ToVxBool(bool val) { return val ? vx_true_e : vx_false_e; }
vsi_bool TranslateToVsibool(bool val) { return val? TRUE : FALSE; }
} // namespace vx
} // namespace tim

View File

@ -39,6 +39,7 @@ vsi_enum TranslateRoundingPolicy(RoundingPolicy type);
vsi_enum TranslateDownScaleSizeRounding(RoundType type);
vsi_enum TranslateResizeType(ResizeType type);
vx_bool_e ToVxBool(bool val);
vsi_bool TranslateToVsibool(bool val);
} // namespace vx
} // namespace tim