Refine UnidirectionalGRU and GRUCell (#587)
Refine unidirectional_gru and gru_cell code to avoid including ovxlib files in header of some op Introduce TranslateToVsibool function to support above code refinement Type: Code Improvement Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>
This commit is contained in:
parent
b81f7979fa
commit
3c372dd646
|
|
@ -26,7 +26,6 @@
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
#include "tim/vx/builtin_op.h"
|
#include "tim/vx/builtin_op.h"
|
||||||
#include "vsi_nn_pub.h"
|
|
||||||
|
|
||||||
namespace tim {
|
namespace tim {
|
||||||
namespace vx {
|
namespace vx {
|
||||||
|
|
@ -56,7 +55,7 @@ class GRUCell : public BuiltinOp {
|
||||||
GRUCell(Graph* graph, uint32_t num_units,
|
GRUCell(Graph* graph, uint32_t num_units,
|
||||||
ActivationType activation = ActivationType::kTANH,
|
ActivationType activation = ActivationType::kTANH,
|
||||||
ActivationType recurrent_activation = ActivationType::kSIGMOID,
|
ActivationType recurrent_activation = ActivationType::kSIGMOID,
|
||||||
vsi_bool reset_after = TRUE);
|
bool reset_after = true);
|
||||||
|
|
||||||
std::shared_ptr<Operation> Clone(
|
std::shared_ptr<Operation> Clone(
|
||||||
std::shared_ptr<Graph>& graph) const override;
|
std::shared_ptr<Graph>& graph) const override;
|
||||||
|
|
@ -65,7 +64,7 @@ class GRUCell : public BuiltinOp {
|
||||||
const uint32_t num_units_;
|
const uint32_t num_units_;
|
||||||
const ActivationType activation_;
|
const ActivationType activation_;
|
||||||
const ActivationType recurrent_activation_;
|
const ActivationType recurrent_activation_;
|
||||||
const int32_t reset_after_;
|
const bool reset_after_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,6 @@
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
#include "tim/vx/builtin_op.h"
|
#include "tim/vx/builtin_op.h"
|
||||||
#include "vsi_nn_pub.h"
|
|
||||||
|
|
||||||
namespace tim {
|
namespace tim {
|
||||||
namespace vx {
|
namespace vx {
|
||||||
|
|
@ -61,9 +60,9 @@ class UnidirectionalSequenceGRU : public BuiltinOp {
|
||||||
Graph* graph, uint32_t num_units,
|
Graph* graph, uint32_t num_units,
|
||||||
ActivationType activation = ActivationType::kTANH,
|
ActivationType activation = ActivationType::kTANH,
|
||||||
ActivationType recurrent_activation = ActivationType::kSIGMOID,
|
ActivationType recurrent_activation = ActivationType::kSIGMOID,
|
||||||
vsi_bool reset_after = TRUE,
|
bool reset_after = true,
|
||||||
vsi_bool return_sequences = FALSE, /*False: only return last state*/
|
bool return_sequences = false, /*False: only return last state*/
|
||||||
vsi_bool time_major = TRUE);
|
bool time_major = true);
|
||||||
|
|
||||||
std::shared_ptr<Operation> Clone(
|
std::shared_ptr<Operation> Clone(
|
||||||
std::shared_ptr<Graph>& graph) const override;
|
std::shared_ptr<Graph>& graph) const override;
|
||||||
|
|
@ -72,9 +71,9 @@ class UnidirectionalSequenceGRU : public BuiltinOp {
|
||||||
const uint32_t num_units_;
|
const uint32_t num_units_;
|
||||||
const ActivationType activation_;
|
const ActivationType activation_;
|
||||||
const ActivationType recurrent_activation_;
|
const ActivationType recurrent_activation_;
|
||||||
const int32_t reset_after_;
|
const bool reset_after_;
|
||||||
const int32_t return_sequences_;
|
const bool return_sequences_;
|
||||||
const int32_t time_major_;
|
const bool time_major_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ namespace tim {
|
||||||
namespace vx {
|
namespace vx {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
GRUCell::GRUCell(Graph* graph, uint32_t num_units, ActivationType activation,
|
GRUCell::GRUCell(Graph* graph, uint32_t num_units, ActivationType activation,
|
||||||
ActivationType recurrent_activation, vsi_bool reset_after)
|
ActivationType recurrent_activation, bool reset_after)
|
||||||
: BuiltinOp(graph, VSI_NN_OP_GRUCELL),
|
: BuiltinOp(graph, VSI_NN_OP_GRUCELL),
|
||||||
num_units_(num_units),
|
num_units_(num_units),
|
||||||
activation_(activation),
|
activation_(activation),
|
||||||
|
|
@ -39,7 +39,7 @@ GRUCell::GRUCell(Graph* graph, uint32_t num_units, ActivationType activation,
|
||||||
this->impl()->node()->nn_param.grucell.num_units = num_units;
|
this->impl()->node()->nn_param.grucell.num_units = num_units;
|
||||||
this->impl()->node()->nn_param.grucell.activation = activation;
|
this->impl()->node()->nn_param.grucell.activation = activation;
|
||||||
this->impl()->node()->nn_param.grucell.recurrent_activation = recurrent_activation;
|
this->impl()->node()->nn_param.grucell.recurrent_activation = recurrent_activation;
|
||||||
this->impl()->node()->nn_param.grucell.reset_after = reset_after;
|
this->impl()->node()->nn_param.grucell.reset_after = TranslateToVsibool(reset_after);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<Operation> GRUCell::Clone(std::shared_ptr<Graph>& graph) const {
|
std::shared_ptr<Operation> GRUCell::Clone(std::shared_ptr<Graph>& graph) const {
|
||||||
|
|
|
||||||
|
|
@ -31,8 +31,8 @@ namespace vx {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
UnidirectionalSequenceGRU::UnidirectionalSequenceGRU(
|
UnidirectionalSequenceGRU::UnidirectionalSequenceGRU(
|
||||||
Graph* graph, uint32_t num_units, ActivationType activation,
|
Graph* graph, uint32_t num_units, ActivationType activation,
|
||||||
ActivationType recurrent_activation, vsi_bool reset_after,
|
ActivationType recurrent_activation, bool reset_after,
|
||||||
vsi_bool return_sequences, vsi_bool time_major)
|
bool return_sequences, bool time_major)
|
||||||
: BuiltinOp(graph, VSI_NN_OP_GRU),
|
: BuiltinOp(graph, VSI_NN_OP_GRU),
|
||||||
num_units_(num_units),
|
num_units_(num_units),
|
||||||
activation_(activation),
|
activation_(activation),
|
||||||
|
|
@ -43,9 +43,9 @@ UnidirectionalSequenceGRU::UnidirectionalSequenceGRU(
|
||||||
this->impl()->node()->nn_param.gru.num_units = num_units;
|
this->impl()->node()->nn_param.gru.num_units = num_units;
|
||||||
this->impl()->node()->nn_param.gru.activation = activation;
|
this->impl()->node()->nn_param.gru.activation = activation;
|
||||||
this->impl()->node()->nn_param.gru.recurrent_activation = recurrent_activation;
|
this->impl()->node()->nn_param.gru.recurrent_activation = recurrent_activation;
|
||||||
this->impl()->node()->nn_param.gru.reset_after = reset_after;
|
this->impl()->node()->nn_param.gru.reset_after = TranslateToVsibool(reset_after);
|
||||||
this->impl()->node()->nn_param.gru.return_sequences = return_sequences;
|
this->impl()->node()->nn_param.gru.return_sequences = TranslateToVsibool(return_sequences);
|
||||||
this->impl()->node()->nn_param.gru.time_major = time_major;
|
this->impl()->node()->nn_param.gru.time_major = TranslateToVsibool(time_major);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<Operation> UnidirectionalSequenceGRU::Clone(
|
std::shared_ptr<Operation> UnidirectionalSequenceGRU::Clone(
|
||||||
|
|
|
||||||
|
|
@ -163,6 +163,7 @@ vsi_enum TranslateResizeType(ResizeType type) {
|
||||||
}
|
}
|
||||||
|
|
||||||
vx_bool_e ToVxBool(bool val) { return val ? vx_true_e : vx_false_e; }
|
vx_bool_e ToVxBool(bool val) { return val ? vx_true_e : vx_false_e; }
|
||||||
|
vsi_bool TranslateToVsibool(bool val) { return val? TRUE : FALSE; }
|
||||||
|
|
||||||
} // namespace vx
|
} // namespace vx
|
||||||
} // namespace tim
|
} // namespace tim
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,7 @@ vsi_enum TranslateRoundingPolicy(RoundingPolicy type);
|
||||||
vsi_enum TranslateDownScaleSizeRounding(RoundType type);
|
vsi_enum TranslateDownScaleSizeRounding(RoundType type);
|
||||||
vsi_enum TranslateResizeType(ResizeType type);
|
vsi_enum TranslateResizeType(ResizeType type);
|
||||||
vx_bool_e ToVxBool(bool val);
|
vx_bool_e ToVxBool(bool val);
|
||||||
|
vsi_bool TranslateToVsibool(bool val);
|
||||||
} // namespace vx
|
} // namespace vx
|
||||||
} // namespace tim
|
} // namespace tim
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue