diff --git a/include/tim/vx/ops/bidirectional_sequence_lstm.h b/include/tim/vx/ops/bidirectional_sequence_lstm.h new file mode 100644 index 0000000..ee7cc02 --- /dev/null +++ b/include/tim/vx/ops/bidirectional_sequence_lstm.h @@ -0,0 +1,67 @@ +/**************************************************************************** +* +* Copyright (c) 2022 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. 
+* +*****************************************************************************/ +#ifndef TIM_VX_OPS_BIDIRECTIONAL_SEQUENCE_LSTM_H_ +#define TIM_VX_OPS_BIDIRECTIONAL_SEQUENCE_LSTM_H_ + +#include "tim/vx/operation.h" +namespace tim { +namespace vx { +namespace ops { + +class BidirectionalSequenceLstm : public Operation { + public: + enum ActivationType { + kNONE = 0, + kRELU = 1, + kRELU1 = 2, + kRELU6 = 3, + kTANH = 4, + kSIGMOID = 6, + kHARDSIGMOID = 31, /* temporary use 31 */ + }; + BidirectionalSequenceLstm( + Graph* graph, float cell_clip, float proj_clip, + ActivationType act_type, float forget_bias, bool time_major = false, + ActivationType recurrent_act_type = ActivationType::kSIGMOID, + bool return_sequences = false /*False: only return last state*/ + ); + + std::shared_ptr<Operation> Clone( + std::shared_ptr<Graph>& graph) const override; + + protected: + const float cell_clip_; + const float proj_clip_; + const ActivationType act_type_; + const float forget_bias_; + const bool time_major_; + const ActivationType recurrent_act_type_; + const bool return_sequences_; +}; + +} // namespace ops +} // namespace vx +} // namespace tim + +#endif /* TIM_VX_OPS_BIDIRECTIONAL_SEQUENCE_LSTM_H_ */ \ No newline at end of file diff --git a/src/tim/vx/ops/README.md b/src/tim/vx/ops/README.md index 056a149..e5d0072 100644 --- a/src/tim/vx/ops/README.md +++ b/src/tim/vx/ops/README.md @@ -109,13 +109,13 @@ GroupedConv1d|GROUPED_CONV1D|Mapped|[tf.keras.layers.Conv1D](https://tensorflow.
|BroadCast|EXPAND_BROADCAST|Mapped|[numpy.broadcast_to](https://numpy.org/doc/stable/reference/generated/numpy.broadcast_to.html) ||PROPOSAL| TBD |[Faster-RCNN Proposal Layer](https://github.com/intel/caffe/blob/master/examples/faster-rcnn/lib/rpn/proposal_layer.py) ||ROI_POOL|Planned 22Q4|[ANEURALNETWORKS_ROI_POOLING](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a6736198af337b2efbdb0b6b64dee7fe4) -ROI_Align||ROI_ALIGN|Mapped|[ANEURALNETWORKS_ROI_ALIGN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a2848b39dd4bfba78f2438fda0d9397a4) -TopK||TOPK|Mapped (limited support)|[tf.math.top_k](https://tensorflow.google.cn/api_docs/python/tf/math/top_k) +|ROI_Align|ROI_ALIGN|Mapped|[ANEURALNETWORKS_ROI_ALIGN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a2848b39dd4bfba78f2438fda0d9397a4) +|TopK|TOPK|Mapped (limited support)|[tf.math.top_k](https://tensorflow.google.cn/api_docs/python/tf/math/top_k) |GRUCell|GRUCELL_OVXLIB|Mapped|[tf.keras.layers.GRUCell](https://tensorflow.google.cn/api_docs/python/tf/keras/layers/GRUCell?hl=en) |UnidirectionalSequenceGRU|GRU_OVXLIB|Planned 22Q3|[tf.keras.layers.GRU](https://tensorflow.google.cn/api_docs/python/tf/keras/layers/GRUCell?hl=en) |UnidirectionalSequenceRNN|UNIDIRECTIONAL_SEQUENCE_RNN|Planned 22Q3|[ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_RNN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0ae11aa1d461d2abaa117f6ee2cb503dd8) |BidirectionalSequenceRNN|BIDIRECTIONAL_SEQUENCE_RNN|Planned 22Q3|[ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_RNN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a487fc5ae247de828f13e62b99f259f3c) 
-|BidirectionalSequenceLSTM|BIDIRECTIONAL_SEQUENCE_LSTM|Planned 22Q3|[ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_LSTM](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a492a71cb7aa50b9a1a834a3cb269d778) +|BidirectionalSequenceLSTM|BIDIRECTIONAL_SEQUENCE_LSTM|Mapped|[ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_LSTM](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a492a71cb7aa50b9a1a834a3cb269d778) |UnidirectionalSequenceLSTM|LSTM_OVXLIB|Mapped|[ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_LSTM](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0aaf30e491ad0b1fc7602cbde695b2c859) |LSTMCell|LSTMUNIT_OVXLIB|replace with UnidirectionalSequenceLSTM by set n_step = 1 |[ANEURALNETWORKS_LSTM](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0ad0377e8c305e596fb7f64ff896671fc5) ||PRE_PROCESS|TBD |Image Preprocessing (YUV2RGB, Input Normalization, Resizing, etc) diff --git a/src/tim/vx/ops/bidirectional_sequence_lstm.cc b/src/tim/vx/ops/bidirectional_sequence_lstm.cc new file mode 100644 index 0000000..11c0c4b --- /dev/null +++ b/src/tim/vx/ops/bidirectional_sequence_lstm.cc @@ -0,0 +1,268 @@ +/**************************************************************************** +* +* Copyright (c) 2022 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright 
notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#include "tim/vx/ops/bidirectional_sequence_lstm.h" +#include "tim/vx/ops/unidirectional_sequence_lstm.h" +#include "vsi_nn_pub.h" +#include "op_impl.h" + +#include <array> +namespace tim { +namespace vx { +namespace ops { + +class BidirectionalSequenceLstmImpl : public OpImpl { + public: + enum { + BI_LSTM_INPUT_INPUT = 0, + + BI_LSTM_FW_INPUT_WEIGHT_I2I = 1, + BI_LSTM_FW_INPUT_WEIGHT_I2F = 2, + BI_LSTM_FW_INPUT_WEIGHT_I2C = 3, + BI_LSTM_FW_INPUT_WEIGHT_I2O = 4, + + BI_LSTM_FW_INPUT_WEIGHT_R2I = 5, + BI_LSTM_FW_INPUT_WEIGHT_R2F = 6, + BI_LSTM_FW_INPUT_WEIGHT_R2C = 7, + BI_LSTM_FW_INPUT_WEIGHT_R2O = 8, + + BI_LSTM_FW_INPUT_WEIGHT_C2I = 9, + BI_LSTM_FW_INPUT_WEIGHT_C2F = 10, + BI_LSTM_FW_INPUT_WEIGHT_C2O = 11, + + BI_LSTM_FW_INPUT_BIAS_I = 12, + BI_LSTM_FW_INPUT_BIAS_F = 13, + BI_LSTM_FW_INPUT_BIAS_C = 14, + BI_LSTM_FW_INPUT_BIAS_O = 15, + + BI_LSTM_FW_INPUT_WEIGHT_PROJ = 16, + BI_LSTM_FW_INPUT_BIAS_PROJ = 17, + + BI_LSTM_BW_INPUT_WEIGHT_I2I = 18, + BI_LSTM_BW_INPUT_WEIGHT_I2F = 19, + BI_LSTM_BW_INPUT_WEIGHT_I2C = 20, + BI_LSTM_BW_INPUT_WEIGHT_I2O = 21, + + BI_LSTM_BW_INPUT_WEIGHT_R2I = 22, + BI_LSTM_BW_INPUT_WEIGHT_R2F = 23, + BI_LSTM_BW_INPUT_WEIGHT_R2C = 24, + BI_LSTM_BW_INPUT_WEIGHT_R2O = 25, + + BI_LSTM_BW_INPUT_WEIGHT_C2I = 26, + BI_LSTM_BW_INPUT_WEIGHT_C2F = 27, + BI_LSTM_BW_INPUT_WEIGHT_C2O = 28, + + 
BI_LSTM_BW_INPUT_BIAS_I = 29, + BI_LSTM_BW_INPUT_BIAS_F = 30, + BI_LSTM_BW_INPUT_BIAS_C = 31, + BI_LSTM_BW_INPUT_BIAS_O = 32, + + BI_LSTM_BW_INPUT_WEIGHT_PROJ = 33, + BI_LSTM_BW_INPUT_BIAS_PROJ = 34, + + BI_LSTM_FW_INPUT_H_STATE = 35, + BI_LSTM_FW_INPUT_C_STATE = 36, + + BI_LSTM_BW_INPUT_H_STATE = 37, + BI_LSTM_BW_INPUT_C_STATE = 38, + + BI_LSTM_AUX_INPUT = 39, + + BI_LSTM_FW_AUX_INPUT_WEIGHT_I2I = 40, + BI_LSTM_FW_AUX_INPUT_WEIGHT_I2F = 41, + BI_LSTM_FW_AUX_INPUT_WEIGHT_I2C = 42, + BI_LSTM_FW_AUX_INPUT_WEIGHT_I2O = 43, + + BI_LSTM_BW_AUX_INPUT_WEIGHT_I2I = 44, + BI_LSTM_BW_AUX_INPUT_WEIGHT_I2F = 45, + BI_LSTM_BW_AUX_INPUT_WEIGHT_I2C = 46, + BI_LSTM_BW_AUX_INPUT_WEIGHT_I2O = 47, + + BI_LSTM_FW_INPUT_LAYERNORM_I = 48, + BI_LSTM_FW_INPUT_LAYERNORM_F = 49, + BI_LSTM_FW_INPUT_LAYERNORM_C = 50, + BI_LSTM_FW_INPUT_LAYERNORM_O = 51, + + BI_LSTM_BW_INPUT_LAYERNORM_I = 52, + BI_LSTM_BW_INPUT_LAYERNORM_F = 53, + BI_LSTM_BW_INPUT_LAYERNORM_C = 54, + BI_LSTM_BW_INPUT_LAYERNORM_O = 55, + + INPUT_CNT, + + BI_LSTM_FW_OUTPUT_OUTPUT = 0, + BI_LSTM_FW_OUTPUT_H_STATE = 1, + BI_LSTM_FW_OUTPUT_C_STATE = 2, + + BI_LSTM_BW_OUTPUT_OUTPUT = 3, + BI_LSTM_BW_OUTPUT_H_STATE = 4, + BI_LSTM_BW_OUTPUT_C_STATE = 5, + + OUTPUT_CNT + }; + + BidirectionalSequenceLstmImpl(Graph* graph, int input_cnt, int output_cnt, + DataLayout layout = DataLayout::ANY) + : OpImpl(graph, -1, input_cnt, output_cnt, layout) { + lstm_forward_ = graph->CreateOperation<UnidirectionalSequenceLstm>( + 0.0, 0.0, UnidirectionalSequenceLstm::kTANH, 0.0, false, + UnidirectionalSequenceLstm::kSIGMOID, true); + lstm_backward_ = + graph->CreateOperation<UnidirectionalSequenceLstm>( + 0.0, 0.0, UnidirectionalSequenceLstm::kTANH, 0.0, false, + UnidirectionalSequenceLstm::kSIGMOID, true); + } + + ~BidirectionalSequenceLstmImpl() {} + + BidirectionalSequenceLstmImpl& BindInput( + const std::shared_ptr<Tensor>& tensor) override { + in_tensors_[input_tensor_index] = tensor; + + if (this->input_tensor_index == INPUT_CNT - 1) { + // Get all input tensor + 
lstm_forward_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]); + lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_H_STATE]); + lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_C_STATE]); + + lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_WEIGHT_I2I]); + lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_WEIGHT_I2F]); + lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_WEIGHT_I2C]); + lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_WEIGHT_I2O]); + + lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_WEIGHT_R2I]); + lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_WEIGHT_R2F]); + lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_WEIGHT_R2C]); + lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_WEIGHT_R2O]); + + lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_WEIGHT_C2I]); + lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_WEIGHT_C2F]); + lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_WEIGHT_C2O]); + + lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_BIAS_I]); + lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_BIAS_F]); + lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_BIAS_C]); + lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_BIAS_O]); + + lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_WEIGHT_PROJ]); + lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_BIAS_PROJ]); + + lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_LAYERNORM_I]); + lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_LAYERNORM_F]); + lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_LAYERNORM_C]); + lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_LAYERNORM_O]); + + lstm_backward_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]); + lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_H_STATE]); + lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_C_STATE]); + + lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_WEIGHT_I2I]); + 
lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_WEIGHT_I2F]); + lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_WEIGHT_I2C]); + lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_WEIGHT_I2O]); + + lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_WEIGHT_R2I]); + lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_WEIGHT_R2F]); + lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_WEIGHT_R2C]); + lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_WEIGHT_R2O]); + + lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_WEIGHT_C2I]); + lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_WEIGHT_C2F]); + lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_WEIGHT_C2O]); + + lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_BIAS_I]); + lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_BIAS_F]); + lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_BIAS_C]); + lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_BIAS_O]); + + lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_WEIGHT_PROJ]); + lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_BIAS_PROJ]); + + lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_LAYERNORM_I]); + lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_LAYERNORM_F]); + lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_LAYERNORM_C]); + lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_LAYERNORM_O]); + } + this->input_tensor_index++; + return *this; + } + + BidirectionalSequenceLstmImpl& BindOutput( + const std::shared_ptr<Tensor>& tensor) override { + out_tensors_[output_tensor_index] = tensor; + + if (this->output_tensor_index == OUTPUT_CNT - 1) { + lstm_forward_->BindOutput(out_tensors_[BI_LSTM_FW_OUTPUT_OUTPUT]); + lstm_forward_->BindOutput(out_tensors_[BI_LSTM_FW_OUTPUT_H_STATE]); + lstm_forward_->BindOutput(out_tensors_[BI_LSTM_FW_OUTPUT_C_STATE]); + + lstm_backward_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_OUTPUT]); + 
lstm_backward_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_H_STATE]); + lstm_backward_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_C_STATE]); + } + this->output_tensor_index++; + return *this; + } + + vsi_nn_node_t* node() override { return nullptr; } + + std::vector<std::shared_ptr<Tensor>> InputsTensor() override { + return inputs_tensor_; + } + std::vector<std::shared_ptr<Tensor>> OutputsTensor() override { + return outputs_tensor_; + } + + private: + std::shared_ptr<UnidirectionalSequenceLstm> lstm_forward_; + std::shared_ptr<UnidirectionalSequenceLstm> lstm_backward_; + + std::array<std::shared_ptr<Tensor>, INPUT_CNT> in_tensors_; + std::array<std::shared_ptr<Tensor>, OUTPUT_CNT> out_tensors_; +}; + +BidirectionalSequenceLstm::BidirectionalSequenceLstm( + Graph* graph, float cell_clip, float proj_clip, ActivationType act_type, + float forget_bias, bool time_major, ActivationType recurrent_act_type, + bool return_sequences) + : cell_clip_(cell_clip), + proj_clip_(proj_clip), + act_type_(act_type), + forget_bias_(forget_bias), + time_major_(time_major), + recurrent_act_type_(recurrent_act_type), + return_sequences_(return_sequences) { + impl_ = std::make_unique<BidirectionalSequenceLstmImpl>(graph, 0, 0, + DataLayout::ANY); +} + +std::shared_ptr<Operation> BidirectionalSequenceLstm::Clone( + std::shared_ptr<Graph>& graph) const { + return graph->CreateOperation<BidirectionalSequenceLstm>( + this->cell_clip_, this->proj_clip_, this->act_type_, this->forget_bias_, + this->time_major_, this->recurrent_act_type_, this->return_sequences_); +} + +} // namespace ops +} // namespace vx +} // namespace tim diff --git a/src/tim/vx/ops/bidirectional_sequence_lstm_test.cc b/src/tim/vx/ops/bidirectional_sequence_lstm_test.cc new file mode 100644 index 0000000..f76edd4 --- /dev/null +++ b/src/tim/vx/ops/bidirectional_sequence_lstm_test.cc @@ -0,0 +1,301 @@ +/**************************************************************************** +* +* Copyright (c) 2022 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without 
limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#include "tim/vx/context.h" +#include "tim/vx/graph.h" +#include "tim/vx/ops/bidirectional_sequence_lstm.h" + +#include "gtest/gtest.h" +#include "test_utils.h" + +std::shared_ptr<tim::vx::Tensor> make_empty_tensor( + std::shared_ptr<tim::vx::Graph> graph, const tim::vx::ShapeType& shape, + const tim::vx::TensorAttribute& role); //, const float& default_value) + +TEST(Bidirectional_LSTM_CELL, shape_in_2_cell_4_out_4_float32) { + // NoCifg_NoPeephole_NoProjection_NoLayerNorm + auto ctx = tim::vx::Context::Create(); + auto graph = ctx->CreateGraph(); + + uint32_t n_batch, n_step, n_cell, n_input, n_output; + n_batch = 1, n_step = 3, n_cell = 4, n_input = 2, n_output = 4; + tim::vx::ShapeType input_shape, cell_shape, state_shape; + input_shape = {n_batch, n_step, n_input}; // non-time-major + + tim::vx::TensorSpec lstm_input_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_input, n_step, n_batch}), tim::vx::TensorAttribute::INPUT); + + tim::vx::TensorSpec fw_weight_i2i_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_input, n_cell}), tim::vx::TensorAttribute::CONSTANT); 
+ tim::vx::TensorSpec fw_weight_i2f_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_input, n_cell}), tim::vx::TensorAttribute::CONSTANT); + tim::vx::TensorSpec fw_weight_i2c_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_input, n_cell}), tim::vx::TensorAttribute::CONSTANT); + tim::vx::TensorSpec fw_weight_i2o_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_input, n_cell}), tim::vx::TensorAttribute::CONSTANT); + + tim::vx::TensorSpec fw_weight_r2i_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_output, n_cell}), tim::vx::TensorAttribute::CONSTANT); + tim::vx::TensorSpec fw_weight_r2f_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_output, n_cell}), tim::vx::TensorAttribute::CONSTANT); + tim::vx::TensorSpec fw_weight_r2c_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_output, n_cell}), tim::vx::TensorAttribute::CONSTANT); + tim::vx::TensorSpec fw_weight_r2o_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_output, n_cell}), tim::vx::TensorAttribute::CONSTANT); + + tim::vx::TensorSpec fw_bias_i_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_cell}), tim::vx::TensorAttribute::CONSTANT); + tim::vx::TensorSpec fw_bias_f_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_cell}), tim::vx::TensorAttribute::CONSTANT); + tim::vx::TensorSpec fw_bias_c_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_cell}), tim::vx::TensorAttribute::CONSTANT); + tim::vx::TensorSpec fw_bias_o_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_cell}), tim::vx::TensorAttribute::CONSTANT); + + tim::vx::TensorSpec bw_weight_i2i_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_input, n_cell}), tim::vx::TensorAttribute::CONSTANT); + tim::vx::TensorSpec bw_weight_i2f_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_input, n_cell}), tim::vx::TensorAttribute::CONSTANT); + tim::vx::TensorSpec bw_weight_i2c_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_input, n_cell}), tim::vx::TensorAttribute::CONSTANT); + 
tim::vx::TensorSpec bw_weight_i2o_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_input, n_cell}), tim::vx::TensorAttribute::CONSTANT); + + tim::vx::TensorSpec bw_weight_r2i_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_output, n_cell}), tim::vx::TensorAttribute::CONSTANT); + tim::vx::TensorSpec bw_weight_r2f_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_output, n_cell}), tim::vx::TensorAttribute::CONSTANT); + tim::vx::TensorSpec bw_weight_r2c_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_output, n_cell}), tim::vx::TensorAttribute::CONSTANT); + tim::vx::TensorSpec bw_weight_r2o_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_output, n_cell}), tim::vx::TensorAttribute::CONSTANT); + + tim::vx::TensorSpec bw_bias_i_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_cell}), tim::vx::TensorAttribute::CONSTANT); + tim::vx::TensorSpec bw_bias_f_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_cell}), tim::vx::TensorAttribute::CONSTANT); + tim::vx::TensorSpec bw_bias_c_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_cell}), tim::vx::TensorAttribute::CONSTANT); + tim::vx::TensorSpec bw_bias_o_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_cell}), tim::vx::TensorAttribute::CONSTANT); + + tim::vx::TensorSpec fw_output_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_output, n_step, n_batch}), tim::vx::TensorAttribute::OUTPUT); + tim::vx::TensorSpec bw_output_spec(tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_output, n_step, n_batch}), tim::vx::TensorAttribute::OUTPUT); + + auto lstm_input = graph->CreateTensor(lstm_input_spec); + std::vector<float> lstm_input_data = {2., 3., 3., 4., 1., 1.}; + lstm_input->CopyDataToTensor(lstm_input_data.data(), lstm_input_data.size() * 4); + + auto fw_output_tensor = graph->CreateTensor(fw_output_spec); + auto bw_output_tensor = graph->CreateTensor(bw_output_spec); + + std::vector<float> fw_weight_i2i = {-0.45018822, -0.02338299, -0.0870589, + -0.34550029, 0.04266912, -0.15680569, 
+ -0.34856534, 0.43890524}; + std::vector<float> fw_weight_i2f = {0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}; + std::vector<float> fw_weight_i2c = {-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, + -0.29909778}; + std::vector<float> fw_weight_i2o = {-0.25065863, -0.28290087, 0.04613829, + 0.40525138, 0.44272184, 0.03897077, -0.1556896, + 0.19487578}; + auto fw_weight_i2i_tensor = graph->CreateTensor(fw_weight_i2i_spec, fw_weight_i2i.data()); + auto fw_weight_i2f_tensor = graph->CreateTensor(fw_weight_i2f_spec, fw_weight_i2f.data()); + auto fw_weight_i2c_tensor = graph->CreateTensor(fw_weight_i2c_spec, fw_weight_i2c.data()); + auto fw_weight_i2o_tensor = graph->CreateTensor(fw_weight_i2o_spec, fw_weight_i2o.data()); + + std::vector<float> fw_weight_r2i = { + -0.0063535, -0.2042388, 0.31454784, -0.35746509, + 0.28902304, 0.08183324, -0.16555229, 0.02286911, + -0.13566875, 0.03034258, 0.48091322, -0.12528998, + 0.24077177, -0.51332325, -0.33502164, 0.10629296}; + std::vector<float> fw_weight_r2f = { + -0.48684245, -0.06655136, 0.42224967, 0.2112639, + 0.27654213, 0.20864892, -0.07646349, 0.45877004, + 0.00141793, -0.14609534, 0.36447752, 0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}; + std::vector<float> fw_weight_r2c = { + -0.3407414, 0.24443203, -0.2078532, 0.26320225, + 0.05695659, -0.00123841, -0.4744786, -0.35869038, + -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}; + std::vector<float> fw_weight_r2o = { + 0.43385774, -0.17194885, 0.2718237, 0.09215671, + 0.24107647, -0.39835793, 0.18212086, 0.01301402, + 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}; + + auto fw_weight_r2i_tensor = graph->CreateTensor(fw_weight_r2i_spec, fw_weight_r2i.data()); + auto fw_weight_r2f_tensor = graph->CreateTensor(fw_weight_r2f_spec, fw_weight_r2f.data()); + auto fw_weight_r2c_tensor = 
graph->CreateTensor(fw_weight_r2c_spec, fw_weight_r2c.data()); + auto fw_weight_r2o_tensor = graph->CreateTensor(fw_weight_r2o_spec, fw_weight_r2o.data()); + + std::vector<float> fw_bias_i = {0.0, 0.0, 0.0, 0.0}; + std::vector<float> fw_bias_f = {1., 1., 1., 1.}; + std::vector<float> fw_bias_c = {0.0, 0.0, 0.0, 0.0}; + std::vector<float> fw_bias_o = {0.0, 0.0, 0.0, 0.0}; + auto fw_bias_i_tensor = graph->CreateTensor(fw_bias_i_spec, fw_bias_i.data()); + auto fw_bias_f_tensor = graph->CreateTensor(fw_bias_f_spec, fw_bias_f.data()); + auto fw_bias_c_tensor = graph->CreateTensor(fw_bias_c_spec, fw_bias_c.data()); + auto fw_bias_o_tensor = graph->CreateTensor(fw_bias_o_spec, fw_bias_o.data()); + + std::vector<float> bw_weight_i2i = {-0.45018822, -0.02338299, -0.0870589, + -0.34550029, 0.04266912, -0.15680569, + -0.34856534, 0.43890524}; + std::vector<float> bw_weight_i2f = {0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}; + std::vector<float> bw_weight_i2c = {-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, + -0.29909778}; + std::vector<float> bw_weight_i2o = {-0.25065863, -0.28290087, 0.04613829, + 0.40525138, 0.44272184, 0.03897077, -0.1556896, + 0.19487578}; + auto bw_weight_i2i_tensor = graph->CreateTensor(bw_weight_i2i_spec, bw_weight_i2i.data()); + auto bw_weight_i2f_tensor = graph->CreateTensor(bw_weight_i2f_spec, bw_weight_i2f.data()); + auto bw_weight_i2c_tensor = graph->CreateTensor(bw_weight_i2c_spec, bw_weight_i2c.data()); + auto bw_weight_i2o_tensor = graph->CreateTensor(bw_weight_i2o_spec, bw_weight_i2o.data()); + + std::vector<float> bw_weight_r2i = { + -0.0063535, -0.2042388, 0.31454784, -0.35746509, + 0.28902304, 0.08183324, -0.16555229, 0.02286911, + -0.13566875, 0.03034258, 0.48091322, -0.12528998, + 0.24077177, -0.51332325, -0.33502164, 0.10629296}; + std::vector<float> bw_weight_r2f = { + -0.48684245, -0.06655136, 0.42224967, 0.2112639, + 0.27654213, 0.20864892, -0.07646349, 0.45877004, + 0.00141793, -0.14609534, 0.36447752, 
0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}; + std::vector<float> bw_weight_r2c = { + -0.3407414, 0.24443203, -0.2078532, 0.26320225, + 0.05695659, -0.00123841, -0.4744786, -0.35869038, + -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}; + std::vector<float> bw_weight_r2o = { + 0.43385774, -0.17194885, 0.2718237, 0.09215671, + 0.24107647, -0.39835793, 0.18212086, 0.01301402, + 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}; + + auto bw_weight_r2i_tensor = graph->CreateTensor(bw_weight_r2i_spec, bw_weight_r2i.data()); + auto bw_weight_r2f_tensor = graph->CreateTensor(bw_weight_r2f_spec, bw_weight_r2f.data()); + auto bw_weight_r2c_tensor = graph->CreateTensor(bw_weight_r2c_spec, bw_weight_r2c.data()); + auto bw_weight_r2o_tensor = graph->CreateTensor(bw_weight_r2o_spec, bw_weight_r2o.data()); + + std::vector<float> bw_bias_i = {0.0, 0.0, 0.0, 0.0}; + std::vector<float> bw_bias_f = {1., 1., 1., 1.}; + std::vector<float> bw_bias_c = {0.0, 0.0, 0.0, 0.0}; + std::vector<float> bw_bias_o = {0.0, 0.0, 0.0, 0.0}; + auto bw_bias_i_tensor = graph->CreateTensor(bw_bias_i_spec, bw_bias_i.data()); + auto bw_bias_f_tensor = graph->CreateTensor(bw_bias_f_spec, bw_bias_f.data()); + auto bw_bias_c_tensor = graph->CreateTensor(bw_bias_c_spec, bw_bias_c.data()); + auto bw_bias_o_tensor = graph->CreateTensor(bw_bias_o_spec, bw_bias_o.data()); + + auto bidirectional_lstm = graph->CreateOperation<tim::vx::ops::BidirectionalSequenceLstm>( + 0.0, 0.0, tim::vx::ops::BidirectionalSequenceLstm::ActivationType::kTANH, 0.0, false, + tim::vx::ops::BidirectionalSequenceLstm::kSIGMOID, true); + + (*bidirectional_lstm) + .BindInputs({ + lstm_input, + + fw_weight_i2i_tensor, + fw_weight_i2f_tensor, + fw_weight_i2c_tensor, + fw_weight_i2o_tensor, + + fw_weight_r2i_tensor, + fw_weight_r2f_tensor, + fw_weight_r2c_tensor, + fw_weight_r2o_tensor, + + graph->CreateTensorPlaceHolder(), /*fw_weight_c2i*/ + graph->CreateTensorPlaceHolder(), /*fw_weight_c2f*/ + 
graph->CreateTensorPlaceHolder(), /*fw_weight_c2o*/ + + fw_bias_i_tensor, + fw_bias_f_tensor, + fw_bias_c_tensor, + fw_bias_o_tensor, + + // optional for projection + graph->CreateTensorPlaceHolder(), /*fw_weight_prj*/ + graph->CreateTensorPlaceHolder(), /*fw_bias_prj*/ + + bw_weight_i2i_tensor, + bw_weight_i2f_tensor, + bw_weight_i2c_tensor, + bw_weight_i2o_tensor, + + bw_weight_r2i_tensor, + bw_weight_r2f_tensor, + bw_weight_r2c_tensor, + bw_weight_r2o_tensor, + + graph->CreateTensorPlaceHolder(), /*bw_weight_c2i*/ + graph->CreateTensorPlaceHolder(), /*bw_weight_c2f*/ + graph->CreateTensorPlaceHolder(), /*bw_weight_c2o*/ + + bw_bias_i_tensor, + bw_bias_f_tensor, + bw_bias_c_tensor, + bw_bias_o_tensor, + + // optional for projection + graph->CreateTensorPlaceHolder(), /*bw_weight_prj*/ + graph->CreateTensorPlaceHolder(), /*bw_bias_prj*/ + + graph->CreateTensorPlaceHolder(), /*fw_h_state*/ + graph->CreateTensorPlaceHolder(), /*fw_c_state*/ + graph->CreateTensorPlaceHolder(), /*bw_h_state*/ + graph->CreateTensorPlaceHolder(), /*bw_c_state*/ + + graph->CreateTensorPlaceHolder(), + graph->CreateTensorPlaceHolder(), + graph->CreateTensorPlaceHolder(), + graph->CreateTensorPlaceHolder(), + graph->CreateTensorPlaceHolder(), + graph->CreateTensorPlaceHolder(), + graph->CreateTensorPlaceHolder(), + graph->CreateTensorPlaceHolder(), + graph->CreateTensorPlaceHolder(), // AUX + + graph->CreateTensorPlaceHolder(), + graph->CreateTensorPlaceHolder(), + graph->CreateTensorPlaceHolder(), + graph->CreateTensorPlaceHolder(), + graph->CreateTensorPlaceHolder(), + graph->CreateTensorPlaceHolder(), + graph->CreateTensorPlaceHolder(), + graph->CreateTensorPlaceHolder(), // Layer_norm + }) + .BindOutputs({ + fw_output_tensor, + make_empty_tensor( + graph, tim::vx::ShapeType({n_output, n_batch}), tim::vx::TensorAttribute::OUTPUT), /*fw_h_state*/ + make_empty_tensor( + graph, tim::vx::ShapeType({n_cell, n_batch}), tim::vx::TensorAttribute::OUTPUT), /*fw_c_state*/ + + bw_output_tensor, + 
make_empty_tensor( + graph, tim::vx::ShapeType({n_output, n_batch}), tim::vx::TensorAttribute::OUTPUT), /*bw_h_state*/ + make_empty_tensor( + graph, tim::vx::ShapeType({n_cell, n_batch}), tim::vx::TensorAttribute::OUTPUT), /*bw_c_state*/ + }); + + graph->Compile(); + graph->Run(); + + std::vector<float> lstm_fw_golden_output = { + -0.02973187, 0.1229473, 0.20885126, -0.15358765, + -0.03716109, 0.12507336, 0.41193449, -0.20860538, + -0.15053082, 0.09120187, 0.24278517, -0.12222792}; + std::vector<float> lstm_bw_golden_output = { + -0.02973187, 0.1229473, 0.20885126, -0.15358765, + -0.03716109, 0.12507336, 0.41193449, -0.20860538, + -0.15053082, 0.09120187, 0.24278517, -0.12222792}; + std::vector<float> fw_output(lstm_fw_golden_output.size()); + std::vector<float> bw_output(lstm_bw_golden_output.size()); + fw_output_tensor->CopyDataFromTensor(fw_output.data()); + bw_output_tensor->CopyDataFromTensor(bw_output.data()); + + EXPECT_TRUE(ArraysMatch(lstm_fw_golden_output, fw_output, 1e-4f)); + EXPECT_TRUE(ArraysMatch(lstm_bw_golden_output, bw_output, 1e-4f)); +} \ No newline at end of file