diff --git a/include/tim/vx/ops.h b/include/tim/vx/ops.h
index a7ed369..d8da63b 100644
--- a/include/tim/vx/ops.h
+++ b/include/tim/vx/ops.h
@@ -29,6 +29,8 @@
 #include "tim/vx/ops/arg.h"
 #include "tim/vx/ops/batch2space.h"
 #include "tim/vx/ops/batchnorm.h"
+#include "tim/vx/ops/bidirectional_sequence_rnn.h"
+#include "tim/vx/ops/bidirectional_sequence_rnn_ext.h"
 #include "tim/vx/ops/broadcast.h"
 #include "tim/vx/ops/clip.h"
 #include "tim/vx/ops/concat.h"
@@ -89,6 +91,8 @@
 #include "tim/vx/ops/tile.h"
 #include "tim/vx/ops/transpose.h"
 #include "tim/vx/ops/unidirectional_sequence_lstm.h"
+#include "tim/vx/ops/unidirectional_sequence_rnn.h"
+#include "tim/vx/ops/unidirectional_sequence_rnn_ext.h"
 #include "tim/vx/ops/unstack.h"
 #include "tim/vx/ops/conv3d.h"
 #include "tim/vx/ops/custom_base.h"
diff --git a/include/tim/vx/ops/bidirectional_sequence_rnn.h b/include/tim/vx/ops/bidirectional_sequence_rnn.h
new file mode 100644
index 0000000..305975d
--- /dev/null
+++ b/include/tim/vx/ops/bidirectional_sequence_rnn.h
@@ -0,0 +1,64 @@
+/****************************************************************************
+*
+* Copyright (c) 2020 Vivante Corporation
+*
+* Permission is hereby granted, free of charge, to any person obtaining a
+* copy of this software and associated documentation files (the "Software"),
+* to deal in the Software without restriction, including without limitation
+* the rights to use, copy, modify, merge, publish, distribute, sublicense,
+* and/or sell copies of the Software, and to permit persons to whom the
+* Software is furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in
+* all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+* DEALINGS IN THE SOFTWARE.
+*
+*****************************************************************************/
+#ifndef TIM_VX_OPS_BIDIRECTIONAL_SEQUENCE_RNN_H_
+#define TIM_VX_OPS_BIDIRECTIONAL_SEQUENCE_RNN_H_
+#include "tim/vx/builtin_op.h"
+
+namespace tim {
+namespace vx {
+namespace ops {
+/**
+ * ## BidirectionalSequenceRnn
+ * How to bind the input/output tensors: see bidirectional_sequence_rnn_test.cc.
+ */
+class BidirectionalSequenceRnn : public DirectMapOp {
+ public:
+  enum ActivationType {
+    kNONE = 0,
+    kRELU = 1,
+    kRELU1 = 2,
+    kRELU6 = 3,
+    kTANH = 4,
+    kSIGMOID = 6,
+    kHARDSIGMOID = 31, /* temporarily mapped to 31 */
+  };
+
+  BidirectionalSequenceRnn(
+      Graph* graph,
+      ActivationType act_type,
+      bool time_major = false,
+      bool merge_outputs = false
+  );
+
+  std::shared_ptr<Operation> Clone(
+      std::shared_ptr<Graph>& graph) const override;
+
+ protected:
+  ActivationType act_type_;
+};
+}  // namespace ops
+}  // namespace vx
+}  // namespace tim
+
+#endif
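Note: a minimal usage sketch for this builtin op, condensed from bidirectional_sequence_rnn_test.cc (tensor variables here are placeholders for tensors created via graph->CreateTensor(...)). With merge_outputs = false, the eleven inputs are bound forward-direction first, then backward, and the four outputs are bound as {fw_state_out, bw_state_out, fw_output, bw_output}:

    auto op = graph->CreateOperation<tim::vx::ops::BidirectionalSequenceRnn>(
        tim::vx::ops::BidirectionalSequenceRnn::ActivationType::kSIGMOID,
        /*time_major=*/true, /*merge_outputs=*/false);
    (*op)
        .BindInputs({input, fw_weights, fw_recurrent_weights, fw_bias,
                     fw_recurrent_bias, fw_state_in, bw_weights,
                     bw_recurrent_weights, bw_bias, bw_recurrent_bias,
                     bw_state_in})
        .BindOutputs({fw_state_out, bw_state_out, fw_output, bw_output});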
diff --git a/include/tim/vx/ops/bidirectional_sequence_rnn_ext.h b/include/tim/vx/ops/bidirectional_sequence_rnn_ext.h
new file mode 100644
index 0000000..249e06f
--- /dev/null
+++ b/include/tim/vx/ops/bidirectional_sequence_rnn_ext.h
@@ -0,0 +1,52 @@
+/****************************************************************************
+*
+* Copyright (c) 2020 Vivante Corporation
+*
+* Permission is hereby granted, free of charge, to any person obtaining a
+* copy of this software and associated documentation files (the "Software"),
+* to deal in the Software without restriction, including without limitation
+* the rights to use, copy, modify, merge, publish, distribute, sublicense,
+* and/or sell copies of the Software, and to permit persons to whom the
+* Software is furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in
+* all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+* DEALINGS IN THE SOFTWARE.
+*
+*****************************************************************************/
+#ifndef TIM_VX_OPS_BIDIRECTIONAL_SEQUENCE_RNN_EXT_H_
+#define TIM_VX_OPS_BIDIRECTIONAL_SEQUENCE_RNN_EXT_H_
+#include "tim/vx/operation.h"
+#include "tim/vx/ops/bidirectional_sequence_rnn.h"
+
+namespace tim {
+namespace vx {
+namespace ops {
+/**
+ * ## BidirectionalSequenceRnnExt (ONNX-style bidirectional sequence RNN)
+ * How to bind the input/output tensors: see bidirectional_sequence_rnn_ext_test.cc.
+ */
+class BidirectionalSequenceRnnExt : public Operation {
+ public:
+  BidirectionalSequenceRnnExt(
+      Graph* graph,
+      tim::vx::ops::BidirectionalSequenceRnn::ActivationType act_type
+  );
+
+  std::shared_ptr<Operation> Clone(
+      std::shared_ptr<Graph>& graph) const override;
+
+ protected:
+  tim::vx::ops::BidirectionalSequenceRnn::ActivationType act_type_;
+};
+}  // namespace ops
+}  // namespace vx
+}  // namespace tim
+
+#endif
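Note: the Ext variant consumes ONNX-style packed tensors. A sketch of the shapes it expects, taken from bidirectional_sequence_rnn_ext_test.cc (TIM-VX lists dimensions innermost-first; the trailing 2 is the direction dimension, and seq_length is 2 in that test):

    tim::vx::ShapeType input_shape({input_size, batch_size, seq_length});
    tim::vx::ShapeType weights_shape({input_size, num_units, 2});
    tim::vx::ShapeType recurrent_weights_shape({num_units, num_units, 2});
    tim::vx::ShapeType bias_shape({num_units * 2, 2});  // [input bias | recurrent bias] per direction
    tim::vx::ShapeType state_in_shape({num_units, batch_size, 2});
    tim::vx::ShapeType output_shape({num_units, batch_size, 2, seq_length});
    tim::vx::ShapeType state_out_shape({num_units, batch_size, 2});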
diff --git a/include/tim/vx/ops/unidirectional_sequence_rnn.h b/include/tim/vx/ops/unidirectional_sequence_rnn.h
new file mode 100644
index 0000000..10e973d
--- /dev/null
+++ b/include/tim/vx/ops/unidirectional_sequence_rnn.h
@@ -0,0 +1,63 @@
+/****************************************************************************
+*
+* Copyright (c) 2020 Vivante Corporation
+*
+* Permission is hereby granted, free of charge, to any person obtaining a
+* copy of this software and associated documentation files (the "Software"),
+* to deal in the Software without restriction, including without limitation
+* the rights to use, copy, modify, merge, publish, distribute, sublicense,
+* and/or sell copies of the Software, and to permit persons to whom the
+* Software is furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in
+* all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+* DEALINGS IN THE SOFTWARE.
+*
+*****************************************************************************/
+#ifndef TIM_VX_OPS_UNIDIRECTIONAL_SEQUENCE_RNN_H_
+#define TIM_VX_OPS_UNIDIRECTIONAL_SEQUENCE_RNN_H_
+#include "tim/vx/builtin_op.h"
+
+namespace tim {
+namespace vx {
+namespace ops {
+/**
+ * ## UnidirectionalSequenceRnn
+ * How to bind the input/output tensors: see unidirectional_sequence_rnn_test.cc.
+ */
+class UnidirectionalSequenceRnn : public DirectMapOp {
+ public:
+  enum ActivationType {
+    kNONE = 0,
+    kRELU = 1,
+    kRELU1 = 2,
+    kRELU6 = 3,
+    kTANH = 4,
+    kSIGMOID = 6,
+    kHARDSIGMOID = 31, /* temporarily mapped to 31 */
+  };
+
+  UnidirectionalSequenceRnn(
+      Graph* graph,
+      ActivationType act_type,
+      bool time_major = false
+  );
+
+  std::shared_ptr<Operation> Clone(
+      std::shared_ptr<Graph>& graph) const override;
+
+ protected:
+  ActivationType act_type_;
+};
+}  // namespace ops
+}  // namespace vx
+}  // namespace tim
+
+#endif
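Note: a minimal usage sketch for the unidirectional builtin op (binding order as used internally by unidirectional_sequence_rnn_ext.cc; variable names are placeholders). Inputs are {input, weights, recurrent_weights, bias, recurrent_bias, state_in}, outputs are {state_out, output}:

    auto op = graph->CreateOperation<tim::vx::ops::UnidirectionalSequenceRnn>(
        tim::vx::ops::UnidirectionalSequenceRnn::ActivationType::kSIGMOID,
        /*time_major=*/true);
    (*op)
        .BindInputs({input, weights, recurrent_weights, bias, recurrent_bias,
                     state_in})
        .BindOutputs({state_out, output});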
diff --git a/include/tim/vx/ops/unidirectional_sequence_rnn_ext.h b/include/tim/vx/ops/unidirectional_sequence_rnn_ext.h
new file mode 100644
index 0000000..98e5cb4
--- /dev/null
+++ b/include/tim/vx/ops/unidirectional_sequence_rnn_ext.h
@@ -0,0 +1,53 @@
+/****************************************************************************
+*
+* Copyright (c) 2020 Vivante Corporation
+*
+* Permission is hereby granted, free of charge, to any person obtaining a
+* copy of this software and associated documentation files (the "Software"),
+* to deal in the Software without restriction, including without limitation
+* the rights to use, copy, modify, merge, publish, distribute, sublicense,
+* and/or sell copies of the Software, and to permit persons to whom the
+* Software is furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in
+* all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+* DEALINGS IN THE SOFTWARE.
+*
+*****************************************************************************/
+#ifndef TIM_VX_OPS_UNIDIRECTIONAL_SEQUENCE_RNN_EXT_H_
+#define TIM_VX_OPS_UNIDIRECTIONAL_SEQUENCE_RNN_EXT_H_
+#include "tim/vx/operation.h"
+#include "tim/vx/ops/unidirectional_sequence_rnn.h"
+
+namespace tim {
+namespace vx {
+namespace ops {
+/**
+ * ## UnidirectionalSequenceRnnExt (ONNX-style unidirectional sequence RNN)
+ * How to bind the input/output tensors: see unidirectional_sequence_rnn_ext_test.cc.
+ */
+class UnidirectionalSequenceRnnExt : public Operation {
+ public:
+  UnidirectionalSequenceRnnExt(
+      Graph* graph,
+      tim::vx::ops::UnidirectionalSequenceRnn::ActivationType act_type
+  );
+
+  std::shared_ptr<Operation> Clone(
+      std::shared_ptr<Graph>& graph) const override;
+
+ protected:
+  tim::vx::ops::UnidirectionalSequenceRnn::ActivationType act_type_;
+};
+}  // namespace ops
+}  // namespace vx
+}  // namespace tim
+
+#endif
\ No newline at end of file
diff --git a/src/tim/vx/ops/bidirectional_sequence_rnn.cc b/src/tim/vx/ops/bidirectional_sequence_rnn.cc
new file mode 100644
index 0000000..8308f9a
--- /dev/null
+++ b/src/tim/vx/ops/bidirectional_sequence_rnn.cc
@@ -0,0 +1,79 @@
+/****************************************************************************
+*
+* Copyright (c) 2020 Vivante Corporation
+*
+* Permission is hereby granted, free of charge, to any person obtaining a
+* copy of this software and associated documentation files (the "Software"),
+* to deal in the Software without restriction, including without limitation
+* the rights to use, copy, modify, merge, publish, distribute, sublicense,
+* and/or sell copies of the Software, and to permit persons to whom the
+* Software is furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in
+* all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+* DEALINGS IN THE SOFTWARE.
+*
+*****************************************************************************/
+#include "tim/vx/ops/bidirectional_sequence_rnn.h"
+
+#include "builtin_op_impl.h"
+#include "vsi_nn_pub.h"
+
+namespace tim {
+namespace vx {
+namespace ops {
+
+vsi_nn_activation_e downcast_act_type(
+    BidirectionalSequenceRnn::ActivationType act) {
+  switch (act) {
+    case BidirectionalSequenceRnn::ActivationType::kRELU:
+      return VSI_NN_ACT_RELU;
+    case BidirectionalSequenceRnn::ActivationType::kRELU1:
+      return VSI_NN_ACT_RELU1;
+    case BidirectionalSequenceRnn::ActivationType::kRELU6:
+      return VSI_NN_ACT_RELU6;
+    case BidirectionalSequenceRnn::ActivationType::kTANH:
+      return VSI_NN_ACT_TANH;
+    case BidirectionalSequenceRnn::ActivationType::kSIGMOID:
+      return VSI_NN_ACT_SIGMOID;
+    case BidirectionalSequenceRnn::ActivationType::kHARDSIGMOID:
+      return VSI_NN_ACT_HARD_SIGMOID;
+    default: {
+      VSILOGW("Unsupported activation type for RNN = %d",
+              static_cast<int>(act));
+      return VSI_NN_ACT_NONE;
+    }
+  }
+}
+
+BidirectionalSequenceRnn::BidirectionalSequenceRnn(
+    Graph* graph,
+    ActivationType act_type,
+    bool time_major,
+    bool merge_outputs)
+    : DirectMapOp(graph, VSI_NN_OP_BIDIRECTIONAL_SEQUENCE_RNN),
+      act_type_(act_type) {
+  this->impl()->node()->nn_param.bidirectional_sequence_rnn.time_major =
+      time_major;
+  this->impl()->node()->nn_param.bidirectional_sequence_rnn.merge_outputs =
+      merge_outputs;
+  this->impl()->node()->nn_param.bidirectional_sequence_rnn.activation =
+      downcast_act_type(act_type);
+}
+
+std::shared_ptr<Operation> BidirectionalSequenceRnn::Clone(
+    std::shared_ptr<Graph>& graph) const {
+  auto cloned_op = graph->CreateOperation<BidirectionalSequenceRnn>(
+      act_type_,
+      this->impl()->node()->nn_param.bidirectional_sequence_rnn.time_major,
+      this->impl()->node()->nn_param.bidirectional_sequence_rnn.merge_outputs);
+  return cloned_op;
+}
+
+}  // namespace ops
+}  // namespace vx
+}  // namespace tim
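Note: for reference, a self-contained scalar sketch of the cell the test goldens encode, h_t = act(W·x_t + R·h_(t-1) + b_w + b_r). This is plain C++, not TIM-VX API, and assumes the uniform 0.1 weights used in the tests:

    #include <cmath>
    #include <vector>

    // One sigmoid RNN step with scalar weight w, recurrent weight r, biases.
    std::vector<float> RnnStep(const std::vector<float>& x,
                               const std::vector<float>& h_prev, float w,
                               float r, float bw, float br, int num_units) {
      float wx = 0.f, rh = 0.f;
      for (float xi : x) wx += w * xi;      // W contribution
      for (float hk : h_prev) rh += r * hk; // recurrent contribution
      std::vector<float> h(num_units);
      for (int j = 0; j < num_units; ++j)
        h[j] = 1.f / (1.f + std::exp(-(wx + rh + bw + br)));  // kSIGMOID
      return h;
    }
    // With x = {1, 2}, h_prev = {0, 0, 0, 0}, w = 0.1, bw = 0.1, br = 0:
    // sigmoid(0.1 + 0.2 + 0.1) = sigmoid(0.4) ≈ 0.5987, matching the 0.5986
    // golden value in the tests below.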
diff --git a/src/tim/vx/ops/bidirectional_sequence_rnn_ext.cc b/src/tim/vx/ops/bidirectional_sequence_rnn_ext.cc
new file mode 100644
index 0000000..78e3f54
--- /dev/null
+++ b/src/tim/vx/ops/bidirectional_sequence_rnn_ext.cc
@@ -0,0 +1,299 @@
+/****************************************************************************
+*
+* Copyright (c) 2021 Vivante Corporation
+*
+* Permission is hereby granted, free of charge, to any person obtaining a
+* copy of this software and associated documentation files (the "Software"),
+* to deal in the Software without restriction, including without limitation
+* the rights to use, copy, modify, merge, publish, distribute, sublicense,
+* and/or sell copies of the Software, and to permit persons to whom the
+* Software is furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in
+* all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+* DEALINGS IN THE SOFTWARE.
+*
+*****************************************************************************/
+#include "tim/vx/ops.h"
+#include "vsi_nn_pub.h"
+#include "op_impl.h"
+
+#include <array>
+
+namespace tim {
+namespace vx {
+namespace ops {
+
+class BidirectionalSequenceRnnExtImpl : public OpImpl {
+ public:
+  enum {
+    // signature
+    RNN_EXT_INPUT_INPUT = 0,
+    RNN_EXT_INPUT_WEIGHT_I = 1,
+    RNN_EXT_INPUT_WEIGHT_H = 2,
+    RNN_EXT_INPUT_BIAS = 3,
+    RNN_EXT_INPUT_H_STATE = 4,
+    RNN_EXT_INPUT_CNT,
+
+    RNN_EXT_OUTPUT_H_STATE = 0,
+    RNN_EXT_OUTPUT_OUTPUT = 1,
+    RNN_EXT_OUT_CNT,
+    // signature end
+  };
+
+  BidirectionalSequenceRnnExtImpl(
+      Graph* graph,
+      tim::vx::ops::BidirectionalSequenceRnn::ActivationType act_type,
+      DataLayout layout = DataLayout::ANY)
+      : OpImpl(graph, layout), act_type_(act_type) {}
+
+  ~BidirectionalSequenceRnnExtImpl() {}
+
+  BidirectionalSequenceRnnExtImpl& BindInput(
+      const std::shared_ptr<Tensor>& tensor) override {
+    in_tensors_[input_tensor_index] = tensor;
+
+    if (this->input_tensor_index == RNN_EXT_INPUT_CNT - 1) {
+      tim::vx::DataType datatype =
+          in_tensors_[RNN_EXT_INPUT_WEIGHT_I]->GetDataType();
+      uint32_t input_size = in_tensors_[RNN_EXT_INPUT_WEIGHT_I]->GetShape()[0];
+      uint32_t num_units = in_tensors_[RNN_EXT_INPUT_WEIGHT_I]->GetShape()[1];
+      uint32_t batch_size = in_tensors_[RNN_EXT_INPUT_INPUT]->GetShape()[1];
+      uint32_t seq_length = in_tensors_[RNN_EXT_INPUT_INPUT]->GetShape()[2];
+
+      // Shapes of all intermediate tensors.
+      tim::vx::ShapeType input_weight_i_shape = {input_size, num_units, 1};
+      tim::vx::ShapeType input_weight_h_shape = {num_units, num_units, 1};
+      tim::vx::ShapeType input_reshape_weight_i_shape = {input_size, num_units};
+      tim::vx::ShapeType input_reshape_weight_h_shape = {num_units, num_units};
+      tim::vx::ShapeType input_bias_shape = {2 * num_units, 1};
+      tim::vx::ShapeType input_reshape_bias_shape = {2 * num_units};
+      tim::vx::ShapeType input_reshape_split_bias_shape = {num_units};
+      tim::vx::ShapeType input_hstate_shape = {num_units, batch_size, 1};
+      tim::vx::ShapeType input_reshape_hstate_shape = {num_units, batch_size};
+      tim::vx::ShapeType output_shape = {num_units, batch_size, seq_length};
+      tim::vx::ShapeType output_reshape_shape = {num_units, batch_size, 1, seq_length};
+      tim::vx::ShapeType output_hstate_shape = {num_units, batch_size};
+      tim::vx::ShapeType output_reshape_hstate_shape = {num_units, batch_size, 1};
+      tim::vx::ShapeType ext_output_shape = {num_units, batch_size, 2, seq_length};
+      tim::vx::ShapeType ext_output_hstate_shape = {num_units, batch_size, 2};
+
+      tim::vx::TensorSpec input_weight_i_spec(datatype, input_weight_i_shape,
+          tim::vx::TensorAttribute::TRANSIENT);
+      tim::vx::TensorSpec input_weight_h_spec(datatype, input_weight_h_shape,
+          tim::vx::TensorAttribute::TRANSIENT);
+      tim::vx::TensorSpec input_reshape_weight_i_spec(datatype,
+          input_reshape_weight_i_shape, tim::vx::TensorAttribute::TRANSIENT);
+      tim::vx::TensorSpec input_reshape_weight_h_spec(datatype,
+          input_reshape_weight_h_shape, tim::vx::TensorAttribute::TRANSIENT);
+      tim::vx::TensorSpec input_bias_spec(datatype, input_bias_shape,
+          tim::vx::TensorAttribute::TRANSIENT);
+      tim::vx::TensorSpec input_reshape_bias_spec(datatype,
+          input_reshape_bias_shape, tim::vx::TensorAttribute::TRANSIENT);
+      tim::vx::TensorSpec input_reshape_split_bias_spec(datatype,
+          input_reshape_split_bias_shape, tim::vx::TensorAttribute::TRANSIENT);
+      tim::vx::TensorSpec input_hstate_spec(datatype, input_hstate_shape,
+          tim::vx::TensorAttribute::TRANSIENT);
+      tim::vx::TensorSpec input_reshape_hstate_spec(datatype,
+          input_reshape_hstate_shape, tim::vx::TensorAttribute::TRANSIENT);
+      tim::vx::TensorSpec output_spec(datatype, output_shape,
+          tim::vx::TensorAttribute::TRANSIENT);
+      tim::vx::TensorSpec output_reshape_spec(datatype, output_reshape_shape,
+          tim::vx::TensorAttribute::TRANSIENT);
+      tim::vx::TensorSpec output_hstate_spec(datatype, output_hstate_shape,
+          tim::vx::TensorAttribute::TRANSIENT);
+      tim::vx::TensorSpec output_reshape_hstate_spec(datatype,
+          output_reshape_hstate_shape, tim::vx::TensorAttribute::TRANSIENT);
+
+      auto input_fw_weight_i_tensor = graph_->CreateTensor(input_weight_i_spec);
+      auto input_fw_weight_h_tensor = graph_->CreateTensor(input_weight_h_spec);
+      auto input_fw_reshape_weight_i_tensor = graph_->CreateTensor(input_reshape_weight_i_spec);
+      auto input_fw_reshape_weight_h_tensor = graph_->CreateTensor(input_reshape_weight_h_spec);
+      auto input_fw_bias_tensor = graph_->CreateTensor(input_bias_spec);
+      auto input_fw_reshape_bias_tensor = graph_->CreateTensor(input_reshape_bias_spec);
+      auto input_fw_reshape_split_bias_i_tensor = graph_->CreateTensor(input_reshape_split_bias_spec);
+      auto input_fw_reshape_split_bias_h_tensor = graph_->CreateTensor(input_reshape_split_bias_spec);
+      auto input_fw_hstate_tensor = graph_->CreateTensor(input_hstate_spec);
+      auto input_fw_reshape_hstate_tensor = graph_->CreateTensor(input_reshape_hstate_spec);
+      auto output_fw_tensor = graph_->CreateTensor(output_spec);
+      auto output_fw_reshape_tensor = graph_->CreateTensor(output_reshape_spec);
+      auto output_fw_hstate_tensor = graph_->CreateTensor(output_hstate_spec);
+      auto output_fw_reshape_hstate_tensor = graph_->CreateTensor(output_reshape_hstate_spec);
+
+      auto input_bw_weight_i_tensor = graph_->CreateTensor(input_weight_i_spec);
+      auto input_bw_weight_h_tensor = graph_->CreateTensor(input_weight_h_spec);
+      auto input_bw_reshape_weight_i_tensor = graph_->CreateTensor(input_reshape_weight_i_spec);
+      auto input_bw_reshape_weight_h_tensor = graph_->CreateTensor(input_reshape_weight_h_spec);
+      auto input_bw_bias_tensor = graph_->CreateTensor(input_bias_spec);
+      auto input_bw_reshape_bias_tensor = graph_->CreateTensor(input_reshape_bias_spec);
+      auto input_bw_reshape_split_bias_i_tensor = graph_->CreateTensor(input_reshape_split_bias_spec);
+      auto input_bw_reshape_split_bias_h_tensor = graph_->CreateTensor(input_reshape_split_bias_spec);
+      auto input_bw_hstate_tensor = graph_->CreateTensor(input_hstate_spec);
+      auto input_bw_reshape_hstate_tensor = graph_->CreateTensor(input_reshape_hstate_spec);
+      auto output_bw_tensor = graph_->CreateTensor(output_spec);
+      auto output_bw_reshape_tensor = graph_->CreateTensor(output_reshape_spec);
+      auto output_bw_hstate_tensor = graph_->CreateTensor(output_hstate_spec);
+      auto output_bw_reshape_hstate_tensor = graph_->CreateTensor(output_reshape_hstate_spec);
+
+      std::vector<uint32_t> slices_directions = {1, 1};
+      split_weight_ = graph_->CreateOperation<tim::vx::ops::Split>(2, slices_directions);
+      reshape_fw_weight_ = graph_->CreateOperation<tim::vx::ops::Reshape>(input_reshape_weight_i_shape);
+      reshape_bw_weight_ = graph_->CreateOperation<tim::vx::ops::Reshape>(input_reshape_weight_i_shape);
+
+      split_recurrent_ = graph_->CreateOperation<tim::vx::ops::Split>(2, slices_directions);
+      reshape_fw_recurrent_ = graph_->CreateOperation<tim::vx::ops::Reshape>(input_reshape_weight_h_shape);
+      reshape_bw_recurrent_ = graph_->CreateOperation<tim::vx::ops::Reshape>(input_reshape_weight_h_shape);
+
+      split_bias_ = graph_->CreateOperation<tim::vx::ops::Split>(1, slices_directions);
+      reshape_fw_bias_ = graph_->CreateOperation<tim::vx::ops::Reshape>(input_reshape_bias_shape);
+      reshape_bw_bias_ = graph_->CreateOperation<tim::vx::ops::Reshape>(input_reshape_bias_shape);
+      std::vector<uint32_t> slices_units = {num_units, num_units};
+      split_reshape_fw_bias_ = graph_->CreateOperation<tim::vx::ops::Split>(0, slices_units);
+      split_reshape_bw_bias_ = graph_->CreateOperation<tim::vx::ops::Split>(0, slices_units);
+
+      split_hstate_ = graph_->CreateOperation<tim::vx::ops::Split>(2, slices_directions);
+      reshape_fw_hstate_ = graph_->CreateOperation<tim::vx::ops::Reshape>(input_reshape_hstate_shape);
+      reshape_bw_hstate_ = graph_->CreateOperation<tim::vx::ops::Reshape>(input_reshape_hstate_shape);
+
+      rnn_ = graph_->CreateOperation<tim::vx::ops::BidirectionalSequenceRnn>(act_type_, true, false);
+
+      reshape_fw_out_ = graph_->CreateOperation<tim::vx::ops::Reshape>(output_reshape_shape);
+      reshape_fw_out_hstate_ = graph_->CreateOperation<tim::vx::ops::Reshape>(output_reshape_hstate_shape);
+      reshape_bw_out_ = graph_->CreateOperation<tim::vx::ops::Reshape>(output_reshape_shape);
+      reshape_bw_out_hstate_ = graph_->CreateOperation<tim::vx::ops::Reshape>(output_reshape_hstate_shape);
+      concat_output_ = graph_->CreateOperation<tim::vx::ops::Concat>(2, 2);
+      concat_out_hstate_ = graph_->CreateOperation<tim::vx::ops::Concat>(2, 2);
+
+      split_weight_->BindInputs({in_tensors_[RNN_EXT_INPUT_WEIGHT_I]})
+          .BindOutputs({input_fw_weight_i_tensor, input_bw_weight_i_tensor});
+      reshape_fw_weight_->BindInputs({input_fw_weight_i_tensor})
+          .BindOutputs({input_fw_reshape_weight_i_tensor});
+      reshape_bw_weight_->BindInputs({input_bw_weight_i_tensor})
+          .BindOutputs({input_bw_reshape_weight_i_tensor});
+
+      split_recurrent_->BindInputs({in_tensors_[RNN_EXT_INPUT_WEIGHT_H]})
+          .BindOutputs({input_fw_weight_h_tensor, input_bw_weight_h_tensor});
+      reshape_fw_recurrent_->BindInputs({input_fw_weight_h_tensor})
+          .BindOutputs({input_fw_reshape_weight_h_tensor});
+      reshape_bw_recurrent_->BindInputs({input_bw_weight_h_tensor})
+          .BindOutputs({input_bw_reshape_weight_h_tensor});
+
+      split_bias_->BindInputs({in_tensors_[RNN_EXT_INPUT_BIAS]})
+          .BindOutputs({input_fw_bias_tensor, input_bw_bias_tensor});
+      reshape_fw_bias_->BindInputs({input_fw_bias_tensor})
+          .BindOutputs({input_fw_reshape_bias_tensor});
+      reshape_bw_bias_->BindInputs({input_bw_bias_tensor})
+          .BindOutputs({input_bw_reshape_bias_tensor});
+      split_reshape_fw_bias_->BindInputs({input_fw_reshape_bias_tensor})
+          .BindOutputs({input_fw_reshape_split_bias_i_tensor,
+                        input_fw_reshape_split_bias_h_tensor});
+      split_reshape_bw_bias_->BindInputs({input_bw_reshape_bias_tensor})
+          .BindOutputs({input_bw_reshape_split_bias_i_tensor,
+                        input_bw_reshape_split_bias_h_tensor});
+
+      split_hstate_->BindInputs({in_tensors_[RNN_EXT_INPUT_H_STATE]})
+          .BindOutputs({input_fw_hstate_tensor, input_bw_hstate_tensor});
+      reshape_fw_hstate_->BindInputs({input_fw_hstate_tensor})
+          .BindOutputs({input_fw_reshape_hstate_tensor});
+      reshape_bw_hstate_->BindInputs({input_bw_hstate_tensor})
+          .BindOutputs({input_bw_reshape_hstate_tensor});
+
+      rnn_->BindInputs({in_tensors_[RNN_EXT_INPUT_INPUT],
+                        input_fw_reshape_weight_i_tensor,
+                        input_fw_reshape_weight_h_tensor,
+                        input_fw_reshape_split_bias_i_tensor,
+                        input_fw_reshape_split_bias_h_tensor,
+                        input_fw_reshape_hstate_tensor,
+                        input_bw_reshape_weight_i_tensor,
+                        input_bw_reshape_weight_h_tensor,
+                        input_bw_reshape_split_bias_i_tensor,
+                        input_bw_reshape_split_bias_h_tensor,
+                        input_bw_reshape_hstate_tensor});
+      rnn_->BindOutputs({output_fw_hstate_tensor, output_bw_hstate_tensor,
+                         output_fw_tensor, output_bw_tensor});
+
+      reshape_fw_out_hstate_->BindInputs({output_fw_hstate_tensor})
+          .BindOutputs({output_fw_reshape_hstate_tensor});
+      reshape_fw_out_->BindInputs({output_fw_tensor})
+          .BindOutputs({output_fw_reshape_tensor});
+      reshape_bw_out_hstate_->BindInputs({output_bw_hstate_tensor})
+          .BindOutputs({output_bw_reshape_hstate_tensor});
+      reshape_bw_out_->BindInputs({output_bw_tensor})
+          .BindOutputs({output_bw_reshape_tensor});
+
+      concat_out_hstate_->BindInputs(
+          {output_fw_reshape_hstate_tensor, output_bw_reshape_hstate_tensor});
+      concat_output_->BindInputs(
+          {output_fw_reshape_tensor, output_bw_reshape_tensor});
+    }
+    this->input_tensor_index++;
+    return *this;
+  }
+
+  BidirectionalSequenceRnnExtImpl& BindOutput(
+      const std::shared_ptr<Tensor>& tensor) override {
+    out_tensors_[output_tensor_index] = tensor;
+
+    if (this->output_tensor_index == RNN_EXT_OUT_CNT - 1) {
+      concat_output_->BindOutput(out_tensors_[RNN_EXT_OUTPUT_OUTPUT]);
+      concat_out_hstate_->BindOutput(out_tensors_[RNN_EXT_OUTPUT_H_STATE]);
+    }
+    this->output_tensor_index++;
+    return *this;
+  }
+
+  vsi_nn_node_t* node() override { return nullptr; }
+
+  std::vector<std::shared_ptr<Tensor>> InputsTensor() override {
+    return inputs_tensor_;
+  }
+  std::vector<std::shared_ptr<Tensor>> OutputsTensor() override {
+    return outputs_tensor_;
+  }
+
+ private:
+  tim::vx::ops::BidirectionalSequenceRnn::ActivationType act_type_;
+
+  std::shared_ptr<tim::vx::Operation> split_weight_;
+  std::shared_ptr<tim::vx::Operation> reshape_fw_weight_;
+  std::shared_ptr<tim::vx::Operation> reshape_bw_weight_;
+
+  std::shared_ptr<tim::vx::Operation> split_recurrent_;
+  std::shared_ptr<tim::vx::Operation> reshape_fw_recurrent_;
+  std::shared_ptr<tim::vx::Operation> reshape_bw_recurrent_;
+
+  std::shared_ptr<tim::vx::Operation> split_bias_;
+  std::shared_ptr<tim::vx::Operation> reshape_fw_bias_;
+  std::shared_ptr<tim::vx::Operation> reshape_bw_bias_;
+  std::shared_ptr<tim::vx::Operation> split_reshape_fw_bias_;
+  std::shared_ptr<tim::vx::Operation> split_reshape_bw_bias_;
+
+  std::shared_ptr<tim::vx::Operation> split_hstate_;
+  std::shared_ptr<tim::vx::Operation> reshape_fw_hstate_;
+  std::shared_ptr<tim::vx::Operation> reshape_bw_hstate_;
+
+  std::shared_ptr<tim::vx::Operation> rnn_;
+
+  std::shared_ptr<tim::vx::Operation> reshape_fw_out_;
+  std::shared_ptr<tim::vx::Operation> reshape_fw_out_hstate_;
+  std::shared_ptr<tim::vx::Operation> reshape_bw_out_;
+  std::shared_ptr<tim::vx::Operation> reshape_bw_out_hstate_;
+  std::shared_ptr<tim::vx::Operation> concat_output_;
+  std::shared_ptr<tim::vx::Operation> concat_out_hstate_;
+
+  std::array<std::shared_ptr<tim::vx::Tensor>, RNN_EXT_INPUT_CNT> in_tensors_;
+  std::array<std::shared_ptr<tim::vx::Tensor>, RNN_EXT_OUT_CNT> out_tensors_;
+};
+
+BidirectionalSequenceRnnExt::BidirectionalSequenceRnnExt(
+    Graph* graph,
+    tim::vx::ops::BidirectionalSequenceRnn::ActivationType act_type)
+    : act_type_(act_type) {
+  impl_ = std::make_unique<BidirectionalSequenceRnnExtImpl>(graph, act_type,
+                                                            DataLayout::ANY);
+}
+
+std::shared_ptr<Operation> BidirectionalSequenceRnnExt::Clone(
+    std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<BidirectionalSequenceRnnExt>(this->act_type_);
+}
+
+}  // namespace ops
+}  // namespace vx
+}  // namespace tim
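Note: because the internal subgraph above is only wired up once the last input arrives, all five packed inputs must be bound before the outputs. A usage sketch, mirroring the test that follows (tensor names are placeholders):

    auto op = graph->CreateOperation<tim::vx::ops::BidirectionalSequenceRnnExt>(
        tim::vx::ops::BidirectionalSequenceRnn::ActivationType::kSIGMOID);
    (*op)
        .BindInputs({input, weights, recurrent_weights, bias, state_in})
        .BindOutputs({state_out, output});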
diff --git a/src/tim/vx/ops/bidirectional_sequence_rnn_ext_test.cc b/src/tim/vx/ops/bidirectional_sequence_rnn_ext_test.cc
new file mode 100644
index 0000000..3cc38a7
--- /dev/null
+++ b/src/tim/vx/ops/bidirectional_sequence_rnn_ext_test.cc
@@ -0,0 +1,161 @@
+/****************************************************************************
+*
+* Copyright (c) 2021 Vivante Corporation
+*
+* Permission is hereby granted, free of charge, to any person obtaining a
+* copy of this software and associated documentation files (the "Software"),
+* to deal in the Software without restriction, including without limitation
+* the rights to use, copy, modify, merge, publish, distribute, sublicense,
+* and/or sell copies of the Software, and to permit persons to whom the
+* Software is furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in
+* all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+* DEALINGS IN THE SOFTWARE.
+*
+*****************************************************************************/
+#include "tim/vx/context.h"
+#include "tim/vx/graph.h"
+#include "tim/vx/ops.h"
+#include "tim/vx/types.h"
+#include "gtest/gtest.h"
+#include "test_utils.h"
+
+TEST(BidirectionalSequenceRnnExt, shape_2_3_4_float_sigmoid) {
+  auto ctx = tim::vx::Context::Create();
+  auto graph = ctx->CreateGraph();
+
+  uint32_t input_size = 2, batch_size = 3, num_units = 4;
+
+  tim::vx::ShapeType input_shape({input_size, batch_size, 2});
+  tim::vx::ShapeType weights_shape({input_size, num_units, 2});
+  tim::vx::ShapeType recurrent_weights_shape({num_units, num_units, 2});
+  tim::vx::ShapeType bias_shape({num_units * 2, 2});
+  tim::vx::ShapeType state_in_shape({num_units, batch_size, 2});
+  tim::vx::ShapeType output_shape({num_units, batch_size, 2, 2});
+  tim::vx::ShapeType state_out_shape({num_units, batch_size, 2});
+
+  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32,
+      input_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec weights_spec(tim::vx::DataType::FLOAT32,
+      weights_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec recurrent_weights_spec(tim::vx::DataType::FLOAT32,
+      recurrent_weights_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec bias_spec(tim::vx::DataType::FLOAT32,
+      bias_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec state_in_spec(tim::vx::DataType::FLOAT32,
+      state_in_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32,
+      output_shape, tim::vx::TensorAttribute::OUTPUT);
+  tim::vx::TensorSpec state_out_spec(tim::vx::DataType::FLOAT32,
+      state_out_shape, tim::vx::TensorAttribute::OUTPUT);
+
+  auto input_tensor = graph->CreateTensor(input_spec);
+  auto weights_tensor = graph->CreateTensor(weights_spec);
+  auto recurrent_weights_tensor = graph->CreateTensor(recurrent_weights_spec);
+  auto bias_tensor = graph->CreateTensor(bias_spec);
+  auto state_in_tensor = graph->CreateTensor(state_in_spec);
+  auto output_tensor = graph->CreateTensor(output_spec);
+  auto state_out_tensor = graph->CreateTensor(state_out_spec);
+
+  std::vector<float> in_data = {
+      1.0, 2.0,
+      3.0, 4.0,
+      5.0, 6.0,
+      7.0, 8.0,
+      9.0, 10.0,
+      11.0, 12.0
+  };
+  std::vector<float> weights_data = {
+      0.1, 0.1,
+      0.1, 0.1,
+      0.1, 0.1,
+      0.1, 0.1,
+      0.1, 0.1,
+      0.1, 0.1,
+      0.1, 0.1,
+      0.1, 0.1
+  };
+  std::vector<float> recurrent_weights_data = {
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+  };
+  std::vector<float> bias_data = {
+      0.1, 0.1, 0.1, 0.1,
+      0.0, 0.0, 0.0, 0.0,
+      0.1, 0.1, 0.1, 0.1,  // bug: this value cannot be fetched
+      0.0, 0.0, 0.0, 0.0,
+  };
+  std::vector<float> state_in_data = {
+      0, 0, 0, 0,
+      0, 0, 0, 0,
+      0, 0, 0, 0,
+      0, 0, 0, 0,
+      0, 0, 0, 0,
+      0, 0, 0, 0
+  };
+  std::vector<float> output_golden = {
+      0.5986, 0.5986, 0.5986, 0.5986,
+      0.6899, 0.6899, 0.6899, 0.6899,
+      0.7685, 0.7685, 0.7685, 0.7685,
+      0.8320, 0.8320, 0.8320, 0.8320,
+      0.8807, 0.8807, 0.8807, 0.8807,
+      0.9168, 0.9168, 0.9168, 0.9168,
+      0.8628, 0.8628, 0.8628, 0.8628,
+      0.9068, 0.9068, 0.9068, 0.9068,
+      0.9374, 0.9374, 0.9374, 0.9374,
+      0.6754, 0.6754, 0.6754, 0.6754,
+      0.7599, 0.7599, 0.7599, 0.7599,
+      0.8273, 0.8273, 0.8273, 0.8273
+  };
+  std::vector<float> state_out_golden = {
+      0.8628, 0.8628, 0.8628, 0.8628,
+      0.9068, 0.9068, 0.9068, 0.9068,
+      0.9374, 0.9374, 0.9374, 0.9374,
+      0.6754, 0.6754, 0.6754, 0.6754,
+      0.7599, 0.7599, 0.7599, 0.7599,
+      0.8273, 0.8273, 0.8273, 0.8273
+  };
+
+  EXPECT_TRUE(input_tensor->CopyDataToTensor(
+      in_data.data(), in_data.size() * sizeof(float)));
+  EXPECT_TRUE(weights_tensor->CopyDataToTensor(
+      weights_data.data(), weights_data.size() * sizeof(float)));
+  EXPECT_TRUE(recurrent_weights_tensor->CopyDataToTensor(
+      recurrent_weights_data.data(), recurrent_weights_data.size() * sizeof(float)));
+  EXPECT_TRUE(bias_tensor->CopyDataToTensor(
+      bias_data.data(), bias_data.size() * sizeof(float)));
+  EXPECT_TRUE(state_in_tensor->CopyDataToTensor(
+      state_in_data.data(), state_in_data.size() * sizeof(float)));
+
+  auto op = graph->CreateOperation<tim::vx::ops::BidirectionalSequenceRnnExt>(
+      tim::vx::ops::BidirectionalSequenceRnn::ActivationType::kSIGMOID);
+  (*op).BindInputs({input_tensor, weights_tensor, recurrent_weights_tensor,
+                    bias_tensor, state_in_tensor})
+      .BindOutputs({state_out_tensor, output_tensor});
+  graph->PrintGraph();
+  EXPECT_TRUE(graph->Compile());
+  EXPECT_TRUE(graph->Run());
+  std::vector<float> output(output_golden.size());
+  std::vector<float> state_out(state_out_golden.size());
+  EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
+  EXPECT_TRUE(state_out_tensor->CopyDataFromTensor(state_out.data()));
+
+  EXPECT_TRUE(ArraysMatch(output_golden, output, 1e-3f));
+  EXPECT_TRUE(ArraysMatch(state_out_golden, state_out, 1e-3f));
+}
\ No newline at end of file
diff --git a/src/tim/vx/ops/bidirectional_sequence_rnn_test.cc b/src/tim/vx/ops/bidirectional_sequence_rnn_test.cc
new file mode 100644
index 0000000..fdf887f
--- /dev/null
+++ b/src/tim/vx/ops/bidirectional_sequence_rnn_test.cc
@@ -0,0 +1,338 @@
+/****************************************************************************
+*
+* Copyright (c) 2021 Vivante Corporation
+*
+* Permission is hereby granted, free of charge, to any person obtaining a
+* copy of this software and associated documentation files (the "Software"),
+* to deal in the Software without restriction, including without limitation
+* the rights to use, copy, modify, merge, publish, distribute, sublicense,
+* and/or sell copies of the Software, and to permit persons to whom the
+* Software is furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in
+* all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+* DEALINGS IN THE SOFTWARE.
+*
+*****************************************************************************/
+#include "tim/vx/context.h"
+#include "tim/vx/graph.h"
+#include "tim/vx/ops/bidirectional_sequence_rnn.h"
+#include "tim/vx/types.h"
+#include "gtest/gtest.h"
+#include "test_utils.h"
+
+TEST(BidirectionalSequenceRnn, shape_2_3_4_float_sigmoid) {
+  auto ctx = tim::vx::Context::Create();
+  auto graph = ctx->CreateGraph();
+
+  uint32_t input_size = 2, batch_size = 3, num_units = 4;
+
+  tim::vx::ShapeType input_shape({input_size, batch_size, 2});
+  tim::vx::ShapeType weights_shape({input_size, num_units});
+  tim::vx::ShapeType recurrent_weights_shape({num_units, num_units});
+  tim::vx::ShapeType bias_shape({num_units});
+  tim::vx::ShapeType recurrent_bias_shape({num_units});
+  tim::vx::ShapeType state_in_shape({num_units, batch_size});
+  tim::vx::ShapeType output_shape({num_units, batch_size, 2});
+  tim::vx::ShapeType state_out_shape({num_units, batch_size});
+
+  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32,
+      input_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec weights_spec(tim::vx::DataType::FLOAT32,
+      weights_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec recurrent_weights_spec(tim::vx::DataType::FLOAT32,
+      recurrent_weights_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec bias_spec(tim::vx::DataType::FLOAT32,
+      bias_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec recurrent_bias_spec(tim::vx::DataType::FLOAT32,
+      recurrent_bias_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec state_in_spec(tim::vx::DataType::FLOAT32,
+      state_in_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32,
+      output_shape, tim::vx::TensorAttribute::OUTPUT);
+  tim::vx::TensorSpec state_out_spec(tim::vx::DataType::FLOAT32,
+      state_out_shape, tim::vx::TensorAttribute::OUTPUT);
+
+  auto input_tensor = graph->CreateTensor(input_spec);
+  auto weights_tensor = graph->CreateTensor(weights_spec);
+  auto recurrent_weights_tensor = graph->CreateTensor(recurrent_weights_spec);
+  auto bias_tensor = graph->CreateTensor(bias_spec);
+  auto recurrent_bias_tensor = graph->CreateTensor(recurrent_bias_spec);
+  auto state_in_tensor = graph->CreateTensor(state_in_spec);
+  auto output_tensor = graph->CreateTensor(output_spec);
+  auto state_out_tensor = graph->CreateTensor(state_out_spec);
+  auto bw_weights_tensor = graph->CreateTensor(weights_spec);
+  auto bw_recurrent_weights_tensor = graph->CreateTensor(recurrent_weights_spec);
+  auto bw_bias_tensor = graph->CreateTensor(bias_spec);
+  auto bw_recurrent_bias_tensor = graph->CreateTensor(recurrent_bias_spec);
+  auto bw_state_in_tensor = graph->CreateTensor(state_in_spec);
+  auto bw_output_tensor = graph->CreateTensor(output_spec);
+  auto bw_state_out_tensor = graph->CreateTensor(state_out_spec);
+
+  std::vector<float> in_data = {
+      1.0, 2.0,
+      3.0, 4.0,
+      5.0, 6.0,
+      7.0, 8.0,
+      9.0, 10.0,
+      11.0, 12.0
+  };
+  std::vector<float> weights_data = {
+      0.1, 0.1,
+      0.1, 0.1,
+      0.1, 0.1,
+      0.1, 0.1
+  };
+  std::vector<float> recurrent_weights_data = {
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+  };
+  std::vector<float> bias_data = {
+      0.1, 0.1, 0.1, 0.1
+  };
+  std::vector<float> recurrent_bias_data = {
+      0.0, 0.0, 0.0, 0.0
+  };
+  std::vector<float> state_in_data = {
+      0, 0, 0, 0,
+      0, 0, 0, 0,
+      0, 0, 0, 0
+  };
+  std::vector<float> output_golden = {
+      0.5986, 0.5986, 0.5986, 0.5986,
+      0.6899, 0.6899, 0.6899, 0.6899,
+      0.7685, 0.7685, 0.7685, 0.7685,
+      0.8628, 0.8628, 0.8628, 0.8628,
+      0.9068, 0.9068, 0.9068, 0.9068,
+      0.9374, 0.9374, 0.9374, 0.9374,
+  };
+  std::vector<float> state_out_golden = {
+      0.8628, 0.8628, 0.8628, 0.8628,
+      0.9068, 0.9068, 0.9068, 0.9068,
+      0.9374, 0.9374, 0.9374, 0.9374,
+  };
+  std::vector<float> bw_output_golden = {
+      0.8320, 0.8320, 0.8320, 0.8320,
+      0.8807, 0.8807, 0.8807, 0.8807,
+      0.9168, 0.9168, 0.9168, 0.9168,
+      0.6754, 0.6754, 0.6754, 0.6754,
+      0.7599, 0.7599, 0.7599, 0.7599,
+      0.8273, 0.8273, 0.8273, 0.8273
+  };
+  std::vector<float> bw_state_out_golden = {
+      0.6754, 0.6754, 0.6754, 0.6754,
+      0.7599, 0.7599, 0.7599, 0.7599,
+      0.8273, 0.8273, 0.8273, 0.8273
+  };
+
+  EXPECT_TRUE(input_tensor->CopyDataToTensor(
+      in_data.data(), in_data.size() * sizeof(float)));
+  EXPECT_TRUE(weights_tensor->CopyDataToTensor(
+      weights_data.data(), weights_data.size() * sizeof(float)));
+  EXPECT_TRUE(recurrent_weights_tensor->CopyDataToTensor(
+      recurrent_weights_data.data(), recurrent_weights_data.size() * sizeof(float)));
+  EXPECT_TRUE(bias_tensor->CopyDataToTensor(
+      bias_data.data(), bias_data.size() * sizeof(float)));
+  EXPECT_TRUE(recurrent_bias_tensor->CopyDataToTensor(
+      recurrent_bias_data.data(), recurrent_bias_data.size() * sizeof(float)));
+  EXPECT_TRUE(state_in_tensor->CopyDataToTensor(
+      state_in_data.data(), state_in_data.size() * sizeof(float)));
+  EXPECT_TRUE(bw_weights_tensor->CopyDataToTensor(
+      weights_data.data(), weights_data.size() * sizeof(float)));
+  EXPECT_TRUE(bw_recurrent_weights_tensor->CopyDataToTensor(
+      recurrent_weights_data.data(), recurrent_weights_data.size() * sizeof(float)));
+  EXPECT_TRUE(bw_bias_tensor->CopyDataToTensor(
+      bias_data.data(), bias_data.size() * sizeof(float)));
+  EXPECT_TRUE(bw_recurrent_bias_tensor->CopyDataToTensor(
+      recurrent_bias_data.data(), recurrent_bias_data.size() * sizeof(float)));
+  EXPECT_TRUE(bw_state_in_tensor->CopyDataToTensor(
+      state_in_data.data(), state_in_data.size() * sizeof(float)));
+
+  auto op = graph->CreateOperation<tim::vx::ops::BidirectionalSequenceRnn>(
+      tim::vx::ops::BidirectionalSequenceRnn::ActivationType::kSIGMOID,
+      true, false);
+  (*op).BindInputs({input_tensor, weights_tensor, recurrent_weights_tensor,
+                    bias_tensor, recurrent_bias_tensor, state_in_tensor,
+                    bw_weights_tensor, bw_recurrent_weights_tensor,
+                    bw_bias_tensor, bw_recurrent_bias_tensor,
+                    bw_state_in_tensor})
+      .BindOutputs({state_out_tensor, bw_state_out_tensor, output_tensor,
+                    bw_output_tensor});
+  graph->PrintGraph();
+  EXPECT_TRUE(graph->Compile());
+  EXPECT_TRUE(graph->Run());
+  std::vector<float> output(output_golden.size());
+  std::vector<float> state_out(state_out_golden.size());
+  std::vector<float> bw_output(output_golden.size());
+  std::vector<float> bw_state_out(state_out_golden.size());
+  EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
+  EXPECT_TRUE(state_out_tensor->CopyDataFromTensor(state_out.data()));
+  EXPECT_TRUE(bw_output_tensor->CopyDataFromTensor(bw_output.data()));
+  EXPECT_TRUE(bw_state_out_tensor->CopyDataFromTensor(bw_state_out.data()));
+
+  EXPECT_TRUE(ArraysMatch(output_golden, output, 1e-3f));
+  EXPECT_TRUE(ArraysMatch(state_out_golden, state_out, 1e-3f));
+
+  EXPECT_TRUE(ArraysMatch(bw_output_golden, bw_output, 1e-3f));
+  EXPECT_TRUE(ArraysMatch(bw_state_out_golden, bw_state_out, 1e-3f));
+}
+
+TEST(BidirectionalSequenceRnn, shape_2_3_4_float_relu) {
+  auto ctx = tim::vx::Context::Create();
+  auto graph = ctx->CreateGraph();
+
+  uint32_t input_size = 2, batch_size = 3, num_units = 4;
+
+  tim::vx::ShapeType input_shape({input_size, batch_size, 2});
+  tim::vx::ShapeType weights_shape({input_size, num_units});
+  tim::vx::ShapeType recurrent_weights_shape({num_units, num_units});
+  tim::vx::ShapeType bias_shape({num_units});
+  tim::vx::ShapeType recurrent_bias_shape({num_units});
+  tim::vx::ShapeType state_in_shape({num_units, batch_size});
+  tim::vx::ShapeType output_shape({num_units, batch_size, 2});
+  tim::vx::ShapeType state_out_shape({num_units, batch_size});
+
+  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32,
+      input_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec weights_spec(tim::vx::DataType::FLOAT32,
+      weights_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec recurrent_weights_spec(tim::vx::DataType::FLOAT32,
+      recurrent_weights_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec bias_spec(tim::vx::DataType::FLOAT32,
+      bias_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec recurrent_bias_spec(tim::vx::DataType::FLOAT32,
+      recurrent_bias_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec state_in_spec(tim::vx::DataType::FLOAT32,
+      state_in_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32,
+      output_shape, tim::vx::TensorAttribute::OUTPUT);
+  tim::vx::TensorSpec state_out_spec(tim::vx::DataType::FLOAT32,
+      state_out_shape, tim::vx::TensorAttribute::OUTPUT);
+
+  auto input_tensor = graph->CreateTensor(input_spec);
+  auto weights_tensor = graph->CreateTensor(weights_spec);
+  auto recurrent_weights_tensor = graph->CreateTensor(recurrent_weights_spec);
+  auto bias_tensor = graph->CreateTensor(bias_spec);
+  auto recurrent_bias_tensor = graph->CreateTensor(recurrent_bias_spec);
+  auto state_in_tensor = graph->CreateTensor(state_in_spec);
+  auto output_tensor = graph->CreateTensor(output_spec);
+  auto state_out_tensor = graph->CreateTensor(state_out_spec);
+  auto bw_weights_tensor = graph->CreateTensor(weights_spec);
+  auto bw_recurrent_weights_tensor = graph->CreateTensor(recurrent_weights_spec);
+  auto bw_bias_tensor = graph->CreateTensor(bias_spec);
+  auto bw_recurrent_bias_tensor = graph->CreateTensor(recurrent_bias_spec);
+  auto bw_state_in_tensor = graph->CreateTensor(state_in_spec);
+  auto bw_output_tensor = graph->CreateTensor(output_spec);
+  auto bw_state_out_tensor = graph->CreateTensor(state_out_spec);
+
+  std::vector<float> in_data = {
+      1.0, 2.0,
+      3.0, 4.0,
+      5.0, 6.0,
+      7.0, 8.0,
+      9.0, 10.0,
+      11.0, 12.0
+  };
+  std::vector<float> weights_data = {
+      0.1, 0.1,
+      0.1, 0.1,
+      0.1, 0.1,
+      0.1, 0.1
+  };
+  std::vector<float> recurrent_weights_data = {
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+  };
+  std::vector<float> bias_data = {
+      0.1, 0.1, 0.1, 0.1
+  };
+  std::vector<float> recurrent_bias_data = {
+      0.0, 0.0, 0.0, 0.0
+  };
+  std::vector<float> state_in_data = {
+      0, 0, 0, 0,
+      0, 0, 0, 0,
+      0, 0, 0, 0
+  };
+  std::vector<float> output_golden = {
+      0.4, 0.4, 0.4, 0.4,
+      0.8, 0.8, 0.8, 0.8,
+      1.2, 1.2, 1.2, 1.2,
+      1.76, 1.76, 1.76, 1.76,
+      2.32, 2.32, 2.32, 2.32,
+      2.88, 2.88, 2.88, 2.88,
+  };
+  std::vector<float> state_out_golden = {
+      1.76, 1.76, 1.76, 1.76,
+      2.32, 2.32, 2.32, 2.32,
+      2.88, 2.88, 2.88, 2.88,
+  };
+  std::vector<float> bw_output_golden = {
+      1.6, 1.6, 1.6, 1.6,
+      2.0, 2.0, 2.0, 2.0,
+      2.4, 2.4, 2.4, 2.4,
+      1.04, 1.04, 1.04, 1.04,
+      1.6, 1.6, 1.6, 1.6,
+      2.16, 2.16, 2.16, 2.16,
+  };
+  std::vector<float> bw_state_out_golden = {
+      1.04, 1.04, 1.04, 1.04,
+      1.6, 1.6, 1.6, 1.6,
+      2.16, 2.16, 2.16, 2.16,
+  };
+
+  EXPECT_TRUE(input_tensor->CopyDataToTensor(
+      in_data.data(), in_data.size() * sizeof(float)));
+  EXPECT_TRUE(weights_tensor->CopyDataToTensor(
+      weights_data.data(), weights_data.size() * sizeof(float)));
+  EXPECT_TRUE(recurrent_weights_tensor->CopyDataToTensor(
+      recurrent_weights_data.data(), recurrent_weights_data.size() * sizeof(float)));
+  EXPECT_TRUE(bias_tensor->CopyDataToTensor(
+      bias_data.data(), bias_data.size() * sizeof(float)));
+  EXPECT_TRUE(recurrent_bias_tensor->CopyDataToTensor(
+      recurrent_bias_data.data(), recurrent_bias_data.size() * sizeof(float)));
+  EXPECT_TRUE(state_in_tensor->CopyDataToTensor(
+      state_in_data.data(), state_in_data.size() * sizeof(float)));
+  EXPECT_TRUE(bw_weights_tensor->CopyDataToTensor(
+      weights_data.data(), weights_data.size() * sizeof(float)));
+  EXPECT_TRUE(bw_recurrent_weights_tensor->CopyDataToTensor(
+      recurrent_weights_data.data(), recurrent_weights_data.size() * sizeof(float)));
+  EXPECT_TRUE(bw_bias_tensor->CopyDataToTensor(
+      bias_data.data(), bias_data.size() * sizeof(float)));
+  EXPECT_TRUE(bw_recurrent_bias_tensor->CopyDataToTensor(
+      recurrent_bias_data.data(), recurrent_bias_data.size() * sizeof(float)));
+  EXPECT_TRUE(bw_state_in_tensor->CopyDataToTensor(
+      state_in_data.data(), state_in_data.size() * sizeof(float)));
+
+  auto op = graph->CreateOperation<tim::vx::ops::BidirectionalSequenceRnn>(
+      tim::vx::ops::BidirectionalSequenceRnn::ActivationType::kRELU,
+      true, false);
+  (*op).BindInputs({input_tensor, weights_tensor, recurrent_weights_tensor,
+                    bias_tensor, recurrent_bias_tensor, state_in_tensor,
+                    bw_weights_tensor, bw_recurrent_weights_tensor,
+                    bw_bias_tensor, bw_recurrent_bias_tensor,
+                    bw_state_in_tensor})
+      .BindOutputs({state_out_tensor, bw_state_out_tensor, output_tensor,
+                    bw_output_tensor});
+  graph->PrintGraph();
+  EXPECT_TRUE(graph->Compile());
+  EXPECT_TRUE(graph->Run());
+  std::vector<float> output(output_golden.size());
+  std::vector<float> state_out(state_out_golden.size());
+  std::vector<float> bw_output(output_golden.size());
+  std::vector<float> bw_state_out(state_out_golden.size());
+  EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
+  EXPECT_TRUE(state_out_tensor->CopyDataFromTensor(state_out.data()));
+  EXPECT_TRUE(bw_output_tensor->CopyDataFromTensor(bw_output.data()));
+  EXPECT_TRUE(bw_state_out_tensor->CopyDataFromTensor(bw_state_out.data()));
+
+  EXPECT_TRUE(ArraysMatch(output_golden, output, 1e-3f));
+  EXPECT_TRUE(ArraysMatch(state_out_golden, state_out, 1e-3f));
+
+  EXPECT_TRUE(ArraysMatch(bw_output_golden, bw_output, 1e-3f));
+  EXPECT_TRUE(ArraysMatch(bw_state_out_golden, bw_state_out, 1e-3f));
+}
\ No newline at end of file
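Note: a quick hand check of the backward-direction relu goldens above (pure arithmetic; all weights 0.1, input bias 0.1, batch 0). The backward cell walks the sequence from the last step to the first, and the goldens suggest its output is stored in processing order, last timestep first:

    // t = 1, processed first: relu(0.1*(7 + 8) + 0.1)             = 1.6
    // t = 0, processed last:  relu(0.1*(1 + 2) + 0.1 + 0.1*4*1.6) = 1.04
    // bw_output_golden starts with 1.6; bw_state_out_golden starts with 1.04.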
diff --git a/src/tim/vx/ops/unidirectional_sequence_rnn.cc b/src/tim/vx/ops/unidirectional_sequence_rnn.cc
new file mode 100644
index 0000000..fd26ce1
--- /dev/null
+++ b/src/tim/vx/ops/unidirectional_sequence_rnn.cc
@@ -0,0 +1,76 @@
+/****************************************************************************
+*
+* Copyright (c) 2020 Vivante Corporation
+*
+* Permission is hereby granted, free of charge, to any person obtaining a
+* copy of this software and associated documentation files (the "Software"),
+* to deal in the Software without restriction, including without limitation
+* the rights to use, copy, modify, merge, publish, distribute, sublicense,
+* and/or sell copies of the Software, and to permit persons to whom the
+* Software is furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in
+* all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+* DEALINGS IN THE SOFTWARE.
+*
+*****************************************************************************/
+#include "tim/vx/ops/unidirectional_sequence_rnn.h"
+
+#include "builtin_op_impl.h"
+#include "vsi_nn_pub.h"
+
+namespace tim {
+namespace vx {
+namespace ops {
+
+vsi_nn_activation_e downcast_act_type(
+    UnidirectionalSequenceRnn::ActivationType act) {
+  switch (act) {
+    case UnidirectionalSequenceRnn::ActivationType::kRELU:
+      return VSI_NN_ACT_RELU;
+    case UnidirectionalSequenceRnn::ActivationType::kRELU1:
+      return VSI_NN_ACT_RELU1;
+    case UnidirectionalSequenceRnn::ActivationType::kRELU6:
+      return VSI_NN_ACT_RELU6;
+    case UnidirectionalSequenceRnn::ActivationType::kTANH:
+      return VSI_NN_ACT_TANH;
+    case UnidirectionalSequenceRnn::ActivationType::kSIGMOID:
+      return VSI_NN_ACT_SIGMOID;
+    case UnidirectionalSequenceRnn::ActivationType::kHARDSIGMOID:
+      return VSI_NN_ACT_HARD_SIGMOID;
+    default: {
+      VSILOGW("Unsupported activation type for RNN = %d",
+              static_cast<int>(act));
+      return VSI_NN_ACT_NONE;
+    }
+  }
+}
+
+UnidirectionalSequenceRnn::UnidirectionalSequenceRnn(
+    Graph* graph,
+    ActivationType act_type,
+    bool time_major)
+    : DirectMapOp(graph, VSI_NN_OP_UNIDIRECTIONAL_SEQUENCE_RNN),
+      act_type_(act_type) {
+  this->impl()->node()->nn_param.unidirectional_sequence_rnn.time_major =
+      time_major;
+  this->impl()->node()->nn_param.unidirectional_sequence_rnn.activation =
+      downcast_act_type(act_type);
+}
+
+std::shared_ptr<Operation> UnidirectionalSequenceRnn::Clone(
+    std::shared_ptr<Graph>& graph) const {
+  auto cloned_op = graph->CreateOperation<UnidirectionalSequenceRnn>(
+      act_type_,
+      this->impl()->node()->nn_param.unidirectional_sequence_rnn.time_major);
+  return cloned_op;
+}
+
+}  // namespace ops
+}  // namespace vx
+}  // namespace tim
diff --git a/src/tim/vx/ops/unidirectional_sequence_rnn_ext.cc b/src/tim/vx/ops/unidirectional_sequence_rnn_ext.cc
new file mode 100644
index 0000000..657051b
--- /dev/null
+++ b/src/tim/vx/ops/unidirectional_sequence_rnn_ext.cc
@@ -0,0 +1,193 @@
+/****************************************************************************
+*
+* Copyright (c) 2021 Vivante Corporation
+*
+* Permission is hereby granted, free of charge, to any person obtaining a
+* copy of this software and associated documentation files (the "Software"),
+* to deal in the Software without restriction, including without limitation
+* the rights to use, copy, modify, merge, publish, distribute, sublicense,
+* and/or sell copies of the Software, and to permit persons to whom the
+* Software is furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in
+* all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+* DEALINGS IN THE SOFTWARE.
+*
+*****************************************************************************/
+#include "tim/vx/ops.h"
+#include "vsi_nn_pub.h"
+#include "op_impl.h"
+
+#include <array>
+
+namespace tim {
+namespace vx {
+namespace ops {
+
+class UnidirectionalSequenceRnnExtImpl : public OpImpl {
+ public:
+  enum {
+    // signature
+    RNN_EXT_INPUT_INPUT = 0,
+    RNN_EXT_INPUT_WEIGHT_I = 1,
+    RNN_EXT_INPUT_WEIGHT_H = 2,
+    RNN_EXT_INPUT_BIAS = 3,
+    RNN_EXT_INPUT_H_STATE = 4,
+    RNN_EXT_INPUT_CNT,
+
+    RNN_EXT_OUTPUT_H_STATE = 0,
+    RNN_EXT_OUTPUT_OUTPUT = 1,
+    RNN_EXT_OUT_CNT,
+    // signature end
+  };
+
+  UnidirectionalSequenceRnnExtImpl(
+      Graph* graph,
+      tim::vx::ops::UnidirectionalSequenceRnn::ActivationType act_type,
+      DataLayout layout = DataLayout::ANY)
+      : OpImpl(graph, layout), act_type_(act_type) {}
+
+  ~UnidirectionalSequenceRnnExtImpl() {}
+
+  UnidirectionalSequenceRnnExtImpl& BindInput(
+      const std::shared_ptr<Tensor>& tensor) override {
+    in_tensors_[input_tensor_index] = tensor;
+
+    if (this->input_tensor_index == RNN_EXT_INPUT_CNT - 1) {
+      tim::vx::DataType datatype =
+          in_tensors_[RNN_EXT_INPUT_WEIGHT_I]->GetDataType();
+      uint32_t input_size = in_tensors_[RNN_EXT_INPUT_WEIGHT_I]->GetShape()[0];
+      uint32_t num_units = in_tensors_[RNN_EXT_INPUT_WEIGHT_I]->GetShape()[1];
+      uint32_t batch_size = in_tensors_[RNN_EXT_INPUT_INPUT]->GetShape()[1];
+      uint32_t seq_length = in_tensors_[RNN_EXT_INPUT_INPUT]->GetShape()[2];
+
+      // Shapes of all intermediate tensors.
+      tim::vx::ShapeType input_weight_i_shape = {input_size, num_units};
+      tim::vx::ShapeType input_weight_h_shape = {num_units, num_units};
+      tim::vx::ShapeType input_bias_shape = {2 * num_units};
+      tim::vx::ShapeType input_bias_i_shape = {num_units};
+      tim::vx::ShapeType input_bias_h_shape = {num_units};
+      tim::vx::ShapeType input_hstate_shape = {num_units, batch_size};
+      tim::vx::ShapeType output_shape = {num_units, batch_size, seq_length};
+      tim::vx::ShapeType output_hstate_shape = {num_units, batch_size};
+      tim::vx::ShapeType ext_output_shape = {num_units, 1, batch_size, seq_length};
+      tim::vx::ShapeType ext_output_hstate_shape = {num_units, batch_size, 1};
+
+      tim::vx::TensorSpec input_weight_i_spec(datatype, input_weight_i_shape,
+          tim::vx::TensorAttribute::TRANSIENT);
+      tim::vx::TensorSpec input_weight_h_spec(datatype, input_weight_h_shape,
+          tim::vx::TensorAttribute::TRANSIENT);
+      tim::vx::TensorSpec input_bias_spec(datatype, input_bias_shape,
+          tim::vx::TensorAttribute::TRANSIENT);
+      tim::vx::TensorSpec input_bias_i_spec(datatype, input_bias_i_shape,
+          tim::vx::TensorAttribute::TRANSIENT);
+      tim::vx::TensorSpec input_bias_h_spec(datatype, input_bias_h_shape,
+          tim::vx::TensorAttribute::TRANSIENT);
+      tim::vx::TensorSpec input_hstate_spec(datatype, input_hstate_shape,
+          tim::vx::TensorAttribute::TRANSIENT);
+      tim::vx::TensorSpec output_spec(datatype, output_shape,
+          tim::vx::TensorAttribute::TRANSIENT);
+      tim::vx::TensorSpec output_hstate_spec(datatype, output_hstate_shape,
+          tim::vx::TensorAttribute::TRANSIENT);
+
+      auto input_weight_i_tensor = graph_->CreateTensor(input_weight_i_spec);
+      auto input_weight_h_tensor = graph_->CreateTensor(input_weight_h_spec);
+      auto input_bias_tensor = graph_->CreateTensor(input_bias_spec);
+      auto input_bias_i_tensor = graph_->CreateTensor(input_bias_i_spec);
+      auto input_bias_h_tensor = graph_->CreateTensor(input_bias_h_spec);
+      auto input_hstate_tensor = graph_->CreateTensor(input_hstate_spec);
+      auto output_tensor = graph_->CreateTensor(output_spec);
+      auto output_hstate_tensor = graph_->CreateTensor(output_hstate_spec);
+
+      reshape_weight_ = graph_->CreateOperation<tim::vx::ops::Reshape>(input_weight_i_shape);
+      reshape_recurrent_ = graph_->CreateOperation<tim::vx::ops::Reshape>(input_weight_h_shape);
+      reshape_bias_ = graph_->CreateOperation<tim::vx::ops::Reshape>(input_bias_shape);
+      std::vector<uint32_t> slices = {num_units, num_units};
+      split_ = graph_->CreateOperation<tim::vx::ops::Split>(0, slices);
+      reshape_hstate_ = graph_->CreateOperation<tim::vx::ops::Reshape>(input_hstate_shape);
+      rnn_ = graph_->CreateOperation<tim::vx::ops::UnidirectionalSequenceRnn>(act_type_, true);
+      reshape_out_ = graph_->CreateOperation<tim::vx::ops::Reshape>(ext_output_shape);
+      reshape_out_hstate_ = graph_->CreateOperation<tim::vx::ops::Reshape>(ext_output_hstate_shape);
+
+      reshape_weight_->BindInput(in_tensors_[RNN_EXT_INPUT_WEIGHT_I]);
+      reshape_weight_->BindOutput(input_weight_i_tensor);
+
+      reshape_recurrent_->BindInput(in_tensors_[RNN_EXT_INPUT_WEIGHT_H]);
+      reshape_recurrent_->BindOutput(input_weight_h_tensor);
+
+      reshape_bias_->BindInput(in_tensors_[RNN_EXT_INPUT_BIAS]);
+      reshape_bias_->BindOutput(input_bias_tensor);
+      split_->BindInput(input_bias_tensor);
+      split_->BindOutput(input_bias_i_tensor);
+      split_->BindOutput(input_bias_h_tensor);
+
+      reshape_hstate_->BindInput(in_tensors_[RNN_EXT_INPUT_H_STATE]);
+      reshape_hstate_->BindOutput(input_hstate_tensor);
+
+      rnn_->BindInputs({in_tensors_[RNN_EXT_INPUT_INPUT],
+                        input_weight_i_tensor, input_weight_h_tensor,
+                        input_bias_i_tensor, input_bias_h_tensor,
+                        input_hstate_tensor});
+      rnn_->BindOutputs({output_hstate_tensor, output_tensor});
+
+      reshape_out_->BindInput(output_tensor);
+      reshape_out_hstate_->BindInput(output_hstate_tensor);
+    }
+    this->input_tensor_index++;
+    return *this;
+  }
+
+  UnidirectionalSequenceRnnExtImpl& BindOutput(
+      const std::shared_ptr<Tensor>& tensor) override {
+    out_tensors_[output_tensor_index] = tensor;
+
+    if (this->output_tensor_index == RNN_EXT_OUT_CNT - 1) {
+      reshape_out_->BindOutput(out_tensors_[RNN_EXT_OUTPUT_OUTPUT]);
+      reshape_out_hstate_->BindOutput(out_tensors_[RNN_EXT_OUTPUT_H_STATE]);
+    }
+    this->output_tensor_index++;
+    return *this;
+  }
+
+  vsi_nn_node_t* node() override { return nullptr; }
+
+  std::vector<std::shared_ptr<Tensor>> InputsTensor() override {
+    return inputs_tensor_;
+  }
+  std::vector<std::shared_ptr<Tensor>> OutputsTensor() override {
+    return outputs_tensor_;
+  }
+
+ private:
+  tim::vx::ops::UnidirectionalSequenceRnn::ActivationType act_type_;
+  std::shared_ptr<tim::vx::Operation> reshape_weight_;
+  std::shared_ptr<tim::vx::Operation> reshape_recurrent_;
+  std::shared_ptr<tim::vx::Operation> reshape_bias_;
+  std::shared_ptr<tim::vx::Operation> split_;
+  std::shared_ptr<tim::vx::Operation> reshape_hstate_;
+  std::shared_ptr<tim::vx::Operation> reshape_out_;
+  std::shared_ptr<tim::vx::Operation> reshape_out_hstate_;
+  std::shared_ptr<tim::vx::Operation> rnn_;
+
+  std::array<std::shared_ptr<tim::vx::Tensor>, RNN_EXT_INPUT_CNT> in_tensors_;
+  std::array<std::shared_ptr<tim::vx::Tensor>, RNN_EXT_OUT_CNT> out_tensors_;
+};
+
+UnidirectionalSequenceRnnExt::UnidirectionalSequenceRnnExt(
+    Graph* graph,
+    tim::vx::ops::UnidirectionalSequenceRnn::ActivationType act_type)
+    : act_type_(act_type) {
+  impl_ = std::make_unique<UnidirectionalSequenceRnnExtImpl>(graph, act_type,
+                                                             DataLayout::ANY);
+}
+
+std::shared_ptr<Operation> UnidirectionalSequenceRnnExt::Clone(
+    std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<UnidirectionalSequenceRnnExt>(this->act_type_);
+}
+
+}  // namespace ops
+}  // namespace vx
+}  // namespace tim
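Note: the unidirectional Ext op accepts the ONNX-style fused bias of shape {2 * num_units} and, as shown above, splits it on axis 0 into the input bias and recurrent bias expected by the builtin op. The standalone equivalent of that split (tensor names are placeholders):

    std::vector<uint32_t> slices = {num_units, num_units};
    auto split = graph->CreateOperation<tim::vx::ops::Split>(0, slices);
    split->BindInput(packed_bias);      // {2 * num_units}: [Wb | Rb]
    split->BindOutput(input_bias);      // {num_units}
    split->BindOutput(recurrent_bias);  // {num_units}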
+/****************************************************************************
+*
+* Copyright (c) 2021 Vivante Corporation
+*
+* Permission is hereby granted, free of charge, to any person obtaining a
+* copy of this software and associated documentation files (the "Software"),
+* to deal in the Software without restriction, including without limitation
+* the rights to use, copy, modify, merge, publish, distribute, sublicense,
+* and/or sell copies of the Software, and to permit persons to whom the
+* Software is furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in
+* all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+* DEALINGS IN THE SOFTWARE.
+*
+*****************************************************************************/
+#include "tim/vx/context.h"
+#include "tim/vx/graph.h"
+#include "tim/vx/ops.h"
+#include "tim/vx/types.h"
+#include "test_utils.h"
+#include "gtest/gtest.h"
+
+TEST(UnidirectionalSequenceRnnExt, shape_2_3_4_float_sigmoid) {
+  auto ctx = tim::vx::Context::Create();
+  auto graph = ctx->CreateGraph();
+
+  uint32_t input_size = 2, batch_size = 3, num_units = 4;
+
+  tim::vx::ShapeType input_shape({input_size, batch_size, 2});
+  tim::vx::ShapeType weights_shape({input_size, num_units, 1});
+  tim::vx::ShapeType recurrent_weights_shape({num_units, num_units, 1});
+  tim::vx::ShapeType bias_shape({num_units * 2, 1});
+  tim::vx::ShapeType state_in_shape({num_units, batch_size, 1});
+  tim::vx::ShapeType output_shape({num_units, batch_size, 1, 2});
+  tim::vx::ShapeType state_out_shape({num_units, batch_size, 1});
+
+  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32,
+                                 input_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec weights_spec(tim::vx::DataType::FLOAT32,
+                                   weights_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec recurrent_weights_spec(tim::vx::DataType::FLOAT32,
+                                             recurrent_weights_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec bias_spec(tim::vx::DataType::FLOAT32,
+                                bias_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec state_in_spec(tim::vx::DataType::FLOAT32,
+                                    state_in_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32,
+                                  output_shape, tim::vx::TensorAttribute::OUTPUT);
+  tim::vx::TensorSpec state_out_spec(tim::vx::DataType::FLOAT32,
+                                     state_out_shape, tim::vx::TensorAttribute::OUTPUT);
+
+  auto input_tensor = graph->CreateTensor(input_spec);
+  auto weights_tensor = graph->CreateTensor(weights_spec);
+  auto recurrent_weights_tensor = graph->CreateTensor(recurrent_weights_spec);
+  auto bias_tensor = graph->CreateTensor(bias_spec);
+  auto state_in_tensor = graph->CreateTensor(state_in_spec);
+  auto output_tensor = graph->CreateTensor(output_spec);
+  auto state_out_tensor = graph->CreateTensor(state_out_spec);
+
+  std::vector<float> in_data = {
+      1.0, 2.0,
+      3.0, 4.0,
+      5.0, 6.0,
+      7.0, 8.0,
+      9.0, 10.0,
+      11.0, 12.0
+  };
+  std::vector<float> weights_data = {
+      0.1, 0.1,
+      0.1, 0.1,
+      0.1, 0.1,
+      0.1, 0.1
+  };
+  std::vector<float> recurrent_weights_data = {
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+  };
+  std::vector<float> bias_data = {
+      0.1, 0.1, 0.1, 0.1,
+      0.0, 0.0, 0.0, 0.0
+  };
+  std::vector<float> state_in_data = {
+      0, 0, 0, 0,
+      0, 0, 0, 0,
+      0, 0, 0, 0
+  };
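+  // Golden values follow h_t = act(W x_t + R h_{t-1} + b) with all weights
+  // 0.1, input bias 0.1, and zero initial state, so every unit of a given
+  // (timestep, batch) pair is identical; e.g. step 1, batch 0:
+  // sigmoid(0.1 * (1 + 2) + 0.1) = sigmoid(0.4) ≈ 0.5986.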
+  std::vector<float> output_golden = {
+      0.5986, 0.5986, 0.5986, 0.5986,
+      0.6899, 0.6899, 0.6899, 0.6899,
+      0.7685, 0.7685, 0.7685, 0.7685,
+      0.8628, 0.8628, 0.8628, 0.8628,
+      0.9068, 0.9068, 0.9068, 0.9068,
+      0.9374, 0.9374, 0.9374, 0.9374,
+  };
+  std::vector<float> state_out_golden = {
+      0.8628, 0.8628, 0.8628, 0.8628,
+      0.9068, 0.9068, 0.9068, 0.9068,
+      0.9374, 0.9374, 0.9374, 0.9374,
+  };
+
+  EXPECT_TRUE(input_tensor->CopyDataToTensor(
+      in_data.data(), in_data.size() * sizeof(float)));
+  EXPECT_TRUE(weights_tensor->CopyDataToTensor(
+      weights_data.data(), weights_data.size() * sizeof(float)));
+  EXPECT_TRUE(recurrent_weights_tensor->CopyDataToTensor(
+      recurrent_weights_data.data(), recurrent_weights_data.size() * sizeof(float)));
+  EXPECT_TRUE(bias_tensor->CopyDataToTensor(
+      bias_data.data(), bias_data.size() * sizeof(float)));
+  EXPECT_TRUE(state_in_tensor->CopyDataToTensor(
+      state_in_data.data(), state_in_data.size() * sizeof(float)));
+
+  auto op = graph->CreateOperation<tim::vx::ops::UnidirectionalSequenceRnnExt>(
+      tim::vx::ops::UnidirectionalSequenceRnn::ActivationType::kSIGMOID);
+  (*op).BindInputs({input_tensor, weights_tensor, recurrent_weights_tensor,
+                    bias_tensor, state_in_tensor})
+      .BindOutputs({state_out_tensor, output_tensor});
+  EXPECT_TRUE(graph->Compile());
+  EXPECT_TRUE(graph->Run());
+  std::vector<float> output(output_golden.size());
+  std::vector<float> state_out(state_out_golden.size());
+  EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
+  EXPECT_TRUE(state_out_tensor->CopyDataFromTensor(state_out.data()));
+
+  EXPECT_TRUE(ArraysMatch(output_golden, output, 1e-3f));
+  EXPECT_TRUE(ArraysMatch(state_out_golden, state_out, 1e-3f));
+}
+
+TEST(UnidirectionalSequenceRnnExt, shape_2_3_4_float_relu) {
+  auto ctx = tim::vx::Context::Create();
+  auto graph = ctx->CreateGraph();
+
+  uint32_t input_size = 2, batch_size = 3, num_units = 4;
+
+  tim::vx::ShapeType input_shape({input_size, batch_size, 2});
+  tim::vx::ShapeType weights_shape({input_size, num_units, 1});
+  tim::vx::ShapeType recurrent_weights_shape({num_units, num_units, 1});
+  tim::vx::ShapeType bias_shape({num_units * 2, 1});
+  tim::vx::ShapeType state_in_shape({num_units, batch_size, 1});
+  tim::vx::ShapeType output_shape({num_units, batch_size, 1, 2});
+  tim::vx::ShapeType state_out_shape({num_units, batch_size, 1});
+
+  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32,
+                                 input_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec weights_spec(tim::vx::DataType::FLOAT32,
+                                   weights_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec recurrent_weights_spec(tim::vx::DataType::FLOAT32,
+                                             recurrent_weights_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec bias_spec(tim::vx::DataType::FLOAT32,
+                                bias_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec state_in_spec(tim::vx::DataType::FLOAT32,
+                                    state_in_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32,
+                                  output_shape, tim::vx::TensorAttribute::OUTPUT);
+  tim::vx::TensorSpec state_out_spec(tim::vx::DataType::FLOAT32,
+                                     state_out_shape, tim::vx::TensorAttribute::OUTPUT);
+
+  auto input_tensor = graph->CreateTensor(input_spec);
+  auto weights_tensor = graph->CreateTensor(weights_spec);
+  auto recurrent_weights_tensor = graph->CreateTensor(recurrent_weights_spec);
+  auto bias_tensor = graph->CreateTensor(bias_spec);
+  auto state_in_tensor = graph->CreateTensor(state_in_spec);
+  auto output_tensor = graph->CreateTensor(output_spec);
+  auto state_out_tensor = graph->CreateTensor(state_out_spec);
+
+  std::vector<float> in_data = {
+      1.0, 2.0,
+      3.0, 4.0,
+      5.0, 6.0,
+      7.0, 8.0,
+      9.0, 10.0,
+      11.0, 12.0
+  };
+  std::vector<float> weights_data = {
+      0.1, 0.1,
+      0.1, 0.1,
+      0.1, 0.1,
+      0.1, 0.1
+  };
+  std::vector<float> recurrent_weights_data = {
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+  };
+  std::vector<float> bias_data = {
+      0.1, 0.1, 0.1, 0.1,
+      0.0, 0.0, 0.0, 0.0
+  };
+  std::vector<float> state_in_data = {
+      0, 0, 0, 0,
+      0, 0, 0, 0,
+      0, 0, 0, 0
+  };
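+  // With kRELU the accumulations stay positive, so the goldens are the raw
+  // pre-activation sums; e.g. step 1, batch 0: 0.1 * (1 + 2) + 0.1 = 0.4,
+  // and step 2, batch 0: 0.1 * (7 + 8) + 0.1 + 0.1 * (4 * 0.4) = 1.76.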
+  std::vector<float> output_golden = {
+      0.4, 0.4, 0.4, 0.4,
+      0.8, 0.8, 0.8, 0.8,
+      1.2, 1.2, 1.2, 1.2,
+      1.76, 1.76, 1.76, 1.76,
+      2.32, 2.32, 2.32, 2.32,
+      2.88, 2.88, 2.88, 2.88
+  };
+  std::vector<float> state_out_golden = {
+      1.76, 1.76, 1.76, 1.76,
+      2.32, 2.32, 2.32, 2.32,
+      2.88, 2.88, 2.88, 2.88
+  };
+
+  EXPECT_TRUE(input_tensor->CopyDataToTensor(
+      in_data.data(), in_data.size() * sizeof(float)));
+  EXPECT_TRUE(weights_tensor->CopyDataToTensor(
+      weights_data.data(), weights_data.size() * sizeof(float)));
+  EXPECT_TRUE(recurrent_weights_tensor->CopyDataToTensor(
+      recurrent_weights_data.data(), recurrent_weights_data.size() * sizeof(float)));
+  EXPECT_TRUE(bias_tensor->CopyDataToTensor(
+      bias_data.data(), bias_data.size() * sizeof(float)));
+  EXPECT_TRUE(state_in_tensor->CopyDataToTensor(
+      state_in_data.data(), state_in_data.size() * sizeof(float)));
+
+  auto op = graph->CreateOperation<tim::vx::ops::UnidirectionalSequenceRnnExt>(
+      tim::vx::ops::UnidirectionalSequenceRnn::ActivationType::kRELU);
+  (*op).BindInputs({input_tensor, weights_tensor, recurrent_weights_tensor,
+                    bias_tensor, state_in_tensor})
+      .BindOutputs({state_out_tensor, output_tensor});
+  EXPECT_TRUE(graph->Compile());
+  EXPECT_TRUE(graph->Run());
+  std::vector<float> output(output_golden.size());
+  std::vector<float> state_out(state_out_golden.size());
+  EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
+  EXPECT_TRUE(state_out_tensor->CopyDataFromTensor(state_out.data()));
+
+  EXPECT_TRUE(ArraysMatch(output_golden, output, 1e-3f));
+  EXPECT_TRUE(ArraysMatch(state_out_golden, state_out, 1e-3f));
+}
diff --git a/src/tim/vx/ops/unidirectional_sequence_rnn_test.cc b/src/tim/vx/ops/unidirectional_sequence_rnn_test.cc
new file mode 100644
index 0000000..f18e2f7
--- /dev/null
+++ b/src/tim/vx/ops/unidirectional_sequence_rnn_test.cc
@@ -0,0 +1,378 @@
+/****************************************************************************
+*
+* Copyright (c) 2021 Vivante Corporation
+*
+* Permission is hereby granted, free of charge, to any person obtaining a
+* copy of this software and associated documentation files (the "Software"),
+* to deal in the Software without restriction, including without limitation
+* the rights to use, copy, modify, merge, publish, distribute, sublicense,
+* and/or sell copies of the Software, and to permit persons to whom the
+* Software is furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in
+* all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+* DEALINGS IN THE SOFTWARE.
+*
+*****************************************************************************/
+#include "tim/vx/context.h"
+#include "tim/vx/graph.h"
+#include "tim/vx/ops.h"
+#include "tim/vx/types.h"
+#include "test_utils.h"
+#include "gtest/gtest.h"
+
+TEST(UnidirectionalSequenceRnn, shape_2_3_4_float_sigmoid) {
+  auto ctx = tim::vx::Context::Create();
+  auto graph = ctx->CreateGraph();
+
+  uint32_t input_size = 2, batch_size = 3, num_units = 4;
+
+  tim::vx::ShapeType input_shape({input_size, batch_size, 2});
+  tim::vx::ShapeType weights_shape({input_size, num_units});
+  tim::vx::ShapeType recurrent_weights_shape({num_units, num_units});
+  tim::vx::ShapeType bias_shape({num_units});
+  tim::vx::ShapeType recurrent_bias_shape({num_units});
+  tim::vx::ShapeType state_in_shape({num_units, batch_size});
+  tim::vx::ShapeType output_shape({num_units, batch_size, 2});
+  tim::vx::ShapeType state_out_shape({num_units, batch_size});
+
+  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32,
+                                 input_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec weights_spec(tim::vx::DataType::FLOAT32,
+                                   weights_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec recurrent_weights_spec(tim::vx::DataType::FLOAT32,
+                                             recurrent_weights_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec bias_spec(tim::vx::DataType::FLOAT32,
+                                bias_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec recurrent_bias_spec(tim::vx::DataType::FLOAT32,
+                                          recurrent_bias_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec state_in_spec(tim::vx::DataType::FLOAT32,
+                                    state_in_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32,
+                                  output_shape, tim::vx::TensorAttribute::OUTPUT);
+  tim::vx::TensorSpec state_out_spec(tim::vx::DataType::FLOAT32,
+                                     state_out_shape, tim::vx::TensorAttribute::OUTPUT);
+
+  auto input_tensor = graph->CreateTensor(input_spec);
+  auto weights_tensor = graph->CreateTensor(weights_spec);
+  auto recurrent_weights_tensor = graph->CreateTensor(recurrent_weights_spec);
+  auto bias_tensor = graph->CreateTensor(bias_spec);
+  auto recurrent_bias_tensor = graph->CreateTensor(recurrent_bias_spec);
+  auto state_in_tensor = graph->CreateTensor(state_in_spec);
+  auto output_tensor = graph->CreateTensor(output_spec);
+  auto state_out_tensor = graph->CreateTensor(state_out_spec);
+
+  std::vector<float> in_data = {
+      1.0, 2.0,
+      3.0, 4.0,
+      5.0, 6.0,
+      7.0, 8.0,
+      9.0, 10.0,
+      11.0, 12.0
+  };
+  std::vector<float> weights_data = {
+      0.1, 0.1,
+      0.1, 0.1,
+      0.1, 0.1,
+      0.1, 0.1
+  };
+  std::vector<float> recurrent_weights_data = {
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+  };
+  std::vector<float> bias_data = {
+      0.1, 0.1, 0.1, 0.1
+  };
+  std::vector<float> recurrent_bias_data = {
+      0.0, 0.0, 0.0, 0.0
+  };
+  std::vector<float> state_in_data = {
+      0, 0, 0, 0,
+      0, 0, 0, 0,
+      0, 0, 0, 0
+  };
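+  // Same goldens as the ext sigmoid test: the op is created with
+  // time_major = true, so the sequence dimension is outermost in
+  // {input_size, batch_size, seq_length}, and step 1, batch 0 gives
+  // sigmoid(0.1 * (1 + 2) + 0.1) = sigmoid(0.4) ≈ 0.5986.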
+  std::vector<float> output_golden = {
+      0.5986, 0.5986, 0.5986, 0.5986,
+      0.6899, 0.6899, 0.6899, 0.6899,
+      0.7685, 0.7685, 0.7685, 0.7685,
+      0.8628, 0.8628, 0.8628, 0.8628,
+      0.9068, 0.9068, 0.9068, 0.9068,
+      0.9374, 0.9374, 0.9374, 0.9374,
+  };
+  std::vector<float> state_out_golden = {
+      0.8628, 0.8628, 0.8628, 0.8628,
+      0.9068, 0.9068, 0.9068, 0.9068,
+      0.9374, 0.9374, 0.9374, 0.9374,
+  };
+
+  EXPECT_TRUE(input_tensor->CopyDataToTensor(
+      in_data.data(), in_data.size() * sizeof(float)));
+  EXPECT_TRUE(weights_tensor->CopyDataToTensor(
+      weights_data.data(), weights_data.size() * sizeof(float)));
+  EXPECT_TRUE(recurrent_weights_tensor->CopyDataToTensor(
+      recurrent_weights_data.data(), recurrent_weights_data.size() * sizeof(float)));
+  EXPECT_TRUE(bias_tensor->CopyDataToTensor(
+      bias_data.data(), bias_data.size() * sizeof(float)));
+  EXPECT_TRUE(recurrent_bias_tensor->CopyDataToTensor(
+      recurrent_bias_data.data(), recurrent_bias_data.size() * sizeof(float)));
+  EXPECT_TRUE(state_in_tensor->CopyDataToTensor(
+      state_in_data.data(), state_in_data.size() * sizeof(float)));
+
+  auto op = graph->CreateOperation<tim::vx::ops::UnidirectionalSequenceRnn>(
+      tim::vx::ops::UnidirectionalSequenceRnn::ActivationType::kSIGMOID, true);
+  (*op).BindInputs({input_tensor, weights_tensor, recurrent_weights_tensor,
+                    bias_tensor, recurrent_bias_tensor, state_in_tensor})
+      .BindOutputs({state_out_tensor, output_tensor});
+  EXPECT_TRUE(graph->Compile());
+  EXPECT_TRUE(graph->Run());
+  std::vector<float> output(output_golden.size());
+  std::vector<float> state_out(state_out_golden.size());
+  EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
+  EXPECT_TRUE(state_out_tensor->CopyDataFromTensor(state_out.data()));
+
+  EXPECT_TRUE(ArraysMatch(output_golden, output, 1e-3f));
+  EXPECT_TRUE(ArraysMatch(state_out_golden, state_out, 1e-3f));
+}
+
+TEST(UnidirectionalSequenceRnn, shape_2_3_4_float_relu) {
+  auto ctx = tim::vx::Context::Create();
+  auto graph = ctx->CreateGraph();
+
+  uint32_t input_size = 2, batch_size = 3, num_units = 4;
+
+  tim::vx::ShapeType input_shape({input_size, batch_size, 2});
+  tim::vx::ShapeType weights_shape({input_size, num_units});
+  tim::vx::ShapeType recurrent_weights_shape({num_units, num_units});
+  tim::vx::ShapeType bias_shape({num_units});
+  tim::vx::ShapeType recurrent_bias_shape({num_units});
+  tim::vx::ShapeType state_in_shape({num_units, batch_size});
+  tim::vx::ShapeType output_shape({num_units, batch_size, 2});
+  tim::vx::ShapeType state_out_shape({num_units, batch_size});
+
+  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32,
+                                 input_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec weights_spec(tim::vx::DataType::FLOAT32,
+                                   weights_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec recurrent_weights_spec(tim::vx::DataType::FLOAT32,
+                                             recurrent_weights_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec bias_spec(tim::vx::DataType::FLOAT32,
+                                bias_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec recurrent_bias_spec(tim::vx::DataType::FLOAT32,
+                                          recurrent_bias_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec state_in_spec(tim::vx::DataType::FLOAT32,
+                                    state_in_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32,
+                                  output_shape, tim::vx::TensorAttribute::OUTPUT);
+  tim::vx::TensorSpec state_out_spec(tim::vx::DataType::FLOAT32,
+                                     state_out_shape, tim::vx::TensorAttribute::OUTPUT);
+
+  auto input_tensor = graph->CreateTensor(input_spec);
+  auto weights_tensor = graph->CreateTensor(weights_spec);
+  auto recurrent_weights_tensor = graph->CreateTensor(recurrent_weights_spec);
+  auto bias_tensor = graph->CreateTensor(bias_spec);
+  auto recurrent_bias_tensor = graph->CreateTensor(recurrent_bias_spec);
+  auto state_in_tensor = graph->CreateTensor(state_in_spec);
+  auto output_tensor = graph->CreateTensor(output_spec);
+  auto state_out_tensor = graph->CreateTensor(state_out_spec);
+
+  std::vector<float> in_data = {
+      1.0, 2.0,
+      3.0, 4.0,
+      5.0, 6.0,
+      7.0, 8.0,
+      9.0, 10.0,
+      11.0, 12.0
+  };
+  std::vector<float> weights_data = {
+      0.1, 0.1,
+      0.1, 0.1,
+      0.1, 0.1,
+      0.1, 0.1
+  };
+  std::vector<float> recurrent_weights_data = {
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+  };
+  std::vector<float> bias_data = {
+      0.1, 0.1, 0.1, 0.1
+  };
+  std::vector<float> recurrent_bias_data = {
+      0.0, 0.0, 0.0, 0.0
+  };
+  std::vector<float> state_in_data = {
+      0, 0, 0, 0,
+      0, 0, 0, 0,
+      0, 0, 0, 0
+  };
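+  // Same arithmetic as the ext relu case: relu passes the positive sums
+  // through, e.g. 0.1 * (1 + 2) + 0.1 = 0.4 at step 1 and 1.76 at step 2
+  // for batch 0.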
+  std::vector<float> output_golden = {
+      0.4, 0.4, 0.4, 0.4,
+      0.8, 0.8, 0.8, 0.8,
+      1.2, 1.2, 1.2, 1.2,
+      1.76, 1.76, 1.76, 1.76,
+      2.32, 2.32, 2.32, 2.32,
+      2.88, 2.88, 2.88, 2.88
+  };
+  std::vector<float> state_out_golden = {
+      1.76, 1.76, 1.76, 1.76,
+      2.32, 2.32, 2.32, 2.32,
+      2.88, 2.88, 2.88, 2.88
+  };
+
+  EXPECT_TRUE(input_tensor->CopyDataToTensor(
+      in_data.data(), in_data.size() * sizeof(float)));
+  EXPECT_TRUE(weights_tensor->CopyDataToTensor(
+      weights_data.data(), weights_data.size() * sizeof(float)));
+  EXPECT_TRUE(recurrent_weights_tensor->CopyDataToTensor(
+      recurrent_weights_data.data(), recurrent_weights_data.size() * sizeof(float)));
+  EXPECT_TRUE(bias_tensor->CopyDataToTensor(
+      bias_data.data(), bias_data.size() * sizeof(float)));
+  EXPECT_TRUE(recurrent_bias_tensor->CopyDataToTensor(
+      recurrent_bias_data.data(), recurrent_bias_data.size() * sizeof(float)));
+  EXPECT_TRUE(state_in_tensor->CopyDataToTensor(
+      state_in_data.data(), state_in_data.size() * sizeof(float)));
+
+  auto op = graph->CreateOperation<tim::vx::ops::UnidirectionalSequenceRnn>(
+      tim::vx::ops::UnidirectionalSequenceRnn::ActivationType::kRELU, true);
+  (*op).BindInputs({input_tensor, weights_tensor, recurrent_weights_tensor,
+                    bias_tensor, recurrent_bias_tensor, state_in_tensor})
+      .BindOutputs({state_out_tensor, output_tensor});
+  EXPECT_TRUE(graph->Compile());
+  EXPECT_TRUE(graph->Run());
+  std::vector<float> output(output_golden.size());
+  std::vector<float> state_out(state_out_golden.size());
+  EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
+  EXPECT_TRUE(state_out_tensor->CopyDataFromTensor(state_out.data()));
+
+  EXPECT_TRUE(ArraysMatch(output_golden, output, 1e-3f));
+  EXPECT_TRUE(ArraysMatch(state_out_golden, state_out, 1e-3f));
+}
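+
+// The tanh variant below is kept disabled. Its first-timestep goldens match
+// tanh(0.3) ≈ 0.2913 rather than tanh(0.4) ≈ 0.3799, i.e. they omit the 0.1
+// bias at the first step, which may be why the case is commented out.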
+/*
+TEST(UnidirectionalSequenceRnn, shape_2_3_4_float_tanh) {
+  auto ctx = tim::vx::Context::Create();
+  auto graph = ctx->CreateGraph();
+
+  uint32_t input_size = 2, batch_size = 3, num_units = 4;
+
+  tim::vx::ShapeType input_shape({input_size, batch_size, 2});
+  tim::vx::ShapeType weights_shape({input_size, num_units});
+  tim::vx::ShapeType recurrent_weights_shape({num_units, num_units});
+  tim::vx::ShapeType bias_shape({num_units});
+  tim::vx::ShapeType recurrent_bias_shape({num_units});
+  tim::vx::ShapeType state_in_shape({num_units, batch_size});
+  tim::vx::ShapeType output_shape({num_units, batch_size, 2});
+  tim::vx::ShapeType state_out_shape({num_units, batch_size});
+
+  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32,
+                                 input_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec weights_spec(tim::vx::DataType::FLOAT32,
+                                   weights_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec recurrent_weights_spec(tim::vx::DataType::FLOAT32,
+                                             recurrent_weights_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec bias_spec(tim::vx::DataType::FLOAT32,
+                                bias_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec recurrent_bias_spec(tim::vx::DataType::FLOAT32,
+                                          recurrent_bias_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec state_in_spec(tim::vx::DataType::FLOAT32,
+                                    state_in_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32,
+                                  output_shape, tim::vx::TensorAttribute::OUTPUT);
+  tim::vx::TensorSpec state_out_spec(tim::vx::DataType::FLOAT32,
+                                     state_out_shape, tim::vx::TensorAttribute::OUTPUT);
+
+  auto input_tensor = graph->CreateTensor(input_spec);
+  auto weights_tensor = graph->CreateTensor(weights_spec);
+  auto recurrent_weights_tensor = graph->CreateTensor(recurrent_weights_spec);
+  auto bias_tensor = graph->CreateTensor(bias_spec);
+  auto recurrent_bias_tensor = graph->CreateTensor(recurrent_bias_spec);
+  auto state_in_tensor = graph->CreateTensor(state_in_spec);
+  auto output_tensor = graph->CreateTensor(output_spec);
+  auto state_out_tensor = graph->CreateTensor(state_out_spec);
+
+  std::vector<float> in_data = {
+      1.0, 2.0,
+      3.0, 4.0,
+      5.0, 6.0,
+      7.0, 8.0,
+      9.0, 10.0,
+      11.0, 12.0
+  };
+  std::vector<float> weights_data = {
+      0.1, 0.1,
+      0.1, 0.1,
+      0.1, 0.1,
+      0.1, 0.1
+  };
+  std::vector<float> recurrent_weights_data = {
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+      0.1, 0.1, 0.1, 0.1,
+  };
+  std::vector<float> bias_data = {
+      0.1, 0.1, 0.1, 0.1
+  };
+  std::vector<float> recurrent_bias_data = {
+      0.0, 0.0, 0.0, 0.0
+  };
+  std::vector<float> state_in_data = {
+      0, 0, 0, 0,
+      0, 0, 0, 0,
+      0, 0, 0, 0
+  };
+  std::vector<float> output_golden = {
+      0.2913, 0.2913, 0.2913, 0.2913,
+      0.6043, 0.6043, 0.6043, 0.6043,
+      0.8004, 0.8004, 0.8004, 0.8004,
+      0.9416, 0.9416, 0.9416, 0.9416,
+      0.9786, 0.9786, 0.9786, 0.9786,
+      0.9915, 0.9915, 0.9915, 0.9915
+  };
+  std::vector<float> state_out_golden = {
+      0.9416, 0.9416, 0.9416, 0.9416,
+      0.9786, 0.9786, 0.9786, 0.9786,
+      0.9915, 0.9915, 0.9915, 0.9915
+  };
+
+  EXPECT_TRUE(input_tensor->CopyDataToTensor(
+      in_data.data(), in_data.size() * sizeof(float)));
+  EXPECT_TRUE(weights_tensor->CopyDataToTensor(
+      weights_data.data(), weights_data.size() * sizeof(float)));
+  EXPECT_TRUE(recurrent_weights_tensor->CopyDataToTensor(
+      recurrent_weights_data.data(), recurrent_weights_data.size() * sizeof(float)));
+  EXPECT_TRUE(bias_tensor->CopyDataToTensor(
+      bias_data.data(), bias_data.size() * sizeof(float)));
+  EXPECT_TRUE(recurrent_bias_tensor->CopyDataToTensor(
+      recurrent_bias_data.data(), recurrent_bias_data.size() * sizeof(float)));
+  EXPECT_TRUE(state_in_tensor->CopyDataToTensor(
+      state_in_data.data(), state_in_data.size() * sizeof(float)));
+
+  auto op = graph->CreateOperation<tim::vx::ops::UnidirectionalSequenceRnn>(
+      tim::vx::ops::UnidirectionalSequenceRnn::ActivationType::kTANH, true);
+  (*op).BindInputs({input_tensor, weights_tensor, recurrent_weights_tensor,
+                    bias_tensor, recurrent_bias_tensor, state_in_tensor})
+      .BindOutputs({state_out_tensor, output_tensor});
+  EXPECT_TRUE(graph->Compile());
+  EXPECT_TRUE(graph->Run());
+  std::vector<float> output(output_golden.size());
+  std::vector<float> state_out(state_out_golden.size());
+  EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
+  EXPECT_TRUE(state_out_tensor->CopyDataFromTensor(state_out.data()));
+
+  EXPECT_TRUE(ArraysMatch(output_golden, output, 1e-3f));
+  EXPECT_TRUE(ArraysMatch(state_out_golden, state_out, 1e-3f));
+}
+*/
\ No newline at end of file