diff --git a/docs/Operators.md b/docs/Operators.md
index e782f86..17e8bf7 100644
--- a/docs/Operators.md
+++ b/docs/Operators.md
@@ -7,9 +7,11 @@
- [ArgMin/ArgMax](#argminargmax)
- [Batch2Space](#batch2space)
- [BatchNorm](#batchnorm)
+ - [Broadcast](#broadcast)
- [Clip](#clip)
- [Concat](#concat)
- [Conv2d](#conv2d)
+ - [Conv3d](#conv3d)
- [DeConv2d](#deconv2d)
- [DeConv1d](#deconv1d)
- [DepthToSpace](#depthtospace)
@@ -22,9 +24,11 @@
- [Minimum](#minimum)
- [Maximum](#maximum)
- [FloorDiv](#floordiv)
+ - [Erf](#erf)
- [FullyConnected](#fullyconnected)
- [Gather](#gather)
- [GatherNd](#gathernd)
+ - [GroupedConv1d](#groupedconv1d)
- [GroupedConv2d](#groupedconv2d)
- [L2Normalization](#l2normalization)
- [LocalResponseNormalization](#localresponsenormalization)
@@ -36,8 +40,12 @@
- [MaxUnpool2d](#maxunpool2d)
- [Moments](#moments)
- [NBG](#nbg)
+ - [OneHot](#onehot)
- [Pad](#pad)
- [Pool2d](#pool2d)
+ - [Classic Pool2d](#classic-pool2d)
+ - [Global Pool2d](#global-pool2d)
+ - [Adaptive Pool2d](#adaptive-pool2d)
- [ReduceMin](#reducemin)
- [ReduceMax](#reducemax)
- [ReduceAny](#reduceany)
@@ -78,6 +86,7 @@
- [Squeeze](#squeeze)
- [Stack](#stack)
- [StridedSlice](#stridedslice)
+ - [Svdf](#svdf)
- [Tile](#tile)
- [Transpose](#transpose)
- [Unidirectional sequence lstm](#unidirectional-sequence-lstm)
@@ -108,12 +117,12 @@ Swish(x) : x * sigmoid(x)
HardSwish(x) : 0 if x <= -3; x(x + 3)/6 if -3 < x < 3; x if x >= 3
-Mish(x) : x if x >= 0 else alpha * x
-
HardSigmoid(x) : min(max(alpha*x + beta, 0), 1)
SoftRelu(x) : log(1 + e^x). Also known as SoftPlus.
+Mish(x) : x * tanh(softrelu(x))
+
LeakyRelu(x) : alpha * x if x <= 0; x if x > 0. alpha is a scalar.
Prelu(x) : alpha * x if x <= 0; x if x > 0. alpha is a tensor.
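+
+As a reference when validating outputs, a minimal C++ sketch of a few of the
+formulas above (scalar float versions; the NPU kernels remain the source of truth):
+
+```cpp
+#include <cmath>
+
+// Reference forms of SoftRelu, Mish, HardSigmoid and HardSwish from the table above.
+float soft_relu(float x) { return std::log1p(std::exp(x)); }  // log(1 + e^x)
+float mish(float x) { return x * std::tanh(soft_relu(x)); }
+float hard_sigmoid(float x, float alpha, float beta) {
+  return std::fmin(std::fmax(alpha * x + beta, 0.f), 1.f);
+}
+float hard_swish(float x) {
+  return x <= -3.f ? 0.f : (x >= 3.f ? x : x * (x + 3.f) / 6.f);
+}
+```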
@@ -153,7 +162,22 @@ rank as the input. This is the reverse transformation of Space2Batch.
Carries out batch normalization as described in the paper
https://arxiv.org/abs/1502.03167.
-Y = (X - Mean) / Sqrt( Var + Eps) * Gama + Beta
+$$\hat x_i\leftarrow \frac{x_i-\mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2+\epsilon}}$$
+
+$$y_i=\gamma\hat x_i+\beta\equiv BN_{\gamma,\beta}(x_i)$$
+
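+A minimal sketch of the per-channel inference-time computation (assuming the
+batch statistics and gamma/beta are already known for the channel):
+
+```cpp
+#include <cmath>
+#include <vector>
+
+// y_i = gamma * (x_i - mean) / sqrt(var + eps) + beta, for a single channel.
+std::vector<float> batch_norm(const std::vector<float>& x, float mean, float var,
+                              float gamma, float beta, float eps = 1e-5f) {
+  std::vector<float> y(x.size());
+  const float inv_std = 1.f / std::sqrt(var + eps);
+  for (size_t i = 0; i < x.size(); ++i) y[i] = gamma * (x[i] - mean) * inv_std + beta;
+  return y;
+}
+```
+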
+
+## Broadcast
+
+Broadcast an array to a compatible shape. See also numpy.broadcast_to().
+
+Input:
+- input.
+
+Attribute:
+- shape : the target shape to broadcast to.
+- dimensions (optional) : which dimension in the target shape each dimension
+of the operand shape corresponds to (BroadcastInDim-style mapping; see the sketch below).
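+
+A minimal usage sketch, mirroring broadcast_test.cc from this patch (tensor and
+graph setup elided):
+
+```cpp
+// Broadcast a {3} tensor to shape {3, 2}; a dimensions vector may be passed
+// as a second argument for BroadcastInDim-style mapping.
+auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(
+    std::vector<uint32_t>({3, 2}));
+(*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
+```
+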
## Clip
@@ -189,6 +213,28 @@ Attribute:
but the value is different. multiplier = weights / group.
- layout : WHCN or CWHN.
+
+## Conv3d
+
+Performs a 3-D convolution operation.
+
+Input:
+- input [WHDCN].
+- kernel [ WHDIcOc ] (Ic: input channels, Oc: output channels).
+- bias [ O ]. Optional.
+
+Attribute:
+- weights : the number of output channels of the weight tensor.
+- ksize : the depth, height and width of the weight tensor.
+- padding : AUTO, VALID or SAME.
+- pad : pad value for each spatial axis. (left, right, top, bottom, front, rear).
+- stride : stride along each spatial axis.
+- dilation : dilation value along each spatial axis of the filter.
+- multiplier : similar to the group attribute in other frameworks,
+but the value differs: multiplier = weights / group.
+- input_layout : WHDCN or WHCDN.
+- kernel_layout : WHDIcOc.
+
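+The output extent follows the usual convolution arithmetic along each spatial
+axis (W, H, D). A sketch, assuming explicit pads and floor rounding (a
+hypothetical helper, not part of the API):
+
+```cpp
+#include <cstdint>
+
+// Output extent along one spatial axis of an explicitly padded convolution.
+uint32_t conv_out_dim(uint32_t in, uint32_t k, uint32_t stride,
+                      uint32_t dilation, uint32_t pad_before, uint32_t pad_after) {
+  const uint32_t eff_k = dilation * (k - 1) + 1;  // dilated kernel extent
+  return (in + pad_before + pad_after - eff_k) / stride + 1;
+}
+```
+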
## DeConv2d
@@ -292,6 +338,13 @@ Maximum(x, y) : max(x, y). This operation supports broadcasting.
FloorDiv(x, y): floor( x / y ). This operation supports broadcasting.
+
+## Erf
+
+Computes the Gauss error function of x element-wise.
+
+- no parameters
+
## FullyConnected
@@ -311,6 +364,26 @@ Gather slices from input, **axis** according to **indices**.
An operation similar to Gather but gathers across multiple axis at once.
+
+## GroupedConv1d
+
+Performs a grouped 1-D convolution operation.
+
+Input:
+- input [WCN].
+- kernel [ WIcOc ] (Ic: input channels, Oc: output channels). Ic * group = C.
+- bias [ O ]. Optional.
+
+Attribute:
+- weights : the number of output channels of the weight tensor.
+- ksize : the width of the weight tensor.
+- padding : AUTO, VALID or SAME.
+- pad : pad value for each spatial axis.
+- stride : stride along each spatial axis.
+- dilation : dilation value along each spatial axis of the filter.
+- group : split the convolution into n groups along the channel axis (see the sketch below).
+- layout : WCN or CWN.
+
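+A sketch of the channel bookkeeping the group attribute implies (hypothetical
+helper for illustration):
+
+```cpp
+#include <cstdint>
+
+struct GroupSplit { uint32_t ic_per_group; uint32_t oc_per_group; };
+
+// Each group convolves C / group input channels (the kernel's Ic, so
+// Ic * group = C) and produces weights / group output channels.
+GroupSplit split_channels(uint32_t in_channels, uint32_t weights, uint32_t group) {
+  return {in_channels / group, weights / group};
+}
+```
+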
## GroupedConv2d
@@ -415,24 +488,59 @@ If x is 1-D and axes = [0] this is just the mean and variance of a vector.
Network Binary Graph is a precompile technology, which can compile a fused graph into
a binary file.
+
+## OneHot
+
+Create a one-hot tensor.
+
+- depth : a scalar defining the depth of the one-hot dimension.
+- on_value : a scalar defining the value to fill in output when indices[j] == i.
+- off_value : a scalar defining the value to fill in output when indices[j] != i.
+- axis : the axis to fill (see the reference sketch below).
+
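+A naive reference for the innermost-axis case (hypothetical helper, not part of
+the API):
+
+```cpp
+#include <cstdint>
+#include <vector>
+
+// output[i][d] = on_value if indices[i] == d, else off_value.
+std::vector<float> one_hot(const std::vector<int32_t>& indices, int32_t depth,
+                           float on_value, float off_value) {
+  std::vector<float> out(indices.size() * depth, off_value);
+  for (size_t i = 0; i < indices.size(); ++i)
+    if (indices[i] >= 0 && indices[i] < depth) out[i * depth + indices[i]] = on_value;
+  return out;
+}
+```
+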
## Pad
Pads a tensor.
- const_val : the value to pad.
+- pad_mode : the padding mode.
+- front_size : pad values added at the front of each axis (left, top).
+- back_size : pad values added at the back of each axis (right, bottom).
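+
+The output shape grows by front_size + back_size along each axis; a sketch
+(hypothetical helper):
+
+```cpp
+#include <cstdint>
+#include <vector>
+
+// out[i] = in[i] + front[i] + back[i] along every axis.
+std::vector<uint32_t> padded_shape(const std::vector<uint32_t>& in,
+                                   const std::vector<uint32_t>& front,
+                                   const std::vector<uint32_t>& back) {
+  std::vector<uint32_t> out(in.size());
+  for (size_t i = 0; i < in.size(); ++i) out[i] = in[i] + front[i] + back[i];
+  return out;
+}
+```
+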
## Pool2d
+
+### Classic Pool2d
+
Performs a 2-D pooling operation.
- type : MAX, AVG, L2 or AVG_ANDROID.
- padding : AUTO, VALID or SAME.
+- pad : Specify the number of pad values for left, right, top, and bottom.
- ksize : filter size.
- stride : stride along each spatial axis.
- round_type : CEILING or FLOOR.
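+
+A sketch of how round_type enters the usual output-size arithmetic (the
+driver's exact rounding is authoritative; hypothetical helper):
+
+```cpp
+#include <cmath>
+#include <cstdint>
+
+// Pooled extent along one axis; CEILING vs FLOOR applies to the division.
+uint32_t pool_out_dim(uint32_t in, uint32_t ksize, uint32_t stride,
+                      uint32_t pad_total, bool ceiling) {
+  const float v = float(in + pad_total - ksize) / float(stride) + 1.f;
+  return uint32_t(ceiling ? std::ceil(v) : std::floor(v));
+}
+```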
+
+### Global Pool2d
+
+- type : MAX, AVG, L2 or AVG_ANDROID.
+- input_size : input size (only [W, H]).
+- round_type : CEILING or FLOOR.
+
+
+### Adaptive Pool2d
+
+Same as torch.nn.AdaptiveXXXPool2d (e.g. torch.nn.AdaptiveAvgPool2d, torch.nn.AdaptiveMaxPool2d).
+
+- type : MAX, AVG, L2 or AVG_ANDROID.
+- input_size : input size (only [W, H]).
+- output_size : output size (only [W, H]).
+- round_type : CEILING or FLOOR.
+
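+A sketch of the window each output index reads, assuming the torch-style
+floor/ceil split (hypothetical helper):
+
+```cpp
+#include <cstdint>
+
+// Adaptive pooling reads input window [start, end) for output index i.
+void adaptive_window(uint32_t i, uint32_t in, uint32_t out,
+                     uint32_t* start, uint32_t* end) {
+  *start = (i * in) / out;                // floor(i * in / out)
+  *end = ((i + 1) * in + out - 1) / out;  // ceil((i + 1) * in / out)
+}
+```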
+
## ReduceMin
@@ -714,12 +822,13 @@ Removes dimensions of size 1 from the shape of a tensor.
## Stack
Packs the list of tensors in inputs into a tensor with rank one higher than
-each tensor in values, by packing them along the **axis** dimension.
+each tensor in values, by packing them along the **axis** dimension.
+Dimensions below the dimension specified by **axis** are packed together with the corresponding dimensions of the other inputs.
## StridedSlice
-Extracts a strided slice of a tensor.
+Extracts a strided slice of a tensor, same as TensorFlow's strided_slice.
Roughly speaking, this op extracts a slice of size (end - begin) / stride from
the given input tensor. Starting at the location specified by begin the slice
@@ -738,6 +847,15 @@ specification shrinks the dimensionality by 1, taking on the value at index begi
In this case, the ith specification must define a slice of size 1,
e.g. begin[i] = x, end[i] = x + 1.
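+
+A sketch of the per-axis length implied by (end - begin) / stride, assuming
+positive strides and already-normalized begin/end (hypothetical helper):
+
+```cpp
+#include <cstdint>
+
+// Number of elements one axis contributes: ceil((end - begin) / stride).
+int32_t slice_len(int32_t begin, int32_t end, int32_t stride) {
+  int32_t len = (end - begin + stride - 1) / stride;
+  return len > 0 ? len : 0;
+}
+```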
+
+## Svdf
+
+Performs an SVDF operation: a fully connected layer whose weight matrix is approximated by a low-rank singular value decomposition, applied over a fixed-size memory of input frames.
+
+- rank : The rank of the SVD approximation.
+- num_units : corresponds to the number of units.
+- spectrogram_length : corresponds to the fixed size of the memory.
+
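+A sketch of the low-rank idea only (not the stateful per-frame kernel): a dense
+m-by-n layer is approximated by rank-r factors U (m-by-r) and V (n-by-r), so
+y = W x is computed as U (V^T x):
+
+```cpp
+#include <cstddef>
+#include <vector>
+
+// y = U * (V^T * x); U and V are row-major rank-r factors of W.
+std::vector<float> low_rank_apply(const std::vector<float>& U,
+                                  const std::vector<float>& V,
+                                  const std::vector<float>& x,
+                                  size_t m, size_t n, size_t r) {
+  std::vector<float> t(r, 0.f), y(m, 0.f);
+  for (size_t j = 0; j < r; ++j)
+    for (size_t k = 0; k < n; ++k) t[j] += V[k * r + j] * x[k];
+  for (size_t i = 0; i < m; ++i)
+    for (size_t j = 0; j < r; ++j) y[i] += U[i * r + j] * t[j];
+  return y;
+}
+```
+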
## Tile
@@ -765,4 +883,4 @@ how to bind input/output: take unidirectional_sequence_lstm_test.cc
Unpacks the given dimension of a rank-R tensor into rank-(R-1) tensors.
- axis : An int. The axis to unstack along. Defaults to the first dimension.
-Negative values wrap around, so the valid range is [-R, R).
+Negative values wrap around, so the valid range is [-R, R).
\ No newline at end of file
diff --git a/include/tim/vx/ops.h b/include/tim/vx/ops.h
index 9f2cc87..253f698 100644
--- a/include/tim/vx/ops.h
+++ b/include/tim/vx/ops.h
@@ -29,6 +29,7 @@
#include "tim/vx/ops/arg.h"
#include "tim/vx/ops/batch2space.h"
#include "tim/vx/ops/batchnorm.h"
+#include "tim/vx/ops/broadcast.h"
#include "tim/vx/ops/clip.h"
#include "tim/vx/ops/concat.h"
#include "tim/vx/ops/conv1d.h"
diff --git a/include/tim/vx/ops/broadcast.h b/include/tim/vx/ops/broadcast.h
new file mode 100644
index 0000000..c6e4a67
--- /dev/null
+++ b/include/tim/vx/ops/broadcast.h
@@ -0,0 +1,61 @@
+/****************************************************************************
+*
+* Copyright (c) 2020 Vivante Corporation
+*
+* Permission is hereby granted, free of charge, to any person obtaining a
+* copy of this software and associated documentation files (the "Software"),
+* to deal in the Software without restriction, including without limitation
+* the rights to use, copy, modify, merge, publish, distribute, sublicense,
+* and/or sell copies of the Software, and to permit persons to whom the
+* Software is furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in
+* all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+* DEALINGS IN THE SOFTWARE.
+*
+*****************************************************************************/
+#ifndef OVXLIBXX_OPERATIONS_BROADCAST_H_
+#define OVXLIBXX_OPERATIONS_BROADCAST_H_
+#include "tim/vx/direct_map_op.h"
+
+namespace tim {
+namespace vx {
+namespace ops {
+
+/**
+ * ## Broadcast
+ *
+ * Broadcast an array to a compatible shape. See also numpy.broadcast_to().
+ *
+ * Input:
+ * - input.
+ *
+ * Attribute:
+ * - shape : the target shape to broadcast to.
+ * - dimensions (optional) : which dimension in the target shape each dimension
+ * of the operand shape corresponds to (BroadcastInDim-style mapping).
+ */
+
+class Broadcast : public DirectMapOp {
+ public:
+  Broadcast(Graph* graph, const std::vector<uint32_t>& shape,
+            const std::vector<uint32_t>& dimensions = {});
+
+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
+ protected:
+  const std::vector<uint32_t> shape_;
+  const std::vector<uint32_t> dimensions_;
+};
+
+} // namespace ops
+} // namespace vx
+} // namespace tim
+
+#endif /* OVXLIBXX_OPERATIONS_BROADCAST_H_ */
diff --git a/src/tim/vx/ops/README.md b/src/tim/vx/ops/README.md
index dd25069..50322c3 100644
--- a/src/tim/vx/ops/README.md
+++ b/src/tim/vx/ops/README.md
@@ -104,6 +104,7 @@ Erf|ERF|Mapped|[tf.math.erf](https://tensorflow.google.cn/api_docs/python/tf/mat
GroupedConv1d|GROUPED_CONV1D|Mapped|[tf.keras.layers.Conv1D](https://tensorflow.google.cn/api_docs/python/tf/keras/layers/Conv1D?hl=en)
|SignalFrame|SIGNAL_FRAME|Mapped|[tf.signal.frame](https://tensorflow.google.cn/api_docs/python/tf/signal/frame)
|RNNCell|RNNCELL_OVXLIB|Mapped|[ANEURALNETWORKS_RNN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0acd2684ac9c73bb29767b534e78a332e8)
+|Broadcast|EXPAND_BROADCAST|Mapped|[numpy.broadcast_to](https://numpy.org/doc/stable/reference/generated/numpy.broadcast_to.html)
||PROPOSAL| TBD |[Faster-RCNN Proposal Layer](https://github.com/intel/caffe/blob/master/examples/faster-rcnn/lib/rpn/proposal_layer.py)
||ROI_POOL|Planned 22Q1 |[ANEURALNETWORKS_ROI_POOLING](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a6736198af337b2efbdb0b6b64dee7fe4)
||ROI_ALIGN| TBD |[ANEURALNETWORKS_ROI_ALIGN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a2848b39dd4bfba78f2438fda0d9397a4)
diff --git a/src/tim/vx/ops/broadcast.cc b/src/tim/vx/ops/broadcast.cc
new file mode 100644
index 0000000..4a1d5ac
--- /dev/null
+++ b/src/tim/vx/ops/broadcast.cc
@@ -0,0 +1,57 @@
+/****************************************************************************
+*
+* Copyright (c) 2021 Vivante Corporation
+*
+* Permission is hereby granted, free of charge, to any person obtaining a
+* copy of this software and associated documentation files (the "Software"),
+* to deal in the Software without restriction, including without limitation
+* the rights to use, copy, modify, merge, publish, distribute, sublicense,
+* and/or sell copies of the Software, and to permit persons to whom the
+* Software is furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in
+* all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+* DEALINGS IN THE SOFTWARE.
+*
+*****************************************************************************/
+#include "tim/vx/ops/broadcast.h"
+
+#include <vector>
+#include "direct_map_op_impl.h"
+#include "vsi_nn_pub.h"
+
+namespace tim {
+namespace vx {
+namespace ops {
+Broadcast::Broadcast(Graph* graph, const std::vector<uint32_t>& shape,
+                     const std::vector<uint32_t>& dimensions)
+ : DirectMapOp(graph, VSI_NN_OP_EXPAND_BROADCAST),
+ shape_(shape),
+ dimensions_(dimensions) {
+ this->impl()->node()->nn_param.expand_broadcast.dim_num = shape_.size();
+ this->impl()->node()->nn_param.expand_broadcast.shape = (uint32_t*)shape_.data();
+ this->impl()->node()->nn_param.expand_broadcast.dimensions_num = dimensions_.size();
+ if (dimensions.size() > 0)
+ {
+ this->impl()->node()->nn_param.expand_broadcast.dimensions = (uint32_t*)dimensions_.data();
+ } else {
+ this->impl()->node()->nn_param.expand_broadcast.dimensions = nullptr;
+ }
+}
+
+std::shared_ptr<Operation> Broadcast::Clone(
+    std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<Broadcast>(this->shape_, this->dimensions_);
+}
+
+} // namespace ops
+} // namespace vx
+} // namespace tim
\ No newline at end of file
diff --git a/src/tim/vx/ops/broadcast_test.cc b/src/tim/vx/ops/broadcast_test.cc
new file mode 100644
index 0000000..b84d55e
--- /dev/null
+++ b/src/tim/vx/ops/broadcast_test.cc
@@ -0,0 +1,329 @@
+/****************************************************************************
+*
+* Copyright (c) 2021 Vivante Corporation
+*
+* Permission is hereby granted, free of charge, to any person obtaining a
+* copy of this software and associated documentation files (the "Software"),
+* to deal in the Software without restriction, including without limitation
+* the rights to use, copy, modify, merge, publish, distribute, sublicense,
+* and/or sell copies of the Software, and to permit persons to whom the
+* Software is furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in
+* all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+* DEALINGS IN THE SOFTWARE.
+*
+*****************************************************************************/
+#include "tim/vx/context.h"
+#include "tim/vx/graph.h"
+#include "tim/vx/ops/broadcast.h"
+#include "tim/transform/layout_inference.h"
+
+#include "gtest/gtest.h"
+#include "test_utils.h"
+
+static void CheckResult(std::shared_ptr<tim::vx::Graph>& graph,
+                        std::vector<float>& golden,
+                        std::shared_ptr<tim::vx::Tensor>& output_tensor) {
+ EXPECT_TRUE(graph->Compile());
+ EXPECT_TRUE(graph->Run());
+
+  std::vector<float> output(golden.size());
+ EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
+ EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
+}
+
+TEST(Broadcast, ScalarTo2D_2x3) {
+ auto ctx = tim::vx::Context::Create();
+ auto graph = ctx->CreateGraph();
+
+ tim::vx::ShapeType input_shape({1});
+ tim::vx::ShapeType output_shape({3, 2});
+ tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
+ tim::vx::TensorAttribute::INPUT);
+ tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
+ tim::vx::TensorAttribute::OUTPUT);
+
+ auto input_tensor = graph->CreateTensor(input_spec);
+ auto output_tensor = graph->CreateTensor(output_spec);
+
+  std::vector<float> in_data = {
+      2.25f,
+  };
+  std::vector<float> golden = {
+      2.25f, 2.25f, 2.25f, 2.25f, 2.25f, 2.25f,
+  };
+  std::vector<uint32_t> shape = {3, 2};
+
+ EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
+ in_data.size() * sizeof(float)));
+
+  auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape);
+ (*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
+
+ CheckResult(graph, golden, output_tensor);
+}
+
+TEST(Broadcast, 1DTo2D) {
+ auto ctx = tim::vx::Context::Create();
+ auto graph = ctx->CreateGraph();
+
+ tim::vx::ShapeType input_shape({3});
+ tim::vx::ShapeType output_shape({3, 2});
+ tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
+ tim::vx::TensorAttribute::INPUT);
+ tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
+ tim::vx::TensorAttribute::OUTPUT);
+
+ auto input_tensor = graph->CreateTensor(input_spec);
+ auto output_tensor = graph->CreateTensor(output_spec);
+
+  std::vector<float> in_data = {
+      1.f, 2.f, 3.f,
+  };
+  std::vector<float> golden = {
+      1.f, 2.f, 3.f, 1.f, 2.f, 3.f,
+  };
+  std::vector<uint32_t> shape = {3, 2};
+
+ EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
+ in_data.size() * sizeof(float)));
+
+  auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape);
+ (*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
+
+ CheckResult(graph, golden, output_tensor);
+}
+
+TEST(Broadcast, 1DTo2D_WithDims0) {
+ auto ctx = tim::vx::Context::Create();
+ auto graph = ctx->CreateGraph();
+
+ tim::vx::ShapeType input_shape({2});
+ tim::vx::ShapeType output_shape({2, 2});
+ tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
+ tim::vx::TensorAttribute::INPUT);
+ tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
+ tim::vx::TensorAttribute::OUTPUT);
+
+ auto input_tensor = graph->CreateTensor(input_spec);
+ auto output_tensor = graph->CreateTensor(output_spec);
+
+  std::vector<float> in_data = {
+      1.f, 2.f,
+  };
+  std::vector<float> golden = {
+      1.f, 2.f,
+      1.f, 2.f,
+  };
+  std::vector<uint32_t> shape = {2, 2};
+  std::vector<uint32_t> dimensions = {0};
+
+ EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
+ in_data.size() * sizeof(float)));
+
+  auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape, dimensions);
+ (*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
+
+ CheckResult(graph, golden, output_tensor);
+}
+
+TEST(Broadcast, 1DTo2D_WithDims1) {
+ auto ctx = tim::vx::Context::Create();
+ auto graph = ctx->CreateGraph();
+
+ tim::vx::ShapeType input_shape({2});
+ tim::vx::ShapeType output_shape({2, 2});
+ tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
+ tim::vx::TensorAttribute::INPUT);
+ tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
+ tim::vx::TensorAttribute::OUTPUT);
+
+ auto input_tensor = graph->CreateTensor(input_spec);
+ auto output_tensor = graph->CreateTensor(output_spec);
+
+  std::vector<float> in_data = {
+      1.f, 2.f,
+  };
+  std::vector<float> golden = {
+      1.f, 1.f,
+      2.f, 2.f,
+  };
+  std::vector<uint32_t> shape = {2, 2};
+  std::vector<uint32_t> dimensions = {1};
+
+ EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
+ in_data.size() * sizeof(float)));
+
+  auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape, dimensions);
+ (*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
+
+ CheckResult(graph, golden, output_tensor);
+}
+
+TEST(Broadcast, 1DTo3D_WithDims0) {
+ auto ctx = tim::vx::Context::Create();
+ auto graph = ctx->CreateGraph();
+
+ tim::vx::ShapeType input_shape({2});
+ tim::vx::ShapeType output_shape({2, 2, 2});
+ tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
+ tim::vx::TensorAttribute::INPUT);
+ tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
+ tim::vx::TensorAttribute::OUTPUT);
+
+ auto input_tensor = graph->CreateTensor(input_spec);
+ auto output_tensor = graph->CreateTensor(output_spec);
+
+  std::vector<float> in_data = {
+      1.f, 2.f,
+  };
+  std::vector<float> golden = {
+      1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f,
+  };
+  std::vector<uint32_t> shape = {2, 2, 2};
+  std::vector<uint32_t> dimensions = {0};
+
+ EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
+ in_data.size() * sizeof(float)));
+
+  auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape, dimensions);
+ (*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
+
+ CheckResult(graph, golden, output_tensor);
+}
+
+TEST(Broadcast, 1DTo3D_WithDims1) {
+ auto ctx = tim::vx::Context::Create();
+ auto graph = ctx->CreateGraph();
+
+ tim::vx::ShapeType input_shape({2});
+ tim::vx::ShapeType output_shape({2, 2, 2});
+ tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
+ tim::vx::TensorAttribute::INPUT);
+ tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
+ tim::vx::TensorAttribute::OUTPUT);
+
+ auto input_tensor = graph->CreateTensor(input_spec);
+ auto output_tensor = graph->CreateTensor(output_spec);
+
+  std::vector<float> in_data = {
+      1.f, 2.f,
+  };
+  std::vector<float> golden = {
+      1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f,
+  };
+  std::vector<uint32_t> shape = {2, 2, 2};
+  std::vector<uint32_t> dimensions = {1};
+
+ EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
+ in_data.size() * sizeof(float)));
+
+  auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape, dimensions);
+ (*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
+
+ CheckResult(graph, golden, output_tensor);
+}
+
+TEST(Broadcast, 1DTo3D_WithDims2) {
+ auto ctx = tim::vx::Context::Create();
+ auto graph = ctx->CreateGraph();
+
+ tim::vx::ShapeType input_shape({2});
+ tim::vx::ShapeType output_shape({2, 2, 2});
+ tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
+ tim::vx::TensorAttribute::INPUT);
+ tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
+ tim::vx::TensorAttribute::OUTPUT);
+
+ auto input_tensor = graph->CreateTensor(input_spec);
+ auto output_tensor = graph->CreateTensor(output_spec);
+
+  std::vector<float> in_data = {
+      1.f, 2.f,
+  };
+  std::vector<float> golden = {
+      1.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 2.f,
+  };
+  std::vector<uint32_t> shape = {2, 2, 2};
+  std::vector<uint32_t> dimensions = {2};
+
+ EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
+ in_data.size() * sizeof(float)));
+
+  auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape, dimensions);
+ (*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
+
+ CheckResult(graph, golden, output_tensor);
+}
+
+TEST(Broadcast, 2DTo3D_WithDims02) {
+ auto ctx = tim::vx::Context::Create();
+ auto graph = ctx->CreateGraph();
+
+ tim::vx::ShapeType input_shape({2, 2});
+ tim::vx::ShapeType output_shape({2, 2, 2});
+ tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
+ tim::vx::TensorAttribute::INPUT);
+ tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
+ tim::vx::TensorAttribute::OUTPUT);
+
+ auto input_tensor = graph->CreateTensor(input_spec);
+ auto output_tensor = graph->CreateTensor(output_spec);
+
+  std::vector<float> in_data = {
+      1.f, 5.f, 2.f, 6.f
+  };
+  std::vector<float> golden = {
+      1.f, 5.f, 1.f, 5.f, 2.f, 6.f, 2.f, 6.f,
+  };
+  std::vector<uint32_t> shape = {2, 2, 2};
+  std::vector<uint32_t> dimensions = {0, 2};
+
+ EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
+ in_data.size() * sizeof(float)));
+
+  auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape, dimensions);
+ (*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
+
+ CheckResult(graph, golden, output_tensor);
+}
+
+TEST(Broadcast, 2DTo3D_WithDims12) {
+ auto ctx = tim::vx::Context::Create();
+ auto graph = ctx->CreateGraph();
+
+ tim::vx::ShapeType input_shape({2, 2});
+ tim::vx::ShapeType output_shape({2, 2, 2});
+ tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
+ tim::vx::TensorAttribute::INPUT);
+ tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
+ tim::vx::TensorAttribute::OUTPUT);
+
+ auto input_tensor = graph->CreateTensor(input_spec);
+ auto output_tensor = graph->CreateTensor(output_spec);
+
+  std::vector<float> in_data = {
+      1.f, 5.f, 2.f, 6.f
+  };
+  std::vector<float> golden = {
+      1.f, 1.f, 5.f, 5.f, 2.f, 2.f, 6.f, 6.f,
+  };
+  std::vector<uint32_t> shape = {2, 2, 2};
+  std::vector<uint32_t> dimensions = {1, 2};
+
+ EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
+ in_data.size() * sizeof(float)));
+
+  auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape, dimensions);
+ (*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
+
+ CheckResult(graph, golden, output_tensor);
+}