/**************************************************************************** * * Copyright (c) 2021 Vivante Corporation * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), * to deal in the Software without restriction, including without limitation * the rights to use, copy, modify, merge, publish, distribute, sublicense, * and/or sell copies of the Software, and to permit persons to whom the * Software is furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ #include "tim/vx/ops.h" #include "vsi_nn_pub.h" #include "op_impl.h" #include namespace tim { namespace vx { namespace ops { class RNNCellImpl : public OpImpl { public: enum { // signature FULLY_CONNECTED_0_IN = 0, FULLY_CONNECTED_0_WEIGHT = 1, FULLY_CONNECTED_0_BIAS = 2, FULLY_CONNECTED_1_WEIGHT = 3, FULLY_CONNECTED_1_STATE_IN = 4, INPUT_CNT, OUT = 0, STATE_OUT, OUT_CNT, // signature end }; RNNCellImpl(Graph* graph, int input_cnt, int output_cnt, DataLayout layout = DataLayout::ANY) : OpImpl(graph, -1, input_cnt, output_cnt, layout) { fc0_ = graph->CreateOperation(0, 4); fc1_ = graph->CreateOperation(0, 4); add_ = graph->CreateOperation(); tanh_ = graph->CreateOperation(); data_convert_ = graph->CreateOperation(); } ~RNNCellImpl() {} RNNCellImpl& BindInput(const std::shared_ptr& tensor) override { in_tensors_[input_tensor_index] = tensor; if (this->input_tensor_index == INPUT_CNT - 1) { // Get all input tensor tim::vx::ShapeType shape = {0, 0}; tim::vx::TensorSpec FC0_spec(tim::vx::DataType::FLOAT32, shape, tim::vx::TensorAttribute::TRANSIENT); tim::vx::TensorSpec FC1_spec(tim::vx::DataType::FLOAT32, shape, tim::vx::TensorAttribute::TRANSIENT); tim::vx::TensorSpec add_spec(tim::vx::DataType::FLOAT32, shape, tim::vx::TensorAttribute::TRANSIENT); auto FC0_tensor = graph_->CreateTensor(FC0_spec); auto FC1_tensor = graph_->CreateTensor(FC1_spec); auto add_tensor = graph_->CreateTensor(add_spec); fc0_->BindInput(in_tensors_[FULLY_CONNECTED_0_IN]); fc0_->BindInput(in_tensors_[FULLY_CONNECTED_0_WEIGHT]); fc0_->BindInput(in_tensors_[FULLY_CONNECTED_0_BIAS]); fc0_->BindOutput(FC0_tensor); fc1_->BindInput(in_tensors_[FULLY_CONNECTED_1_WEIGHT]); fc1_->BindInput(in_tensors_[FULLY_CONNECTED_1_STATE_IN]); fc1_->BindOutput(FC1_tensor); add_->BindInput(FC0_tensor); add_->BindInput(FC1_tensor); add_->BindOutput(add_tensor); tanh_->BindInput(add_tensor); } this->input_tensor_index++; return *this; } RNNCellImpl& BindOutput(const std::shared_ptr& tensor) override { out_tensors_[output_tensor_index] = tensor; tanh_->BindOutput(out_tensors_[OUT]); data_convert_->BindInput(out_tensors_[OUT]); if (this->output_tensor_index == OUT_CNT - 1) { data_convert_->BindOutput(out_tensors_[STATE_OUT]); } this->output_tensor_index++; return *this; } vsi_nn_node_t* node() override { return nullptr; } std::vector> InputsTensor() override { return inputs_tensor_; } std::vector> OutputsTensor() override { return outputs_tensor_; } private: std::shared_ptr fc0_; std::shared_ptr fc1_; std::shared_ptr add_; std::shared_ptr tanh_; std::shared_ptr data_convert_; std::array, INPUT_CNT> in_tensors_; std::array, OUT_CNT> out_tensors_; }; RNNCell::RNNCell(Graph* graph, ActivationType activation) : activation_(activation) { impl_ = std::make_unique(graph, 0, 0, DataLayout::ANY); } std::shared_ptr RNNCell::Clone(std::shared_ptr& graph) const { return graph->CreateOperation(this->activation_); } } // namespace ops } // namespace vx } // namespace tim