This reverts commit f2e71a3deb.
This commit is contained in:
parent
18ce7b45fb
commit
171abb0f1b
|
|
@ -40,7 +40,6 @@
|
|||
#include "tim/vx/ops/elementwise.h"
|
||||
#include "tim/vx/ops/erf.h"
|
||||
#include "tim/vx/ops/fullyconnected.h"
|
||||
#include "tim/vx/ops/dense.h"
|
||||
#include "tim/vx/ops/gather.h"
|
||||
#include "tim/vx/ops/gathernd.h"
|
||||
#include "tim/vx/ops/groupedconv2d.h"
|
||||
|
|
|
|||
|
|
@ -1,54 +0,0 @@
|
|||
/****************************************************************************
|
||||
*
|
||||
* Copyright (c) 2022 Vivante Corporation
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||
* copy of this software and associated documentation files (the "Software"),
|
||||
* to deal in the Software without restriction, including without limitation
|
||||
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
* and/or sell copies of the Software, and to permit persons to whom the
|
||||
* Software is furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
* DEALINGS IN THE SOFTWARE.
|
||||
*
|
||||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_DENSE_H_
|
||||
#define TIM_VX_OPS_DENSE_H_
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
namespace ops {
|
||||
|
||||
/**
|
||||
* ## Dense
|
||||
*
|
||||
* Denotes a fully (densely) connected layer, which connects all elements in the
|
||||
* input tensor with each element in the output tensor.
|
||||
*
|
||||
* - axis: Describes the axis of the inputs when coerced to 2D.
|
||||
* - weights: the output channel number for weight tensor.
|
||||
*/
|
||||
|
||||
class Dense : public Operation {
|
||||
public:
|
||||
Dense(Graph* graph, uint32_t axis);
|
||||
Dense(Graph* graph, uint32_t axis, uint32_t weights);
|
||||
|
||||
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
|
||||
};
|
||||
|
||||
} // namespace ops
|
||||
} // namespace vx
|
||||
} // namespace tim
|
||||
|
||||
#endif /* TIM_VX_OPS_DENSE_H_ */
|
||||
|
|
@ -44,7 +44,8 @@ class RNNCell : public Operation {
|
|||
std::shared_ptr<Operation> Clone(
|
||||
std::shared_ptr<Graph>& graph) const override;
|
||||
|
||||
|
||||
protected:
|
||||
const ActivationType activation_;
|
||||
};
|
||||
|
||||
} // namespace ops
|
||||
|
|
|
|||
|
|
@ -1,131 +0,0 @@
|
|||
/****************************************************************************
|
||||
*
|
||||
* Copyright (c) 2022 Vivante Corporation
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||
* copy of this software and associated documentation files (the "Software"),
|
||||
* to deal in the Software without restriction, including without limitation
|
||||
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
* and/or sell copies of the Software, and to permit persons to whom the
|
||||
* Software is furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
* DEALINGS IN THE SOFTWARE.
|
||||
*
|
||||
*****************************************************************************/
|
||||
#include "tim/vx/ops.h"
|
||||
|
||||
#include "op_impl.h"
|
||||
#include "vsi_nn_pub.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
namespace ops {
|
||||
|
||||
class DenseImpl : public OpImpl {
|
||||
public:
|
||||
DenseImpl(Graph* graph, int input_cnt, int output_cnt, uint32_t axis,
|
||||
uint32_t weights, DataLayout layout = DataLayout::ANY)
|
||||
: OpImpl(graph, -1, input_cnt, output_cnt, layout),
|
||||
axis_(axis),
|
||||
weights_(weights) {
|
||||
FC_op_ =
|
||||
graph->CreateOperation<tim::vx::ops::FullyConnected>(axis, weights);
|
||||
}
|
||||
|
||||
~DenseImpl() {}
|
||||
|
||||
DenseImpl& BindInput(const std::shared_ptr<Tensor>& tensor) override {
|
||||
in_tensors_[input_tensor_index] = tensor;
|
||||
if (this->input_tensor_index == 1) {
|
||||
auto input_tensor = in_tensors_[0];
|
||||
auto weight_tensor = in_tensors_[1];
|
||||
|
||||
if (input_tensor->GetShape().size() > 2 ||
|
||||
(input_tensor->GetShape().size() == 2 &&
|
||||
input_tensor->GetShape()[0] != weight_tensor->GetShape()[0])) {
|
||||
uint32_t input_size = weight_tensor->GetShape()[0];
|
||||
uint32_t total_input_size = 1;
|
||||
for (uint8_t i = 0; i < input_tensor->GetShape().size(); i++) {
|
||||
total_input_size *= input_tensor->GetShape()[i];
|
||||
}
|
||||
uint32_t input_batch = total_input_size / input_size;
|
||||
tim::vx::TensorSpec reshape_spec(tim::vx::DataType::FLOAT32, {0, 0},
|
||||
tim::vx::TensorAttribute::TRANSIENT);
|
||||
auto reshape_output = graph_->CreateTensor(reshape_spec);
|
||||
std::vector<uint32_t> new_shape{input_size, input_batch};
|
||||
auto reshape_op =
|
||||
graph_->CreateOperation<tim::vx::ops::Reshape>(new_shape);
|
||||
(*reshape_op).BindInput(in_tensors_[0]);
|
||||
(*reshape_op).BindOutput(reshape_output);
|
||||
in_tensors_[0] = reshape_output;
|
||||
}
|
||||
FC_op_->BindInput(in_tensors_[0]);
|
||||
FC_op_->BindInput(in_tensors_[1]);
|
||||
}
|
||||
if (this->input_tensor_index == 2) {
|
||||
FC_op_->BindInput(in_tensors_[input_tensor_index]);
|
||||
}
|
||||
input_tensor_index++;
|
||||
return *this;
|
||||
}
|
||||
|
||||
DenseImpl& BindOutput(const std::shared_ptr<Tensor>& tensor) override {
|
||||
out_tensors_[output_tensor_index] = tensor;
|
||||
if (tensor->GetShape().size() > 2) {
|
||||
tim::vx::TensorSpec fc_spec(tim::vx::DataType::FLOAT32, {0, 0},
|
||||
tim::vx::TensorAttribute::TRANSIENT);
|
||||
auto fc_out = graph_->CreateTensor(fc_spec);
|
||||
FC_op_->BindOutput(fc_out);
|
||||
auto reshape_op =
|
||||
graph_->CreateOperation<tim::vx::ops::Reshape>(tensor->GetShape());
|
||||
(*reshape_op).BindInput(fc_out);
|
||||
(*reshape_op).BindOutput(tensor);
|
||||
} else {
|
||||
FC_op_->BindOutput(tensor);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
vsi_nn_node_t* node() override { return nullptr; }
|
||||
|
||||
std::vector<std::shared_ptr<Tensor>> InputsTensor() override {
|
||||
return inputs_tensor_;
|
||||
}
|
||||
std::vector<std::shared_ptr<Tensor>> OutputsTensor() override {
|
||||
return outputs_tensor_;
|
||||
}
|
||||
|
||||
uint32_t axis_;
|
||||
uint32_t weights_;
|
||||
|
||||
private:
|
||||
std::shared_ptr<tim::vx::Operation> FC_op_;
|
||||
std::array<std::shared_ptr<tim::vx::Tensor>, 3> in_tensors_;
|
||||
std::array<std::shared_ptr<tim::vx::Tensor>, 1> out_tensors_;
|
||||
};
|
||||
|
||||
Dense::Dense(Graph* graph, uint32_t axis) : Dense(graph, axis, 0) {}
|
||||
|
||||
Dense::Dense(Graph* graph, uint32_t axis, uint32_t weights) {
|
||||
impl_ =
|
||||
std::make_unique<DenseImpl>(graph, 0, 0, axis, weights, DataLayout::ANY);
|
||||
}
|
||||
|
||||
std::shared_ptr<Operation> Dense::Clone(std::shared_ptr<Graph>& graph) const {
|
||||
return graph->CreateOperation<Dense>(
|
||||
dynamic_cast<DenseImpl*>(this->impl_.get())->axis_,
|
||||
dynamic_cast<DenseImpl*>(this->impl_.get())->weights_);
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
} // namespace vx
|
||||
} // namespace tim
|
||||
|
|
@ -1,143 +0,0 @@
|
|||
/****************************************************************************
|
||||
*
|
||||
* Copyright (c) 2021 Vivante Corporation
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||
* copy of this software and associated documentation files (the "Software"),
|
||||
* to deal in the Software without restriction, including without limitation
|
||||
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
* and/or sell copies of the Software, and to permit persons to whom the
|
||||
* Software is furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
* DEALINGS IN THE SOFTWARE.
|
||||
*
|
||||
*****************************************************************************/
|
||||
#include "tim/vx/context.h"
|
||||
#include "tim/vx/graph.h"
|
||||
#include "tim/vx/ops.h"
|
||||
|
||||
#include "test_utils.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
TEST(Dense, shape_2_2) {
|
||||
auto ctx = tim::vx::Context::Create();
|
||||
auto graph = ctx->CreateGraph();
|
||||
|
||||
tim::vx::ShapeType in_shape({2, 2});
|
||||
tim::vx::ShapeType weight_shape({2, 3});
|
||||
tim::vx::ShapeType bias_shape({3, 1});
|
||||
tim::vx::ShapeType out_shape({3, 2});
|
||||
tim::vx::TensorSpec in_spec(tim::vx::DataType::FLOAT32, in_shape,
|
||||
tim::vx::TensorAttribute::INPUT);
|
||||
tim::vx::TensorSpec weight_spec(tim::vx::DataType::FLOAT32, weight_shape,
|
||||
tim::vx::TensorAttribute::INPUT);
|
||||
tim::vx::TensorSpec bias_spec(tim::vx::DataType::FLOAT32, bias_shape,
|
||||
tim::vx::TensorAttribute::INPUT);
|
||||
tim::vx::TensorSpec out_spec(tim::vx::DataType::FLOAT32, out_shape,
|
||||
tim::vx::TensorAttribute::OUTPUT);
|
||||
auto in_tensor = graph->CreateTensor(in_spec);
|
||||
auto weight_tensor = graph->CreateTensor(weight_spec);
|
||||
auto bias_tensor = graph->CreateTensor(bias_spec);
|
||||
auto out_tensor = graph->CreateTensor(out_spec);
|
||||
|
||||
std::vector<float> in_data = {
|
||||
1,
|
||||
4,
|
||||
2,
|
||||
6,
|
||||
};
|
||||
std::vector<float> weight_data = {
|
||||
-3, 3, 2, 1, 0, 4,
|
||||
};
|
||||
std::vector<float> bias_data = {
|
||||
0.1,
|
||||
0.4,
|
||||
0.5,
|
||||
};
|
||||
std::vector<float> golden = {9.1, 6.4, 16.5, 12.1, 10.4, 24.5};
|
||||
|
||||
EXPECT_TRUE(in_tensor->CopyDataToTensor(in_data.data(),
|
||||
in_data.size() * sizeof(float)));
|
||||
EXPECT_TRUE(weight_tensor->CopyDataToTensor(
|
||||
weight_data.data(), weight_data.size() * sizeof(float)));
|
||||
EXPECT_TRUE(bias_tensor->CopyDataToTensor(bias_data.data(),
|
||||
bias_data.size() * sizeof(float)));
|
||||
auto op = graph->CreateOperation<tim::vx::ops::Dense>(0, 3);
|
||||
(*op)
|
||||
.BindInputs({in_tensor, weight_tensor, bias_tensor})
|
||||
.BindOutputs({out_tensor});
|
||||
|
||||
EXPECT_TRUE(graph->Compile());
|
||||
EXPECT_TRUE(graph->Run());
|
||||
std::vector<float> output(golden.size());
|
||||
|
||||
EXPECT_TRUE(out_tensor->CopyDataFromTensor(output.data()));
|
||||
EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
|
||||
}
|
||||
|
||||
TEST(Dense, shape_1_2_3_2) {
|
||||
auto ctx = tim::vx::Context::Create();
|
||||
auto graph = ctx->CreateGraph();
|
||||
|
||||
tim::vx::ShapeType in_shape({2, 1, 2, 3}); // (d2, d1, d0, batch_size)
|
||||
tim::vx::ShapeType weight_shape({2, 3}); // (input_size, weights)
|
||||
tim::vx::ShapeType bias_shape({3, 1}); // (weights, 1)
|
||||
tim::vx::ShapeType out_shape({3, 1, 2, 3}); // (weights, d1, d0, batch_size)
|
||||
tim::vx::TensorSpec in_spec(tim::vx::DataType::FLOAT32, in_shape,
|
||||
tim::vx::TensorAttribute::INPUT);
|
||||
tim::vx::TensorSpec weight_spec(tim::vx::DataType::FLOAT32, weight_shape,
|
||||
tim::vx::TensorAttribute::INPUT);
|
||||
tim::vx::TensorSpec bias_spec(tim::vx::DataType::FLOAT32, bias_shape,
|
||||
tim::vx::TensorAttribute::INPUT);
|
||||
tim::vx::TensorSpec out_spec(tim::vx::DataType::FLOAT32, out_shape,
|
||||
tim::vx::TensorAttribute::OUTPUT);
|
||||
auto in_tensor = graph->CreateTensor(in_spec);
|
||||
auto weight_tensor = graph->CreateTensor(weight_spec);
|
||||
auto bias_tensor = graph->CreateTensor(bias_spec);
|
||||
auto out_tensor = graph->CreateTensor(out_spec);
|
||||
|
||||
std::vector<float> in_data = {
|
||||
0.12609188, 0.46347019, 0.89598465, 0.27901134, 0.35867718, 0.36897406,
|
||||
0.73463392, 0.27901134, 0.12609188, 0.46347019, 0.89598465, 0.27901134,
|
||||
};
|
||||
std::vector<float> weight_data = {
|
||||
-0.31930989, 0.37613347, 0.27901134, -1.36916667, 0.38031587, 0.21580373,
|
||||
};
|
||||
std::vector<float> bias_data = {
|
||||
0.12609188,
|
||||
0.46347019,
|
||||
0.21580373,
|
||||
};
|
||||
std::vector<float> golden = {
|
||||
0.260156, -0.135917, 0.363777, -0.0550594, 0.331447, 0.616773,
|
||||
0.150346, 0.0583582, 0.43184, -0.0035385, 0.286428, 0.555408,
|
||||
0.260156, -0.135917, 0.363777, -0.0550594, 0.331447, 0.616773,
|
||||
};
|
||||
|
||||
EXPECT_TRUE(in_tensor->CopyDataToTensor(in_data.data(),
|
||||
in_data.size() * sizeof(float)));
|
||||
EXPECT_TRUE(weight_tensor->CopyDataToTensor(
|
||||
weight_data.data(), weight_data.size() * sizeof(float)));
|
||||
EXPECT_TRUE(bias_tensor->CopyDataToTensor(bias_data.data(),
|
||||
bias_data.size() * sizeof(float)));
|
||||
auto op = graph->CreateOperation<tim::vx::ops::Dense>(0, 3);
|
||||
(*op)
|
||||
.BindInputs({in_tensor, weight_tensor, bias_tensor})
|
||||
.BindOutputs({out_tensor});
|
||||
|
||||
EXPECT_TRUE(graph->Compile());
|
||||
EXPECT_TRUE(graph->Run());
|
||||
std::vector<float> output(golden.size());
|
||||
|
||||
EXPECT_TRUE(out_tensor->CopyDataFromTensor(output.data()));
|
||||
EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
|
||||
}
|
||||
|
|
@ -27,7 +27,7 @@
|
|||
#include "test_utils.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
TEST(GroupedConv1d, shape_6_2_1_float_ksize_6_stride_1_group_2_no_bias_wcn_PaddingTest) {
|
||||
TEST(GroupedConv1d, shape_6_2_1_float_ksize_6_stride_1_group_2_no_bias_wcn) {
|
||||
auto ctx = tim::vx::Context::Create();
|
||||
auto graph = ctx->CreateGraph();
|
||||
|
||||
|
|
@ -47,21 +47,20 @@ TEST(GroupedConv1d, shape_6_2_1_float_ksize_6_stride_1_group_2_no_bias_wcn_Paddi
|
|||
|
||||
std::vector<float> in_data = {
|
||||
-1, 0, 1, -1.5, 0.5, 1.5,
|
||||
2, -0.5, 2, -2.5, 0, 2.5,
|
||||
-2, -0.5, 2, -2.5, 0, 2.5,
|
||||
};
|
||||
std::vector<float> weight = {
|
||||
-3, 2, -1.5, 1.5, 2, 3,
|
||||
-3, -2, -1.5, 1.5, 2, 3,
|
||||
-2.5, -2, -1.5, 1.5, 2, 2.5,
|
||||
};
|
||||
std::vector<float> golden = {
|
||||
-1, -6
|
||||
4.75, 5.5,
|
||||
};
|
||||
|
||||
EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * sizeof(float)));
|
||||
EXPECT_TRUE(weight_tensor->CopyDataToTensor(weight.data(), weight.size() * sizeof(float)));
|
||||
|
||||
std::array<uint32_t, 2> pad = {0,0};
|
||||
auto op = graph->CreateOperation<tim::vx::ops::GroupedConv1d>(tim::vx::PadType::SAME, pad, 1, 1, 2);
|
||||
auto op = graph->CreateOperation<tim::vx::ops::GroupedConv1d>(tim::vx::PadType::VALID, 1, 1, 2);
|
||||
(*op).BindInputs({input_tensor, weight_tensor}).BindOutputs({output_tensor});
|
||||
|
||||
EXPECT_TRUE(graph->Compile());
|
||||
|
|
@ -72,7 +71,7 @@ TEST(GroupedConv1d, shape_6_2_1_float_ksize_6_stride_1_group_2_no_bias_wcn_Paddi
|
|||
EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
|
||||
}
|
||||
|
||||
TEST(GroupedConv1d, shape_6_2_1_float_ksize_6_stride_1_group_2_no_bias_wcn) {
|
||||
TEST(GroupedConv1d, shape_6_2_1_float_ksize_6_stride_1_group_2_no_bias_wcn_PaddingTest) {
|
||||
auto ctx = tim::vx::Context::Create();
|
||||
auto graph = ctx->CreateGraph();
|
||||
|
||||
|
|
@ -114,5 +113,6 @@ TEST(GroupedConv1d, shape_6_2_1_float_ksize_6_stride_1_group_2_no_bias_wcn) {
|
|||
|
||||
std::vector<float> output(golden.size());
|
||||
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||
EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
|
||||
// EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
|
||||
EXPECT_EQ(golden, output);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -48,9 +48,9 @@ class RNNCellImpl : public OpImpl {
|
|||
// signature end
|
||||
};
|
||||
|
||||
RNNCellImpl(Graph* graph, int input_cnt, int output_cnt, RNNCell::ActivationType activation,
|
||||
RNNCellImpl(Graph* graph, int input_cnt, int output_cnt,
|
||||
DataLayout layout = DataLayout::ANY)
|
||||
: OpImpl(graph, -1, input_cnt, output_cnt, layout), activation_(activation){
|
||||
: OpImpl(graph, -1, input_cnt, output_cnt, layout) {
|
||||
fc0_ = graph->CreateOperation<tim::vx::ops::FullyConnected>(0, 4);
|
||||
fc1_ = graph->CreateOperation<tim::vx::ops::FullyConnected>(0, 4);
|
||||
add_ = graph->CreateOperation<tim::vx::ops::Add>();
|
||||
|
|
@ -126,16 +126,15 @@ class RNNCellImpl : public OpImpl {
|
|||
|
||||
std::array<std::shared_ptr<tim::vx::Tensor>, INPUT_CNT> in_tensors_;
|
||||
std::array<std::shared_ptr<tim::vx::Tensor>, OUT_CNT> out_tensors_;
|
||||
public:
|
||||
const RNNCell::ActivationType activation_;
|
||||
};
|
||||
|
||||
RNNCell::RNNCell(Graph* graph, ActivationType activation) {
|
||||
impl_ = std::make_unique<RNNCellImpl>(graph, 0, 0, activation, DataLayout::ANY);
|
||||
RNNCell::RNNCell(Graph* graph, ActivationType activation)
|
||||
: activation_(activation) {
|
||||
impl_ = std::make_unique<RNNCellImpl>(graph, 0, 0, DataLayout::ANY);
|
||||
}
|
||||
|
||||
std::shared_ptr<Operation> RNNCell::Clone(std::shared_ptr<Graph>& graph) const {
|
||||
return graph->CreateOperation<RNNCell>(dynamic_cast<RNNCellImpl*>(this->impl_.get())->activation_);
|
||||
return graph->CreateOperation<RNNCell>(this->activation_);
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
|
|
|
|||
Loading…
Reference in New Issue