Add Broadcast op (#365)
This commit is contained in:
parent
96dedc1453
commit
b916e1301a
|
|
@ -7,9 +7,11 @@
|
||||||
- [ArgMin/ArgMax](#argminargmax)
|
- [ArgMin/ArgMax](#argminargmax)
|
||||||
- [Batch2Space](#batch2space)
|
- [Batch2Space](#batch2space)
|
||||||
- [BatchNorm](#batchnorm)
|
- [BatchNorm](#batchnorm)
|
||||||
|
- [Broadcast](#broadcast)
|
||||||
- [Clip](#clip)
|
- [Clip](#clip)
|
||||||
- [Concat](#concat)
|
- [Concat](#concat)
|
||||||
- [Conv2d](#conv2d)
|
- [Conv2d](#conv2d)
|
||||||
|
- [Conv3d](#conv3d)
|
||||||
- [DeConv2d](#deconv2d)
|
- [DeConv2d](#deconv2d)
|
||||||
- [DeConv1d](#deconv1d)
|
- [DeConv1d](#deconv1d)
|
||||||
- [DepthToSpace](#depthtospace)
|
- [DepthToSpace](#depthtospace)
|
||||||
|
|
@ -22,9 +24,11 @@
|
||||||
- [Minimum](#minimum)
|
- [Minimum](#minimum)
|
||||||
- [Maximum](#maximum)
|
- [Maximum](#maximum)
|
||||||
- [FloorDiv](#floordiv)
|
- [FloorDiv](#floordiv)
|
||||||
|
- [Erf](#erf)
|
||||||
- [FullyConnected](#fullyconnected)
|
- [FullyConnected](#fullyconnected)
|
||||||
- [Gather](#gather)
|
- [Gather](#gather)
|
||||||
- [GatherNd](#gathernd)
|
- [GatherNd](#gathernd)
|
||||||
|
- [GroupedConv1d](#groupedconv1d)
|
||||||
- [GroupedConv2d](#groupedconv2d)
|
- [GroupedConv2d](#groupedconv2d)
|
||||||
- [L2Normalization](#l2normalization)
|
- [L2Normalization](#l2normalization)
|
||||||
- [LocalResponseNormalization](#localresponsenormalization)
|
- [LocalResponseNormalization](#localresponsenormalization)
|
||||||
|
|
@ -36,8 +40,12 @@
|
||||||
- [MaxUnpool2d](#maxunpool2d)
|
- [MaxUnpool2d](#maxunpool2d)
|
||||||
- [Moments](#moments)
|
- [Moments](#moments)
|
||||||
- [NBG](#nbg)
|
- [NBG](#nbg)
|
||||||
|
- [OneHot](#onehot)
|
||||||
- [Pad](#pad)
|
- [Pad](#pad)
|
||||||
- [Pool2d](#pool2d)
|
- [Pool2d](#pool2d)
|
||||||
|
- [Classic Pool2d](#classic-pool2d)
|
||||||
|
- [Global Pool2d](#global-pool2d)
|
||||||
|
- [Adaptive Pool2d](#adaptive-pool2d)
|
||||||
- [ReduceMin](#reducemin)
|
- [ReduceMin](#reducemin)
|
||||||
- [ReduceMax](#reducemax)
|
- [ReduceMax](#reducemax)
|
||||||
- [ReduceAny](#reduceany)
|
- [ReduceAny](#reduceany)
|
||||||
|
|
@ -78,6 +86,7 @@
|
||||||
- [Squeeze](#squeeze)
|
- [Squeeze](#squeeze)
|
||||||
- [Stack](#stack)
|
- [Stack](#stack)
|
||||||
- [StridedSlice](#stridedslice)
|
- [StridedSlice](#stridedslice)
|
||||||
|
- [Svdf](#svdf)
|
||||||
- [Tile](#tile)
|
- [Tile](#tile)
|
||||||
- [Transpose](#transpose)
|
- [Transpose](#transpose)
|
||||||
- [Unidirectional sequence lstm](#unidirectional-sequence-lstm)
|
- [Unidirectional sequence lstm](#unidirectional-sequence-lstm)
|
||||||
|
|
@ -108,12 +117,12 @@ Swish(x) : x * sigmoid(x)
|
||||||
|
|
||||||
HardSwish(x) : 0 if x <= -3; x(x + 3)/6 if -3 < x < 3; x if x >= 3
|
HardSwish(x) : 0 if x <= -3; x(x + 3)/6 if -3 < x < 3; x if x >= 3
|
||||||
|
|
||||||
Mish(x) : x if x >= 0 else alpha * x
|
|
||||||
|
|
||||||
HardSigmoid(x) : min(max(alpha*x + beta, 0), 1)
|
HardSigmoid(x) : min(max(alpha*x + beta, 0), 1)
|
||||||
|
|
||||||
SoftRelu(x) : log(1 + e^x). Also known as SoftPlus.
|
SoftRelu(x) : log(1 + e^x). Also known as SoftPlus.
|
||||||
|
|
||||||
|
Mish(x) : x * tanh(softrelu(x))
|
||||||
|
|
||||||
LeakyRelu(x) : alpha * x if x <= 0; x if x > 0. alpha is a scalar.
|
LeakyRelu(x) : alpha * x if x <= 0; x if x > 0. alpha is a scalar.
|
||||||
|
|
||||||
Prelu(x) : alpha * x if x <= 0; x if x > 0. alpha is a tensor.
|
Prelu(x) : alpha * x if x <= 0; x if x > 0. alpha is a tensor.
|
||||||
|
|
@ -153,7 +162,22 @@ rank as the input. This is the reverse transformation of Space2Batch.
|
||||||
Carries out batch normalization as described in the paper
|
Carries out batch normalization as described in the paper
|
||||||
https://arxiv.org/abs/1502.03167.
|
https://arxiv.org/abs/1502.03167.
|
||||||
|
|
||||||
Y = (X - Mean) / Sqrt( Var + Eps) * Gama + Beta
|
$$\hat x_i\leftarrow \frac{x_i-\mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2+\epsilon}}$$
|
||||||
|
|
||||||
|
$$y_i=\gamma\hat x_i+\beta\equiv BN_{\gamma,\beta}(x_i)$$
|
||||||
|
|
||||||
|
<a class="mk-toclify" id="broadcast"></a>
|
||||||
|
## Broadcast
|
||||||
|
|
||||||
|
Broadcasts an array to a compatible shape. See also numpy.broadcast_to().
|
||||||
|
|
||||||
|
Input:
|
||||||
|
- input.
|
||||||
|
|
||||||
|
Attribute:
|
||||||
|
- shape: the target shape to broadcast to.
|
||||||
|
- dimensions(optional): Which dimension in the target shape each dimension
|
||||||
|
of the operand shape corresponds to. For BroadcastInDim.
|
||||||
|
|
||||||
<a class="mk-toclify" id="clip"></a>
|
<a class="mk-toclify" id="clip"></a>
|
||||||
## Clip
|
## Clip
|
||||||
|
|
@ -189,6 +213,28 @@ Attribute:
|
||||||
but the value is different. multiplier = weights / group.
|
but the value is different. multiplier = weights / group.
|
||||||
- layout : WHCN or CWHN.
|
- layout : WHCN or CWHN.
|
||||||
|
|
||||||
|
<a class="mk-toclify" id="conv3d"></a>
|
||||||
|
## Conv3d
|
||||||
|
|
||||||
|
Performs a 3-D convolution operation
|
||||||
|
|
||||||
|
Input:
|
||||||
|
- input [WHDCN].
|
||||||
|
- kernel [ WHDIcOc ] (Ic: Input Channels. Oc: Output Channels).
|
||||||
|
- bias [ O ]. Optional.
|
||||||
|
|
||||||
|
Attribute:
|
||||||
|
- weights : the output channel number for weight tensor.
|
||||||
|
- ksize : the height and width for weight tensor.
|
||||||
|
- padding : AUTO, VALID or SAME.
|
||||||
|
- pad : pad value for each spatial axis. (left, right, top, bottom, front, rear).
|
||||||
|
- stride : stride along each spatial axis.
|
||||||
|
- dilation : dilation value along each spatial axis of the filter.
|
||||||
|
- multiplier: function similar to group attribute on other framework,
|
||||||
|
but the value is different. multiplier = weights / group.
|
||||||
|
- input_layout : WHDCN or WHCDN.
|
||||||
|
- kernel_layout : WHDIcOc
|
||||||
|
|
||||||
<a class="mk-toclify" id="deconv2d"></a>
|
<a class="mk-toclify" id="deconv2d"></a>
|
||||||
## DeConv2d
|
## DeConv2d
|
||||||
|
|
||||||
|
|
@ -292,6 +338,13 @@ Maximum(x, y) : max(x, y). This operation supports broadcasting.
|
||||||
|
|
||||||
FloorDiv(x, y): floor( x / y ). This operation supports broadcasting.
|
FloorDiv(x, y): floor( x / y ). This operation supports broadcasting.
|
||||||
|
|
||||||
|
<a class="mk-toclify" id="erf"></a>
|
||||||
|
## Erf
|
||||||
|
|
||||||
|
Computes the Gauss error function of x element-wise.
|
||||||
|
|
||||||
|
- no parameters
|
||||||
|
|
||||||
<a class="mk-toclify" id="fullyconnected"></a>
|
<a class="mk-toclify" id="fullyconnected"></a>
|
||||||
## FullyConnected
|
## FullyConnected
|
||||||
|
|
||||||
|
|
@ -311,6 +364,26 @@ Gather slices from input, **axis** according to **indices**.
|
||||||
|
|
||||||
An operation similar to Gather but gathers across multiple axis at once.
|
An operation similar to Gather but gathers across multiple axis at once.
|
||||||
|
|
||||||
|
<a class="mk-toclify" id="groupedconv1d"></a>
|
||||||
|
## GroupedConv1d
|
||||||
|
|
||||||
|
Performs a grouped 1-D convolution operation.
|
||||||
|
|
||||||
|
Input:
|
||||||
|
- input [WCN].
|
||||||
|
- kernel [ WIcOc ] (Ic: Input Channels. Oc: Output Channels). Ic*group=C.
|
||||||
|
- bias [ O ]. Optional.
|
||||||
|
|
||||||
|
Attribute:
|
||||||
|
- weights : the output channel number for weight tensor.
|
||||||
|
- ksize : the height and width for weight tensor.
|
||||||
|
- padding : AUTO, VALID or SAME.
|
||||||
|
- pad : pad value for each spatial axis.
|
||||||
|
- stride : stride along each spatial axis.
|
||||||
|
- dilation : dilation value along each spatial axis of the filter.
|
||||||
|
- group: Split conv to n group.
|
||||||
|
- layout : WCN or CWN.
|
||||||
|
|
||||||
<a class="mk-toclify" id="groupedconv2d"></a>
|
<a class="mk-toclify" id="groupedconv2d"></a>
|
||||||
## GroupedConv2d
|
## GroupedConv2d
|
||||||
|
|
||||||
|
|
@ -415,24 +488,59 @@ If x is 1-D and axes = [0] this is just the mean and variance of a vector.
|
||||||
Network Binary Graph is a precompile technology, which can compile a fuse graph into
|
Network Binary Graph is a precompile technology, which can compile a fuse graph into
|
||||||
a binary file.
|
a binary file.
|
||||||
|
|
||||||
|
<a class="mk-toclify" id="onehot"></a>
|
||||||
|
## OneHot
|
||||||
|
|
||||||
|
Create a one-hot tensor.
|
||||||
|
|
||||||
|
- depth : A scalar defining the depth of the one hot dimension.
|
||||||
|
- on_value : A scalar defining the value to fill in output.
|
||||||
|
- off_value : A scalar defining the value to fill in output.
|
||||||
|
- axis : The axis to fill.
|
||||||
|
|
||||||
<a class="mk-toclify" id="pad"></a>
|
<a class="mk-toclify" id="pad"></a>
|
||||||
## Pad
|
## Pad
|
||||||
|
|
||||||
Pads a tensor.
|
Pads a tensor.
|
||||||
|
|
||||||
- const_val : the value to pad.
|
- const_val : the value to pad.
|
||||||
|
- pad_mode : the mode of pad.
|
||||||
|
- front_size : Add pad values to the left and top.
|
||||||
|
- back_size : Add pad values to the right and bottom.
|
||||||
|
|
||||||
<a class="mk-toclify" id="pool2d"></a>
|
<a class="mk-toclify" id="pool2d"></a>
|
||||||
## Pool2d
|
## Pool2d
|
||||||
|
|
||||||
|
<a class="mk-toclify" id="classic-pool2d"></a>
|
||||||
|
### Classic Pool2d
|
||||||
|
|
||||||
Performs a 2-D pooling operation.
|
Performs a 2-D pooling operation.
|
||||||
|
|
||||||
- type : MAX, AVG, L2 or AVG_ANDROID.
|
- type : MAX, AVG, L2 or AVG_ANDROID.
|
||||||
- padding : AUTO, VALID or SAME.
|
- padding : AUTO, VALID or SAME.
|
||||||
|
- pad : Specify the number of pad values for left, right, top, and bottom.
|
||||||
- ksize : filter size.
|
- ksize : filter size.
|
||||||
- stride : stride along each spatial axis.
|
- stride : stride along each spatial axis.
|
||||||
- round_type : CEILING or FLOOR.
|
- round_type : CEILING or FLOOR.
|
||||||
|
|
||||||
|
<a class="mk-toclify" id="global-pool2d"></a>
|
||||||
|
### Global Pool2d
|
||||||
|
|
||||||
|
- type : MAX, AVG, L2 or AVG_ANDROID.
|
||||||
|
- input_size : input size(only [W, H])
|
||||||
|
- round_type : CEILING or FLOOR.
|
||||||
|
|
||||||
|
<a class="mk-toclify" id="adaptive-pool2d"></a>
|
||||||
|
### Adaptive Pool2d
|
||||||
|
|
||||||
|
Same as torch.nn.AdaptiveXXXPool2d.
|
||||||
|
|
||||||
|
- type : MAX, AVG, L2 or AVG_ANDROID.
|
||||||
|
- input_size : input size(only [W, H])
|
||||||
|
- output_size : output size(only [W, H])
|
||||||
|
- round_type : CEILING or FLOOR.
|
||||||
|
|
||||||
|
|
||||||
<a class="mk-toclify" id="reducemin"></a>
|
<a class="mk-toclify" id="reducemin"></a>
|
||||||
## ReduceMin
|
## ReduceMin
|
||||||
|
|
||||||
|
|
@ -715,11 +823,12 @@ Removes dimensions of size 1 from the shape of a tensor.
|
||||||
|
|
||||||
Packs the list of tensors in inputs into a tensor with rank one higher than
|
Packs the list of tensors in inputs into a tensor with rank one higher than
|
||||||
each tensor in values, by packing them along the **axis** dimension.
|
each tensor in values, by packing them along the **axis** dimension.
|
||||||
|
Dimensions below the dimension specified by axis will be packed together with other inputs.
|
||||||
|
|
||||||
<a class="mk-toclify" id="stridedslice"></a>
|
<a class="mk-toclify" id="stridedslice"></a>
|
||||||
## StridedSlice
|
## StridedSlice
|
||||||
|
|
||||||
Extracts a strided slice of a tensor.
|
Extracts a strided slice of a tensor. Same as TensorFlow.
|
||||||
|
|
||||||
Roughly speaking, this op extracts a slice of size (end - begin) / stride from
|
Roughly speaking, this op extracts a slice of size (end - begin) / stride from
|
||||||
the given input tensor. Starting at the location specified by begin the slice
|
the given input tensor. Starting at the location specified by begin the slice
|
||||||
|
|
@ -738,6 +847,15 @@ specification shrinks the dimensionality by 1, taking on the value at index begi
|
||||||
In this case, the ith specification must define a slice of size 1,
|
In this case, the ith specification must define a slice of size 1,
|
||||||
e.g. begin[i] = x, end[i] = x + 1.
|
e.g. begin[i] = x, end[i] = x + 1.
|
||||||
|
|
||||||
|
<a class="mk-toclify" id="svdf"></a>
|
||||||
|
## Svdf
|
||||||
|
|
||||||
|
Performs an SVDF (singular value decomposition filter) operation.
|
||||||
|
|
||||||
|
- rank : The rank of the SVD approximation.
|
||||||
|
- num_units : corresponds to the number of units.
|
||||||
|
- spectrogram_length : corresponds to the fixed-size of the memory.
|
||||||
|
|
||||||
<a class="mk-toclify" id="tile"></a>
|
<a class="mk-toclify" id="tile"></a>
|
||||||
## Tile
|
## Tile
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@
|
||||||
#include "tim/vx/ops/arg.h"
|
#include "tim/vx/ops/arg.h"
|
||||||
#include "tim/vx/ops/batch2space.h"
|
#include "tim/vx/ops/batch2space.h"
|
||||||
#include "tim/vx/ops/batchnorm.h"
|
#include "tim/vx/ops/batchnorm.h"
|
||||||
|
#include "tim/vx/ops/broadcast.h"
|
||||||
#include "tim/vx/ops/clip.h"
|
#include "tim/vx/ops/clip.h"
|
||||||
#include "tim/vx/ops/concat.h"
|
#include "tim/vx/ops/concat.h"
|
||||||
#include "tim/vx/ops/conv1d.h"
|
#include "tim/vx/ops/conv1d.h"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,61 @@
|
||||||
|
/****************************************************************************
|
||||||
|
*
|
||||||
|
* Copyright (c) 2020 Vivante Corporation
|
||||||
|
*
|
||||||
|
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
* copy of this software and associated documentation files (the "Software"),
|
||||||
|
* to deal in the Software without restriction, including without limitation
|
||||||
|
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
* and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
* Software is furnished to do so, subject to the following conditions:
|
||||||
|
*
|
||||||
|
* The above copyright notice and this permission notice shall be included in
|
||||||
|
* all copies or substantial portions of the Software.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
* DEALINGS IN THE SOFTWARE.
|
||||||
|
*
|
||||||
|
*****************************************************************************/
|
||||||
|
#ifndef OVXLIBXX_OPERATIONS_BROADCAST_H_
|
||||||
|
#define OVXLIBXX_OPERATIONS_BROADCAST_H_
|
||||||
|
#include "tim/vx/direct_map_op.h"
|
||||||
|
|
||||||
|
namespace tim {
|
||||||
|
namespace vx {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ## Broadcast
|
||||||
|
*
|
||||||
|
* Broadcasts an array to a compatible shape. See also numpy.broadcast_to().
|
||||||
|
*
|
||||||
|
* Input:
|
||||||
|
* - input.
|
||||||
|
*
|
||||||
|
* Attribute:
|
||||||
|
* - shape: the target shape to broadcast to.
|
||||||
|
* - dimensions(optional): Which dimension in the target shape each dimension
|
||||||
|
* of the operand shape corresponds to. For BroadcastInDim.
|
||||||
|
*/
|
||||||
|
|
||||||
|
class Broadcast : public DirectMapOp {
|
||||||
|
public:
|
||||||
|
Broadcast(Graph* graph, const std::vector<int32_t>& shape, const std::vector<int32_t>& dimensions = {});
|
||||||
|
|
||||||
|
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
const std::vector<int32_t> shape_;
|
||||||
|
const std::vector<int32_t> dimensions_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace vx
|
||||||
|
} // namespace tim
|
||||||
|
|
||||||
|
#endif /* OVXLIBXX_OPERATIONS_BROADCAST_H_ */
|
||||||
|
|
@ -104,6 +104,7 @@ Erf|ERF|Mapped|[tf.math.erf](https://tensorflow.google.cn/api_docs/python/tf/mat
|
||||||
GroupedConv1d|GROUPED_CONV1D|Mapped|[tf.keras.layers.Conv1D](https://tensorflow.google.cn/api_docs/python/tf/keras/layers/Conv1D?hl=en)
|
GroupedConv1d|GROUPED_CONV1D|Mapped|[tf.keras.layers.Conv1D](https://tensorflow.google.cn/api_docs/python/tf/keras/layers/Conv1D?hl=en)
|
||||||
|SignalFrame|SIGNAL_FRAME|Mapped|[tf.signal.frame](https://tensorflow.google.cn/api_docs/python/tf/signal/frame)
|
|SignalFrame|SIGNAL_FRAME|Mapped|[tf.signal.frame](https://tensorflow.google.cn/api_docs/python/tf/signal/frame)
|
||||||
|RNNCell|RNNCELL_OVXLIB|Mapped|[ANEURALNETWORKS_RNN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0acd2684ac9c73bb29767b534e78a332e8)
|
|RNNCell|RNNCELL_OVXLIB|Mapped|[ANEURALNETWORKS_RNN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0acd2684ac9c73bb29767b534e78a332e8)
|
||||||
|
|Broadcast|EXPAND_BROADCAST|Mapped|[numpy.broadcast_to](https://numpy.org/doc/stable/reference/generated/numpy.broadcast_to.html)
|
||||||
||PROPOSAL| TBD |[Faster-RCNN Proposal Layer](https://github.com/intel/caffe/blob/master/examples/faster-rcnn/lib/rpn/proposal_layer.py)
|
||PROPOSAL| TBD |[Faster-RCNN Proposal Layer](https://github.com/intel/caffe/blob/master/examples/faster-rcnn/lib/rpn/proposal_layer.py)
|
||||||
||ROI_POOL|Planned 22Q1 |[ANEURALNETWORKS_ROI_POOLING](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a6736198af337b2efbdb0b6b64dee7fe4)
|
||ROI_POOL|Planned 22Q1 |[ANEURALNETWORKS_ROI_POOLING](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a6736198af337b2efbdb0b6b64dee7fe4)
|
||||||
||ROI_ALIGN| TBD |[ANEURALNETWORKS_ROI_ALIGN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a2848b39dd4bfba78f2438fda0d9397a4)
|
||ROI_ALIGN| TBD |[ANEURALNETWORKS_ROI_ALIGN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a2848b39dd4bfba78f2438fda0d9397a4)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,57 @@
|
||||||
|
/****************************************************************************
|
||||||
|
*
|
||||||
|
* Copyright (c) 2021 Vivante Corporation
|
||||||
|
*
|
||||||
|
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
* copy of this software and associated documentation files (the "Software"),
|
||||||
|
* to deal in the Software without restriction, including without limitation
|
||||||
|
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
* and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
* Software is furnished to do so, subject to the following conditions:
|
||||||
|
*
|
||||||
|
* The above copyright notice and this permission notice shall be included in
|
||||||
|
* all copies or substantial portions of the Software.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
* DEALINGS IN THE SOFTWARE.
|
||||||
|
*
|
||||||
|
*****************************************************************************/
|
||||||
|
#include "tim/vx/ops/broadcast.h"
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include "direct_map_op_impl.h"
|
||||||
|
#include "vsi_nn_pub.h"
|
||||||
|
|
||||||
|
namespace tim {
|
||||||
|
namespace vx {
|
||||||
|
namespace ops {
|
||||||
|
// Builds a VSI_NN_OP_EXPAND_BROADCAST node and fills in its parameter struct
// from the target shape and the optional dimension mapping.
//
// The parameter struct receives pointers into shape_ and dimensions_ rather
// than copies of their contents.
Broadcast::Broadcast(Graph* graph, const std::vector<int32_t>& shape,
                     const std::vector<int32_t>& dimensions)
    : DirectMapOp(graph, VSI_NN_OP_EXPAND_BROADCAST),
      shape_(shape),
      dimensions_(dimensions) {
  auto& param = this->impl()->node()->nn_param.expand_broadcast;

  param.dim_num = shape_.size();
  // The members are const int32_t vectors while the struct field is assigned a
  // uint32_t*, so both constness and signedness are cast away explicitly.
  param.shape = reinterpret_cast<uint32_t*>(
      const_cast<int32_t*>(shape_.data()));

  param.dimensions_num = dimensions_.size();
  // With no mapping supplied the node gets a null pointer instead of data().
  param.dimensions =
      dimensions.empty()
          ? nullptr
          : reinterpret_cast<uint32_t*>(
                const_cast<int32_t*>(dimensions_.data()));
}
|
||||||
|
|
||||||
|
std::shared_ptr<Operation> Broadcast::Clone(
|
||||||
|
std::shared_ptr<Graph>& graph) const {
|
||||||
|
return graph->CreateOperation<Broadcast>(this->shape_, this->dimensions_);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace vx
|
||||||
|
} // namespace tim
|
||||||
|
|
@ -0,0 +1,329 @@
|
||||||
|
/****************************************************************************
|
||||||
|
*
|
||||||
|
* Copyright (c) 2021 Vivante Corporation
|
||||||
|
*
|
||||||
|
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
* copy of this software and associated documentation files (the "Software"),
|
||||||
|
* to deal in the Software without restriction, including without limitation
|
||||||
|
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
* and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
* Software is furnished to do so, subject to the following conditions:
|
||||||
|
*
|
||||||
|
* The above copyright notice and this permission notice shall be included in
|
||||||
|
* all copies or substantial portions of the Software.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
* DEALINGS IN THE SOFTWARE.
|
||||||
|
*
|
||||||
|
*****************************************************************************/
|
||||||
|
#include "tim/vx/context.h"
|
||||||
|
#include "tim/vx/graph.h"
|
||||||
|
#include "tim/vx/ops/broadcast.h"
|
||||||
|
#include "tim/transform/layout_inference.h"
|
||||||
|
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
#include "test_utils.h"
|
||||||
|
|
||||||
|
// Compiles and runs `graph`, reads `output_tensor` back into a host buffer and
// compares it with `golden` via ArraysMatch using a tolerance of 1e-5f.
//
// graph         : graph under test; must already have its ops and tensors bound.
// golden        : expected output values, one per output element.
// output_tensor : tensor whose contents are read back after the run.
//
// NOTE(review): assumes the output tensor holds exactly golden.size() floats,
// which is true for every test in this file — confirm for new callers.
static void CheckResult(std::shared_ptr<tim::vx::Graph>& graph,
                        std::vector<float>& golden,
                        std::shared_ptr<tim::vx::Tensor>& output_tensor) {
  EXPECT_TRUE(graph->Compile());
  EXPECT_TRUE(graph->Run());

  // The vector constructor takes an element count, not a byte count, so the
  // size must not be multiplied by sizeof(float).
  std::vector<float> output(golden.size());
  EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
  EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
}
|
||||||
|
|
||||||
|
// One-element input broadcast to a {3, 2} output, no dimension mapping.
TEST(Broadcast, ScalarTo2D_2x3) {
  auto context = tim::vx::Context::Create();
  auto graph = context->CreateGraph();

  tim::vx::ShapeType src_shape = {1};
  tim::vx::ShapeType dst_shape = {3, 2};
  tim::vx::TensorSpec src_spec(tim::vx::DataType::FLOAT32, src_shape,
                               tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec dst_spec(tim::vx::DataType::FLOAT32, dst_shape,
                               tim::vx::TensorAttribute::OUTPUT);

  auto src_tensor = graph->CreateTensor(src_spec);
  auto dst_tensor = graph->CreateTensor(dst_spec);

  std::vector<float> src_values = {2.25f};
  // Every one of the six output elements repeats the single input value.
  std::vector<float> expected(6, 2.25f);
  std::vector<int32_t> target_shape = {3, 2};

  EXPECT_TRUE(src_tensor->CopyDataToTensor(
      src_values.data(), src_values.size() * sizeof(float)));

  auto broadcast =
      graph->CreateOperation<tim::vx::ops::Broadcast>(target_shape);
  broadcast->BindInputs({src_tensor}).BindOutputs({dst_tensor});

  CheckResult(graph, expected, dst_tensor);
}
|
||||||
|
|
||||||
|
// Three-element input broadcast to a {3, 2} output, no dimension mapping.
TEST(Broadcast, 1DTo2D) {
  auto context = tim::vx::Context::Create();
  auto graph = context->CreateGraph();

  tim::vx::ShapeType src_shape = {3};
  tim::vx::ShapeType dst_shape = {3, 2};
  tim::vx::TensorSpec src_spec(tim::vx::DataType::FLOAT32, src_shape,
                               tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec dst_spec(tim::vx::DataType::FLOAT32, dst_shape,
                               tim::vx::TensorAttribute::OUTPUT);

  auto src_tensor = graph->CreateTensor(src_spec);
  auto dst_tensor = graph->CreateTensor(dst_spec);

  std::vector<float> src_values = {1.f, 2.f, 3.f};
  std::vector<float> expected = {
      1.f, 2.f, 3.f,
      1.f, 2.f, 3.f,
  };
  std::vector<int32_t> target_shape = {3, 2};

  EXPECT_TRUE(src_tensor->CopyDataToTensor(
      src_values.data(), src_values.size() * sizeof(float)));

  auto broadcast =
      graph->CreateOperation<tim::vx::ops::Broadcast>(target_shape);
  broadcast->BindInputs({src_tensor}).BindOutputs({dst_tensor});

  CheckResult(graph, expected, dst_tensor);
}
|
||||||
|
|
||||||
|
// Two-element input broadcast to {2, 2} with the operand mapped to dimension 0.
TEST(Broadcast, 1DTo2D_WithDims0) {
  auto context = tim::vx::Context::Create();
  auto graph = context->CreateGraph();

  tim::vx::ShapeType src_shape = {2};
  tim::vx::ShapeType dst_shape = {2, 2};
  tim::vx::TensorSpec src_spec(tim::vx::DataType::FLOAT32, src_shape,
                               tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec dst_spec(tim::vx::DataType::FLOAT32, dst_shape,
                               tim::vx::TensorAttribute::OUTPUT);

  auto src_tensor = graph->CreateTensor(src_spec);
  auto dst_tensor = graph->CreateTensor(dst_spec);

  std::vector<float> src_values = {1.f, 2.f};
  std::vector<float> expected = {
      1.f, 2.f,
      1.f, 2.f,
  };
  std::vector<int32_t> target_shape = {2, 2};
  std::vector<int32_t> dim_mapping = {0};

  EXPECT_TRUE(src_tensor->CopyDataToTensor(
      src_values.data(), src_values.size() * sizeof(float)));

  auto broadcast = graph->CreateOperation<tim::vx::ops::Broadcast>(
      target_shape, dim_mapping);
  broadcast->BindInputs({src_tensor}).BindOutputs({dst_tensor});

  CheckResult(graph, expected, dst_tensor);
}
|
||||||
|
|
||||||
|
// Two-element input broadcast to {2, 2} with the operand mapped to dimension 1.
TEST(Broadcast, 1DTo2D_WithDims1) {
  auto context = tim::vx::Context::Create();
  auto graph = context->CreateGraph();

  tim::vx::ShapeType src_shape = {2};
  tim::vx::ShapeType dst_shape = {2, 2};
  tim::vx::TensorSpec src_spec(tim::vx::DataType::FLOAT32, src_shape,
                               tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec dst_spec(tim::vx::DataType::FLOAT32, dst_shape,
                               tim::vx::TensorAttribute::OUTPUT);

  auto src_tensor = graph->CreateTensor(src_spec);
  auto dst_tensor = graph->CreateTensor(dst_spec);

  std::vector<float> src_values = {1.f, 2.f};
  std::vector<float> expected = {
      1.f, 1.f,
      2.f, 2.f,
  };
  std::vector<int32_t> target_shape = {2, 2};
  std::vector<int32_t> dim_mapping = {1};

  EXPECT_TRUE(src_tensor->CopyDataToTensor(
      src_values.data(), src_values.size() * sizeof(float)));

  auto broadcast = graph->CreateOperation<tim::vx::ops::Broadcast>(
      target_shape, dim_mapping);
  broadcast->BindInputs({src_tensor}).BindOutputs({dst_tensor});

  CheckResult(graph, expected, dst_tensor);
}
|
||||||
|
|
||||||
|
// Two-element input broadcast to {2, 2, 2} with the operand mapped to
// dimension 0.
TEST(Broadcast, 1DTo3D_WithDims0) {
  auto context = tim::vx::Context::Create();
  auto graph = context->CreateGraph();

  tim::vx::ShapeType src_shape = {2};
  tim::vx::ShapeType dst_shape = {2, 2, 2};
  tim::vx::TensorSpec src_spec(tim::vx::DataType::FLOAT32, src_shape,
                               tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec dst_spec(tim::vx::DataType::FLOAT32, dst_shape,
                               tim::vx::TensorAttribute::OUTPUT);

  auto src_tensor = graph->CreateTensor(src_spec);
  auto dst_tensor = graph->CreateTensor(dst_spec);

  std::vector<float> src_values = {1.f, 2.f};
  std::vector<float> expected = {
      1.f, 2.f, 1.f, 2.f,
      1.f, 2.f, 1.f, 2.f,
  };
  std::vector<int32_t> target_shape = {2, 2, 2};
  std::vector<int32_t> dim_mapping = {0};

  EXPECT_TRUE(src_tensor->CopyDataToTensor(
      src_values.data(), src_values.size() * sizeof(float)));

  auto broadcast = graph->CreateOperation<tim::vx::ops::Broadcast>(
      target_shape, dim_mapping);
  broadcast->BindInputs({src_tensor}).BindOutputs({dst_tensor});

  CheckResult(graph, expected, dst_tensor);
}
|
||||||
|
|
||||||
|
// Two-element input broadcast to {2, 2, 2} with the operand mapped to
// dimension 1.
TEST(Broadcast, 1DTo3D_WithDims1) {
  auto context = tim::vx::Context::Create();
  auto graph = context->CreateGraph();

  tim::vx::ShapeType src_shape = {2};
  tim::vx::ShapeType dst_shape = {2, 2, 2};
  tim::vx::TensorSpec src_spec(tim::vx::DataType::FLOAT32, src_shape,
                               tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec dst_spec(tim::vx::DataType::FLOAT32, dst_shape,
                               tim::vx::TensorAttribute::OUTPUT);

  auto src_tensor = graph->CreateTensor(src_spec);
  auto dst_tensor = graph->CreateTensor(dst_spec);

  std::vector<float> src_values = {1.f, 2.f};
  std::vector<float> expected = {
      1.f, 1.f, 2.f, 2.f,
      1.f, 1.f, 2.f, 2.f,
  };
  std::vector<int32_t> target_shape = {2, 2, 2};
  std::vector<int32_t> dim_mapping = {1};

  EXPECT_TRUE(src_tensor->CopyDataToTensor(
      src_values.data(), src_values.size() * sizeof(float)));

  auto broadcast = graph->CreateOperation<tim::vx::ops::Broadcast>(
      target_shape, dim_mapping);
  broadcast->BindInputs({src_tensor}).BindOutputs({dst_tensor});

  CheckResult(graph, expected, dst_tensor);
}
|
||||||
|
|
||||||
|
// Two-element input broadcast to {2, 2, 2} with the operand mapped to
// dimension 2.
TEST(Broadcast, 1DTo3D_WithDims2) {
  auto context = tim::vx::Context::Create();
  auto graph = context->CreateGraph();

  tim::vx::ShapeType src_shape = {2};
  tim::vx::ShapeType dst_shape = {2, 2, 2};
  tim::vx::TensorSpec src_spec(tim::vx::DataType::FLOAT32, src_shape,
                               tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec dst_spec(tim::vx::DataType::FLOAT32, dst_shape,
                               tim::vx::TensorAttribute::OUTPUT);

  auto src_tensor = graph->CreateTensor(src_spec);
  auto dst_tensor = graph->CreateTensor(dst_spec);

  std::vector<float> src_values = {1.f, 2.f};
  std::vector<float> expected = {
      1.f, 1.f, 1.f, 1.f,
      2.f, 2.f, 2.f, 2.f,
  };
  std::vector<int32_t> target_shape = {2, 2, 2};
  std::vector<int32_t> dim_mapping = {2};

  EXPECT_TRUE(src_tensor->CopyDataToTensor(
      src_values.data(), src_values.size() * sizeof(float)));

  auto broadcast = graph->CreateOperation<tim::vx::ops::Broadcast>(
      target_shape, dim_mapping);
  broadcast->BindInputs({src_tensor}).BindOutputs({dst_tensor});

  CheckResult(graph, expected, dst_tensor);
}
|
||||||
|
|
||||||
|
// {2, 2} input broadcast to {2, 2, 2} with the operand mapped to dimensions
// 0 and 2.
TEST(Broadcast, 2DTo3D_WithDims02) {
  auto context = tim::vx::Context::Create();
  auto graph = context->CreateGraph();

  tim::vx::ShapeType src_shape = {2, 2};
  tim::vx::ShapeType dst_shape = {2, 2, 2};
  tim::vx::TensorSpec src_spec(tim::vx::DataType::FLOAT32, src_shape,
                               tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec dst_spec(tim::vx::DataType::FLOAT32, dst_shape,
                               tim::vx::TensorAttribute::OUTPUT);

  auto src_tensor = graph->CreateTensor(src_spec);
  auto dst_tensor = graph->CreateTensor(dst_spec);

  std::vector<float> src_values = {1.f, 5.f, 2.f, 6.f};
  std::vector<float> expected = {
      1.f, 5.f, 1.f, 5.f,
      2.f, 6.f, 2.f, 6.f,
  };
  std::vector<int32_t> target_shape = {2, 2, 2};
  std::vector<int32_t> dim_mapping = {0, 2};

  EXPECT_TRUE(src_tensor->CopyDataToTensor(
      src_values.data(), src_values.size() * sizeof(float)));

  auto broadcast = graph->CreateOperation<tim::vx::ops::Broadcast>(
      target_shape, dim_mapping);
  broadcast->BindInputs({src_tensor}).BindOutputs({dst_tensor});

  CheckResult(graph, expected, dst_tensor);
}
|
||||||
|
|
||||||
|
// {2, 2} input broadcast to {2, 2, 2} with the operand mapped to dimensions
// 1 and 2.
TEST(Broadcast, 2DTo3D_WithDims12) {
  auto context = tim::vx::Context::Create();
  auto graph = context->CreateGraph();

  tim::vx::ShapeType src_shape = {2, 2};
  tim::vx::ShapeType dst_shape = {2, 2, 2};
  tim::vx::TensorSpec src_spec(tim::vx::DataType::FLOAT32, src_shape,
                               tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec dst_spec(tim::vx::DataType::FLOAT32, dst_shape,
                               tim::vx::TensorAttribute::OUTPUT);

  auto src_tensor = graph->CreateTensor(src_spec);
  auto dst_tensor = graph->CreateTensor(dst_spec);

  std::vector<float> src_values = {1.f, 5.f, 2.f, 6.f};
  std::vector<float> expected = {
      1.f, 1.f, 5.f, 5.f,
      2.f, 2.f, 6.f, 6.f,
  };
  std::vector<int32_t> target_shape = {2, 2, 2};
  std::vector<int32_t> dim_mapping = {1, 2};

  EXPECT_TRUE(src_tensor->CopyDataToTensor(
      src_values.data(), src_values.size() * sizeof(float)));

  auto broadcast = graph->CreateOperation<tim::vx::ops::Broadcast>(
      target_shape, dim_mapping);
  broadcast->BindInputs({src_tensor}).BindOutputs({dst_tensor});

  CheckResult(graph, expected, dst_tensor);
}
|
||||||
Loading…
Reference in New Issue