From b916e1301acdd05c307f391f2b3de7a18531da4d Mon Sep 17 00:00:00 2001
From: Antkillerfarm
Date: Mon, 18 Apr 2022 15:45:15 +0800
Subject: [PATCH] Add Broadcast op (#365)

---
 docs/Operators.md                | 130 +++++++++++-
 include/tim/vx/ops.h             |   1 +
 include/tim/vx/ops/broadcast.h   |  61 ++++++
 src/tim/vx/ops/README.md         |   1 +
 src/tim/vx/ops/broadcast.cc      |  57 ++++++
 src/tim/vx/ops/broadcast_test.cc | 329 +++++++++++++++++++++++++++++++
 6 files changed, 573 insertions(+), 6 deletions(-)
 create mode 100644 include/tim/vx/ops/broadcast.h
 create mode 100644 src/tim/vx/ops/broadcast.cc
 create mode 100644 src/tim/vx/ops/broadcast_test.cc

diff --git a/docs/Operators.md b/docs/Operators.md
index e782f86..17e8bf7 100644
--- a/docs/Operators.md
+++ b/docs/Operators.md
@@ -7,9 +7,11 @@
  - [ArgMin/ArgMax](#argminargmax)
  - [Batch2Space](#batch2space)
  - [BatchNorm](#batchnorm)
+ - [Broadcast](#broadcast)
  - [Clip](#clip)
  - [Concat](#concat)
  - [Conv2d](#conv2d)
+ - [Conv3d](#conv3d)
  - [DeConv2d](#deconv2d)
  - [DeConv1d](#deconv1d)
  - [DepthToSpace](#depthtospace)
@@ -22,9 +24,11 @@
  - [Minimum](#minimum)
  - [Maximum](#maximum)
  - [FloorDiv](#floordiv)
+ - [Erf](#erf)
  - [FullyConnected](#fullyconnected)
  - [Gather](#gather)
  - [GatherNd](#gathernd)
+ - [GroupedConv1d](#groupedconv1d)
  - [GroupedConv2d](#groupedconv2d)
  - [L2Normalization](#l2normalization)
  - [LocalResponseNormalization](#localresponsenormalization)
@@ -36,8 +40,12 @@
  - [MaxUnpool2d](#maxunpool2d)
  - [Moments](#moments)
  - [NBG](#nbg)
+ - [OneHot](#onehot)
  - [Pad](#pad)
  - [Pool2d](#pool2d)
+   - [Classic Pool2d](#classic-pool2d)
+   - [Global Pool2d](#global-pool2d)
+   - [Adaptive Pool2d](#adaptive-pool2d)
  - [ReduceMin](#reducemin)
  - [ReduceMax](#reducemax)
  - [ReduceAny](#reduceany)
@@ -78,6 +86,7 @@
  - [Squeeze](#squeeze)
  - [Stack](#stack)
  - [StridedSlice](#stridedslice)
+ - [Svdf](#svdf)
  - [Tile](#tile)
  - [Transpose](#transpose)
  - [Unidirectional sequence lstm](#unidirectional-sequence-lstm)
@@ -108,12 +117,12 @@
 Swish(x) : x * sigmoid(x)
 
 HardSwish(x) : 0 if x <= -3; x(x + 3)/6 if -3 < x < 3; x if x >= 3
 
-Mish(x) : x if x >= 0 else alpha * x
-
 HardSigmoid(x) : min(max(alpha*x + beta, 0), 1)
 
 SoftRelu(x) : log(1 + e^x). Also known as SoftPlus.
 
+Mish(x) : x * tanh(softrelu(x))
+
 LeakyRelu(x) : alpha * x if x <= 0; x if x > 0. alpha is a scalar.
 
 Prelu(x) : alpha * x if x <= 0; x if x > 0. alpha is a tensor.
@@ -153,7 +162,22 @@ rank as the input.
 
 This is the reverse transformation of Space2Batch.
 
 Carries out batch normalization as described in the paper
 https://arxiv.org/abs/1502.03167.
 
-Y = (X - Mean) / Sqrt( Var + Eps) * Gama + Beta
+$$\hat x_i\leftarrow \frac{x_i-\mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2+\epsilon}}$$
+
+$$y_i=\gamma\hat x_i+\beta\equiv BN_{\gamma,\beta}(x_i)$$
+
+
+## Broadcast
+
+Broadcast an array to a compatible shape. See also numpy.broadcast_to().
+
+Input:
+- input.
+
+Attribute:
+- shape : the shape to broadcast to.
+- dimensions (optional) : which dimension in the target shape each dimension
+of the operand shape corresponds to. Used for BroadcastInDim.
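+
+A short usage sketch (illustrative only; adapted from broadcast_test.cc in
+this patch, with graph and tensor setup following the usual TIM-VX test
+pattern):
+
+```cpp
+// Broadcast a single-element FLOAT32 tensor to shape {3, 2}.
+auto ctx = tim::vx::Context::Create();
+auto graph = ctx->CreateGraph();
+
+tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, {1},
+                               tim::vx::TensorAttribute::INPUT);
+tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, {3, 2},
+                                tim::vx::TensorAttribute::OUTPUT);
+auto input = graph->CreateTensor(input_spec);
+auto output = graph->CreateTensor(output_spec);
+
+std::vector<uint32_t> shape = {3, 2};
+auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape);
+(*op).BindInputs({input}).BindOutputs({output});
+```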
 
 ## Clip
 
@@ -189,6 +213,28 @@ Attribute:
 but the value is different. multiplier = weights / group.
 - layout : WHCN or CWHN.
 
+
+## Conv3d
+
+Performs a 3-D convolution operation.
+
+Input:
+- input [WHDCN].
+- kernel [ WHDIcOc ] (Ic: Input Channels. Oc: Output Channels).
+- bias [ O ]. Optional.
+
+Attribute:
+- weights : the output channel number for the weight tensor.
+- ksize : the spatial size of the weight tensor.
+- padding : AUTO, VALID or SAME.
+- pad : pad value for each spatial axis (left, right, top, bottom, front, rear).
+- stride : stride along each spatial axis.
+- dilation : dilation value along each spatial axis of the filter.
+- multiplier : similar to the group attribute in other frameworks,
+but the value is different. multiplier = weights / group.
+- input_layout : WHDCN or WHCDN.
+- kernel_layout : WHDIcOc.
+
 ## DeConv2d
 
@@ -292,6 +338,13 @@ Maximum(x, y) : max(x, y). This operation supports broadcasting.
 
 FloorDiv(x, y): floor( x / y ). This operation supports broadcasting.
 
+
+## Erf
+
+Computes the Gauss error function of x element-wise.
+
+- no parameters
+
 ## FullyConnected
 
@@ -311,6 +364,26 @@ Gather slices from input, **axis** according to **indices**.
 
 An operation similar to Gather but gathers across multiple axes at once.
 
+
+## GroupedConv1d
+
+Performs a grouped 1-D convolution operation.
+
+Input:
+- input [WCN].
+- kernel [ WIcOc ] (Ic: Input Channels. Oc: Output Channels). Ic * group = C.
+- bias [ O ]. Optional.
+
+Attribute:
+- weights : the output channel number for the weight tensor.
+- ksize : the kernel size of the weight tensor.
+- padding : AUTO, VALID or SAME.
+- pad : pad value for each spatial axis.
+- stride : stride along each spatial axis.
+- dilation : dilation value along each spatial axis of the filter.
+- group : split the conv into n groups.
+- layout : WCN or CWN.
+
 ## GroupedConv2d
 
@@ -415,24 +488,59 @@ If x is 1-D and axes = [0] this is just the mean and variance of a vector.
 
 Network Binary Graph is a precompilation technology which compiles a fused
 graph into a binary file.
 
+
+## OneHot
+
+Create a one-hot tensor.
+
+- depth : a scalar defining the depth of the one-hot dimension.
+- on_value : a scalar defining the value to fill in the output where the index matches.
+- off_value : a scalar defining the value to fill in the output everywhere else.
+- axis : the axis to fill.
+
 ## Pad
 
 Pads a tensor.
 
 - const_val : the value to pad.
+- pad_mode : the padding mode.
+- front_size : add pad values to the left and top.
+- back_size : add pad values to the right and bottom.
 
 ## Pool2d
 
+
+### Classic Pool2d
+
 Performs a 2-D pooling operation.
 
 - type : MAX, AVG, L2 or AVG_ANDROID.
 - padding : AUTO, VALID or SAME.
+- pad : the number of pad values for left, right, top, and bottom.
 - ksize : filter size.
 - stride : stride along each spatial axis.
 - round_type : CEILING or FLOOR.
 
+
+### Global Pool2d
+
+- type : MAX, AVG, L2 or AVG_ANDROID.
+- input_size : input size (only [W, H]).
+- round_type : CEILING or FLOOR.
+
+
+### Adaptive Pool2d
+
+Same as torch.nn.AdaptiveXXXPool2d.
+
+- type : MAX, AVG, L2 or AVG_ANDROID.
+- input_size : input size (only [W, H]).
+- output_size : output size (only [W, H]).
+- round_type : CEILING or FLOOR.
+
+
 ## ReduceMin
 
@@ -714,12 +822,13 @@ Removes dimensions of size 1 from the shape of a tensor.
 
 ## Stack
 
 Packs the list of tensors in inputs into a tensor with rank one higher than
-each tensor in values, by packing them along the **axis** dimension.
+each tensor in values, by packing them along the **axis** dimension.
+Dimensions below the dimension specified by **axis** are packed together with those of the other inputs.
 
 ## StridedSlice
 
 Extracts a strided slice of a tensor. Same as tensorflow.
 
 Roughly speaking, this op extracts a slice of size (end - begin) / stride
 from the given input tensor. Starting at the location specified by begin the
 slice continues by adding stride to the index until all dimensions are not
 less than end.
 
 Note that a stride can be negative, which causes a reverse slice.
 
 If the ith bit of begin_mask is set, begin[i] is ignored and the fullest
 possible range in that dimension is used instead. end_mask works analogously,
 except with the end range.
 
 If the ith bit of shrink_axis_mask is set, it implies that the ith
 specification shrinks the dimensionality by 1, taking on the value at index
 begin[i]. In this case, the ith specification must define a slice of size 1, e.g.
 begin[i] = x, end[i] = x + 1.
 
+
+## Svdf
+
+Performs a single time step of an SVDF (Singular Value Decomposition Filter) operation.
+
+- rank : the rank of the SVD approximation.
+- num_units : corresponds to the number of units.
+- spectrogram_length : corresponds to the fixed size of the memory.
+
 ## Tile
 
@@ -765,4 +883,4 @@ how to bind input/output: take unidirectional_sequence_lstm_test.cc
 
 Unpacks the given dimension of a rank-R tensor into rank-(R-1) tensors.
 
 - axis : An int. The axis to unstack along. Defaults to the first dimension.
-Negative values wrap around, so the valid range is [-R, R).
+Negative values wrap around, so the valid range is [-R, R).
\ No newline at end of file
diff --git a/include/tim/vx/ops.h b/include/tim/vx/ops.h
index 9f2cc87..253f698 100644
--- a/include/tim/vx/ops.h
+++ b/include/tim/vx/ops.h
@@ -29,6 +29,7 @@
 #include "tim/vx/ops/arg.h"
 #include "tim/vx/ops/batch2space.h"
 #include "tim/vx/ops/batchnorm.h"
+#include "tim/vx/ops/broadcast.h"
 #include "tim/vx/ops/clip.h"
 #include "tim/vx/ops/concat.h"
 #include "tim/vx/ops/conv1d.h"
diff --git a/include/tim/vx/ops/broadcast.h b/include/tim/vx/ops/broadcast.h
new file mode 100644
index 0000000..c6e4a67
--- /dev/null
+++ b/include/tim/vx/ops/broadcast.h
@@ -0,0 +1,61 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef OVXLIBXX_OPERATIONS_BROADCAST_H_
#define OVXLIBXX_OPERATIONS_BROADCAST_H_
#include "tim/vx/direct_map_op.h"

namespace tim {
namespace vx {
namespace ops {

/**
 * ## Broadcast
 *
 * Broadcast an array to a compatible shape. See also numpy.broadcast_to().
 *
 * Input:
 * - input.
 *
 * Attribute:
 * - shape : the shape to broadcast to.
 * - dimensions (optional) : which dimension in the target shape each dimension
 * of the operand shape corresponds to. Used for BroadcastInDim.
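 *
 * A minimal usage sketch (editor's illustration, based on the tests in
 * broadcast_test.cc; the graph and tensors are assumed to already exist):
 *
 *   std::vector<uint32_t> shape = {3, 2};
 *   auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape);
 *   (*op).BindInputs({input}).BindOutputs({output});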
 */

class Broadcast : public DirectMapOp {
 public:
  Broadcast(Graph* graph, const std::vector<uint32_t>& shape,
            const std::vector<uint32_t>& dimensions = {});

  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;

 protected:
  const std::vector<uint32_t> shape_;
  const std::vector<uint32_t> dimensions_;
};

}  // namespace ops
}  // namespace vx
}  // namespace tim

#endif /* OVXLIBXX_OPERATIONS_BROADCAST_H_ */
diff --git a/src/tim/vx/ops/README.md b/src/tim/vx/ops/README.md
index dd25069..50322c3 100644
--- a/src/tim/vx/ops/README.md
+++ b/src/tim/vx/ops/README.md
@@ -104,6 +104,7 @@ Erf|ERF|Mapped|[tf.math.erf](https://tensorflow.google.cn/api_docs/python/tf/math/erf)
 |GroupedConv1d|GROUPED_CONV1D|Mapped|[tf.keras.layers.Conv1D](https://tensorflow.google.cn/api_docs/python/tf/keras/layers/Conv1D?hl=en)
 |SignalFrame|SIGNAL_FRAME|Mapped|[tf.signal.frame](https://tensorflow.google.cn/api_docs/python/tf/signal/frame)
 |RNNCell|RNNCELL_OVXLIB|Mapped|[ANEURALNETWORKS_RNN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0acd2684ac9c73bb29767b534e78a332e8)
+|BroadCast|EXPAND_BROADCAST|Mapped|[numpy.broadcast_to](https://numpy.org/doc/stable/reference/generated/numpy.broadcast_to.html)
 ||PROPOSAL| TBD |[Faster-RCNN Proposal Layer](https://github.com/intel/caffe/blob/master/examples/faster-rcnn/lib/rpn/proposal_layer.py)
 ||ROI_POOL|Planned 22Q1 |[ANEURALNETWORKS_ROI_POOLING](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a6736198af337b2efbdb0b6b64dee7fe4)
 ||ROI_ALIGN| TBD |[ANEURALNETWORKS_ROI_ALIGN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a2848b39dd4bfba78f2438fda0d9397a4)
diff --git a/src/tim/vx/ops/broadcast.cc b/src/tim/vx/ops/broadcast.cc
new file mode 100644
index 0000000..4a1d5ac
--- /dev/null
+++ b/src/tim/vx/ops/broadcast.cc
@@ -0,0 +1,57 @@
/****************************************************************************
*
* Copyright (c) 2021 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#include "tim/vx/ops/broadcast.h"

#include <vector>

#include "direct_map_op_impl.h"
#include "vsi_nn_pub.h"

namespace tim {
namespace vx {
namespace ops {
Broadcast::Broadcast(Graph* graph, const std::vector<uint32_t>& shape,
                     const std::vector<uint32_t>& dimensions)
    : DirectMapOp(graph, VSI_NN_OP_EXPAND_BROADCAST),
      shape_(shape),
      dimensions_(dimensions) {
  // Wire the attributes into the underlying ovxlib node. The node keeps raw
  // pointers into shape_/dimensions_, which live as long as this op.
  this->impl()->node()->nn_param.expand_broadcast.dim_num = shape_.size();
  this->impl()->node()->nn_param.expand_broadcast.shape =
      (uint32_t*)shape_.data();
  this->impl()->node()->nn_param.expand_broadcast.dimensions_num =
      dimensions_.size();
  if (dimensions.size() > 0) {
    this->impl()->node()->nn_param.expand_broadcast.dimensions =
        (uint32_t*)dimensions_.data();
  } else {
    this->impl()->node()->nn_param.expand_broadcast.dimensions = nullptr;
  }
}

std::shared_ptr<Operation> Broadcast::Clone(
    std::shared_ptr<Graph>& graph) const {
  // Recreate this op on another graph with the same attributes; used by graph
  // transformations such as layout inference.
  return graph->CreateOperation<Broadcast>(this->shape_, this->dimensions_);
}

}  // namespace ops
}  // namespace vx
}  // namespace tim
\ No newline at end of file
diff --git a/src/tim/vx/ops/broadcast_test.cc b/src/tim/vx/ops/broadcast_test.cc
new file mode 100644
index 0000000..b84d55e
--- /dev/null
+++ b/src/tim/vx/ops/broadcast_test.cc
@@ -0,0 +1,329 @@
/****************************************************************************
*
* Copyright (c) 2021 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/ops/broadcast.h"
#include "tim/transform/layout_inference.h"

#include "gtest/gtest.h"
#include "test_utils.h"

static void CheckResult(std::shared_ptr<tim::vx::Graph>& graph,
                        std::vector<float>& golden,
                        std::shared_ptr<tim::vx::Tensor>& output_tensor) {
  EXPECT_TRUE(graph->Compile());
  EXPECT_TRUE(graph->Run());

  std::vector<float> output(golden.size());
  EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
  EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
}

TEST(Broadcast, ScalarTo2D_2x3) {
  auto ctx = tim::vx::Context::Create();
  auto graph = ctx->CreateGraph();

  tim::vx::ShapeType input_shape({1});
  tim::vx::ShapeType output_shape({3, 2});
  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
                                 tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
                                  tim::vx::TensorAttribute::OUTPUT);

  auto input_tensor = graph->CreateTensor(input_spec);
  auto output_tensor = graph->CreateTensor(output_spec);

  std::vector<float> in_data = {
      2.25f,
  };
  std::vector<float> golden = {
      2.25f, 2.25f, 2.25f, 2.25f, 2.25f, 2.25f,
  };
  std::vector<uint32_t> shape = {3, 2};

  EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
                                             in_data.size() * sizeof(float)));

  auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape);
  (*op).BindInputs({input_tensor}).BindOutputs({output_tensor});

  CheckResult(graph, golden, output_tensor);
}

TEST(Broadcast, 1DTo2D) {
  auto ctx = tim::vx::Context::Create();
  auto graph = ctx->CreateGraph();

  tim::vx::ShapeType input_shape({3});
  tim::vx::ShapeType output_shape({3, 2});
  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
                                 tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
                                  tim::vx::TensorAttribute::OUTPUT);

  auto input_tensor = graph->CreateTensor(input_spec);
  auto output_tensor = graph->CreateTensor(output_spec);

  std::vector<float> in_data = {
      1.f, 2.f, 3.f,
  };
  std::vector<float> golden = {
      1.f, 2.f, 3.f, 1.f, 2.f, 3.f,
  };
  std::vector<uint32_t> shape = {3, 2};

  EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
                                             in_data.size() * sizeof(float)));

  auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape);
  (*op).BindInputs({input_tensor}).BindOutputs({output_tensor});

  CheckResult(graph, golden, output_tensor);
}

TEST(Broadcast, 1DTo2D_WithDims0) {
  auto ctx = tim::vx::Context::Create();
  auto graph = ctx->CreateGraph();

  tim::vx::ShapeType input_shape({2});
  tim::vx::ShapeType output_shape({2, 2});
  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
                                 tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
                                  tim::vx::TensorAttribute::OUTPUT);

  auto input_tensor = graph->CreateTensor(input_spec);
  auto output_tensor = graph->CreateTensor(output_spec);

  std::vector<float> in_data = {
      1.f, 2.f,
  };
  std::vector<float> golden = {
      1.f, 2.f,
      1.f, 2.f,
  };
  std::vector<uint32_t> shape = {2, 2};
  std::vector<uint32_t> dimensions = {0};

  EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
                                             in_data.size() * sizeof(float)));

  auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape, dimensions);
  (*op).BindInputs({input_tensor}).BindOutputs({output_tensor});

  CheckResult(graph, golden, output_tensor);
}

TEST(Broadcast, 1DTo2D_WithDims1) {
  auto ctx = tim::vx::Context::Create();
  auto graph = ctx->CreateGraph();

  tim::vx::ShapeType input_shape({2});
  tim::vx::ShapeType output_shape({2, 2});
  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
                                 tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
                                  tim::vx::TensorAttribute::OUTPUT);

  auto input_tensor = graph->CreateTensor(input_spec);
  auto output_tensor = graph->CreateTensor(output_spec);

  std::vector<float> in_data = {
      1.f, 2.f,
  };
  std::vector<float> golden = {
      1.f, 1.f,
      2.f, 2.f,
  };
  std::vector<uint32_t> shape = {2, 2};
  std::vector<uint32_t> dimensions = {1};

  EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
                                             in_data.size() * sizeof(float)));

  auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape, dimensions);
  (*op).BindInputs({input_tensor}).BindOutputs({output_tensor});

  CheckResult(graph, golden, output_tensor);
}

TEST(Broadcast, 1DTo3D_WithDims0) {
  auto ctx = tim::vx::Context::Create();
  auto graph = ctx->CreateGraph();

  tim::vx::ShapeType input_shape({2});
  tim::vx::ShapeType output_shape({2, 2, 2});
  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
                                 tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
                                  tim::vx::TensorAttribute::OUTPUT);

  auto input_tensor = graph->CreateTensor(input_spec);
  auto output_tensor = graph->CreateTensor(output_spec);

  std::vector<float> in_data = {
      1.f, 2.f,
  };
  std::vector<float> golden = {
      1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f,
  };
  std::vector<uint32_t> shape = {2, 2, 2};
  std::vector<uint32_t> dimensions = {0};

  EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
                                             in_data.size() * sizeof(float)));

  auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape, dimensions);
  (*op).BindInputs({input_tensor}).BindOutputs({output_tensor});

  CheckResult(graph, golden, output_tensor);
}

TEST(Broadcast, 1DTo3D_WithDims1) {
  auto ctx = tim::vx::Context::Create();
  auto graph = ctx->CreateGraph();

  tim::vx::ShapeType input_shape({2});
  tim::vx::ShapeType output_shape({2, 2, 2});
  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
                                 tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
                                  tim::vx::TensorAttribute::OUTPUT);

  auto input_tensor = graph->CreateTensor(input_spec);
  auto output_tensor = graph->CreateTensor(output_spec);

  std::vector<float> in_data = {
      1.f, 2.f,
  };
  std::vector<float> golden = {
      1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f,
  };
  std::vector<uint32_t> shape = {2, 2, 2};
  std::vector<uint32_t> dimensions = {1};

  EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
                                             in_data.size() * sizeof(float)));

  auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape, dimensions);
  (*op).BindInputs({input_tensor}).BindOutputs({output_tensor});

  CheckResult(graph, golden, output_tensor);
}

TEST(Broadcast, 1DTo3D_WithDims2) {
  auto ctx = tim::vx::Context::Create();
  auto graph = ctx->CreateGraph();

  tim::vx::ShapeType input_shape({2});
  tim::vx::ShapeType output_shape({2, 2, 2});
  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
                                 tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
                                  tim::vx::TensorAttribute::OUTPUT);

  auto input_tensor = graph->CreateTensor(input_spec);
  auto output_tensor = graph->CreateTensor(output_spec);

  std::vector<float> in_data = {
      1.f, 2.f,
  };
  std::vector<float> golden = {
      1.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 2.f,
  };
  std::vector<uint32_t> shape = {2, 2, 2};
  std::vector<uint32_t> dimensions = {2};

  EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
                                             in_data.size() * sizeof(float)));

  auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape, dimensions);
  (*op).BindInputs({input_tensor}).BindOutputs({output_tensor});

  CheckResult(graph, golden, output_tensor);
}

TEST(Broadcast, 2DTo3D_WithDims02) {
  auto ctx = tim::vx::Context::Create();
  auto graph = ctx->CreateGraph();

  tim::vx::ShapeType input_shape({2, 2});
  tim::vx::ShapeType output_shape({2, 2, 2});
  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
                                 tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
                                  tim::vx::TensorAttribute::OUTPUT);

  auto input_tensor = graph->CreateTensor(input_spec);
  auto output_tensor = graph->CreateTensor(output_spec);

  std::vector<float> in_data = {
      1.f, 5.f, 2.f, 6.f
  };
  std::vector<float> golden = {
      1.f, 5.f, 1.f, 5.f, 2.f, 6.f, 2.f, 6.f,
  };
  std::vector<uint32_t> shape = {2, 2, 2};
  std::vector<uint32_t> dimensions = {0, 2};

  EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
                                             in_data.size() * sizeof(float)));

  auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape, dimensions);
  (*op).BindInputs({input_tensor}).BindOutputs({output_tensor});

  CheckResult(graph, golden, output_tensor);
}

TEST(Broadcast, 2DTo3D_WithDims12) {
  auto ctx = tim::vx::Context::Create();
  auto graph = ctx->CreateGraph();

  tim::vx::ShapeType input_shape({2, 2});
  tim::vx::ShapeType output_shape({2, 2, 2});
  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
                                 tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
                                  tim::vx::TensorAttribute::OUTPUT);

  auto input_tensor = graph->CreateTensor(input_spec);
  auto output_tensor = graph->CreateTensor(output_spec);

  std::vector<float> in_data = {
      1.f, 5.f, 2.f, 6.f
  };
  std::vector<float> golden = {
      1.f, 1.f, 5.f, 5.f, 2.f, 2.f, 6.f, 6.f,
  };
  std::vector<uint32_t> shape = {2, 2, 2};
  std::vector<uint32_t> dimensions = {1, 2};

  EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
                                             in_data.size() * sizeof(float)));

  auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape, dimensions);
  (*op).BindInputs({input_tensor}).BindOutputs({output_tensor});

  CheckResult(graph, golden, output_tensor);
}