Refine UnidirectionalGRU and GRUCell (#587)
Refine unidirectional_gru and gru_cell code to avoid including ovxlib files in header of some op Introduce TranslateToVsibool function to support above code refinement Type: Code Improvement Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>
This commit is contained in:
parent
b81f7979fa
commit
3c372dd646
|
|
@ -26,7 +26,6 @@
|
|||
|
||||
#include <array>
|
||||
#include "tim/vx/builtin_op.h"
|
||||
#include "vsi_nn_pub.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -56,7 +55,7 @@ class GRUCell : public BuiltinOp {
|
|||
GRUCell(Graph* graph, uint32_t num_units,
|
||||
ActivationType activation = ActivationType::kTANH,
|
||||
ActivationType recurrent_activation = ActivationType::kSIGMOID,
|
||||
vsi_bool reset_after = TRUE);
|
||||
bool reset_after = true);
|
||||
|
||||
std::shared_ptr<Operation> Clone(
|
||||
std::shared_ptr<Graph>& graph) const override;
|
||||
|
|
@ -65,7 +64,7 @@ class GRUCell : public BuiltinOp {
|
|||
const uint32_t num_units_;
|
||||
const ActivationType activation_;
|
||||
const ActivationType recurrent_activation_;
|
||||
const int32_t reset_after_;
|
||||
const bool reset_after_;
|
||||
};
|
||||
|
||||
} // namespace ops
|
||||
|
|
|
|||
|
|
@ -26,7 +26,6 @@
|
|||
|
||||
#include <array>
|
||||
#include "tim/vx/builtin_op.h"
|
||||
#include "vsi_nn_pub.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -61,9 +60,9 @@ class UnidirectionalSequenceGRU : public BuiltinOp {
|
|||
Graph* graph, uint32_t num_units,
|
||||
ActivationType activation = ActivationType::kTANH,
|
||||
ActivationType recurrent_activation = ActivationType::kSIGMOID,
|
||||
vsi_bool reset_after = TRUE,
|
||||
vsi_bool return_sequences = FALSE, /*False: only return last state*/
|
||||
vsi_bool time_major = TRUE);
|
||||
bool reset_after = true,
|
||||
bool return_sequences = false, /*False: only return last state*/
|
||||
bool time_major = true);
|
||||
|
||||
std::shared_ptr<Operation> Clone(
|
||||
std::shared_ptr<Graph>& graph) const override;
|
||||
|
|
@ -72,9 +71,9 @@ class UnidirectionalSequenceGRU : public BuiltinOp {
|
|||
const uint32_t num_units_;
|
||||
const ActivationType activation_;
|
||||
const ActivationType recurrent_activation_;
|
||||
const int32_t reset_after_;
|
||||
const int32_t return_sequences_;
|
||||
const int32_t time_major_;
|
||||
const bool reset_after_;
|
||||
const bool return_sequences_;
|
||||
const bool time_major_;
|
||||
};
|
||||
|
||||
} // namespace ops
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ namespace tim {
|
|||
namespace vx {
|
||||
namespace ops {
|
||||
GRUCell::GRUCell(Graph* graph, uint32_t num_units, ActivationType activation,
|
||||
ActivationType recurrent_activation, vsi_bool reset_after)
|
||||
ActivationType recurrent_activation, bool reset_after)
|
||||
: BuiltinOp(graph, VSI_NN_OP_GRUCELL),
|
||||
num_units_(num_units),
|
||||
activation_(activation),
|
||||
|
|
@ -39,7 +39,7 @@ GRUCell::GRUCell(Graph* graph, uint32_t num_units, ActivationType activation,
|
|||
this->impl()->node()->nn_param.grucell.num_units = num_units;
|
||||
this->impl()->node()->nn_param.grucell.activation = activation;
|
||||
this->impl()->node()->nn_param.grucell.recurrent_activation = recurrent_activation;
|
||||
this->impl()->node()->nn_param.grucell.reset_after = reset_after;
|
||||
this->impl()->node()->nn_param.grucell.reset_after = TranslateToVsibool(reset_after);
|
||||
}
|
||||
|
||||
std::shared_ptr<Operation> GRUCell::Clone(std::shared_ptr<Graph>& graph) const {
|
||||
|
|
|
|||
|
|
@ -31,8 +31,8 @@ namespace vx {
|
|||
namespace ops {
|
||||
UnidirectionalSequenceGRU::UnidirectionalSequenceGRU(
|
||||
Graph* graph, uint32_t num_units, ActivationType activation,
|
||||
ActivationType recurrent_activation, vsi_bool reset_after,
|
||||
vsi_bool return_sequences, vsi_bool time_major)
|
||||
ActivationType recurrent_activation, bool reset_after,
|
||||
bool return_sequences, bool time_major)
|
||||
: BuiltinOp(graph, VSI_NN_OP_GRU),
|
||||
num_units_(num_units),
|
||||
activation_(activation),
|
||||
|
|
@ -43,9 +43,9 @@ UnidirectionalSequenceGRU::UnidirectionalSequenceGRU(
|
|||
this->impl()->node()->nn_param.gru.num_units = num_units;
|
||||
this->impl()->node()->nn_param.gru.activation = activation;
|
||||
this->impl()->node()->nn_param.gru.recurrent_activation = recurrent_activation;
|
||||
this->impl()->node()->nn_param.gru.reset_after = reset_after;
|
||||
this->impl()->node()->nn_param.gru.return_sequences = return_sequences;
|
||||
this->impl()->node()->nn_param.gru.time_major = time_major;
|
||||
this->impl()->node()->nn_param.gru.reset_after = TranslateToVsibool(reset_after);
|
||||
this->impl()->node()->nn_param.gru.return_sequences = TranslateToVsibool(return_sequences);
|
||||
this->impl()->node()->nn_param.gru.time_major = TranslateToVsibool(time_major);
|
||||
}
|
||||
|
||||
std::shared_ptr<Operation> UnidirectionalSequenceGRU::Clone(
|
||||
|
|
|
|||
|
|
@ -163,6 +163,7 @@ vsi_enum TranslateResizeType(ResizeType type) {
|
|||
}
|
||||
|
||||
vx_bool_e ToVxBool(bool val) { return val ? vx_true_e : vx_false_e; }
|
||||
vsi_bool TranslateToVsibool(bool val) { return val? TRUE : FALSE; }
|
||||
|
||||
} // namespace vx
|
||||
} // namespace tim
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ vsi_enum TranslateRoundingPolicy(RoundingPolicy type);
|
|||
vsi_enum TranslateDownScaleSizeRounding(RoundType type);
|
||||
vsi_enum TranslateResizeType(ResizeType type);
|
||||
vx_bool_e ToVxBool(bool val);
|
||||
vsi_bool TranslateToVsibool(bool val);
|
||||
} // namespace vx
|
||||
} // namespace tim
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue