/****************************************************************************
*
*    Copyright (c) 2022 Vivante Corporation
*
*    Permission is hereby granted, free of charge, to any person obtaining a
*    copy of this software and associated documentation files (the "Software"),
*    to deal in the Software without restriction, including without limitation
*    the rights to use, copy, modify, merge, publish, distribute, sublicense,
*    and/or sell copies of the Software, and to permit persons to whom the
*    Software is furnished to do so, subject to the following conditions:
*
*    The above copyright notice and this permission notice shall be included in
*    all copies or substantial portions of the Software.
*
*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
*    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
*    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
*    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
*    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
*    DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/ops/bidirectional_sequence_lstm.h"

#include "gtest/gtest.h"
#include "test_utils.h"

// Helper used below to supply the state outputs of the op. Only the
// declaration is visible in this file; presumably it creates a data-less
// FLOAT32 tensor of the given shape and role -- confirm against its definition.
std::shared_ptr<tim::vx::Tensor> make_empty_tensor(
    std::shared_ptr<tim::vx::Graph> graph, const tim::vx::ShapeType& shape,
    const tim::vx::TensorAttribute& role);  //, const float& default_value)

// Builds a graph containing a single BidirectionalSequenceLstm operation with
// FLOAT32 weights (no peephole, projection, aux or layer-norm tensors), runs
// it on a 1-batch / 3-step / 2-feature input and compares both directional
// sequence outputs against golden values with an absolute tolerance of 1e-4.
TEST(Bidirectional_LSTM_CELL, shape_in_2_cell_4_out_4_float32) {
  // NoCifg_NoPeephole_NoProjection_NoLayerNorm
  auto ctx = tim::vx::Context::Create();
  auto graph = ctx->CreateGraph();

  const uint32_t n_batch = 1;
  const uint32_t n_step = 3;
  const uint32_t n_cell = 4;
  const uint32_t n_input = 2;
  const uint32_t n_output = 4;

  // Shapes are written innermost dimension first. The input is laid out
  // non-time-major, matching the time_major = false argument passed to the
  // operation below.
  tim::vx::TensorSpec lstm_input_spec(
      tim::vx::DataType::FLOAT32,
      tim::vx::ShapeType({n_input, n_step, n_batch}),
      tim::vx::TensorAttribute::INPUT);

  // Every gate of both directions uses the same three constant layouts, so
  // the specs are shared instead of being repeated once per tensor.
  tim::vx::TensorSpec input_weight_spec(
      tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_input, n_cell}),
      tim::vx::TensorAttribute::CONSTANT);
  tim::vx::TensorSpec recurrent_weight_spec(
      tim::vx::DataType::FLOAT32, tim::vx::ShapeType({n_output, n_cell}),
      tim::vx::TensorAttribute::CONSTANT);
  tim::vx::TensorSpec bias_spec(tim::vx::DataType::FLOAT32,
                                tim::vx::ShapeType({n_cell}),
                                tim::vx::TensorAttribute::CONSTANT);

  tim::vx::TensorSpec sequence_output_spec(
      tim::vx::DataType::FLOAT32,
      tim::vx::ShapeType({n_output, n_step, n_batch}),
      tim::vx::TensorAttribute::OUTPUT);

  auto lstm_input = graph->CreateTensor(lstm_input_spec);
  std::vector<float> lstm_input_data = {2., 3., 3., 4., 1., 1.};
  // The size argument is a byte count; derive it from the element type
  // rather than hard-coding 4.
  EXPECT_TRUE(lstm_input->CopyDataToTensor(
      lstm_input_data.data(), lstm_input_data.size() * sizeof(float)));

  auto fw_output_tensor = graph->CreateTensor(sequence_output_spec);
  auto bw_output_tensor = graph->CreateTensor(sequence_output_spec);

  // ---------------------------------------------------------------------
  // Forward direction constants.
  // ---------------------------------------------------------------------
  std::vector<float> fw_weight_i2i = {-0.45018822, -0.02338299, -0.0870589,
                                      -0.34550029, 0.04266912,  -0.15680569,
                                      -0.34856534, 0.43890524};
  std::vector<float> fw_weight_i2f = {0.09701663,  0.20334584, -0.50592935,
                                      -0.31343272, -0.40032279, 0.44781327,
                                      0.01387155,  -0.35593212};
  std::vector<float> fw_weight_i2c = {-0.50013041, 0.1370284,  0.11810488,
                                      0.2013163,   -0.20583314, 0.44344562,
                                      0.22077113,  -0.29909778};
  std::vector<float> fw_weight_i2o = {-0.25065863, -0.28290087, 0.04613829,
                                      0.40525138,  0.44272184,  0.03897077,
                                      -0.1556896,  0.19487578};
  auto fw_weight_i2i_tensor =
      graph->CreateTensor(input_weight_spec, fw_weight_i2i.data());
  auto fw_weight_i2f_tensor =
      graph->CreateTensor(input_weight_spec, fw_weight_i2f.data());
  auto fw_weight_i2c_tensor =
      graph->CreateTensor(input_weight_spec, fw_weight_i2c.data());
  auto fw_weight_i2o_tensor =
      graph->CreateTensor(input_weight_spec, fw_weight_i2o.data());

  std::vector<float> fw_weight_r2i = {
      -0.0063535,  -0.2042388,  0.31454784,  -0.35746509,
      0.28902304,  0.08183324,  -0.16555229, 0.02286911,
      -0.13566875, 0.03034258,  0.48091322,  -0.12528998,
      0.24077177,  -0.51332325, -0.33502164, 0.10629296};
  std::vector<float> fw_weight_r2f = {
      -0.48684245, -0.06655136, 0.42224967,  0.2112639,
      0.27654213,  0.20864892,  -0.07646349, 0.45877004,
      0.00141793,  -0.14609534, 0.36447752,  0.09196436,
      0.28053468,  0.01560611,  -0.20127171, -0.01140004};
  std::vector<float> fw_weight_r2c = {
      -0.3407414,  0.24443203,  -0.2078532,  0.26320225,
      0.05695659,  -0.00123841, -0.4744786,  -0.35869038,
      -0.06418842, -0.13502428, -0.501764,   0.22830659,
      -0.46367589, 0.26016325,  -0.03894562, -0.16368064};
  std::vector<float> fw_weight_r2o = {
      0.43385774,  -0.17194885, 0.2718237,  0.09215671,
      0.24107647,  -0.39835793, 0.18212086, 0.01301402,
      0.48572797,  -0.50656658, 0.20047462, -0.20607421,
      -0.51818722, -0.15390486, 0.0468148,  0.39922136};
  auto fw_weight_r2i_tensor =
      graph->CreateTensor(recurrent_weight_spec, fw_weight_r2i.data());
  auto fw_weight_r2f_tensor =
      graph->CreateTensor(recurrent_weight_spec, fw_weight_r2f.data());
  auto fw_weight_r2c_tensor =
      graph->CreateTensor(recurrent_weight_spec, fw_weight_r2c.data());
  auto fw_weight_r2o_tensor =
      graph->CreateTensor(recurrent_weight_spec, fw_weight_r2o.data());

  std::vector<float> fw_bias_i = {0.0, 0.0, 0.0, 0.0};
  std::vector<float> fw_bias_f = {1., 1., 1., 1.};
  std::vector<float> fw_bias_c = {0.0, 0.0, 0.0, 0.0};
  std::vector<float> fw_bias_o = {0.0, 0.0, 0.0, 0.0};
  auto fw_bias_i_tensor = graph->CreateTensor(bias_spec, fw_bias_i.data());
  auto fw_bias_f_tensor = graph->CreateTensor(bias_spec, fw_bias_f.data());
  auto fw_bias_c_tensor = graph->CreateTensor(bias_spec, fw_bias_c.data());
  auto fw_bias_o_tensor = graph->CreateTensor(bias_spec, fw_bias_o.data());

  // ---------------------------------------------------------------------
  // Backward direction constants (same values as the forward direction).
  // ---------------------------------------------------------------------
  std::vector<float> bw_weight_i2i = {-0.45018822, -0.02338299, -0.0870589,
                                      -0.34550029, 0.04266912,  -0.15680569,
                                      -0.34856534, 0.43890524};
  std::vector<float> bw_weight_i2f = {0.09701663,  0.20334584, -0.50592935,
                                      -0.31343272, -0.40032279, 0.44781327,
                                      0.01387155,  -0.35593212};
  std::vector<float> bw_weight_i2c = {-0.50013041, 0.1370284,  0.11810488,
                                      0.2013163,   -0.20583314, 0.44344562,
                                      0.22077113,  -0.29909778};
  std::vector<float> bw_weight_i2o = {-0.25065863, -0.28290087, 0.04613829,
                                      0.40525138,  0.44272184,  0.03897077,
                                      -0.1556896,  0.19487578};
  auto bw_weight_i2i_tensor =
      graph->CreateTensor(input_weight_spec, bw_weight_i2i.data());
  auto bw_weight_i2f_tensor =
      graph->CreateTensor(input_weight_spec, bw_weight_i2f.data());
  auto bw_weight_i2c_tensor =
      graph->CreateTensor(input_weight_spec, bw_weight_i2c.data());
  auto bw_weight_i2o_tensor =
      graph->CreateTensor(input_weight_spec, bw_weight_i2o.data());

  std::vector<float> bw_weight_r2i = {
      -0.0063535,  -0.2042388,  0.31454784,  -0.35746509,
      0.28902304,  0.08183324,  -0.16555229, 0.02286911,
      -0.13566875, 0.03034258,  0.48091322,  -0.12528998,
      0.24077177,  -0.51332325, -0.33502164, 0.10629296};
  std::vector<float> bw_weight_r2f = {
      -0.48684245, -0.06655136, 0.42224967,  0.2112639,
      0.27654213,  0.20864892,  -0.07646349, 0.45877004,
      0.00141793,  -0.14609534, 0.36447752,  0.09196436,
      0.28053468,  0.01560611,  -0.20127171, -0.01140004};
  std::vector<float> bw_weight_r2c = {
      -0.3407414,  0.24443203,  -0.2078532,  0.26320225,
      0.05695659,  -0.00123841, -0.4744786,  -0.35869038,
      -0.06418842, -0.13502428, -0.501764,   0.22830659,
      -0.46367589, 0.26016325,  -0.03894562, -0.16368064};
  std::vector<float> bw_weight_r2o = {
      0.43385774,  -0.17194885, 0.2718237,  0.09215671,
      0.24107647,  -0.39835793, 0.18212086, 0.01301402,
      0.48572797,  -0.50656658, 0.20047462, -0.20607421,
      -0.51818722, -0.15390486, 0.0468148,  0.39922136};
  auto bw_weight_r2i_tensor =
      graph->CreateTensor(recurrent_weight_spec, bw_weight_r2i.data());
  auto bw_weight_r2f_tensor =
      graph->CreateTensor(recurrent_weight_spec, bw_weight_r2f.data());
  auto bw_weight_r2c_tensor =
      graph->CreateTensor(recurrent_weight_spec, bw_weight_r2c.data());
  auto bw_weight_r2o_tensor =
      graph->CreateTensor(recurrent_weight_spec, bw_weight_r2o.data());

  std::vector<float> bw_bias_i = {0.0, 0.0, 0.0, 0.0};
  std::vector<float> bw_bias_f = {1., 1., 1., 1.};
  std::vector<float> bw_bias_c = {0.0, 0.0, 0.0, 0.0};
  std::vector<float> bw_bias_o = {0.0, 0.0, 0.0, 0.0};
  auto bw_bias_i_tensor = graph->CreateTensor(bias_spec, bw_bias_i.data());
  auto bw_bias_f_tensor = graph->CreateTensor(bias_spec, bw_bias_f.data());
  auto bw_bias_c_tensor = graph->CreateTensor(bias_spec, bw_bias_c.data());
  // Uses the backward bias data; this previously read fw_bias_o by mistake.
  auto bw_bias_o_tensor = graph->CreateTensor(bias_spec, bw_bias_o.data());

  // Arguments are positional: two 0.0 values, the kTANH activation, another
  // 0.0, time_major = false (see the input layout above), the kSIGMOID
  // activation and a trailing `true`. See the operation header for the
  // meaning of each parameter.
  auto bidirectional_lstm =
      graph->CreateOperation<tim::vx::ops::BidirectionalSequenceLstm>(
          0.0, 0.0,
          tim::vx::ops::BidirectionalSequenceLstm::ActivationType::kTANH, 0.0,
          false,
          tim::vx::ops::BidirectionalSequenceLstm::ActivationType::kSIGMOID,
          true);

  // The operation takes a fixed-position input list; slots this test does not
  // provide are filled with placeholder tensors.
  (*bidirectional_lstm)
      .BindInputs({
          lstm_input,

          fw_weight_i2i_tensor,
          fw_weight_i2f_tensor,
          fw_weight_i2c_tensor,
          fw_weight_i2o_tensor,

          fw_weight_r2i_tensor,
          fw_weight_r2f_tensor,
          fw_weight_r2c_tensor,
          fw_weight_r2o_tensor,

          graph->CreateTensorPlaceHolder(), /*fw_weight_c2i*/
          graph->CreateTensorPlaceHolder(), /*fw_weight_c2f*/
          graph->CreateTensorPlaceHolder(), /*fw_weight_c2o*/

          fw_bias_i_tensor,
          fw_bias_f_tensor,
          fw_bias_c_tensor,
          fw_bias_o_tensor,

          // optional for projection
          graph->CreateTensorPlaceHolder(), /*fw_weight_prj*/
          graph->CreateTensorPlaceHolder(), /*fw_bias_prj*/

          bw_weight_i2i_tensor,
          bw_weight_i2f_tensor,
          bw_weight_i2c_tensor,
          bw_weight_i2o_tensor,

          bw_weight_r2i_tensor,
          bw_weight_r2f_tensor,
          bw_weight_r2c_tensor,
          bw_weight_r2o_tensor,

          graph->CreateTensorPlaceHolder(), /*bw_weight_c2i*/
          graph->CreateTensorPlaceHolder(), /*bw_weight_c2f*/
          graph->CreateTensorPlaceHolder(), /*bw_weight_c2o*/

          bw_bias_i_tensor,
          bw_bias_f_tensor,
          bw_bias_c_tensor,
          bw_bias_o_tensor,

          // optional for projection
          graph->CreateTensorPlaceHolder(), /*bw_weight_prj*/
          graph->CreateTensorPlaceHolder(), /*bw_bias_prj*/

          graph->CreateTensorPlaceHolder(), /*fw_h_state*/
          graph->CreateTensorPlaceHolder(), /*fw_c_state*/
          graph->CreateTensorPlaceHolder(), /*bw_h_state*/
          graph->CreateTensorPlaceHolder(), /*bw_c_state*/

          // AUX (9 slots)
          graph->CreateTensorPlaceHolder(),
          graph->CreateTensorPlaceHolder(),
          graph->CreateTensorPlaceHolder(),
          graph->CreateTensorPlaceHolder(),
          graph->CreateTensorPlaceHolder(),
          graph->CreateTensorPlaceHolder(),
          graph->CreateTensorPlaceHolder(),
          graph->CreateTensorPlaceHolder(),
          graph->CreateTensorPlaceHolder(),

          // Layer_norm (8 slots)
          graph->CreateTensorPlaceHolder(),
          graph->CreateTensorPlaceHolder(),
          graph->CreateTensorPlaceHolder(),
          graph->CreateTensorPlaceHolder(),
          graph->CreateTensorPlaceHolder(),
          graph->CreateTensorPlaceHolder(),
          graph->CreateTensorPlaceHolder(),
          graph->CreateTensorPlaceHolder(),
      })
      .BindOutputs({
          fw_output_tensor,
          make_empty_tensor(graph, tim::vx::ShapeType({n_output, n_batch}),
                            tim::vx::TensorAttribute::OUTPUT), /*fw_h_state*/
          make_empty_tensor(graph, tim::vx::ShapeType({n_cell, n_batch}),
                            tim::vx::TensorAttribute::OUTPUT), /*fw_c_state*/

          bw_output_tensor,
          make_empty_tensor(graph, tim::vx::ShapeType({n_output, n_batch}),
                            tim::vx::TensorAttribute::OUTPUT), /*bw_h_state*/
          make_empty_tensor(graph, tim::vx::ShapeType({n_cell, n_batch}),
                            tim::vx::TensorAttribute::OUTPUT), /*bw_c_state*/
      });

  // Surface graph build / execution failures directly instead of letting them
  // show up only as a golden-data mismatch further down.
  EXPECT_TRUE(graph->Compile());
  EXPECT_TRUE(graph->Run());

  std::vector<float> lstm_fw_golden_output = {
      -0.02973187, 0.1229473,  0.20885126, -0.15358765,
      -0.03716109, 0.12507336, 0.41193449, -0.20860538,
      -0.15053082, 0.09120187, 0.24278517, -0.12222792};
  // NOTE(review): these values are identical to the forward golden output
  // even though the input sequence is not symmetric in time; confirm that
  // this is the intended reference for the backward direction.
  std::vector<float> lstm_bw_golden_output = {
      -0.02973187, 0.1229473,  0.20885126, -0.15358765,
      -0.03716109, 0.12507336, 0.41193449, -0.20860538,
      -0.15053082, 0.09120187, 0.24278517, -0.12222792};

  std::vector<float> fw_output(lstm_fw_golden_output.size());
  std::vector<float> bw_output(lstm_bw_golden_output.size());
  EXPECT_TRUE(fw_output_tensor->CopyDataFromTensor(fw_output.data()));
  EXPECT_TRUE(bw_output_tensor->CopyDataFromTensor(bw_output.data()));

  EXPECT_TRUE(ArraysMatch(lstm_fw_golden_output, fw_output, 1e-4f));
  EXPECT_TRUE(ArraysMatch(lstm_bw_golden_output, bw_output, 1e-4f));
}