Add Broadcast op (#365)

This commit is contained in:
Antkillerfarm 2022-04-18 15:45:15 +08:00 committed by GitHub
parent 96dedc1453
commit b916e1301a
6 changed files with 573 additions and 6 deletions

View File

@@ -7,9 +7,11 @@
- [ArgMin/ArgMax](#argminargmax)
- [Batch2Space](#batch2space)
- [BatchNorm](#batchnorm)
- [Broadcast](#broadcast)
- [Clip](#clip)
- [Concat](#concat)
- [Conv2d](#conv2d)
- [Conv3d](#conv3d)
- [DeConv2d](#deconv2d)
- [DeConv1d](#deconv1d)
- [DepthToSpace](#depthtospace)
@@ -22,9 +24,11 @@
- [Minimum](#minimum)
- [Maximum](#maximum)
- [FloorDiv](#floordiv)
- [Erf](#erf)
- [FullyConnected](#fullyconnected)
- [Gather](#gather)
- [GatherNd](#gathernd)
- [GroupedConv1d](#groupedconv1d)
- [GroupedConv2d](#groupedconv2d)
- [L2Normalization](#l2normalization)
- [LocalResponseNormalization](#localresponsenormalization)
@@ -36,8 +40,12 @@
- [MaxUnpool2d](#maxunpool2d)
- [Moments](#moments)
- [NBG](#nbg)
- [OneHot](#onehot)
- [Pad](#pad)
- [Pool2d](#pool2d)
- [Classic Pool2d](#classic-pool2d)
- [Global Pool2d](#global-pool2d)
- [Adaptive Pool2d](#adaptive-pool2d)
- [ReduceMin](#reducemin)
- [ReduceMax](#reducemax)
- [ReduceAny](#reduceany)
@@ -78,6 +86,7 @@
- [Squeeze](#squeeze)
- [Stack](#stack)
- [StridedSlice](#stridedslice)
- [Svdf](#svdf)
- [Tile](#tile)
- [Transpose](#transpose)
- [Unidirectional sequence lstm](#unidirectional-sequence-lstm)
@@ -108,12 +117,12 @@ Swish(x) : x * sigmoid(x)
HardSwish(x) : 0 if x <= -3; x(x + 3)/6 if -3 < x < 3; x if x >= 3
HardSigmoid(x) : min(max(alpha*x + beta, 0), 1)
SoftRelu(x) : log(1 + e^x). Also known as SoftPlus.
Mish(x) : x * tanh(softrelu(x))
LeakyRelu(x) : alpha * x if x <= 0; x if x > 0. alpha is a scalar.
Prelu(x) : alpha * x if x <= 0; x if x > 0. alpha is a tensor.
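To sanity-check the SoftRelu/Mish relationship numerically, here is a minimal standalone C++ sketch (illustration only, not part of the TIM-VX API):

```cpp
#include <cmath>
#include <cstdio>
#include <initializer_list>

// SoftRelu (SoftPlus): log(1 + e^x)
double soft_relu(double x) { return std::log1p(std::exp(x)); }

// Mish: x * tanh(softrelu(x))
double mish(double x) { return x * std::tanh(soft_relu(x)); }

int main() {
  for (double x : {-3.0, -1.0, 0.0, 1.0, 3.0}) {
    std::printf("mish(%+.1f) = %+.6f\n", x, mish(x));
  }
  return 0;
}
```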
@@ -153,7 +162,22 @@ rank as the input. This is the reverse transformation of Space2Batch.
Carries out batch normalization as described in the paper
https://arxiv.org/abs/1502.03167.
$$\hat x_i\leftarrow \frac{x_i-\mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2+\epsilon}}$$
$$y_i=\gamma\hat x_i+\beta\equiv BN_{\gamma,\beta}(x_i)$$
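A scalar instance of the two formulas above, with illustrative values $x_i = 2$, $\mu_\mathcal{B} = 1$, $\sigma_\mathcal{B}^2 = 4$, $\epsilon = 0$, $\gamma = 3$, $\beta = 0.5$:

$$\hat x_i = \frac{2-1}{\sqrt{4+0}} = 0.5,\qquad y_i = 3\cdot 0.5 + 0.5 = 2$$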
<a class="mk-toclify" id="broadcast"></a>
## Broadcast
Broadcast an array to a compatible shape. See also numpy.broadcast_to().
Input:
- input.
Attribute:
- shape : the target shape to broadcast the input to.
- dimensions (optional) : which dimension in the target shape each dimension
of the operand shape corresponds to, as in XLA's BroadcastInDim.
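A minimal usage sketch in C++, mirroring the unit tests added in this commit (shapes and values are illustrative):

```cpp
#include <vector>

#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/ops/broadcast.h"

int main() {
  auto ctx = tim::vx::Context::Create();
  auto graph = ctx->CreateGraph();

  // Broadcast a 1-D tensor of shape [3] to shape [3, 2].
  tim::vx::ShapeType input_shape({3});
  tim::vx::ShapeType output_shape({3, 2});
  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
                                 tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
                                  tim::vx::TensorAttribute::OUTPUT);
  auto input_tensor = graph->CreateTensor(input_spec);
  auto output_tensor = graph->CreateTensor(output_spec);

  std::vector<float> in_data = {1.f, 2.f, 3.f};
  input_tensor->CopyDataToTensor(in_data.data(),
                                 in_data.size() * sizeof(float));

  std::vector<int32_t> shape = {3, 2};  // target shape
  auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape);
  (*op).BindInputs({input_tensor}).BindOutputs({output_tensor});

  graph->Compile();
  graph->Run();

  std::vector<float> output(6);  // expected: 1 2 3 1 2 3
  output_tensor->CopyDataFromTensor(output.data());
  return 0;
}
```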
<a class="mk-toclify" id="clip"></a>
## Clip
@@ -189,6 +213,28 @@ Attribute:
but the value is different. multiplier = weights / group.
- layout : WHCN or CWHN.
<a class="mk-toclify" id="conv3d"></a>
## Conv3d
Performs a 3-D convolution operation.
Input:
- input [WHDCN].
- kernel [ WHDIcOc ] (Ic: Input Channels. Oc: Output Channels).
- bias [ O ]. Optional.
Attribute:
- weights : the output channel number for weight tensor.
- ksize : the spatial size (width, height and depth) of the weight tensor.
- padding : AUTO, VALID or SAME.
- pad : pad value for each spatial axis. (left, right, top, bottom, front, rear).
- stride : stride along each spatial axis.
- dilation : dilation value along each spatial axis of the filter.
- multiplier : functions similarly to the group attribute in other frameworks,
but the value differs: multiplier = weights / group.
- input_layout : WHDCN or WHCDN.
- kernel_layout : WHDIcOc
<a class="mk-toclify" id="deconv2d"></a>
## DeConv2d
@@ -292,6 +338,13 @@ Maximum(x, y) : max(x, y). This operation supports broadcasting.
FloorDiv(x, y): floor( x / y ). This operation supports broadcasting.
<a class="mk-toclify" id="erf"></a>
## Erf
Computes the Gauss error function of x element-wise.
- no parameters
<a class="mk-toclify" id="fullyconnected"></a>
## FullyConnected
@@ -311,6 +364,26 @@ Gather slices from input, **axis** according to **indices**.
An operation similar to Gather but gathers across multiple axis at once.
<a class="mk-toclify" id="groupedconv1d"></a>
## GroupedConv1d
Performs a grouped 1-D convolution operation.
Input:
- input [WCN].
- kernel [ WIcOc ] (Ic: Input Channels. Oc: Output Channels). Ic * group = C.
- bias [ O ]. Optional.
Attribute:
- weights : the output channel number for weight tensor.
- ksize : the kernel width of the weight tensor.
- padding : AUTO, VALID or SAME.
- pad : pad value for each spatial axis.
- stride : stride along each spatial axis.
- dilation : dilation value along each spatial axis of the filter.
- group : split the convolution into n groups.
- layout : WCN or CWN.
<a class="mk-toclify" id="groupedconv2d"></a>
## GroupedConv2d
@@ -415,24 +488,59 @@ If x is 1-D and axes = [0] this is just the mean and variance of a vector.
Network Binary Graph is a precompilation technique that can compile a fused
graph into a binary file.
<a class="mk-toclify" id="onehot"></a>
## OneHot
Create a one-hot tensor.
- depth : A scalar defining the depth of the one-hot dimension.
- on_value : A scalar defining the value to fill in output when indices[j] = i.
- off_value : A scalar defining the value to fill in output when indices[j] != i.
- axis : The axis to fill.

For example, indices = [0, 2] with depth = 3, on_value = 1, off_value = 0 and
axis = -1 yields [[1, 0, 0], [0, 0, 1]].
<a class="mk-toclify" id="pad"></a>
## Pad
Pads a tensor.
- const_val : the value to pad with.
- pad_mode : the padding mode.
- front_size : Add pad values to the left and top.
- back_size : Add pad values to the right and bottom.
<a class="mk-toclify" id="pool2d"></a>
## Pool2d
<a class="mk-toclify" id="classic-pool2d"></a>
### Classic Pool2d
Performs a 2-D pooling operation.
- type : MAX, AVG, L2 or AVG_ANDROID.
- padding : AUTO, VALID or SAME.
- pad : Specify the number of pad values for left, right, top, and bottom.
- ksize : filter size.
- stride : stride along each spatial axis.
- round_type : CEILING or FLOOR.
<a class="mk-toclify" id="global-pool2d"></a>
### Global Pool2d
- type : MAX, AVG, L2 or AVG_ANDROID.
- input_size : input size (only [W, H]).
- round_type : CEILING or FLOOR.
<a class="mk-toclify" id="adaptive-pool2d"></a>
### Adaptive Pool2d
Same as torch.nn.AdaptiveXXXPool2d.
- type : MAX, AVG, L2 or AVG_ANDROID.
- input_size : input size (only [W, H]).
- output_size : output size (only [W, H]).
- round_type : CEILING or FLOOR.
<a class="mk-toclify" id="reducemin"></a>
## ReduceMin
@@ -714,12 +822,13 @@ Removes dimensions of size 1 from the shape of a tensor.
## Stack
Packs the list of tensors in inputs into a tensor with rank one higher than
each tensor in values, by packing them along the **axis** dimension.
Dimensions below the dimension specified by axis will be packed together with other inputs.
<a class="mk-toclify" id="stridedslice"></a>
## StridedSlice
Extracts a strided slice of a tensor. Same as TensorFlow's strided slice.
Roughly speaking, this op extracts a slice of size (end - begin) / stride from
the given input tensor. Starting at the location specified by begin the slice
@@ -738,6 +847,15 @@ specification shrinks the dimensionality by 1, taking on the value at index begi
In this case, the ith specification must define a slice of size 1,
e.g. begin[i] = x, end[i] = x + 1.
<a class="mk-toclify" id="svdf"></a>
## Svdf
Performs an SVDF (Singular Value Decomposition Filter) operation: a fully
connected layer whose weight matrix is factored with a rank-constrained SVD,
used to process a sequence of input frames.
- rank : The rank of the SVD approximation.
- num_units : the number of units.
- spectrogram_length : the fixed size of the memory.
<a class="mk-toclify" id="tile"></a>
## Tile
@@ -765,4 +883,4 @@ how to bind input/output: take unidirectional_sequence_lstm_test.cc
Unpacks the given dimension of a rank-R tensor into rank-(R-1) tensors.
- axis : An int. The axis to unstack along. Defaults to the first dimension.
Negative values wrap around, so the valid range is [-R, R).

View File

@@ -29,6 +29,7 @@
#include "tim/vx/ops/arg.h"
#include "tim/vx/ops/batch2space.h"
#include "tim/vx/ops/batchnorm.h"
#include "tim/vx/ops/broadcast.h"
#include "tim/vx/ops/clip.h"
#include "tim/vx/ops/concat.h"
#include "tim/vx/ops/conv1d.h"

View File

@@ -0,0 +1,61 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef OVXLIBXX_OPERATIONS_BROADCAST_H_
#define OVXLIBXX_OPERATIONS_BROADCAST_H_
#include "tim/vx/direct_map_op.h"
namespace tim {
namespace vx {
namespace ops {
/**
* ## Broadcast
*
* Broadcast an array to a compatible shape. See also numpy.broadcast_to().
*
* Input:
* - input.
*
* Attribute:
* - shape : the target shape to broadcast the input to.
* - dimensions (optional) : which dimension in the target shape each dimension
*   of the operand shape corresponds to, as in XLA's BroadcastInDim.
*/
class Broadcast : public DirectMapOp {
public:
Broadcast(Graph* graph, const std::vector<int32_t>& shape, const std::vector<int32_t>& dimensions = {});
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
protected:
const std::vector<int32_t> shape_;
const std::vector<int32_t> dimensions_;
};
} // namespace ops
} // namespace vx
} // namespace tim
#endif /* OVXLIBXX_OPERATIONS_BROADCAST_H_ */

View File

@@ -104,6 +104,7 @@ Erf|ERF|Mapped|[tf.math.erf](https://tensorflow.google.cn/api_docs/python/tf/mat
GroupedConv1d|GROUPED_CONV1D|Mapped|[tf.keras.layers.Conv1D](https://tensorflow.google.cn/api_docs/python/tf/keras/layers/Conv1D?hl=en)
|SignalFrame|SIGNAL_FRAME|Mapped|[tf.signal.frame](https://tensorflow.google.cn/api_docs/python/tf/signal/frame)
|RNNCell|RNNCELL_OVXLIB|Mapped|[ANEURALNETWORKS_RNN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0acd2684ac9c73bb29767b534e78a332e8)
|Broadcast|EXPAND_BROADCAST|Mapped|[numpy.broadcast_to](https://numpy.org/doc/stable/reference/generated/numpy.broadcast_to.html)
||PROPOSAL| TBD |[Faster-RCNN Proposal Layer](https://github.com/intel/caffe/blob/master/examples/faster-rcnn/lib/rpn/proposal_layer.py)
||ROI_POOL|Planned 22Q1 |[ANEURALNETWORKS_ROI_POOLING](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a6736198af337b2efbdb0b6b64dee7fe4)
||ROI_ALIGN| TBD |[ANEURALNETWORKS_ROI_ALIGN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a2848b39dd4bfba78f2438fda0d9397a4)

View File

@@ -0,0 +1,57 @@
/****************************************************************************
*
* Copyright (c) 2021 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#include "tim/vx/ops/broadcast.h"
#include <cassert>
#include "direct_map_op_impl.h"
#include "vsi_nn_pub.h"
namespace tim {
namespace vx {
namespace ops {
Broadcast::Broadcast(Graph* graph, const std::vector<int32_t>& shape,
const std::vector<int32_t>& dimensions)
: DirectMapOp(graph, VSI_NN_OP_EXPAND_BROADCAST),
shape_(shape),
dimensions_(dimensions) {
// The ovxlib parameter struct keeps raw pointers into shape_ and dimensions_,
// so both vectors are stored as const members that outlive the node. The
// struct expects uint32_t*, hence the cast from the int32_t storage.
this->impl()->node()->nn_param.expand_broadcast.dim_num = shape_.size();
this->impl()->node()->nn_param.expand_broadcast.shape = (uint32_t*)shape_.data();
this->impl()->node()->nn_param.expand_broadcast.dimensions_num = dimensions_.size();
if (dimensions_.size() > 0) {
this->impl()->node()->nn_param.expand_broadcast.dimensions = (uint32_t*)dimensions_.data();
} else {
this->impl()->node()->nn_param.expand_broadcast.dimensions = nullptr;
}
}
std::shared_ptr<Operation> Broadcast::Clone(
std::shared_ptr<Graph>& graph) const {
return graph->CreateOperation<Broadcast>(this->shape_, this->dimensions_);
}
} // namespace ops
} // namespace vx
} // namespace tim

View File

@@ -0,0 +1,329 @@
/****************************************************************************
*
* Copyright (c) 2021 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/ops/broadcast.h"
#include "tim/transform/layout_inference.h"
#include "gtest/gtest.h"
#include "test_utils.h"
static void CheckResult(std::shared_ptr<tim::vx::Graph>& graph,
std::vector<float>& golden,
std::shared_ptr<tim::vx::Tensor>& output_tensor) {
EXPECT_TRUE(graph->Compile());
EXPECT_TRUE(graph->Run());
std::vector<float> output(golden.size());
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
}
TEST(Broadcast, ScalarTo2D_2x3) {
auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph();
tim::vx::ShapeType input_shape({1});
tim::vx::ShapeType output_shape({3, 2});
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
tim::vx::TensorAttribute::INPUT);
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
tim::vx::TensorAttribute::OUTPUT);
auto input_tensor = graph->CreateTensor(input_spec);
auto output_tensor = graph->CreateTensor(output_spec);
std::vector<float> in_data = {
2.25f,
};
std::vector<float> golden = {
2.25f, 2.25f, 2.25f, 2.25f, 2.25f, 2.25f,
};
std::vector<int32_t> shape = {3, 2};
EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
in_data.size() * sizeof(float)));
auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape);
(*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
CheckResult(graph, golden, output_tensor);
}
TEST(Broadcast, 1DTo2D) {
auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph();
tim::vx::ShapeType input_shape({3});
tim::vx::ShapeType output_shape({3, 2});
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
tim::vx::TensorAttribute::INPUT);
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
tim::vx::TensorAttribute::OUTPUT);
auto input_tensor = graph->CreateTensor(input_spec);
auto output_tensor = graph->CreateTensor(output_spec);
std::vector<float> in_data = {
1.f, 2.f, 3.f,
};
std::vector<float> golden = {
1.f, 2.f, 3.f, 1.f, 2.f, 3.f,
};
std::vector<int32_t> shape = {3, 2};
EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
in_data.size() * sizeof(float)));
auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape);
(*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
CheckResult(graph, golden, output_tensor);
}
TEST(Broadcast, 1DTo2D_WithDims0) {
auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph();
tim::vx::ShapeType input_shape({2});
tim::vx::ShapeType output_shape({2, 2});
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
tim::vx::TensorAttribute::INPUT);
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
tim::vx::TensorAttribute::OUTPUT);
auto input_tensor = graph->CreateTensor(input_spec);
auto output_tensor = graph->CreateTensor(output_spec);
std::vector<float> in_data = {
1.f, 2.f,
};
// dimensions = {0}: the input axis maps to output dimension 0,
// so {1, 2} repeats along dimension 1.
std::vector<float> golden = {
1.f, 2.f,
1.f, 2.f,
};
std::vector<int32_t> shape = {2, 2};
std::vector<int32_t> dimensions = {0};
EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
in_data.size() * sizeof(float)));
auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape, dimensions);
(*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
CheckResult(graph, golden, output_tensor);
}
TEST(Broadcast, 1DTo2D_WithDims1) {
auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph();
tim::vx::ShapeType input_shape({2});
tim::vx::ShapeType output_shape({2, 2});
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
tim::vx::TensorAttribute::INPUT);
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
tim::vx::TensorAttribute::OUTPUT);
auto input_tensor = graph->CreateTensor(input_spec);
auto output_tensor = graph->CreateTensor(output_spec);
std::vector<float> in_data = {
1.f, 2.f,
};
// dimensions = {1}: the input axis maps to output dimension 1,
// so each element repeats along dimension 0.
std::vector<float> golden = {
1.f, 1.f,
2.f, 2.f,
};
std::vector<int32_t> shape = {2, 2};
std::vector<int32_t> dimensions = {1};
EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
in_data.size() * sizeof(float)));
auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape, dimensions);
(*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
CheckResult(graph, golden, output_tensor);
}
TEST(Broadcast, 1DTo3D_WithDims0) {
auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph();
tim::vx::ShapeType input_shape({2});
tim::vx::ShapeType output_shape({2, 2, 2});
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
tim::vx::TensorAttribute::INPUT);
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
tim::vx::TensorAttribute::OUTPUT);
auto input_tensor = graph->CreateTensor(input_spec);
auto output_tensor = graph->CreateTensor(output_spec);
std::vector<float> in_data = {
1.f, 2.f,
};
std::vector<float> golden = {
1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f,
};
std::vector<int32_t> shape = {2, 2, 2};
std::vector<int32_t> dimensions = {0};
EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
in_data.size() * sizeof(float)));
auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape, dimensions);
(*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
CheckResult(graph, golden, output_tensor);
}
TEST(Broadcast, 1DTo3D_WithDims1) {
auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph();
tim::vx::ShapeType input_shape({2});
tim::vx::ShapeType output_shape({2, 2, 2});
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
tim::vx::TensorAttribute::INPUT);
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
tim::vx::TensorAttribute::OUTPUT);
auto input_tensor = graph->CreateTensor(input_spec);
auto output_tensor = graph->CreateTensor(output_spec);
std::vector<float> in_data = {
1.f, 2.f,
};
std::vector<float> golden = {
1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f,
};
std::vector<int32_t> shape = {2, 2, 2};
std::vector<int32_t> dimensions = {1};
EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
in_data.size() * sizeof(float)));
auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape, dimensions);
(*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
CheckResult(graph, golden, output_tensor);
}
TEST(Broadcast, 1DTo3D_WithDims2) {
auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph();
tim::vx::ShapeType input_shape({2});
tim::vx::ShapeType output_shape({2, 2, 2});
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
tim::vx::TensorAttribute::INPUT);
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
tim::vx::TensorAttribute::OUTPUT);
auto input_tensor = graph->CreateTensor(input_spec);
auto output_tensor = graph->CreateTensor(output_spec);
std::vector<float> in_data = {
1.f, 2.f,
};
std::vector<float> golden = {
1.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 2.f,
};
std::vector<int32_t> shape = {2, 2, 2};
std::vector<int32_t> dimensions = {2};
EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
in_data.size() * sizeof(float)));
auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape, dimensions);
(*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
CheckResult(graph, golden, output_tensor);
}
TEST(Broadcast, 2DTo3D_WithDims02) {
auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph();
tim::vx::ShapeType input_shape({2, 2});
tim::vx::ShapeType output_shape({2, 2, 2});
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
tim::vx::TensorAttribute::INPUT);
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
tim::vx::TensorAttribute::OUTPUT);
auto input_tensor = graph->CreateTensor(input_spec);
auto output_tensor = graph->CreateTensor(output_spec);
std::vector<float> in_data = {
1.f, 5.f, 2.f, 6.f
};
std::vector<float> golden = {
1.f, 5.f, 1.f, 5.f, 2.f, 6.f, 2.f, 6.f,
};
std::vector<int32_t> shape = {2, 2, 2};
std::vector<int32_t> dimensions = {0, 2};
EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
in_data.size() * sizeof(float)));
auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape, dimensions);
(*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
CheckResult(graph, golden, output_tensor);
}
TEST(Broadcast, 2DTo3D_WithDims12) {
auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph();
tim::vx::ShapeType input_shape({2, 2});
tim::vx::ShapeType output_shape({2, 2, 2});
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
tim::vx::TensorAttribute::INPUT);
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
tim::vx::TensorAttribute::OUTPUT);
auto input_tensor = graph->CreateTensor(input_spec);
auto output_tensor = graph->CreateTensor(output_spec);
std::vector<float> in_data = {
1.f, 5.f, 2.f, 6.f
};
std::vector<float> golden = {
1.f, 1.f, 5.f, 5.f, 2.f, 2.f, 6.f, 6.f,
};
std::vector<int32_t> shape = {2, 2, 2};
std::vector<int32_t> dimensions = {1, 2};
EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(),
in_data.size() * sizeof(float)));
auto op = graph->CreateOperation<tim::vx::ops::Broadcast>(shape, dimensions);
(*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
CheckResult(graph, golden, output_tensor);
}