From 1f85d2155819c7ac4b9c30d4e6d4b7f2252ce487 Mon Sep 17 00:00:00 2001 From: chxin66 <57057788+chxin66@users.noreply.github.com> Date: Thu, 9 Dec 2021 10:33:40 +0800 Subject: [PATCH] mapped signal frame & unit test (#234) Signed-off-by: Chen Xin Co-authored-by: Chen Xin --- include/tim/vx/ops/shuffle_channel.h | 4 +- include/tim/vx/ops/signal_frame.h | 60 +++++++++++++++++++ src/tim/vx/ops/README.md | 9 +-- src/tim/vx/ops/shuffle_channel.cc | 6 +- src/tim/vx/ops/shuffle_channel_test.cc | 20 +++---- src/tim/vx/ops/signal_frame.cc | 52 +++++++++++++++++ src/tim/vx/ops/signal_frame_test.cc | 81 ++++++++++++++++++++++++++ 7 files changed, 211 insertions(+), 21 deletions(-) create mode 100644 include/tim/vx/ops/signal_frame.h create mode 100644 src/tim/vx/ops/signal_frame.cc create mode 100644 src/tim/vx/ops/signal_frame_test.cc diff --git a/include/tim/vx/ops/shuffle_channel.h b/include/tim/vx/ops/shuffle_channel.h index 382aa89..c3c7bfa 100644 --- a/include/tim/vx/ops/shuffle_channel.h +++ b/include/tim/vx/ops/shuffle_channel.h @@ -38,9 +38,9 @@ namespace ops { * ``` */ -class shuffle_channel : public Operation { +class ShuffleChannel : public Operation { public: - explicit shuffle_channel(Graph* graph, int32_t num_groups, int32_t index_axis); + explicit ShuffleChannel(Graph* graph, int32_t num_groups, int32_t index_axis); std::shared_ptr Clone(std::shared_ptr& graph) const override; }; diff --git a/include/tim/vx/ops/signal_frame.h b/include/tim/vx/ops/signal_frame.h new file mode 100644 index 0000000..3203873 --- /dev/null +++ b/include/tim/vx/ops/signal_frame.h @@ -0,0 +1,60 @@ +/**************************************************************************** +* +* Copyright (c) 2021 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#ifndef TIM_VX_OPS_SIGNALFRAME_H_ +#define TIM_VX_OPS_SIGNALFRAME_H_ +#include "tim/vx/operation.h" + +namespace tim { +namespace vx { +namespace ops { + +/** + * ## Signalframe + * + * ``` + * tf.signal.frame( + signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=0, name=None + ) : Expands signal's axis dimension into frames of frame_length. + * ``` + */ + +class SignalFrame : public Operation { + public: + SignalFrame(Graph* graph, uint32_t window_length, uint32_t step, uint32_t pad_end=0, + uint32_t axis=0); + + std::shared_ptr Clone(std::shared_ptr& graph) const override; + + protected: + const uint32_t window_length_; + const uint32_t step_; + const uint32_t pad_end_; + const uint32_t axis_; +}; + +} // namespace ops +} // namespace vx +} // namespace tim + +#endif /* TIM_VX_OPS_SIGNALFRAME_H_ */ \ No newline at end of file diff --git a/src/tim/vx/ops/README.md b/src/tim/vx/ops/README.md index a301e56..89d698a 100644 --- a/src/tim/vx/ops/README.md +++ b/src/tim/vx/ops/README.md @@ -97,15 +97,15 @@ Unstack|UNSTACK|Mapped|[tf.unstack](https://tensorflow.google.cn/api_docs/python Tile|TILE|Mapped|[tf.tile](https://tensorflow.google.cn/api_docs/python/tf/tile) GroupedConv2d|GROUPED_CONV2D|Mapped|[ANEURALNETWORKS_GROUPED_CONV_2D](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a847acf8d9f3d2343328c3dbe6d447c50) SpatialTransformer|SPATIAL_TRANSFORMER|Mapped|[SpatialTransformer](https://github.com/daerduoCarey/SpatialTransformerLayer) -shuffle_channel|SHUFFLECHANNEL|Mapped|[ANEURALNETWORKS_CHANNEL_SHUFFLE](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a5b993c1211c4b1bc52fb595a3025251d) +ShuffleChannel|SHUFFLECHANNEL|Mapped|[ANEURALNETWORKS_CHANNEL_SHUFFLE](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a5b993c1211c4b1bc52fb595a3025251d) Gelu|GELU|Mapped|[tf.nn.gelu](https://tensorflow.google.cn/api_docs/python/tf/nn/gelu) Svdf|SVDF|Mapped|[ANEURALNETWORKS_SVDF](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a7096de21038c1ce49d354a00cba7b552) Erf|ERF|Mapped|[tf.math.erf](https://tensorflow.google.cn/api_docs/python/tf/math/erf) -GROUPED_CONV1D|Mapped|[tf.keras.layers.Conv1D](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Conv1D) +GroupedConv1d|GROUPED_CONV1D|Mapped|[tf.keras.layers.Conv1D](https://tensorflow.google.cn/api_docs/python/tf/keras/layers/Conv1D?hl=en) +|SignalFrame|SIGNAL_FRAME|Mapped|[tf.signal.frame](https://tensorflow.google.cn/api_docs/python/tf/signal/frame) ||PROPOSAL| TBD |[Faster-RCNN Proposal Layer](https://github.com/intel/caffe/blob/master/examples/faster-rcnn/lib/rpn/proposal_layer.py) ||ROI_POOL|Planned 22Q1 |[ANEURALNETWORKS_ROI_POOLING](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a6736198af337b2efbdb0b6b64dee7fe4) ||ROI_ALIGN| TBD |[ANEURALNETWORKS_ROI_ALIGN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a2848b39dd4bfba78f2438fda0d9397a4) -||SIGNAL_FRAME|Planned 21Q3|[tf.signal.frame](https://tensorflow.google.cn/api_docs/python/tf/signal/frame) ||TOPK|Planned 21Q4|[tf.math.top_k](https://tensorflow.google.cn/api_docs/python/tf/math/top_k) |GRUCell|GRUCELL_OVXLIB|Planned 21Q3|[tf.keras.layers.GRUCell](https://tensorflow.google.cn/api_docs/python/tf/keras/layers/GRUCell?hl=en) |UnidirectionalSequenceGRU|GRU_OVXLIB|Planned 21Q4|[tf.keras.layers.GRU](https://tensorflow.google.cn/api_docs/python/tf/keras/layers/GRUCell?hl=en) @@ -119,7 +119,6 @@ GROUPED_CONV1D|Mapped|[tf.keras.layers.Conv1D](https://www.tensorflow.org/api_do ||HASHTABLE_LOOKUP|Planned 21Q4|[ANEURALNETWORKS_HASHTABLE_LOOKUP](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0aca92716c8c73c1f0fa7f0757916fee26) ||EMBEDDING_LOOKUP|Planned 21Q4|[ANEURALNETWORKS_EMBEDDING_LOOKUP](developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a8d2ada77adb74357fc0770405bca0e3) ||LSH_PROJECTION|Planned 21Q4|[ANEURALNETWORKS_LSH_PROJECTION](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a800cdcec5d7ba776789cb2d1ef669965) -||SVDF|Mapped |[ANEURALNETWORKS_SVDF](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a7096de21038c1ce49d354a00cba7b552) ||HEATMAP_MAX_KEYPOINT|Planned 21Q4|[ANEURALNETWORKS_HEATMAP_MAX_KEYPOINT](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a5ffccf92d127766a741225ff7ad6f743) ||AXIS_ALIGNED_BBOX_TRANSFORM|Planned 21Q4|[ANEURALNETWORKS_AXIS_ALIGNED_BBOX_TRANSFORM](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0afd7603dd54060e6a52f5861674448528) ||BOX_WITH_NMS_LIMIT|Planned 21Q4|[ANEURALNETWORKS_BOX_WITH_NMX_LIMIT](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a2d81e878c19e15700dad111ba6c0be89) @@ -132,10 +131,8 @@ GROUPED_CONV1D|Mapped|[tf.keras.layers.Conv1D](https://www.tensorflow.org/api_do ||CEIL|Planned 21Q4|[tf.math.ceil](https://tensorflow.google.cn/api_docs/python/tf/math/ceil) ||SEQUENCE_MASK|Planned 21Q4|[tf.math.ceil](https://tensorflow.google.cn/api_docs/python/tf/sequence_mask) ||REPEAT|Planned 21Q4|[tf.repeat](https://tensorflow.google.cn/api_docs/python/tf/repeat) -||ERF|Planned 21Q4|[tf.math.erf](https://tensorflow.google.cn/api_docs/python/tf/math/erf) ||ONE_HOT|Planned 21Q4|[tf.one_hot](https://tensorflow.google.cn/api_docs/python/tf/one_hot) ||NMS|Planned 21Q4|[tf.image.non_max_suppression](https://tensorflow.google.cn/api_docs/python/tf/image/non_max_suppression) -||GROUPED_CONV1D|Planned 21Q4| ||SCATTER_ND_UPDATE|Planned 21Q4|[tf.compat.v1.scatter_nd_update](https://tensorflow.google.cn/api_docs/python/tf/compat/v1/scatter_nd_update) ||GELU|Planned 21Q4|[tf.nn.gelu](https://tensorflow.google.cn/api_docs/python/tf/nn/gelu) ||CONV_RELU|Deprecated diff --git a/src/tim/vx/ops/shuffle_channel.cc b/src/tim/vx/ops/shuffle_channel.cc index 18b3b45..2557cd0 100644 --- a/src/tim/vx/ops/shuffle_channel.cc +++ b/src/tim/vx/ops/shuffle_channel.cc @@ -28,16 +28,16 @@ namespace tim { namespace vx { namespace ops { -shuffle_channel::shuffle_channel(Graph* graph, int32_t num_groups, +ShuffleChannel::ShuffleChannel(Graph* graph, int32_t num_groups, int32_t index_axis) : Operation(graph, VSI_NN_OP_SHUFFLECHANNEL, 1, 1) { this->impl()->node()->nn_param.shufflechannel.group_number = num_groups; this->impl()->node()->nn_param.shufflechannel.axis = index_axis; } -std::shared_ptr shuffle_channel::Clone( +std::shared_ptr ShuffleChannel::Clone( std::shared_ptr& graph) const { - return graph->CreateOperation( + return graph->CreateOperation( this->impl()->node()->nn_param.shufflechannel.group_number, this->impl()->node()->nn_param.shufflechannel.axis); } diff --git a/src/tim/vx/ops/shuffle_channel_test.cc b/src/tim/vx/ops/shuffle_channel_test.cc index 9e1e028..52ae8a7 100644 --- a/src/tim/vx/ops/shuffle_channel_test.cc +++ b/src/tim/vx/ops/shuffle_channel_test.cc @@ -29,7 +29,7 @@ #include "gtest/gtest.h" -TEST(shuffle_channel, shape_3_6_groupnum2_dim1_float32) { +TEST(ShuffleChannel, shape_3_6_groupnum2_dim1_float32) { auto ctx = tim::vx::Context::Create(); auto graph = ctx->CreateGraph(); @@ -61,7 +61,7 @@ TEST(shuffle_channel, shape_3_6_groupnum2_dim1_float32) { }; EXPECT_TRUE(in_tensor->CopyDataToTensor(in_data.data(), in_data.size() * sizeof(float))); - auto op = graph->CreateOperation(2, 1); + auto op = graph->CreateOperation(2, 1); (*op).BindInput(in_tensor).BindOutput(out_tensor); EXPECT_TRUE(graph->Compile()); @@ -72,7 +72,7 @@ TEST(shuffle_channel, shape_3_6_groupnum2_dim1_float32) { EXPECT_EQ(golden, output); } -TEST(shuffle_channel, shape_4_2_2_groupnum2_dim0_float32) { +TEST(ShuffleChannel, shape_4_2_2_groupnum2_dim0_float32) { auto ctx = tim::vx::Context::Create(); auto graph = ctx->CreateGraph(); @@ -94,7 +94,7 @@ TEST(shuffle_channel, shape_4_2_2_groupnum2_dim0_float32) { }; EXPECT_TRUE(in_tensor->CopyDataToTensor(in_data.data(), in_data.size() * sizeof(float))); - auto op = graph->CreateOperation(2, 0); + auto op = graph->CreateOperation(2, 0); (*op).BindInput(in_tensor).BindOutput(out_tensor); EXPECT_TRUE(graph->Compile()); @@ -105,7 +105,7 @@ TEST(shuffle_channel, shape_4_2_2_groupnum2_dim0_float32) { EXPECT_EQ(golden, output); } -TEST(shuffle_channel, shape_1_4_2_2_groupnum2_dim1_float32) { +TEST(ShuffleChannel, shape_1_4_2_2_groupnum2_dim1_float32) { auto ctx = tim::vx::Context::Create(); auto graph = ctx->CreateGraph(); @@ -127,7 +127,7 @@ TEST(shuffle_channel, shape_1_4_2_2_groupnum2_dim1_float32) { }; EXPECT_TRUE(in_tensor->CopyDataToTensor(in_data.data(), in_data.size() * sizeof(float))); - auto op = graph->CreateOperation(2, 1); + auto op = graph->CreateOperation(2, 1); (*op).BindInput(in_tensor).BindOutput(out_tensor); EXPECT_TRUE(graph->Compile()); @@ -138,7 +138,7 @@ TEST(shuffle_channel, shape_1_4_2_2_groupnum2_dim1_float32) { EXPECT_EQ(golden, output); } -TEST(shuffle_channel, shape_4_1_2_2_groupnum4_dim0_float32) { +TEST(ShuffleChannel, shape_4_1_2_2_groupnum4_dim0_float32) { auto ctx = tim::vx::Context::Create(); auto graph = ctx->CreateGraph(); @@ -160,7 +160,7 @@ TEST(shuffle_channel, shape_4_1_2_2_groupnum4_dim0_float32) { }; EXPECT_TRUE(in_tensor->CopyDataToTensor(in_data.data(), in_data.size() * sizeof(float))); - auto op = graph->CreateOperation(4, 0); + auto op = graph->CreateOperation(4, 0); (*op).BindInput(in_tensor).BindOutput(out_tensor); EXPECT_TRUE(graph->Compile()); @@ -171,7 +171,7 @@ TEST(shuffle_channel, shape_4_1_2_2_groupnum4_dim0_float32) { EXPECT_EQ(golden, output); } -TEST(shuffle_channel, shape_4_1_2_2_groupnum1_dim3_float32) { +TEST(ShuffleChannel, shape_4_1_2_2_groupnum1_dim3_float32) { auto ctx = tim::vx::Context::Create(); auto graph = ctx->CreateGraph(); @@ -193,7 +193,7 @@ TEST(shuffle_channel, shape_4_1_2_2_groupnum1_dim3_float32) { }; EXPECT_TRUE(in_tensor->CopyDataToTensor(in_data.data(), in_data.size() * sizeof(float))); - auto op = graph->CreateOperation(1, 3); + auto op = graph->CreateOperation(1, 3); (*op).BindInput(in_tensor).BindOutput(out_tensor); EXPECT_TRUE(graph->Compile()); diff --git a/src/tim/vx/ops/signal_frame.cc b/src/tim/vx/ops/signal_frame.cc new file mode 100644 index 0000000..f7b90ee --- /dev/null +++ b/src/tim/vx/ops/signal_frame.cc @@ -0,0 +1,52 @@ +/**************************************************************************** +* +* Copyright (c) 2021 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#include "operation_private.h" +#include "tim/vx/ops/signal_frame.h" +#include "vsi_nn_pub.h" +namespace tim { +namespace vx { +namespace ops { + +SignalFrame::SignalFrame(Graph* graph, uint32_t window_length, uint32_t step, uint32_t pad_end, + uint32_t axis) + : Operation(graph, VSI_NN_OP_SIGNAL_FRAME), + window_length_(window_length), + step_(step), + pad_end_(pad_end), + axis_(axis) { + this->impl()->node()->nn_param.signalframe.window_length = window_length_; + this->impl()->node()->nn_param.signalframe.step = step_; + this->impl()->node()->nn_param.signalframe.pad_end = pad_end_; + this->impl()->node()->nn_param.signalframe.axis = axis_; +} + +std::shared_ptr SignalFrame::Clone( + std::shared_ptr& graph) const { + return graph->CreateOperation( + this->window_length_, this->step_, this->pad_end_, this->axis_); +} + +} // namespace ops +} // namespace vx +} // namespace tim diff --git a/src/tim/vx/ops/signal_frame_test.cc b/src/tim/vx/ops/signal_frame_test.cc new file mode 100644 index 0000000..3dd87b6 --- /dev/null +++ b/src/tim/vx/ops/signal_frame_test.cc @@ -0,0 +1,81 @@ +/**************************************************************************** +* +* Copyright (c) 2021 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#include "tim/vx/context.h" +#include "tim/vx/graph.h" +#include "tim/vx/ops/signal_frame.h" +#include "test_utils.h" +#include "gtest/gtest.h" + +TEST(SignalFrame, shape_10_3_float_step_2_windows_4) { + auto ctx = tim::vx::Context::Create(); + auto graph = ctx->CreateGraph(); + + tim::vx::ShapeType in_shape({10, 3}); + tim::vx::ShapeType out_shape({4, 4, 3}); + tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, + in_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, + out_shape, tim::vx::TensorAttribute::OUTPUT); + + auto input_tensor = graph->CreateTensor(input_spec); + auto output_tensor = graph->CreateTensor(output_spec); + + std::vector in_data = { + 0.9854245 , 1.3478903 , 2.079034 , 0.5336022 , -0.8521084 , + 1.4714626 , -1.6673858 , 1.1760164 , 0.58944523, -0.38136077, + 0.4713266 , -0.54476035, 0.17260066, 0.4458921 , 0.07180826, + -0.5209453 , 0.67287415, -0.40036386, 1.819254 , -0.83165807, + 0.7842376 , -0.51183605, 0.5516365 , -0.3449794 , -0.4545289 , + 1.4418068 , 2.6290808 , 0.26231438, -0.50589 , -1.903558 , + }; + + std::vector golden = { + 0.9854245 , 1.3478903 , 2.079034 , 0.5336022 , + 2.079034 , 0.5336022 , -0.8521084 , 1.4714626 , + -0.8521084 , 1.4714626 , -1.6673858 , 1.1760164 , + -1.6673858 , 1.1760164 , 0.58944523, -0.38136077, + + 0.4713266 , -0.54476035, 0.17260066, 0.4458921 , + 0.17260066, 0.4458921 , 0.07180826, -0.5209453 , + 0.07180826, -0.5209453 , 0.67287415, -0.40036386, + 0.67287415, -0.40036386, 1.819254 , -0.83165807, + + 0.7842376 , -0.51183605, 0.5516365 , -0.3449794 , + 0.5516365 , -0.3449794 , -0.4545289 , 1.4418068 , + -0.4545289 , 1.4418068 , 2.6290808 , 0.26231438, + 2.6290808 , 0.26231438, -0.50589 , -1.903558 , + }; + + EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * sizeof(float))); + + auto op = graph->CreateOperation(4, 2, 0, 0); + (*op).BindInputs({input_tensor}).BindOutputs({output_tensor}); + + EXPECT_TRUE(graph->Compile()); + EXPECT_TRUE(graph->Run()); + + std::vector output(golden.size() * sizeof(float)); + EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); + EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f)); +} \ No newline at end of file