From cea11422b89223403f15d4cbcac3fb5de2307ad1 Mon Sep 17 00:00:00 2001 From: chxin66 <57057788+chxin66@users.noreply.github.com> Date: Wed, 29 Dec 2021 11:08:24 +0800 Subject: [PATCH] Added RNNCell & unit test (#249) Signed-off-by: Chen Xin Co-authored-by: Chen Xin --- BUILD | 8 +- include/tim/vx/direct_map_op.h | 42 +++ include/tim/vx/operation.h | 11 +- include/tim/vx/ops.h | 2 + include/tim/vx/ops/activations.h | 12 +- include/tim/vx/ops/addn.h | 4 +- include/tim/vx/ops/arg.h | 4 +- include/tim/vx/ops/batch2space.h | 4 +- include/tim/vx/ops/batchnorm.h | 4 +- include/tim/vx/ops/clip.h | 4 +- include/tim/vx/ops/concat.h | 4 +- include/tim/vx/ops/conv1d.h | 4 +- include/tim/vx/ops/conv2d.h | 4 +- include/tim/vx/ops/deconv.h | 4 +- include/tim/vx/ops/deconv1d.h | 4 +- include/tim/vx/ops/depth2space.h | 4 +- include/tim/vx/ops/dropout.h | 4 +- include/tim/vx/ops/elementwise.h | 8 +- include/tim/vx/ops/erf.h | 4 +- include/tim/vx/ops/fullyconnected.h | 4 +- include/tim/vx/ops/gather.h | 4 +- include/tim/vx/ops/gathernd.h | 4 +- include/tim/vx/ops/groupedconv1d.h | 4 +- include/tim/vx/ops/groupedconv2d.h | 4 +- include/tim/vx/ops/instancenormalization.h | 4 +- include/tim/vx/ops/l2normalization.h | 4 +- include/tim/vx/ops/layernormalization.h | 4 +- .../tim/vx/ops/localresponsenormalization.h | 4 +- include/tim/vx/ops/logical.h | 4 +- include/tim/vx/ops/logsoftmax.h | 4 +- include/tim/vx/ops/matmul.h | 4 +- include/tim/vx/ops/maxpoolwithargmax.h | 4 +- include/tim/vx/ops/maxunpool2d.h | 4 +- include/tim/vx/ops/moments.h | 4 +- include/tim/vx/ops/nbg.h | 4 +- include/tim/vx/ops/pad.h | 4 +- include/tim/vx/ops/pool2d.h | 4 +- include/tim/vx/ops/reduce.h | 4 +- include/tim/vx/ops/relational_operations.h | 4 +- include/tim/vx/ops/reorg.h | 4 +- include/tim/vx/ops/reshape.h | 4 +- include/tim/vx/ops/resize.h | 4 +- include/tim/vx/ops/resize1d.h | 4 +- include/tim/vx/ops/reverse.h | 4 +- include/tim/vx/ops/rnn_cell.h | 54 ++++ include/tim/vx/ops/scatternd.h | 4 +- 
include/tim/vx/ops/select.h | 4 +- include/tim/vx/ops/shuffle_channel.h | 4 +- include/tim/vx/ops/signal_frame.h | 4 +- include/tim/vx/ops/simple_operations.h | 4 +- include/tim/vx/ops/slice.h | 4 +- include/tim/vx/ops/softmax.h | 4 +- include/tim/vx/ops/space2batch.h | 4 +- include/tim/vx/ops/space2depth.h | 4 +- include/tim/vx/ops/spatial_transformer.h | 4 +- include/tim/vx/ops/split.h | 4 +- include/tim/vx/ops/squeeze.h | 4 +- include/tim/vx/ops/stack.h | 4 +- include/tim/vx/ops/stridedslice.h | 4 +- include/tim/vx/ops/svdf.h | 4 +- include/tim/vx/ops/tile.h | 4 +- include/tim/vx/ops/transpose.h | 4 +- .../tim/vx/ops/unidirectional_sequence_lstm.h | 4 +- include/tim/vx/ops/unstack.h | 4 +- src/tim/transform/layout_inference.cc | 2 +- .../ops/activation_layout_inference.h | 2 +- src/tim/transform/ops/addn_layout_inference.h | 2 +- src/tim/transform/ops/arg_layout_inference.h | 2 +- .../ops/batch2space_layout_inference.h | 2 +- .../ops/batchnorm_layout_inference.h | 2 +- .../transform/ops/concat_layout_inferene.h | 2 +- .../transform/ops/conv2d_layout_inference.h | 2 +- .../transform/ops/deconv2d_layout_inference.h | 2 +- .../transform/ops/default_layout_inference.h | 2 +- .../ops/depth2space_layout_inference.h | 2 +- .../ops/elementwise_layout_inference.h | 2 +- .../ops/fullyconnected_layout_inference.h | 2 +- .../transform/ops/gather_layout_inference.h | 2 +- .../ops/gather_nd_layout_inference.h | 2 +- .../ops/l2normalization_layout_inference.h | 2 +- .../transform/ops/logical_layout_inference.h | 2 +- src/tim/transform/ops/lrn_layout_inference.h | 2 +- src/tim/transform/ops/op_layout_inference.cc | 2 +- src/tim/transform/ops/pad_layout_inference.h | 2 +- .../transform/ops/pool2d_layout_inference.h | 2 +- .../transform/ops/reduce_layout_inference.h | 2 +- .../transform/ops/resize_layout_inference.h | 2 +- .../transform/ops/reverse_layout_inference.h | 2 +- .../transform/ops/select_layout_inference.h | 2 +- .../ops/simple_ops_layout_inference.h | 2 +- 
.../transform/ops/slice_layout_inference.h | 2 +- .../transform/ops/softmax_layout_inference.h | 2 +- .../ops/space2batch_layout_inference.h | 2 +- .../ops/space2depth_layout_inference.h | 2 +- .../transform/ops/split_layout_inference.h | 2 +- .../transform/ops/squeeze_layout_inference.h | 2 +- .../transform/ops/stack_layout_inference.h | 2 +- .../ops/stridedslice_layout_inference.h | 2 +- src/tim/vx/direct_map_op.cc | 36 +++ src/tim/vx/direct_map_op_impl.cc | 75 ++++++ ...eration_private.h => direct_map_op_impl.h} | 57 ++--- src/tim/vx/graph.cc | 2 +- src/tim/vx/op_impl.cc | 37 +++ src/tim/vx/op_impl.h | 58 +++++ src/tim/vx/operation.cc | 66 +---- src/tim/vx/ops/README.md | 2 +- src/tim/vx/ops/activations.cc | 18 +- src/tim/vx/ops/addn.cc | 5 +- src/tim/vx/ops/arg.cc | 4 +- src/tim/vx/ops/batch2space.cc | 4 +- src/tim/vx/ops/batchnorm.cc | 6 +- src/tim/vx/ops/clip.cc | 6 +- src/tim/vx/ops/concat.cc | 4 +- src/tim/vx/ops/conv1d.cc | 4 +- src/tim/vx/ops/conv2d.cc | 4 +- src/tim/vx/ops/deconv.cc | 4 +- src/tim/vx/ops/deconv1d.cc | 4 +- src/tim/vx/ops/depth2space.cc | 4 +- src/tim/vx/ops/dropout.cc | 4 +- src/tim/vx/ops/elementwise.cc | 12 +- src/tim/vx/ops/erf.cc | 4 +- src/tim/vx/ops/fullyconnected.cc | 4 +- src/tim/vx/ops/gather.cc | 4 +- src/tim/vx/ops/gathernd.cc | 4 +- src/tim/vx/ops/groupedconv1d.cc | 4 +- src/tim/vx/ops/groupedconv2d.cc | 6 +- src/tim/vx/ops/instancenormalization.cc | 4 +- src/tim/vx/ops/l2normalization.cc | 4 +- src/tim/vx/ops/layernormalization.cc | 4 +- src/tim/vx/ops/localresponsenormalization.cc | 4 +- src/tim/vx/ops/logical.cc | 4 +- src/tim/vx/ops/logsoftmax.cc | 4 +- src/tim/vx/ops/matmul.cc | 4 +- src/tim/vx/ops/maxpoolwithargmax.cc | 4 +- src/tim/vx/ops/maxunpool2d.cc | 4 +- src/tim/vx/ops/mements.cc | 4 +- src/tim/vx/ops/nbg.cc | 6 +- src/tim/vx/ops/pad.cc | 4 +- src/tim/vx/ops/pool2d.cc | 10 +- src/tim/vx/ops/reduce.cc | 4 +- src/tim/vx/ops/relational_operations.cc | 4 +- src/tim/vx/ops/reorg.cc | 4 +- src/tim/vx/ops/reshape.cc | 4 
+- src/tim/vx/ops/resize.cc | 4 +- src/tim/vx/ops/resize1d.cc | 4 +- src/tim/vx/ops/reverse.cc | 4 +- src/tim/vx/ops/rnn_cell.cc | 142 +++++++++++ src/tim/vx/ops/rnn_cell_test.cc | 241 ++++++++++++++++++ src/tim/vx/ops/scatternd.cc | 4 +- src/tim/vx/ops/select.cc | 4 +- src/tim/vx/ops/shuffle_channel.cc | 4 +- src/tim/vx/ops/signal_frame.cc | 6 +- src/tim/vx/ops/simple_operations.cc | 2 +- src/tim/vx/ops/simple_operations_test.cc | 6 +- src/tim/vx/ops/slice.cc | 4 +- src/tim/vx/ops/softmax.cc | 4 +- src/tim/vx/ops/space2batch.cc | 4 +- src/tim/vx/ops/space2depth.cc | 4 +- src/tim/vx/ops/spatial_transformer.cc | 4 +- src/tim/vx/ops/split.cc | 4 +- src/tim/vx/ops/squeeze.cc | 4 +- src/tim/vx/ops/stack.cc | 4 +- src/tim/vx/ops/stridedslice.cc | 4 +- src/tim/vx/ops/svdf.cc | 4 +- src/tim/vx/ops/tile.cc | 4 +- src/tim/vx/ops/transpose.cc | 4 +- .../vx/ops/unidirectional_sequence_lstm.cc | 4 +- src/tim/vx/ops/unstack.cc | 4 +- 168 files changed, 1035 insertions(+), 393 deletions(-) create mode 100644 include/tim/vx/direct_map_op.h create mode 100644 include/tim/vx/ops/rnn_cell.h create mode 100644 src/tim/vx/direct_map_op.cc create mode 100644 src/tim/vx/direct_map_op_impl.cc rename src/tim/vx/{operation_private.h => direct_map_op_impl.h} (64%) create mode 100644 src/tim/vx/op_impl.cc create mode 100644 src/tim/vx/op_impl.h create mode 100644 src/tim/vx/ops/rnn_cell.cc create mode 100644 src/tim/vx/ops/rnn_cell_test.cc diff --git a/BUILD b/BUILD index 8db4ec8..0115a4b 100644 --- a/BUILD +++ b/BUILD @@ -28,8 +28,10 @@ cc_library( ], hdrs = [ "include/tim/vx/context.h", + "include/tim/vx/direct_map_op.h", "include/tim/vx/graph.h", "include/tim/vx/operation.h", + "include/tim/vx/ops.h", "include/tim/vx/tensor.h", "include/tim/vx/types.h", "include/tim/transform/layout_inference.h", @@ -41,8 +43,12 @@ cc_library( "src/tim/vx/context.cc", "src/tim/vx/graph_private.h", "src/tim/vx/graph.cc", + "src/tim/vx/direct_map_op_impl.cc", + "src/tim/vx/direct_map_op.cc", + 
"src/tim/vx/direct_map_op_impl.h", + "src/tim/vx/op_impl.cc", + "src/tim/vx/op_impl.h", "src/tim/vx/operation.cc", - "src/tim/vx/operation_private.h", "src/tim/vx/tensor.cc", "src/tim/vx/tensor_private.h", "src/tim/vx/type_utils.h", diff --git a/include/tim/vx/direct_map_op.h b/include/tim/vx/direct_map_op.h new file mode 100644 index 0000000..25092fb --- /dev/null +++ b/include/tim/vx/direct_map_op.h @@ -0,0 +1,42 @@ +/**************************************************************************** +* +* Copyright (c) 2021 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. 
+* +*****************************************************************************/ +#ifndef TIM_VX_DIRECTMAPOP_H +#define TIM_VX_DIRECTMAPOP_H + +#include "tim/vx/operation.h" + +namespace tim { +namespace vx { +// interface +class DirectMapOp : public Operation { + public: + DirectMapOp(Graph* graph, uint32_t kind, int in_cnt = 0, int out_cnt = 0, + DataLayout layout = DataLayout::ANY); +}; + +} // namespace vx + +} // namespace tim + +#endif diff --git a/include/tim/vx/operation.h b/include/tim/vx/operation.h index e978e45..eff4077 100644 --- a/include/tim/vx/operation.h +++ b/include/tim/vx/operation.h @@ -30,12 +30,11 @@ namespace tim { namespace vx { -class OperationImpl; +class OpImpl; class Operation { public: - Operation(Graph* graph, uint32_t operation_id, - int input_cnt = 0, int ouput_cnt = 0, DataLayout layout = DataLayout::ANY); + Operation(); virtual ~Operation(); virtual std::shared_ptr Clone(std::shared_ptr& graph) const = 0; Operation& BindInput(const std::shared_ptr& tensor); @@ -47,11 +46,11 @@ class Operation { RoundingPolicy rounding_policy = RoundingPolicy::RTNE, RoundType down_scale_size_rounding = RoundType::FLOOR, uint32_t accumulator_bits = 0); - std::unique_ptr& impl(); - const std::unique_ptr& impl() const; + std::unique_ptr& impl(); + const std::unique_ptr& impl() const; protected: - std::unique_ptr impl_; + std::unique_ptr impl_; }; } // namespace vx diff --git a/include/tim/vx/ops.h b/include/tim/vx/ops.h index e9b378f..3811c21 100644 --- a/include/tim/vx/ops.h +++ b/include/tim/vx/ops.h @@ -63,10 +63,12 @@ #include "tim/vx/ops/resize1d.h" #include "tim/vx/ops/resize.h" #include "tim/vx/ops/reverse.h" +#include "tim/vx/ops/rnn_cell.h" #include "tim/vx/ops/scatternd.h" #include "tim/vx/ops/select.h" #include "tim/vx/ops/shuffle_channel.h" #include "tim/vx/ops/simple_operations.h" +#include "tim/vx/ops/signal_frame.h" #include "tim/vx/ops/slice.h" #include "tim/vx/ops/softmax.h" #include "tim/vx/ops/space2batch.h" diff --git 
a/include/tim/vx/ops/activations.h b/include/tim/vx/ops/activations.h index b5b2ba7..3911493 100644 --- a/include/tim/vx/ops/activations.h +++ b/include/tim/vx/ops/activations.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_ACTIVATIONS_H_ #define TIM_VX_OPS_ACTIVATIONS_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -69,7 +69,7 @@ namespace ops { */ #define DECLARE_NO_PARAMETER_ACTIVATION(NAME) \ - class NAME : public Operation { \ + class NAME : public DirectMapOp { \ public: \ NAME(Graph* graph); \ std::shared_ptr Clone( \ @@ -90,7 +90,7 @@ DECLARE_NO_PARAMETER_ACTIVATION(SoftRelu) #undef DEFINE_NO_PARAMETER_ACTIVATION -class Prelu : public Operation { +class Prelu : public DirectMapOp { public: Prelu(Graph* graph, int axis); std::shared_ptr Clone( @@ -100,7 +100,7 @@ class Prelu : public Operation { int axis_; }; -class LeakyRelu : public Operation { +class LeakyRelu : public DirectMapOp { public: LeakyRelu(Graph* graph, float alpha); std::shared_ptr Clone( @@ -110,7 +110,7 @@ class LeakyRelu : public Operation { float alpha_; }; -class Linear : public Operation { +class Linear : public DirectMapOp { public: Linear(Graph* graph, float a, float b = 0.0); std::shared_ptr Clone( @@ -121,7 +121,7 @@ class Linear : public Operation { float b_; }; -class Gelu : public Operation { +class Gelu : public DirectMapOp { public: /**************************************************************************** *Non-approximate calculations will also have errors when the data type is diff --git a/include/tim/vx/ops/addn.h b/include/tim/vx/ops/addn.h index 3de97e5..cd7a105 100644 --- a/include/tim/vx/ops/addn.h +++ b/include/tim/vx/ops/addn.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_ADDN_H_ #define TIM_VX_OPS_ADDN_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" 
namespace tim { namespace vx { @@ -37,7 +37,7 @@ namespace ops { * ``` */ -class AddN : public Operation { +class AddN : public DirectMapOp { public: AddN(Graph* graph, uint32_t num_inputs); diff --git a/include/tim/vx/ops/arg.h b/include/tim/vx/ops/arg.h index 316ae8b..3f26cd0 100644 --- a/include/tim/vx/ops/arg.h +++ b/include/tim/vx/ops/arg.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_ARG_H_ #define TIM_VX_OPS_ARG_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -37,7 +37,7 @@ namespace ops { */ #define DECLARE_ARG_OP(NAME) \ - class Arg##NAME : public Operation { \ + class Arg##NAME : public DirectMapOp { \ public: \ Arg##NAME(Graph* graph, int32_t axis); \ std::shared_ptr Clone( \ diff --git a/include/tim/vx/ops/batch2space.h b/include/tim/vx/ops/batch2space.h index 61fe51d..809d1a7 100644 --- a/include/tim/vx/ops/batch2space.h +++ b/include/tim/vx/ops/batch2space.h @@ -26,7 +26,7 @@ #include -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -43,7 +43,7 @@ namespace ops { * - crop : corp the output tensor for ROI usage. 
*/ -class Batch2Space : public Operation { +class Batch2Space : public DirectMapOp { public: Batch2Space(Graph* graph, const std::vector& block_size, const std::vector& crop, diff --git a/include/tim/vx/ops/batchnorm.h b/include/tim/vx/ops/batchnorm.h index 6feb08d..57015ce 100644 --- a/include/tim/vx/ops/batchnorm.h +++ b/include/tim/vx/ops/batchnorm.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef OVXLIBXX_OPERATIONS_BATCHNORM_H_ #define OVXLIBXX_OPERATIONS_BATCHNORM_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -40,7 +40,7 @@ namespace ops { * $$y_i=\gamma\hat x_i+\beta\equiv BN_{\gamma,\beta}(x_i)$$ */ -class BatchNorm : public Operation { +class BatchNorm : public DirectMapOp { public: BatchNorm(Graph* graph, float eps, DataLayout input_layout = DataLayout::WHCN); diff --git a/include/tim/vx/ops/clip.h b/include/tim/vx/ops/clip.h index 9cc4c8a..6b0fda7 100644 --- a/include/tim/vx/ops/clip.h +++ b/include/tim/vx/ops/clip.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef OVXLIBXX_OPERATIONS_CLIP_H_ #define OVXLIBXX_OPERATIONS_CLIP_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { @@ -36,7 +36,7 @@ namespace ops { * Clip(x) : min if x <= min; x if min < x < max; max if x >= max */ -class Clip : public Operation { +class Clip : public DirectMapOp { public: Clip(Graph* graph, float min, float max); diff --git a/include/tim/vx/ops/concat.h b/include/tim/vx/ops/concat.h index 9c3c9aa..eab629d 100644 --- a/include/tim/vx/ops/concat.h +++ b/include/tim/vx/ops/concat.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_CONCAT_H_ #define TIM_VX_OPS_CONCAT_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -37,7 +37,7 @@ namespace ops { * - axis : Which 
axis to concat on. */ -class Concat : public Operation { +class Concat : public DirectMapOp { public: Concat(Graph* graph, uint32_t axis, int input_cnt); diff --git a/include/tim/vx/ops/conv1d.h b/include/tim/vx/ops/conv1d.h index 957d437..4d15f7f 100644 --- a/include/tim/vx/ops/conv1d.h +++ b/include/tim/vx/ops/conv1d.h @@ -26,13 +26,13 @@ #include -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { namespace ops { -class Conv1d : public Operation { +class Conv1d : public DirectMapOp { public: Conv1d(Graph* graph, PadType padding, uint32_t stride, uint32_t dilation, int32_t multiplier = 0, diff --git a/include/tim/vx/ops/conv2d.h b/include/tim/vx/ops/conv2d.h index a997cf2..5c1b6fd 100644 --- a/include/tim/vx/ops/conv2d.h +++ b/include/tim/vx/ops/conv2d.h @@ -26,7 +26,7 @@ #include -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -55,7 +55,7 @@ namespace ops { * - layout : WHCN or CWHN. */ -class Conv2d : public Operation { +class Conv2d : public DirectMapOp { public: Conv2d(Graph* graph, PadType padding, const std::array& stride, diff --git a/include/tim/vx/ops/deconv.h b/include/tim/vx/ops/deconv.h index 79c8f9b..86007e9 100644 --- a/include/tim/vx/ops/deconv.h +++ b/include/tim/vx/ops/deconv.h @@ -26,7 +26,7 @@ #include -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -53,7 +53,7 @@ namespace ops { * - kernel_layout: Layout for kernel, WHIO by default. 
*/ -class DeConv2d : public Operation { +class DeConv2d : public DirectMapOp { public: DeConv2d(Graph* graph, int32_t oc_count_, PadType pad_type, const std::array& ksize, diff --git a/include/tim/vx/ops/deconv1d.h b/include/tim/vx/ops/deconv1d.h index 1e30017..206cec7 100644 --- a/include/tim/vx/ops/deconv1d.h +++ b/include/tim/vx/ops/deconv1d.h @@ -26,7 +26,7 @@ #include -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -49,7 +49,7 @@ namespace ops { * the output tensor. */ -class DeConv1d : public Operation { +class DeConv1d : public DirectMapOp { public: DeConv1d(Graph* graph, PadType pad_type, uint32_t stride, uint32_t output_padding, uint32_t group = 1, diff --git a/include/tim/vx/ops/depth2space.h b/include/tim/vx/ops/depth2space.h index ba33d14..275b8ae 100644 --- a/include/tim/vx/ops/depth2space.h +++ b/include/tim/vx/ops/depth2space.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_DEPTH2SPACE_H_ #define TIM_VX_OPS_DEPTH2SPACE_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -45,7 +45,7 @@ namespace ops { * - crop : corp the output tensor for ROI usage. */ -class DepthToSpace : public Operation { +class DepthToSpace : public DirectMapOp { public: DepthToSpace(Graph* Graph, int block_size, DataLayout layout = DataLayout::WHCN); diff --git a/include/tim/vx/ops/dropout.h b/include/tim/vx/ops/dropout.h index 7a5b6ff..a14b85c 100644 --- a/include/tim/vx/ops/dropout.h +++ b/include/tim/vx/ops/dropout.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef OVXLIBXX_OPERATIONS_DROPOUT_H_ #define OVXLIBXX_OPERATIONS_DROPOUT_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { @@ -40,7 +40,7 @@ namespace ops { * for Dropout operator. 
*/ -class Dropout : public Operation { +class Dropout : public DirectMapOp { public: Dropout(Graph* graph, float ratio); diff --git a/include/tim/vx/ops/elementwise.h b/include/tim/vx/ops/elementwise.h index 23a758a..85d70b6 100644 --- a/include/tim/vx/ops/elementwise.h +++ b/include/tim/vx/ops/elementwise.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_ELEMENTWISE_H_ #define TIM_VX_OPS_ELEMENTWISE_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -67,7 +67,7 @@ namespace ops { */ #define DECLARE_ELEMENTWISE_OP(NAME) \ - class NAME : public Operation { \ + class NAME : public DirectMapOp { \ public: \ NAME(Graph* graph); \ std::shared_ptr Clone( \ @@ -81,14 +81,14 @@ DECLARE_ELEMENTWISE_OP(Sub) DECLARE_ELEMENTWISE_OP(Pow) DECLARE_ELEMENTWISE_OP(FloorDiv) -class Multiply : public Operation { +class Multiply : public DirectMapOp { public: Multiply(Graph* graph, float scale = 1.0f); std::shared_ptr Clone(std::shared_ptr& graph) const override; }; -class Div : public Operation { +class Div : public DirectMapOp { public: Div(Graph* graph, float scale = 1.0f); diff --git a/include/tim/vx/ops/erf.h b/include/tim/vx/ops/erf.h index 2aabeb5..0ecfa23 100644 --- a/include/tim/vx/ops/erf.h +++ b/include/tim/vx/ops/erf.h @@ -24,7 +24,7 @@ #ifndef TIM_VX_OPS_ERF_H_ #define TIM_VX_OPS_ERF_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" #include "tim/vx/types.h" namespace tim { @@ -39,7 +39,7 @@ namespace ops { * - no parameters */ -class Erf : public Operation { +class Erf : public DirectMapOp { public: Erf(Graph* graph); std::shared_ptr Clone(std::shared_ptr& graph) const override; diff --git a/include/tim/vx/ops/fullyconnected.h b/include/tim/vx/ops/fullyconnected.h index 38ffb04..7192bc5 100644 --- a/include/tim/vx/ops/fullyconnected.h +++ b/include/tim/vx/ops/fullyconnected.h @@ -23,7 +23,7 @@ 
*****************************************************************************/ #ifndef TIM_VX_OPS_FULLYCONNECTED_H_ #define TIM_VX_OPS_FULLYCONNECTED_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -39,7 +39,7 @@ namespace ops { * - weights: the output channel number for weight tensor. */ -class FullyConnected : public Operation { +class FullyConnected : public DirectMapOp { public: FullyConnected(Graph* graph, uint32_t axis); FullyConnected(Graph* graph, uint32_t axis, uint32_t weights); diff --git a/include/tim/vx/ops/gather.h b/include/tim/vx/ops/gather.h index 953190f..1728ac2 100644 --- a/include/tim/vx/ops/gather.h +++ b/include/tim/vx/ops/gather.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_GATHER_H_ #define TIM_VX_OPS_GATHER_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -35,7 +35,7 @@ namespace ops { * Gather slices from input, **axis** according to **indices**. */ -class Gather : public Operation { +class Gather : public DirectMapOp { public: Gather(Graph* Graph, int axis); diff --git a/include/tim/vx/ops/gathernd.h b/include/tim/vx/ops/gathernd.h index 31f8151..1d2e8e7 100644 --- a/include/tim/vx/ops/gathernd.h +++ b/include/tim/vx/ops/gathernd.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_GATHERND_H_ #define TIM_VX_OPS_GATHERND_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -35,7 +35,7 @@ namespace ops { * An operation similar to Gather but gathers across multiple axis at once. 
*/ -class GatherNd : public Operation { +class GatherNd : public DirectMapOp { public: GatherNd(Graph* Graph); diff --git a/include/tim/vx/ops/groupedconv1d.h b/include/tim/vx/ops/groupedconv1d.h index c00054e..388bb8d 100644 --- a/include/tim/vx/ops/groupedconv1d.h +++ b/include/tim/vx/ops/groupedconv1d.h @@ -26,7 +26,7 @@ #include -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -53,7 +53,7 @@ namespace ops { * - layout : WCN or CWN. */ -class GroupedConv1d : public Operation { +class GroupedConv1d : public DirectMapOp { public: GroupedConv1d(Graph* graph, PadType padding, uint32_t stride, diff --git a/include/tim/vx/ops/groupedconv2d.h b/include/tim/vx/ops/groupedconv2d.h index 9e9a5eb..d888fe6 100644 --- a/include/tim/vx/ops/groupedconv2d.h +++ b/include/tim/vx/ops/groupedconv2d.h @@ -26,7 +26,7 @@ #include -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -53,7 +53,7 @@ namespace ops { * - layout : WHCN or CWHN. 
*/ -class GroupedConv2d : public Operation { +class GroupedConv2d : public DirectMapOp { public: GroupedConv2d(Graph* graph, PadType padding, const std::array& strides, diff --git a/include/tim/vx/ops/instancenormalization.h b/include/tim/vx/ops/instancenormalization.h index 0ffd60e..8764587 100644 --- a/include/tim/vx/ops/instancenormalization.h +++ b/include/tim/vx/ops/instancenormalization.h @@ -23,12 +23,12 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_INSTANCENOMALIZATION_H_ #define TIM_VX_OPS_INSTANCENOMALIZATION_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { namespace ops { -class InstanceNormalization : public Operation { +class InstanceNormalization : public DirectMapOp { public: InstanceNormalization(Graph* graph, float eps = 1e-5f); diff --git a/include/tim/vx/ops/l2normalization.h b/include/tim/vx/ops/l2normalization.h index 9e0e00e..876483f 100644 --- a/include/tim/vx/ops/l2normalization.h +++ b/include/tim/vx/ops/l2normalization.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_L2NOMALIZATION_H_ #define TIM_VX_OPS_L2NOMALIZATION_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" /** * ## L2Normalization @@ -40,7 +40,7 @@ namespace tim { namespace vx { namespace ops { -class L2Normalization : public Operation { +class L2Normalization : public DirectMapOp { public: L2Normalization(Graph* graph, int32_t axis); diff --git a/include/tim/vx/ops/layernormalization.h b/include/tim/vx/ops/layernormalization.h index 55e2d4e..34216a4 100644 --- a/include/tim/vx/ops/layernormalization.h +++ b/include/tim/vx/ops/layernormalization.h @@ -25,12 +25,12 @@ #define TIM_VX_OPS_LAYERNOMALIZATION_H_ #include -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { namespace ops { -class LayerNormalization : public Operation { +class 
LayerNormalization : public DirectMapOp { public: LayerNormalization(Graph* graph, int32_t axis = 0, float eps = 1e-5f); diff --git a/include/tim/vx/ops/localresponsenormalization.h b/include/tim/vx/ops/localresponsenormalization.h index 0b0bc24..237151e 100644 --- a/include/tim/vx/ops/localresponsenormalization.h +++ b/include/tim/vx/ops/localresponsenormalization.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_LOCALRESPONSENORMALIZATION_H_ #define TIM_VX_OPS_LOCALRESPONSENORMALIZATION_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" /** * ## LocalResponseNormalization @@ -40,7 +40,7 @@ namespace tim { namespace vx { namespace ops { -class LocalResponseNormalization : public Operation { +class LocalResponseNormalization : public DirectMapOp { public: LocalResponseNormalization(Graph* graph, uint32_t size, float alpha, float beta, float bias, int32_t axis); diff --git a/include/tim/vx/ops/logical.h b/include/tim/vx/ops/logical.h index 66a2573..369c7b5 100644 --- a/include/tim/vx/ops/logical.h +++ b/include/tim/vx/ops/logical.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_LOGICAL_H_ #define TIM_VX_OPS_LOGICAL_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -40,7 +40,7 @@ namespace ops { */ #define DECLARE_LOGICAL_OP(NAME) \ - class Logical##NAME : public Operation { \ + class Logical##NAME : public DirectMapOp { \ public: \ Logical##NAME(Graph* graph); \ std::shared_ptr Clone( \ diff --git a/include/tim/vx/ops/logsoftmax.h b/include/tim/vx/ops/logsoftmax.h index 67cdde2..53fb94d 100644 --- a/include/tim/vx/ops/logsoftmax.h +++ b/include/tim/vx/ops/logsoftmax.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_LOG_SOFTMAX_H_ #define TIM_VX_OPS_LOG_SOFTMAX_H_ -#include "tim/vx/operation.h" 
+#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -39,7 +39,7 @@ namespace ops { * ``` */ -class LogSoftmax : public Operation { +class LogSoftmax : public DirectMapOp { public: LogSoftmax(Graph* graph, int32_t axis, float beta = 1.f); diff --git a/include/tim/vx/ops/matmul.h b/include/tim/vx/ops/matmul.h index 9a63013..a0c420e 100644 --- a/include/tim/vx/ops/matmul.h +++ b/include/tim/vx/ops/matmul.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_MATMUL_H_ #define TIM_VX_OPS_MATMUL_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -40,7 +40,7 @@ namespace ops { * - adjoint_b: If True, b is conjugated and transposed before multiplication. */ -class Matmul : public Operation { +class Matmul : public DirectMapOp { public: Matmul(Graph* graph, bool transpose_a = false, bool transpose_b = false, bool adjoint_a = false, bool adjoint_b = false); diff --git a/include/tim/vx/ops/maxpoolwithargmax.h b/include/tim/vx/ops/maxpoolwithargmax.h index 5bfba13..ca44ae5 100644 --- a/include/tim/vx/ops/maxpoolwithargmax.h +++ b/include/tim/vx/ops/maxpoolwithargmax.h @@ -26,7 +26,7 @@ #include -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" #include "tim/vx/types.h" namespace tim { @@ -44,7 +44,7 @@ namespace ops { * - round_type : CEILING or FLOOR. */ -class MaxpoolWithArgmax : public Operation { +class MaxpoolWithArgmax : public DirectMapOp { public: MaxpoolWithArgmax(Graph* graph, PadType padding, const std::array& ksize, diff --git a/include/tim/vx/ops/maxunpool2d.h b/include/tim/vx/ops/maxunpool2d.h index bd004d6..60ae4ee 100644 --- a/include/tim/vx/ops/maxunpool2d.h +++ b/include/tim/vx/ops/maxunpool2d.h @@ -26,7 +26,7 @@ #include -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" #include "tim/vx/types.h" namespace tim { @@ -42,7 +42,7 @@ namespace ops { * - ksize : filter size. 
*/ -class MaxUnpool2d : public Operation { +class MaxUnpool2d : public DirectMapOp { public: MaxUnpool2d(Graph* graph, const std::array& ksize, const std::array& stride, DataLayout layout = DataLayout::WHCN); diff --git a/include/tim/vx/ops/moments.h b/include/tim/vx/ops/moments.h index 6d34481..b1fc39b 100644 --- a/include/tim/vx/ops/moments.h +++ b/include/tim/vx/ops/moments.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_MOMENTS_H_ #define TIM_VX_OPS_MOMENTS_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -39,7 +39,7 @@ namespace ops { * - keep_dims : Produce moments with the same dimensionality as input. */ -class Moments : public Operation { +class Moments : public DirectMapOp { public: Moments(Graph* graph, const std::vector& axes, bool keep_dims = false); diff --git a/include/tim/vx/ops/nbg.h b/include/tim/vx/ops/nbg.h index 0d0bd6d..051c302 100644 --- a/include/tim/vx/ops/nbg.h +++ b/include/tim/vx/ops/nbg.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_NBG_H_ #define TIM_VX_OPS_NBG_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -36,7 +36,7 @@ namespace ops { * a bianry file. */ -class NBG : public Operation { +class NBG : public DirectMapOp { public: NBG(Graph* graph, const char* binary, size_t input_count, size_t output_count); diff --git a/include/tim/vx/ops/pad.h b/include/tim/vx/ops/pad.h index cb863a1..4a214fa 100644 --- a/include/tim/vx/ops/pad.h +++ b/include/tim/vx/ops/pad.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPERATION_PAD_H_ #define TIM_VX_OPERATION_PAD_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -37,7 +37,7 @@ namespace ops { * - const_val : the value to pad. 
*/ -class Pad : public Operation { +class Pad : public DirectMapOp { public: Pad(Graph* graph, const std::vector& front_size, const std::vector& back_size, int32_t const_val); diff --git a/include/tim/vx/ops/pool2d.h b/include/tim/vx/ops/pool2d.h index 73188e7..b5ba042 100644 --- a/include/tim/vx/ops/pool2d.h +++ b/include/tim/vx/ops/pool2d.h @@ -26,7 +26,7 @@ #include -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" #include "tim/vx/types.h" namespace tim { @@ -63,7 +63,7 @@ namespace ops { * */ -class Pool2d : public Operation { +class Pool2d : public DirectMapOp { public: // for Classic Pool2d Pool2d(Graph* graph, PoolType type, PadType padding, diff --git a/include/tim/vx/ops/reduce.h b/include/tim/vx/ops/reduce.h index 592bfd7..1964971 100644 --- a/include/tim/vx/ops/reduce.h +++ b/include/tim/vx/ops/reduce.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_REDUCE_H_ #define TIM_VX_OPS_REDUCE_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -95,7 +95,7 @@ namespace ops { */ #define DECLARE_REDUCE_OP(NAME) \ - class Reduce##NAME : public Operation { \ + class Reduce##NAME : public DirectMapOp { \ public: \ Reduce##NAME(Graph* graph, const std::vector& axis, \ bool keep_dims); \ diff --git a/include/tim/vx/ops/relational_operations.h b/include/tim/vx/ops/relational_operations.h index a2a2732..ed02d8a 100644 --- a/include/tim/vx/ops/relational_operations.h +++ b/include/tim/vx/ops/relational_operations.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_RELATIONAL_H_ #define TIM_VX_OPS_RELATIONAL_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -56,7 +56,7 @@ namespace ops { */ #define DECLARE_RELATIONAL_OP(NAME) \ - class NAME : public Operation { \ + class NAME : public DirectMapOp { \ public: \ NAME(Graph* 
graph); \ std::shared_ptr Clone( \ diff --git a/include/tim/vx/ops/reorg.h b/include/tim/vx/ops/reorg.h index cd8fe62..69a5dcc 100644 --- a/include/tim/vx/ops/reorg.h +++ b/include/tim/vx/ops/reorg.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_REORG_H_ #define TIM_VX_OPS_REORG_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -35,7 +35,7 @@ namespace ops { * The layer used in YOLOv2. See also https://github.com/pjreddie/darknet/blob/master/src/reorg_layer.c */ -class Reorg : public Operation { +class Reorg : public DirectMapOp { public: Reorg(Graph* graph, const uint32_t stride); diff --git a/include/tim/vx/ops/reshape.h b/include/tim/vx/ops/reshape.h index 8b6b16b..b02ccc9 100644 --- a/include/tim/vx/ops/reshape.h +++ b/include/tim/vx/ops/reshape.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_RESHAPE_H_ #define TIM_VX_OPS_RESHAPE_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -37,7 +37,7 @@ namespace ops { * - size : defining the shape of the output tensor. */ -class Reshape : public Operation { +class Reshape : public DirectMapOp { public: Reshape(Graph* graph, const std::vector& size); diff --git a/include/tim/vx/ops/resize.h b/include/tim/vx/ops/resize.h index 32e5687..26bb7aa 100644 --- a/include/tim/vx/ops/resize.h +++ b/include/tim/vx/ops/resize.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_RESIZE_H_ #define TIM_VX_OPS_RESIZE_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -44,7 +44,7 @@ namespace ops { * - target_height / target_width : output height / width. DO NOT use it with factor together. 
*/ -class Resize : public Operation { +class Resize : public DirectMapOp { public: Resize(Graph* graph, ResizeType type, float factor, bool align_corners, bool half_pixel_centers, int target_height, int target_width, diff --git a/include/tim/vx/ops/resize1d.h b/include/tim/vx/ops/resize1d.h index 0f76c76..c1f0968 100644 --- a/include/tim/vx/ops/resize1d.h +++ b/include/tim/vx/ops/resize1d.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_RESIZE1D_H_ #define TIM_VX_OPS_RESIZE1D_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -44,7 +44,7 @@ namespace ops { * - target_height / target_width : output height / width. DO NOT use it with factor together. */ -class Resize1d : public Operation { +class Resize1d : public DirectMapOp { public: Resize1d(Graph* graph, ResizeType type, float factor, bool align_corners, bool half_pixel_centers, int target_size, diff --git a/include/tim/vx/ops/reverse.h b/include/tim/vx/ops/reverse.h index 731950a..4bbc98a 100644 --- a/include/tim/vx/ops/reverse.h +++ b/include/tim/vx/ops/reverse.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_REVERSE_H_ #define TIM_VX_OPS_REVERSE_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -37,7 +37,7 @@ namespace ops { * - axis : The indices of the dimensions to reverse. 
*/ -class Reverse : public Operation { +class Reverse : public DirectMapOp { public: Reverse(Graph* graph, const std::vector& axis); diff --git a/include/tim/vx/ops/rnn_cell.h b/include/tim/vx/ops/rnn_cell.h new file mode 100644 index 0000000..1419803 --- /dev/null +++ b/include/tim/vx/ops/rnn_cell.h @@ -0,0 +1,54 @@ +/**************************************************************************** +* +* Copyright (c) 2021 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. 
+* +*****************************************************************************/ +#ifndef TIM_VX_OPS_RNN_CELL_H_ +#define TIM_VX_OPS_RNN_CELL_H_ + +#include "tim/vx/operation.h" +namespace tim { +namespace vx { +namespace ops { + +class RNNCell : public Operation{ + public: + enum ActivationType { + kNONE = 0, + kRELU = 1, + kRELU1 = 2, + kRELU6 = 3, + kTANH = 4, + kSIGMOID = 6, + kHARDSIGMOID = 31, /* temporary use 31*/ + }; + RNNCell(Graph* graph, ActivationType activation); + std::shared_ptr Clone(std::shared_ptr& graph) const override; + + protected: + const ActivationType activation_; +}; + +} // namespace ops +} // namespace vx +} // namespace tim + +#endif /* TIM_VX_OPS_RNN_CELL_H_ */ \ No newline at end of file diff --git a/include/tim/vx/ops/scatternd.h b/include/tim/vx/ops/scatternd.h index 8de7d4b..1a211f1 100644 --- a/include/tim/vx/ops/scatternd.h +++ b/include/tim/vx/ops/scatternd.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_SCATTERND_H_ #define TIM_VX_OPS_SCATTERND_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -37,7 +37,7 @@ namespace ops { * - shape : The shape of the resulting tensor. */ -class ScatterND : public Operation { +class ScatterND : public DirectMapOp { public: ScatterND(Graph* graph, const std::vector& shape); diff --git a/include/tim/vx/ops/select.h b/include/tim/vx/ops/select.h index 38dfb32..6f87630 100644 --- a/include/tim/vx/ops/select.h +++ b/include/tim/vx/ops/select.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_SELECT_H_ #define TIM_VX_OPS_SELECT_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -36,7 +36,7 @@ namespace ops { * from both input tensors: O[i] = C[i] ? x[i] : y[i]. 
*/ -class Select : public Operation { +class Select : public DirectMapOp { public: Select(Graph* graph); diff --git a/include/tim/vx/ops/shuffle_channel.h b/include/tim/vx/ops/shuffle_channel.h index c3c7bfa..cb45f60 100644 --- a/include/tim/vx/ops/shuffle_channel.h +++ b/include/tim/vx/ops/shuffle_channel.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_SHUFFLE_H_ #define TIM_VX_OPS_SHUFFLE_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -38,7 +38,7 @@ namespace ops { * ``` */ -class ShuffleChannel : public Operation { +class ShuffleChannel : public DirectMapOp { public: explicit ShuffleChannel(Graph* graph, int32_t num_groups, int32_t index_axis); std::shared_ptr Clone(std::shared_ptr& graph) const override; diff --git a/include/tim/vx/ops/signal_frame.h b/include/tim/vx/ops/signal_frame.h index 3203873..c44f484 100644 --- a/include/tim/vx/ops/signal_frame.h +++ b/include/tim/vx/ops/signal_frame.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_SIGNALFRAME_H_ #define TIM_VX_OPS_SIGNALFRAME_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -39,7 +39,7 @@ namespace ops { * ``` */ -class SignalFrame : public Operation { +class SignalFrame : public DirectMapOp { public: SignalFrame(Graph* graph, uint32_t window_length, uint32_t step, uint32_t pad_end=0, uint32_t axis=0); diff --git a/include/tim/vx/ops/simple_operations.h b/include/tim/vx/ops/simple_operations.h index 715848c..084e377 100644 --- a/include/tim/vx/ops/simple_operations.h +++ b/include/tim/vx/ops/simple_operations.h @@ -23,14 +23,14 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_SIMPLE_OPERATIONS_H_ #define TIM_VX_OPS_SIMPLE_OPERATIONS_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" 
namespace tim { namespace vx { namespace ops { #define DECLARE_SIMPLE_OP(NAME) \ - class NAME : public Operation { \ + class NAME : public DirectMapOp { \ public: \ NAME(Graph* graph); \ std::shared_ptr Clone( \ diff --git a/include/tim/vx/ops/slice.h b/include/tim/vx/ops/slice.h index 8cc6e31..880ebd8 100644 --- a/include/tim/vx/ops/slice.h +++ b/include/tim/vx/ops/slice.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_SLICE_H_ #define TIM_VX_OPS_SLICE_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -38,7 +38,7 @@ namespace ops { * - length : the size of the slice in each dimension. */ -class Slice : public Operation { +class Slice : public DirectMapOp { public: Slice(Graph* graph, uint32_t dims, diff --git a/include/tim/vx/ops/softmax.h b/include/tim/vx/ops/softmax.h index 525ce55..12f0610 100644 --- a/include/tim/vx/ops/softmax.h +++ b/include/tim/vx/ops/softmax.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_SOFTMAX_H_ #define TIM_VX_OPS_SOFTMAX_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -42,7 +42,7 @@ namespace ops { * ``` */ -class Softmax : public Operation { +class Softmax : public DirectMapOp { public: Softmax(Graph* graph, float beta, int32_t axis); diff --git a/include/tim/vx/ops/space2batch.h b/include/tim/vx/ops/space2batch.h index e8e855a..9799104 100644 --- a/include/tim/vx/ops/space2batch.h +++ b/include/tim/vx/ops/space2batch.h @@ -26,7 +26,7 @@ #include -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -46,7 +46,7 @@ namespace ops { * - pad : the paddings for each spatial dimension of the input tensor. 
*/ -class Space2Batch : public Operation { +class Space2Batch : public DirectMapOp { public: Space2Batch(Graph* graph, const std::vector& block_size, const std::vector& pad, diff --git a/include/tim/vx/ops/space2depth.h b/include/tim/vx/ops/space2depth.h index 832197f..9bea5d7 100644 --- a/include/tim/vx/ops/space2depth.h +++ b/include/tim/vx/ops/space2depth.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_SPACE2DEPTH_H_ #define TIM_VX_OPS_SPACE2DEPTH_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -38,7 +38,7 @@ namespace ops { * transformation of DepthToSpace. */ -class SpaceToDepth : public Operation { +class SpaceToDepth : public DirectMapOp { public: SpaceToDepth(Graph* graph, std::vector block_size, DataLayout layout = DataLayout::WHCN); diff --git a/include/tim/vx/ops/spatial_transformer.h b/include/tim/vx/ops/spatial_transformer.h index 3972b07..5775a14 100644 --- a/include/tim/vx/ops/spatial_transformer.h +++ b/include/tim/vx/ops/spatial_transformer.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_SPATIAL_TRANSFORMER_H_ #define TIM_VX_OPS_SPATIAL_TRANSFORMER_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -40,7 +40,7 @@ namespace ops { It is the output of the localization network. 
*/ -class SpatialTransformer : public Operation { +class SpatialTransformer : public DirectMapOp { public: SpatialTransformer(Graph* graph, uint32_t output_h, uint32_t output_w, bool has_theta_1_1, bool has_theta_1_2, bool has_theta_1_3, diff --git a/include/tim/vx/ops/split.h b/include/tim/vx/ops/split.h index ec3ec09..0c062b6 100644 --- a/include/tim/vx/ops/split.h +++ b/include/tim/vx/ops/split.h @@ -25,7 +25,7 @@ #define TIM_VX_OPS_SPLIT_H_ #include -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -40,7 +40,7 @@ namespace ops { * - slices : indicating the number of splits along given axis. */ -class Split : public Operation { +class Split : public DirectMapOp { public: Split(Graph* graph, uint32_t axis, std::vector slices); diff --git a/include/tim/vx/ops/squeeze.h b/include/tim/vx/ops/squeeze.h index af06edb..76841fd 100644 --- a/include/tim/vx/ops/squeeze.h +++ b/include/tim/vx/ops/squeeze.h @@ -24,7 +24,7 @@ #ifndef TIM_VX_OPS_SQUEEZE_H_ #define TIM_VX_OPS_SQUEEZE_H_ #include -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -38,7 +38,7 @@ namespace ops { * - axis : the dimensions to squeeze. */ -class Squeeze : public Operation { +class Squeeze : public DirectMapOp { public: Squeeze(Graph* graph, std::vector axis); diff --git a/include/tim/vx/ops/stack.h b/include/tim/vx/ops/stack.h index eefe2c3..8f1fde6 100644 --- a/include/tim/vx/ops/stack.h +++ b/include/tim/vx/ops/stack.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_STACK_H_ #define TIM_VX_OPS_STACK_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -36,7 +36,7 @@ namespace ops { * each tensor in values, by packing them along the **axis** dimension. 
*/ -class Stack : public Operation { +class Stack : public DirectMapOp { public: Stack(Graph* graph, uint32_t axis, int input_cnt); diff --git a/include/tim/vx/ops/stridedslice.h b/include/tim/vx/ops/stridedslice.h index d369069..c0d75f1 100644 --- a/include/tim/vx/ops/stridedslice.h +++ b/include/tim/vx/ops/stridedslice.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_STRIDEDSLICE_H_ #define TIM_VX_OPS_STRIDEDSLICE_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -52,7 +52,7 @@ namespace ops { * e.g. begin[i] = x, end[i] = x + 1. */ -class StridedSlice : public Operation { +class StridedSlice : public DirectMapOp { public: StridedSlice(Graph* graph, const std::vector begin_dims, const std::vector end_dims, diff --git a/include/tim/vx/ops/svdf.h b/include/tim/vx/ops/svdf.h index 4ed5bea..cc7256c 100644 --- a/include/tim/vx/ops/svdf.h +++ b/include/tim/vx/ops/svdf.h @@ -26,7 +26,7 @@ #include -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" #include "tim/vx/types.h" namespace tim { @@ -43,7 +43,7 @@ namespace ops { * - spectrogram_length : corresponds to the fixed-size of the memory. */ -class Svdf : public Operation { +class Svdf : public DirectMapOp { public: Svdf(Graph* graph, int32_t rank, int32_t num_units, int32_t spectrogram_length); diff --git a/include/tim/vx/ops/tile.h b/include/tim/vx/ops/tile.h index 14c393b..6caaa8f 100644 --- a/include/tim/vx/ops/tile.h +++ b/include/tim/vx/ops/tile.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_TILE_H_ #define TIM_VX_OPS_TILE_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -37,7 +37,7 @@ namespace ops { * Length must be the same as the number of dimensions in input. 
*/ -class Tile : public Operation { +class Tile : public DirectMapOp { public: Tile(Graph* graph, const std::vector& multiples); diff --git a/include/tim/vx/ops/transpose.h b/include/tim/vx/ops/transpose.h index 6cc6176..70d6d9c 100644 --- a/include/tim/vx/ops/transpose.h +++ b/include/tim/vx/ops/transpose.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_TRANSPOSE_H_ #define TIM_VX_OPS_TRANSPOSE_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -41,7 +41,7 @@ namespace ops { * 2-D input Tensors. */ -class Transpose : public Operation { +class Transpose : public DirectMapOp { public: Transpose(Graph* graph, const std::vector& perm); diff --git a/include/tim/vx/ops/unidirectional_sequence_lstm.h b/include/tim/vx/ops/unidirectional_sequence_lstm.h index 629ad9f..a9ddba5 100644 --- a/include/tim/vx/ops/unidirectional_sequence_lstm.h +++ b/include/tim/vx/ops/unidirectional_sequence_lstm.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_UNIDIRECTIONAL_SEQUENCE_LSTM_H_ #define TIM_VX_OPS_UNIDIRECTIONAL_SEQUENCE_LSTM_H_ -#include "tim/vx/operation.h" +#include "tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -32,7 +32,7 @@ namespace ops { * ## Unidirectional sequence lstm * how to bind input/output: take unidirectional_sequence_lstm_test.cc */ - class UnidirectionalSequenceLstm: public Operation { + class UnidirectionalSequenceLstm: public DirectMapOp { public: enum ActivationType { kNONE = 0, diff --git a/include/tim/vx/ops/unstack.h b/include/tim/vx/ops/unstack.h index 73e1d77..ff663f3 100644 --- a/include/tim/vx/ops/unstack.h +++ b/include/tim/vx/ops/unstack.h @@ -23,7 +23,7 @@ *****************************************************************************/ #ifndef TIM_VX_OPS_UNSTACK_H_ #define TIM_VX_OPS_UNSTACK_H_ -#include "tim/vx/operation.h" +#include 
"tim/vx/direct_map_op.h" namespace tim { namespace vx { @@ -37,7 +37,7 @@ namespace ops { * Negative values wrap around, so the valid range is [-R, R). */ -class Unstack : public Operation { +class Unstack : public DirectMapOp { public: Unstack(Graph* graph, int32_t axis, uint32_t output_num); diff --git a/src/tim/transform/layout_inference.cc b/src/tim/transform/layout_inference.cc index 7a94277..632d417 100644 --- a/src/tim/transform/layout_inference.cc +++ b/src/tim/transform/layout_inference.cc @@ -199,7 +199,7 @@ std::vector> HandleLayoutInfer( std::shared_ptr& ctx, const std::shared_ptr& op) { ctx->MarkVisited(op); - auto op_id = op->impl()->operation_id_; + auto op_id = op->impl()->kind_; std::vector> next_tensors; switch (op_id) { REGIST_LAYOUT_INFERENCE(VSI_NN_OP_CONV2D, Conv2d); diff --git a/src/tim/transform/ops/activation_layout_inference.h b/src/tim/transform/ops/activation_layout_inference.h index c1e22fb..bf35f6d 100644 --- a/src/tim/transform/ops/activation_layout_inference.h +++ b/src/tim/transform/ops/activation_layout_inference.h @@ -28,7 +28,7 @@ #include "ops/op_layout_inference.h" #include "permute_vector.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" namespace tim { namespace transform { diff --git a/src/tim/transform/ops/addn_layout_inference.h b/src/tim/transform/ops/addn_layout_inference.h index 51d2af3..9bac523 100644 --- a/src/tim/transform/ops/addn_layout_inference.h +++ b/src/tim/transform/ops/addn_layout_inference.h @@ -25,7 +25,7 @@ #define TIM_LAYOUT_INFER_ADDN_LAYOUT_INFERENCE_H_ #include "ops/op_layout_inference.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "tim/vx/ops/addn.h" namespace tim { diff --git a/src/tim/transform/ops/arg_layout_inference.h b/src/tim/transform/ops/arg_layout_inference.h index 18138d8..dfbf044 100644 --- a/src/tim/transform/ops/arg_layout_inference.h +++ b/src/tim/transform/ops/arg_layout_inference.h @@ -25,7 +25,7 @@ #define 
TIM_LAYOUT_INFER_ARG_OPS_LAYOUT_INFERENCE_H_ #include "ops/op_layout_inference.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "tim/vx/ops/arg.h" namespace tim { namespace transform { diff --git a/src/tim/transform/ops/batch2space_layout_inference.h b/src/tim/transform/ops/batch2space_layout_inference.h index 876b885..055029e 100644 --- a/src/tim/transform/ops/batch2space_layout_inference.h +++ b/src/tim/transform/ops/batch2space_layout_inference.h @@ -28,7 +28,7 @@ #include "ops/op_layout_inference.h" #include "permute_vector.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" namespace tim { namespace transform { class Batch2SpaceLayoutInfer : public OpLayoutInfer { diff --git a/src/tim/transform/ops/batchnorm_layout_inference.h b/src/tim/transform/ops/batchnorm_layout_inference.h index b1e7518..5afd119 100644 --- a/src/tim/transform/ops/batchnorm_layout_inference.h +++ b/src/tim/transform/ops/batchnorm_layout_inference.h @@ -28,7 +28,7 @@ #include "ops/op_layout_inference.h" #include "permute_vector.h" -#include "operation_private.h" +#include "op_impl.h" namespace tim { namespace transform { class BatchNormLayoutInfer : public OpLayoutInfer { diff --git a/src/tim/transform/ops/concat_layout_inferene.h b/src/tim/transform/ops/concat_layout_inferene.h index 9643ddf..848800b 100644 --- a/src/tim/transform/ops/concat_layout_inferene.h +++ b/src/tim/transform/ops/concat_layout_inferene.h @@ -28,7 +28,7 @@ #include "ops/op_layout_inference.h" #include "permute_vector.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" namespace tim { namespace transform { diff --git a/src/tim/transform/ops/conv2d_layout_inference.h b/src/tim/transform/ops/conv2d_layout_inference.h index 368ea02..03f8478 100644 --- a/src/tim/transform/ops/conv2d_layout_inference.h +++ b/src/tim/transform/ops/conv2d_layout_inference.h @@ -26,7 +26,7 @@ #include "tim/vx/ops/conv2d.h" -#include "operation_private.h" +#include 
"direct_map_op_impl.h" #include "permute_vector.h" #include "ops/op_layout_inference.h" diff --git a/src/tim/transform/ops/deconv2d_layout_inference.h b/src/tim/transform/ops/deconv2d_layout_inference.h index b91be41..291f847 100644 --- a/src/tim/transform/ops/deconv2d_layout_inference.h +++ b/src/tim/transform/ops/deconv2d_layout_inference.h @@ -26,7 +26,7 @@ #include "ops/op_layout_inference.h" #include "permute_vector.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "tim/vx/ops/deconv.h" namespace tim { diff --git a/src/tim/transform/ops/default_layout_inference.h b/src/tim/transform/ops/default_layout_inference.h index 9f54c4b..1797de9 100644 --- a/src/tim/transform/ops/default_layout_inference.h +++ b/src/tim/transform/ops/default_layout_inference.h @@ -33,7 +33,7 @@ #include "ops/op_layout_inference.h" #include "permute_vector.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" namespace tim { namespace transform { diff --git a/src/tim/transform/ops/depth2space_layout_inference.h b/src/tim/transform/ops/depth2space_layout_inference.h index 84ec4d5..7df177a 100644 --- a/src/tim/transform/ops/depth2space_layout_inference.h +++ b/src/tim/transform/ops/depth2space_layout_inference.h @@ -28,7 +28,7 @@ #include "ops/op_layout_inference.h" #include "permute_vector.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" namespace tim { namespace transform { diff --git a/src/tim/transform/ops/elementwise_layout_inference.h b/src/tim/transform/ops/elementwise_layout_inference.h index ab6b6ff..472446b 100644 --- a/src/tim/transform/ops/elementwise_layout_inference.h +++ b/src/tim/transform/ops/elementwise_layout_inference.h @@ -28,7 +28,7 @@ #include "ops/op_layout_inference.h" #include "permute_vector.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" namespace tim { namespace transform { diff --git a/src/tim/transform/ops/fullyconnected_layout_inference.h 
b/src/tim/transform/ops/fullyconnected_layout_inference.h index a49febc..b86c53d 100644 --- a/src/tim/transform/ops/fullyconnected_layout_inference.h +++ b/src/tim/transform/ops/fullyconnected_layout_inference.h @@ -28,7 +28,7 @@ #include "ops/op_layout_inference.h" #include "permute_vector.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" namespace tim { namespace transform { diff --git a/src/tim/transform/ops/gather_layout_inference.h b/src/tim/transform/ops/gather_layout_inference.h index 72b145a..db4db9d 100644 --- a/src/tim/transform/ops/gather_layout_inference.h +++ b/src/tim/transform/ops/gather_layout_inference.h @@ -25,7 +25,7 @@ #define TIM_LAYOUT_INFER_GATHER_LAYOUT_INFERENCE_H_ #include "ops/op_layout_inference.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "tim/vx/ops/gather.h" namespace tim { diff --git a/src/tim/transform/ops/gather_nd_layout_inference.h b/src/tim/transform/ops/gather_nd_layout_inference.h index 9c8fc23..7db8114 100644 --- a/src/tim/transform/ops/gather_nd_layout_inference.h +++ b/src/tim/transform/ops/gather_nd_layout_inference.h @@ -25,7 +25,7 @@ #define TIM_LAYOUT_INFER_GATHER_ND_LAYOUT_INFERENCE_H_ #include "ops/op_layout_inference.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "tim/vx/ops/gathernd.h" namespace tim { diff --git a/src/tim/transform/ops/l2normalization_layout_inference.h b/src/tim/transform/ops/l2normalization_layout_inference.h index 027cecc..ea80e3a 100644 --- a/src/tim/transform/ops/l2normalization_layout_inference.h +++ b/src/tim/transform/ops/l2normalization_layout_inference.h @@ -25,7 +25,7 @@ #define TIM_LAYOUT_INFER_L2_NORMALIZATION_LAYOUT_INFERENCE_H_ #include "ops/op_layout_inference.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "tim/vx/ops/l2normalization.h" namespace tim { diff --git a/src/tim/transform/ops/logical_layout_inference.h b/src/tim/transform/ops/logical_layout_inference.h index 
848f0eb..05e5c87 100644 --- a/src/tim/transform/ops/logical_layout_inference.h +++ b/src/tim/transform/ops/logical_layout_inference.h @@ -25,7 +25,7 @@ #define TIM_LAYOUT_INFER_LOGICAL_OPS_LAYOUT_INFERENCE_H_ #include "ops/op_layout_inference.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "tim/vx/ops/logical.h" namespace tim { diff --git a/src/tim/transform/ops/lrn_layout_inference.h b/src/tim/transform/ops/lrn_layout_inference.h index c541007..7251411 100644 --- a/src/tim/transform/ops/lrn_layout_inference.h +++ b/src/tim/transform/ops/lrn_layout_inference.h @@ -27,7 +27,7 @@ #include "tim/vx/ops/localresponsenormalization.h" #include "ops/op_layout_inference.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" namespace tim { namespace transform { diff --git a/src/tim/transform/ops/op_layout_inference.cc b/src/tim/transform/ops/op_layout_inference.cc index 15938a5..9e89f43 100644 --- a/src/tim/transform/ops/op_layout_inference.cc +++ b/src/tim/transform/ops/op_layout_inference.cc @@ -24,7 +24,7 @@ #include "op_layout_inference.h" #include "permute_vector.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "tim/vx/ops/transpose.h" #include "type_utils.h" diff --git a/src/tim/transform/ops/pad_layout_inference.h b/src/tim/transform/ops/pad_layout_inference.h index 607927c..cb4fe5a 100644 --- a/src/tim/transform/ops/pad_layout_inference.h +++ b/src/tim/transform/ops/pad_layout_inference.h @@ -28,7 +28,7 @@ #include "ops/op_layout_inference.h" #include "permute_vector.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" namespace tim { namespace transform { class PadLayoutInfer : public OpLayoutInfer { diff --git a/src/tim/transform/ops/pool2d_layout_inference.h b/src/tim/transform/ops/pool2d_layout_inference.h index 9954a7e..c36896b 100644 --- a/src/tim/transform/ops/pool2d_layout_inference.h +++ b/src/tim/transform/ops/pool2d_layout_inference.h @@ -26,7 +26,7 @@ #include 
"ops/op_layout_inference.h" #include "permute_vector.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "tim/vx/ops/pool2d.h" namespace tim { diff --git a/src/tim/transform/ops/reduce_layout_inference.h b/src/tim/transform/ops/reduce_layout_inference.h index e8ec676..3480798 100644 --- a/src/tim/transform/ops/reduce_layout_inference.h +++ b/src/tim/transform/ops/reduce_layout_inference.h @@ -30,7 +30,7 @@ #include "ops/op_layout_inference.h" #include "permute_vector.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" namespace tim { namespace transform { diff --git a/src/tim/transform/ops/resize_layout_inference.h b/src/tim/transform/ops/resize_layout_inference.h index 530907f..363cdf0 100644 --- a/src/tim/transform/ops/resize_layout_inference.h +++ b/src/tim/transform/ops/resize_layout_inference.h @@ -28,7 +28,7 @@ #include "ops/op_layout_inference.h" #include "permute_vector.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" namespace tim { namespace transform { class ResizeLayoutInfer : public OpLayoutInfer { diff --git a/src/tim/transform/ops/reverse_layout_inference.h b/src/tim/transform/ops/reverse_layout_inference.h index f999e27..9f7f17d 100644 --- a/src/tim/transform/ops/reverse_layout_inference.h +++ b/src/tim/transform/ops/reverse_layout_inference.h @@ -25,7 +25,7 @@ #define TIM_LAYOUT_INFER_REVERSE_LAYOUT_INFERENCE_H_ #include "ops/op_layout_inference.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "tim/vx/ops/reverse.h" namespace tim { diff --git a/src/tim/transform/ops/select_layout_inference.h b/src/tim/transform/ops/select_layout_inference.h index 60dd898..b3046be 100644 --- a/src/tim/transform/ops/select_layout_inference.h +++ b/src/tim/transform/ops/select_layout_inference.h @@ -25,7 +25,7 @@ #define TIM_LAYOUT_INFER_SELECT_LAYOUT_INFERENCE_H_ #include "ops/op_layout_inference.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include 
"tim/vx/ops/select.h" namespace tim { diff --git a/src/tim/transform/ops/simple_ops_layout_inference.h b/src/tim/transform/ops/simple_ops_layout_inference.h index a1e92d2..5c50c66 100644 --- a/src/tim/transform/ops/simple_ops_layout_inference.h +++ b/src/tim/transform/ops/simple_ops_layout_inference.h @@ -28,7 +28,7 @@ #include "ops/op_layout_inference.h" #include "permute_vector.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" namespace tim { namespace transform { diff --git a/src/tim/transform/ops/slice_layout_inference.h b/src/tim/transform/ops/slice_layout_inference.h index 9b6079e..abfb846 100644 --- a/src/tim/transform/ops/slice_layout_inference.h +++ b/src/tim/transform/ops/slice_layout_inference.h @@ -25,7 +25,7 @@ #define TIM_LAYOUT_INFER_SLICE_LAYOUT_INFERENCE_H_ #include "ops/op_layout_inference.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "tim/vx/ops/slice.h" namespace tim { diff --git a/src/tim/transform/ops/softmax_layout_inference.h b/src/tim/transform/ops/softmax_layout_inference.h index 1b8a21f..cbdfa04 100644 --- a/src/tim/transform/ops/softmax_layout_inference.h +++ b/src/tim/transform/ops/softmax_layout_inference.h @@ -26,7 +26,7 @@ #include "tim/vx/ops/softmax.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "permute_vector.h" #include "ops/op_layout_inference.h" diff --git a/src/tim/transform/ops/space2batch_layout_inference.h b/src/tim/transform/ops/space2batch_layout_inference.h index 77fdfd9..f59683b 100644 --- a/src/tim/transform/ops/space2batch_layout_inference.h +++ b/src/tim/transform/ops/space2batch_layout_inference.h @@ -28,7 +28,7 @@ #include "ops/op_layout_inference.h" #include "permute_vector.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" namespace tim { namespace transform { class Space2BatchLayoutInfer : public OpLayoutInfer { diff --git a/src/tim/transform/ops/space2depth_layout_inference.h 
b/src/tim/transform/ops/space2depth_layout_inference.h index fbc89cb..f0111c8 100644 --- a/src/tim/transform/ops/space2depth_layout_inference.h +++ b/src/tim/transform/ops/space2depth_layout_inference.h @@ -28,7 +28,7 @@ #include "ops/op_layout_inference.h" #include "permute_vector.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" namespace tim { namespace transform { class SpaceToDepthLayoutInfer : public OpLayoutInfer { diff --git a/src/tim/transform/ops/split_layout_inference.h b/src/tim/transform/ops/split_layout_inference.h index 6fceec9..fe93e3d 100644 --- a/src/tim/transform/ops/split_layout_inference.h +++ b/src/tim/transform/ops/split_layout_inference.h @@ -28,7 +28,7 @@ #include "ops/op_layout_inference.h" #include "permute_vector.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" namespace tim { namespace transform { diff --git a/src/tim/transform/ops/squeeze_layout_inference.h b/src/tim/transform/ops/squeeze_layout_inference.h index 0fe8d67..d58b91b 100644 --- a/src/tim/transform/ops/squeeze_layout_inference.h +++ b/src/tim/transform/ops/squeeze_layout_inference.h @@ -28,7 +28,7 @@ #include "ops/op_layout_inference.h" #include "permute_vector.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" namespace tim { namespace transform { diff --git a/src/tim/transform/ops/stack_layout_inference.h b/src/tim/transform/ops/stack_layout_inference.h index 4161245..87492d1 100644 --- a/src/tim/transform/ops/stack_layout_inference.h +++ b/src/tim/transform/ops/stack_layout_inference.h @@ -26,7 +26,7 @@ #include "tim/vx/ops/stack.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "permute_vector.h" #include "ops/op_layout_inference.h" diff --git a/src/tim/transform/ops/stridedslice_layout_inference.h b/src/tim/transform/ops/stridedslice_layout_inference.h index afbdebf..058d2db 100644 --- a/src/tim/transform/ops/stridedslice_layout_inference.h +++ 
b/src/tim/transform/ops/stridedslice_layout_inference.h @@ -28,7 +28,7 @@ #include "ops/op_layout_inference.h" #include "permute_vector.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" namespace tim { namespace transform { diff --git a/src/tim/vx/direct_map_op.cc b/src/tim/vx/direct_map_op.cc new file mode 100644 index 0000000..569d988 --- /dev/null +++ b/src/tim/vx/direct_map_op.cc @@ -0,0 +1,36 @@ +/**************************************************************************** +* +* Copyright (c) 2021 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. 
+* +*****************************************************************************/ +#include "tim/vx/direct_map_op.h" + +#include "direct_map_op_impl.h" + +namespace tim { +namespace vx { +DirectMapOp::DirectMapOp(Graph* graph, uint32_t kind, int in_cnt, int out_cnt, + DataLayout layout) { + impl_ = std::make_unique(graph, kind, in_cnt, out_cnt, layout); +} + +} // namespace vx +} // namespace tim \ No newline at end of file diff --git a/src/tim/vx/direct_map_op_impl.cc b/src/tim/vx/direct_map_op_impl.cc new file mode 100644 index 0000000..1701604 --- /dev/null +++ b/src/tim/vx/direct_map_op_impl.cc @@ -0,0 +1,75 @@ +/**************************************************************************** +* +* Copyright (c) 2021 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. 
+* +*****************************************************************************/ +#include "direct_map_op_impl.h" +#include "type_utils.h" + +namespace tim{ +namespace vx{ + +DirectMapOpImpl::DirectMapOpImpl(Graph* graph, uint32_t kind, int input_cnt, + int output_cnt, DataLayout layout) + : OpImpl(graph, kind, input_cnt, output_cnt, layout), + node_(vsi_nn_AddNode(graph_->graph(), kind_, input_cnt_, output_cnt_, + NULL)) { + SetRoundingPolicy(); + node_->uid = graph_->graph()->cur_nid; +} + +DirectMapOpImpl& DirectMapOpImpl::BindInput(const std::shared_ptr& tensor) { + inputs_tensor_.push_back(tensor); + uint32_t tensor_id = tensor->GetId(); + node_->input.tensors[input_tensor_index++] = tensor_id; + if (tensor->GetSpec().attr_ & TensorAttribute::INPUT) { + graph_->AddInput(tensor_id); + graph_->AddInput(tensor); + } + return *this; +} + +DirectMapOpImpl& DirectMapOpImpl::BindOutput( + const std::shared_ptr& tensor) { + outputs_tensor_.push_back(tensor); + uint32_t tensor_id = tensor->GetId(); + node_->output.tensors[output_tensor_index++] = tensor_id; + if (tensor->GetSpec().attr_ == TensorAttribute::OUTPUT) { + graph_->AddOutput(tensor_id); + graph_->AddOutput(tensor); + } + return *this; +} + +void DirectMapOpImpl::SetRoundingPolicy( + OverflowPolicy overflow_policy, + RoundingPolicy rounding_policy, + RoundType down_scale_size_rounding, + uint32_t accumulator_bits) { + node_->vx_param.overflow_policy = TranslateOverflowPolicy(overflow_policy); + node_->vx_param.rounding_policy = TranslateRoundingPolicy(rounding_policy); + node_->vx_param.down_scale_size_rounding = + TranslateDownScaleSizeRounding(down_scale_size_rounding); + node_->vx_param.accumulator_bits = accumulator_bits; +} + +} +} \ No newline at end of file diff --git a/src/tim/vx/operation_private.h b/src/tim/vx/direct_map_op_impl.h similarity index 64% rename from src/tim/vx/operation_private.h rename to src/tim/vx/direct_map_op_impl.h index 176b8a1..e4ff432 100644 --- 
a/src/tim/vx/operation_private.h +++ b/src/tim/vx/direct_map_op_impl.h @@ -1,6 +1,6 @@ /**************************************************************************** * -* Copyright (c) 2020 Vivante Corporation +* Copyright (c) 2021 Vivante Corporation * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -21,50 +21,49 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ -#ifndef TIM_VX_OPERATION_PRIVATE_H_ -#define TIM_VX_OPERATION_PRIVATE_H_ -#include "graph_private.h" +#ifndef TIM_VX_DIRECT_MAP_OP_IMPL_H_ +#define TIM_VX_DIRECT_MAP_OP_IMPL_H_ + + #include "vsi_nn_pub.h" +#include "graph_private.h" + +#include "op_impl.h" namespace tim { namespace vx { -class OperationImpl { - public: - // OperationImpl(Graph* graph, uint32_t operation_id, int input_cnt = 0, - // int output_cnt = 0); - OperationImpl(Graph* graph, uint32_t operation_id, int input_cnt = 0, - int output_cnt = 0, DataLayout layout = DataLayout::ANY); - ~OperationImpl() {} - OperationImpl& BindInput(const std::shared_ptr& tensor); - OperationImpl& BindOutput(const std::shared_ptr& tensor); - OperationImpl& SetRoundingPolicy( +class DirectMapOpImpl : public OpImpl { + public: + // DirectMapOpImpl(Graph* graph, uint32_t kind, int input_cnt = 0, + // int output_cnt = 0); + DirectMapOpImpl(Graph* graph, uint32_t kind, int input_cnt = 0, + int output_cnt = 0, DataLayout layout = DataLayout::ANY); + ~DirectMapOpImpl() {} + + DirectMapOpImpl& BindInput(const std::shared_ptr& tensor) override; + DirectMapOpImpl& BindOutput(const std::shared_ptr& tensor) override; + + vsi_nn_node_t* node() override { return this->node_; } + + void SetRoundingPolicy( OverflowPolicy overflow_policy = OverflowPolicy::SATURATE, RoundingPolicy rounding_policy = RoundingPolicy::RTNE, RoundType down_scale_size_rounding = RoundType::FLOOR, - uint32_t accumulator_bits = 0); + 
uint32_t accumulator_bits =0); - vsi_nn_node_t* node() { return this->node_; } - - std::vector> InputsTensor() { return inputs_tensor_; } + std::vector> InputsTensor() { + return inputs_tensor_; + } std::vector> OutputsTensor() { return outputs_tensor_; } - GraphImpl* graph_; - uint32_t operation_id_{0}; - int32_t input_cnt_{0}; - int32_t output_cnt_{0}; - DataLayout layout_{DataLayout::ANY}; + protected: vsi_nn_node_t* node_{nullptr}; - int32_t input_tensor_index{0}; - int32_t output_tensor_index{0}; - std::vector> inputs_tensor_; - std::vector> outputs_tensor_; - }; } // namespace vx } // namespace tim -#endif /* TIM_VX_OPERATION_PRIVATE_H_ */ +#endif /* TIM_VX_DIRECT_MAP_OP_IMPL_H_ */ diff --git a/src/tim/vx/graph.cc b/src/tim/vx/graph.cc index 9a47c50..06a6d9e 100644 --- a/src/tim/vx/graph.cc +++ b/src/tim/vx/graph.cc @@ -26,7 +26,7 @@ #include "context_private.h" #include "graph_private.h" -#include "operation_private.h" +#include "op_impl.h" #include "tensor_private.h" #include "tim/vx/context.h" #include "tim/vx/ops/nbg.h" diff --git a/src/tim/vx/op_impl.cc b/src/tim/vx/op_impl.cc new file mode 100644 index 0000000..3d60375 --- /dev/null +++ b/src/tim/vx/op_impl.cc @@ -0,0 +1,37 @@ +/**************************************************************************** +* +* Copyright (c) 2021 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. 
+* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#include "op_impl.h" + +namespace tim{ +namespace vx{ + +OpImpl::OpImpl(Graph* graph, uint32_t kind, int input_cnt, int output_cnt, + DataLayout layout) + : graph_(reinterpret_cast(graph)), + kind_(kind), + input_cnt_(input_cnt), + output_cnt_(output_cnt), + layout_(layout) {} +} +} diff --git a/src/tim/vx/op_impl.h b/src/tim/vx/op_impl.h new file mode 100644 index 0000000..582ba79 --- /dev/null +++ b/src/tim/vx/op_impl.h @@ -0,0 +1,58 @@ +/**************************************************************************** +* +* Copyright (c) 2021 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ + +#ifndef TIM_VX_OP_IMPL_H_ +#define TIM_VX_OP_IMPL_H_ +#include "graph_private.h" +#include "tim/vx/graph.h" +#include "tim/vx/types.h" + +namespace tim { +namespace vx { + +class OpImpl { + public: + OpImpl(Graph* graph, uint32_t kind, int input_cnt, int output_cnt, + DataLayout layout); + virtual OpImpl& BindInput(const std::shared_ptr& tensor) = 0; + virtual OpImpl& BindOutput(const std::shared_ptr& tensor) = 0; + virtual std::vector> InputsTensor() = 0; + virtual std::vector> OutputsTensor() = 0; + virtual vsi_nn_node_t* node() = 0; + + GraphImpl* graph_; + uint32_t kind_{0}; + int32_t input_cnt_{0}; + int32_t output_cnt_{0}; + DataLayout layout_{DataLayout::ANY}; + int32_t input_tensor_index{0}; + int32_t output_tensor_index{0}; + std::vector> inputs_tensor_; + std::vector> outputs_tensor_; +}; + +} // namespace vx +} // namespace tim + +#endif \ No newline at end of file diff --git a/src/tim/vx/operation.cc b/src/tim/vx/operation.cc index f3c825a..c24ab90 100644 --- a/src/tim/vx/operation.cc +++ b/src/tim/vx/operation.cc @@ -22,75 +22,22 @@ * *****************************************************************************/ #include "tim/vx/operation.h" - #include +#include "op_impl.h" #include "graph_private.h" -#include "operation_private.h" #include "type_utils.h" #include "vsi_nn_pub.h" namespace tim { namespace vx { -OperationImpl::OperationImpl(Graph* graph, uint32_t operation_id, - int input_cnt, int output_cnt, DataLayout layout) - : graph_(reinterpret_cast(graph)), - operation_id_(operation_id), - input_cnt_(input_cnt), - output_cnt_(output_cnt), - layout_(layout), - node_(vsi_nn_AddNode(graph_->graph(), 
operation_id_, input_cnt_, - output_cnt_, NULL)) { - SetRoundingPolicy(); - node_->uid = graph_->graph()->cur_nid; -} - -OperationImpl& OperationImpl::BindInput(const std::shared_ptr& tensor) { - inputs_tensor_.push_back(tensor); - uint32_t tensor_id = tensor->GetId(); - node_->input.tensors[input_tensor_index++] = tensor_id; - if (tensor->GetSpec().attr_ & TensorAttribute::INPUT) { - graph_->AddInput(tensor_id); - graph_->AddInput(tensor); - } - return *this; -} - -OperationImpl& OperationImpl::BindOutput( - const std::shared_ptr& tensor) { - outputs_tensor_.push_back(tensor); - uint32_t tensor_id = tensor->GetId(); - node_->output.tensors[output_tensor_index++] = tensor_id; - if (tensor->GetSpec().attr_ == TensorAttribute::OUTPUT) { - graph_->AddOutput(tensor_id); - graph_->AddOutput(tensor); - } - return *this; -} - -OperationImpl& OperationImpl::SetRoundingPolicy( - OverflowPolicy overflow_policy, RoundingPolicy rounding_policy, - RoundType down_scale_size_rounding, uint32_t accumulator_bits) { - node_->vx_param.overflow_policy = TranslateOverflowPolicy(overflow_policy); - node_->vx_param.rounding_policy = TranslateRoundingPolicy(rounding_policy); - node_->vx_param.down_scale_size_rounding = - TranslateDownScaleSizeRounding(down_scale_size_rounding); - node_->vx_param.accumulator_bits = accumulator_bits; - - return *this; -} - // Operation implementation -Operation::Operation(Graph* graph, uint32_t operation_id, - int input_cnt, int output_cnt, DataLayout layout) { - impl_ = std::make_unique(graph, operation_id, - input_cnt, output_cnt, layout); -} +Operation::Operation() {} Operation::~Operation() {} -std::unique_ptr& Operation::impl() { return impl_; } -const std::unique_ptr& Operation::impl() const { return impl_; } +std::unique_ptr& Operation::impl() { return impl_; } +const std::unique_ptr& Operation::impl() const { return impl_; } Operation& Operation::BindInput(const std::shared_ptr& tensor) { impl_->BindInput(tensor); @@ -107,8 +54,9 @@ Operation& 
Operation::BindOutput(const std::shared_ptr& tensor) { Operation& Operation::SetRoundingPolicy( OverflowPolicy overflow_policy, RoundingPolicy rounding_policy, RoundType down_scale_size_rounding, uint32_t accumulator_bits) { - impl_->SetRoundingPolicy(overflow_policy, rounding_policy, - down_scale_size_rounding, accumulator_bits); + // impl_->SetRoundingPolicy(overflow_policy, rounding_policy, + // down_scale_size_rounding, accumulator_bits); + (void) overflow_policy;(void) rounding_policy;(void) down_scale_size_rounding;(void) accumulator_bits; return *this; } diff --git a/src/tim/vx/ops/README.md b/src/tim/vx/ops/README.md index 89d698a..b222db4 100644 --- a/src/tim/vx/ops/README.md +++ b/src/tim/vx/ops/README.md @@ -103,6 +103,7 @@ Svdf|SVDF|Mapped|[ANEURALNETWORKS_SVDF](https://developer.android.com/ndk/refere Erf|ERF|Mapped|[tf.math.erf](https://tensorflow.google.cn/api_docs/python/tf/math/erf) GroupedConv1d|GROUPED_CONV1D|Mapped|[tf.keras.layers.Conv1D](https://tensorflow.google.cn/api_docs/python/tf/keras/layers/Conv1D?hl=en) |SignalFrame|SIGNAL_FRAME|Mapped|[tf.signal.frame](https://tensorflow.google.cn/api_docs/python/tf/signal/frame) +|RNNCell|RNNCELL_OVXLIB|Mapped|[ANEURALNETWORKS_RNN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0acd2684ac9c73bb29767b534e78a332e8) ||PROPOSAL| TBD |[Faster-RCNN Proposal Layer](https://github.com/intel/caffe/blob/master/examples/faster-rcnn/lib/rpn/proposal_layer.py) ||ROI_POOL|Planned 22Q1 |[ANEURALNETWORKS_ROI_POOLING](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a6736198af337b2efbdb0b6b64dee7fe4) ||ROI_ALIGN| TBD |[ANEURALNETWORKS_ROI_ALIGN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a2848b39dd4bfba78f2438fda0d9397a4) @@ -111,7 +112,6 @@ 
GroupedConv1d|GROUPED_CONV1D|Mapped|[tf.keras.layers.Conv1D](https://tensorflow. |UnidirectionalSequenceGRU|GRU_OVXLIB|Planned 21Q4|[tf.keras.layers.GRU](https://tensorflow.google.cn/api_docs/python/tf/keras/layers/GRUCell?hl=en) |UnidirectionalSequenceRNN|UNIDIRECTIONAL_SEQUENCE_RNN|Planned 21Q4|[ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_RNN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0ae11aa1d461d2abaa117f6ee2cb503dd8) |BidirectionalSequenceRNN|BIDIRECTIONAL_SEQUENCE_RNN|Planned 21Q4|[ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_RNN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a487fc5ae247de828f13e62b99f259f3c) -|RNNCell|RNNCELL_OVXLIB|Planned 21Q3|[ANEURALNETWORKS_RNN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0acd2684ac9c73bb29767b534e78a332e8) |BidirectionalSequenceLSTM|BIDIRECTIONAL_SEQUENCE_LSTM|Planned 21Q4|[ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_LSTM](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a492a71cb7aa50b9a1a834a3cb269d778) |UnidirectionalSequenceLSTM|LSTM_OVXLIB|Mapped|[ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_LSTM](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0aaf30e491ad0b1fc7602cbde695b2c859) |LSTMCell|LSTMUNIT_OVXLIB|replace with UnidirectionalSequenceLSTM by set n_step = 1 |[ANEURALNETWORKS_LSTM](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0ad0377e8c305e596fb7f64ff896671fc5) diff --git a/src/tim/vx/ops/activations.cc b/src/tim/vx/ops/activations.cc index 855a4d1..11e7243 100644 --- a/src/tim/vx/ops/activations.cc +++ b/src/tim/vx/ops/activations.cc @@ -23,7 +23,7 @@ 
*****************************************************************************/ #include "tim/vx/ops/activations.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace tim { @@ -31,7 +31,7 @@ namespace vx { namespace ops { #define DEFINE_NO_PARAMETER_ACTIVATION(NAME, VSI_OP_CODE) \ - NAME::NAME(Graph* graph) : Operation(graph, VSI_OP_CODE) {} \ + NAME::NAME(Graph* graph) : DirectMapOp(graph, VSI_OP_CODE) {} \ std::shared_ptr NAME::Clone(std::shared_ptr& graph) \ const { \ return graph->CreateOperation(); \ @@ -49,7 +49,7 @@ DEFINE_NO_PARAMETER_ACTIVATION(SoftRelu, VSI_NN_OP_SOFTRELU) #undef DEFINE_NO_PARAMETER_ACTIVATION -HardSwish::HardSwish(Graph* graph) : Operation(graph, VSI_NN_OP_SWISH) { +HardSwish::HardSwish(Graph* graph) : DirectMapOp(graph, VSI_NN_OP_SWISH) { this->impl()->node()->nn_param.swish.type = VSI_NN_HSWISH; this->impl()->node()->nn_param.swish.beta = 1.0f; } @@ -59,7 +59,7 @@ std::shared_ptr HardSwish::Clone( return graph->CreateOperation(); } -Swish::Swish(Graph* graph) : Operation(graph, VSI_NN_OP_SWISH) { +Swish::Swish(Graph* graph) : DirectMapOp(graph, VSI_NN_OP_SWISH) { this->impl()->node()->nn_param.swish.type = VSI_NN_SWISH; this->impl()->node()->nn_param.swish.beta = 1.0f; } @@ -70,7 +70,7 @@ std::shared_ptr Swish::Clone( } Prelu::Prelu(Graph* graph, int axis) - : Operation(graph, VSI_NN_OP_PRELU), axis_(axis) { + : DirectMapOp(graph, VSI_NN_OP_PRELU), axis_(axis) { this->impl()->node()->nn_param.prelu.axis = axis_; } @@ -78,7 +78,7 @@ std::shared_ptr Prelu::Clone(std::shared_ptr& graph) const { return graph->CreateOperation(this->axis_); } -Tanh::Tanh(Graph* graph) : Operation(graph, VSI_NN_OP_TANH) { +Tanh::Tanh(Graph* graph) : DirectMapOp(graph, VSI_NN_OP_TANH) { this->impl()->node()->nn_param.tanh.scale_a = 1.0; this->impl()->node()->nn_param.tanh.scale_b = 1.0; } @@ -88,7 +88,7 @@ std::shared_ptr Tanh::Clone(std::shared_ptr& graph) const { } LeakyRelu::LeakyRelu(Graph* graph, float alpha) 
- : Operation(graph, VSI_NN_OP_LEAKY_RELU), alpha_(alpha) { + : DirectMapOp(graph, VSI_NN_OP_LEAKY_RELU), alpha_(alpha) { this->impl()->node()->nn_param.activation.leaky_ratio = alpha_; } @@ -98,7 +98,7 @@ std::shared_ptr LeakyRelu::Clone( } Linear::Linear(Graph* graph, float a, float b) - : Operation(graph, VSI_NN_OP_LINEAR), a_(a), b_(b) { + : DirectMapOp(graph, VSI_NN_OP_LINEAR), a_(a), b_(b) { this->impl()->node()->nn_param.linear.a = a_; this->impl()->node()->nn_param.linear.b = b_; } @@ -108,7 +108,7 @@ std::shared_ptr Linear::Clone(std::shared_ptr& graph) const { } Gelu::Gelu(Graph* graph, bool approximate) - : Operation(graph, VSI_NN_OP_GELU){ + : DirectMapOp(graph, VSI_NN_OP_GELU){ this->impl()->node()->nn_param.gelu.approximate = approximate; } diff --git a/src/tim/vx/ops/addn.cc b/src/tim/vx/ops/addn.cc index fd6c86c..bc51b96 100644 --- a/src/tim/vx/ops/addn.cc +++ b/src/tim/vx/ops/addn.cc @@ -23,7 +23,8 @@ *****************************************************************************/ #include "tim/vx/ops/addn.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" + #include "vsi_nn_pub.h" namespace tim { @@ -31,7 +32,7 @@ namespace vx { namespace ops { AddN::AddN(Graph* graph, uint32_t num_inputs) - : Operation(graph, VSI_NN_OP_ADDN, num_inputs, 1) {} + : DirectMapOp(graph, VSI_NN_OP_ADDN, num_inputs, 1) {} std::shared_ptr AddN::Clone(std::shared_ptr& graph) const { return graph->CreateOperation(this->impl_->input_cnt_); diff --git a/src/tim/vx/ops/arg.cc b/src/tim/vx/ops/arg.cc index 9d1e586..d674c92 100644 --- a/src/tim/vx/ops/arg.cc +++ b/src/tim/vx/ops/arg.cc @@ -25,7 +25,7 @@ #include "vsi_nn_pub.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" namespace tim { namespace vx { @@ -33,7 +33,7 @@ namespace ops { #define DEFINE_ARG_OP(NAME, VSI_OP_TYPE, OP_PARAM) \ Arg##NAME::Arg##NAME(Graph* graph, int32_t axis) \ - : Operation(graph, VSI_NN_OP_ARG##VSI_OP_TYPE), axis_(axis) { \ + : DirectMapOp(graph, 
VSI_NN_OP_ARG##VSI_OP_TYPE), axis_(axis) { \ this->impl()->node()->nn_param.arg##OP_PARAM.axis = axis_; \ } \ std::shared_ptr Arg##NAME::Clone(std::shared_ptr& graph) \ diff --git a/src/tim/vx/ops/batch2space.cc b/src/tim/vx/ops/batch2space.cc index 1d3f07c..7d2c604 100644 --- a/src/tim/vx/ops/batch2space.cc +++ b/src/tim/vx/ops/batch2space.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/batch2space.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace tim { @@ -32,7 +32,7 @@ namespace ops { Batch2Space::Batch2Space(Graph* graph, const std::vector& block_size, const std::vector& crop, DataLayout layout) - : Operation(graph, VSI_NN_OP_BATCH2SPACE, 0, 0, layout), + : DirectMapOp(graph, VSI_NN_OP_BATCH2SPACE, 0, 0, layout), block_size_(block_size), crop_(crop) { this->impl()->node()->nn_param.batch2space.block_size = block_size_.data(); diff --git a/src/tim/vx/ops/batchnorm.cc b/src/tim/vx/ops/batchnorm.cc index 3685259..d45ded0 100644 --- a/src/tim/vx/ops/batchnorm.cc +++ b/src/tim/vx/ops/batchnorm.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/batchnorm.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace tim { @@ -31,14 +31,14 @@ namespace vx { namespace ops { BatchNorm::BatchNorm(Graph* graph, float eps, DataLayout input_layout) - : Operation(graph, VSI_NN_OP_BATCH_NORM, 0, 0, input_layout), eps_(eps) { + : DirectMapOp(graph, VSI_NN_OP_BATCH_NORM, 0, 0, input_layout), eps_(eps) { this->impl()->node()->nn_param.batch_norm.eps = eps_; } std::shared_ptr BatchNorm::Clone( std::shared_ptr& graph) const { return graph->CreateOperation( - this->impl_->node_->nn_param.batch_norm.eps); + this->impl_->node()->nn_param.batch_norm.eps); } } // namespace ops diff --git a/src/tim/vx/ops/clip.cc b/src/tim/vx/ops/clip.cc index 
af51aca..f06a2a6 100644 --- a/src/tim/vx/ops/clip.cc +++ b/src/tim/vx/ops/clip.cc @@ -25,7 +25,7 @@ #include "vsi_nn_pub.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" namespace tim { namespace vx { @@ -33,7 +33,7 @@ namespace ops { Clip::Clip(Graph* graph, float min, float max) - : Operation(graph, VSI_NN_OP_CLIP), + : DirectMapOp(graph, VSI_NN_OP_CLIP), min_(min), max_(max) { this->impl()->node()->nn_param.clip.min = min_; @@ -42,7 +42,7 @@ Clip::Clip(Graph* graph, float min, float max) std::shared_ptr Clip::Clone(std::shared_ptr& graph) const { return graph->CreateOperation(this->impl_->node()->nn_param.clip.min, - this->impl_->node_->nn_param.clip.max); + this->impl_->node()->nn_param.clip.max); } } // namespace ops diff --git a/src/tim/vx/ops/concat.cc b/src/tim/vx/ops/concat.cc index 86bda04..56e9f8d 100644 --- a/src/tim/vx/ops/concat.cc +++ b/src/tim/vx/ops/concat.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/concat.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace tim { @@ -31,7 +31,7 @@ namespace vx { namespace ops { Concat::Concat(Graph* graph, uint32_t axis, int input_cnt) - : Operation(graph, VSI_NN_OP_CONCAT, input_cnt, 1), axis_(axis) { + : DirectMapOp(graph, VSI_NN_OP_CONCAT, input_cnt, 1), axis_(axis) { this->impl()->node()->nn_param.concat.axis = axis_; } diff --git a/src/tim/vx/ops/conv1d.cc b/src/tim/vx/ops/conv1d.cc index 413be40..c35dd5e 100644 --- a/src/tim/vx/ops/conv1d.cc +++ b/src/tim/vx/ops/conv1d.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/conv1d.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "type_utils.h" #include "vsi_nn_pub.h" @@ -54,7 +54,7 @@ Conv1d::Conv1d(Graph* graph, int32_t weights, PadType padding, uint32_t ksize, uint32_t stride, uint32_t dilation, const std::array& pad, 
int32_t multiplier, DataLayout input_layout, DataLayout kernel_layout) - : Operation(graph, VSI_NN_OP_CONV1D, 0, 0, input_layout), + : DirectMapOp(graph, VSI_NN_OP_CONV1D, 0, 0, input_layout), weights_(weights), padding_(padding), ksize_(ksize), diff --git a/src/tim/vx/ops/conv2d.cc b/src/tim/vx/ops/conv2d.cc index 66faf5b..14b50d8 100644 --- a/src/tim/vx/ops/conv2d.cc +++ b/src/tim/vx/ops/conv2d.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/conv2d.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "type_utils.h" #include "vsi_nn_pub.h" @@ -59,7 +59,7 @@ Conv2d::Conv2d(Graph* graph, int32_t weights, PadType padding, const std::array& dilation, const std::array& pad, int32_t multiplier, DataLayout input_layout, DataLayout kernel_layout) - : Operation(graph, VSI_NN_OP_CONV2D, 0, 0, input_layout), + : DirectMapOp(graph, VSI_NN_OP_CONV2D, 0, 0, input_layout), weights_(weights), padding_(padding), ksize_(ksize), diff --git a/src/tim/vx/ops/deconv.cc b/src/tim/vx/ops/deconv.cc index 7704555..af043f9 100644 --- a/src/tim/vx/ops/deconv.cc +++ b/src/tim/vx/ops/deconv.cc @@ -25,7 +25,7 @@ #include -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "type_utils.h" #include "vsi_nn_pub.h" @@ -50,7 +50,7 @@ DeConv2d::DeConv2d(Graph* graph, int32_t oc_count, PadType pad_type, const uint32_t group, DataLayout input_layout, DataLayout kernel_layout) - : Operation(graph, VSI_NN_OP_DECONVOLUTION, 0, 0, input_layout), + : DirectMapOp(graph, VSI_NN_OP_DECONVOLUTION, 0, 0, input_layout), oc_count_(oc_count), pad_type_(pad_type), ksize_(ksize), diff --git a/src/tim/vx/ops/deconv1d.cc b/src/tim/vx/ops/deconv1d.cc index f833e99..dfc8d0f 100644 --- a/src/tim/vx/ops/deconv1d.cc +++ b/src/tim/vx/ops/deconv1d.cc @@ -25,7 +25,7 @@ #include -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "type_utils.h" #include "vsi_nn_pub.h" @@ -68,7 +68,7 @@ 
DeConv1d::DeConv1d(Graph* graph, PadType pad_type, uint32_t stride, uint32_t output_padding, const std::array& pad, uint32_t group, DataLayout input_layout, DataLayout kernel_layout) - : Operation(graph, VSI_NN_OP_DECONVOLUTION1D, 3, 1, input_layout), + : DirectMapOp(graph, VSI_NN_OP_DECONVOLUTION1D, 3, 1, input_layout), oc_count_(0), pad_type_(pad_type), ksize_(0), diff --git a/src/tim/vx/ops/depth2space.cc b/src/tim/vx/ops/depth2space.cc index 6f35494..20b0328 100644 --- a/src/tim/vx/ops/depth2space.cc +++ b/src/tim/vx/ops/depth2space.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/depth2space.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace tim { @@ -31,7 +31,7 @@ namespace vx { namespace ops { DepthToSpace::DepthToSpace(Graph* graph, int block_size, DataLayout layout) - : Operation(graph, VSI_NN_OP_DEPTH2SPACE, 0, 0, layout), + : DirectMapOp(graph, VSI_NN_OP_DEPTH2SPACE, 0, 0, layout), block_size_(block_size) { this->impl()->node()->nn_param.depth2space.block_size = block_size_; } diff --git a/src/tim/vx/ops/dropout.cc b/src/tim/vx/ops/dropout.cc index 8869846..1c94042 100644 --- a/src/tim/vx/ops/dropout.cc +++ b/src/tim/vx/ops/dropout.cc @@ -25,7 +25,7 @@ #include "vsi_nn_pub.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" namespace tim { namespace vx { @@ -33,7 +33,7 @@ namespace ops { Dropout::Dropout(Graph* graph, float ratio) - : Operation(graph, VSI_NN_OP_DROPOUT), + : DirectMapOp(graph, VSI_NN_OP_DROPOUT), ratio_(ratio) { this->impl()->node()->nn_param.dropout.ratio = ratio_; } diff --git a/src/tim/vx/ops/elementwise.cc b/src/tim/vx/ops/elementwise.cc index 1589949..726b7d5 100644 --- a/src/tim/vx/ops/elementwise.cc +++ b/src/tim/vx/ops/elementwise.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/elementwise.h" -#include 
"operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace tim { @@ -31,7 +31,7 @@ namespace vx { namespace ops { #define DEFINE_ELEMENTWISE_OP(NAME, VSI_OP_CODE) \ - NAME::NAME(Graph* graph) : Operation(graph, VSI_OP_CODE, 2, 1) {} \ + NAME::NAME(Graph* graph) : DirectMapOp(graph, VSI_OP_CODE, 2, 1) {} \ std::shared_ptr NAME::Clone(std::shared_ptr& graph) \ const { \ return graph->CreateOperation(); \ @@ -47,25 +47,25 @@ DEFINE_ELEMENTWISE_OP(FloorDiv, VSI_NN_OP_FLOORDIV) #undef DEFINE_ELEMENTWISE_OP Multiply::Multiply(Graph* graph, float scale) - : Operation(graph, VSI_NN_OP_MULTIPLY, 2, 1) { + : DirectMapOp(graph, VSI_NN_OP_MULTIPLY, 2, 1) { this->impl()->node()->nn_param.multiply.scale = scale; } std::shared_ptr Multiply::Clone( std::shared_ptr& graph) const { return graph->CreateOperation( - this->impl_->node_->nn_param.multiply.scale); + this->impl_->node()->nn_param.multiply.scale); } Div::Div(Graph* graph, float scale) - : Operation(graph, VSI_NN_OP_DIVIDE, 2, 1) { + : DirectMapOp(graph, VSI_NN_OP_DIVIDE, 2, 1) { this->impl()->node()->nn_param.divide.scale = scale; } std::shared_ptr Div::Clone( std::shared_ptr& graph) const { return graph->CreateOperation
( - this->impl_->node_->nn_param.divide.scale); + this->impl_->node()->nn_param.divide.scale); } } // namespace ops diff --git a/src/tim/vx/ops/erf.cc b/src/tim/vx/ops/erf.cc index 40c737f..e72e265 100644 --- a/src/tim/vx/ops/erf.cc +++ b/src/tim/vx/ops/erf.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/erf.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "type_utils.h" #include "vsi_nn_pub.h" @@ -31,7 +31,7 @@ namespace tim { namespace vx { namespace ops { -Erf::Erf(Graph* graph) : Operation(graph, VSI_NN_OP_ERF) {} +Erf::Erf(Graph* graph) : DirectMapOp(graph, VSI_NN_OP_ERF) {} std::shared_ptr Erf::Clone(std::shared_ptr& graph) const { return graph->CreateOperation(); diff --git a/src/tim/vx/ops/fullyconnected.cc b/src/tim/vx/ops/fullyconnected.cc index 2b296d8..73259a9 100644 --- a/src/tim/vx/ops/fullyconnected.cc +++ b/src/tim/vx/ops/fullyconnected.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/fullyconnected.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace tim { @@ -35,7 +35,7 @@ FullyConnected::FullyConnected(Graph* graph, uint32_t axis) } FullyConnected::FullyConnected(Graph* graph, uint32_t axis, uint32_t weights) - : Operation(graph, VSI_NN_OP_FCL2), axis_(axis), weights_(weights) { + : DirectMapOp(graph, VSI_NN_OP_FCL2), axis_(axis), weights_(weights) { this->impl()->node()->nn_param.fcl.axis = axis; this->impl()->node()->nn_param.fcl.weights = weights; } diff --git a/src/tim/vx/ops/gather.cc b/src/tim/vx/ops/gather.cc index f0ef74e..b247249 100644 --- a/src/tim/vx/ops/gather.cc +++ b/src/tim/vx/ops/gather.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/gather.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace 
tim { @@ -31,7 +31,7 @@ namespace vx { namespace ops { Gather::Gather(Graph* graph, int axis) - : Operation(graph, VSI_NN_OP_GATHER), axis_(axis) { + : DirectMapOp(graph, VSI_NN_OP_GATHER), axis_(axis) { this->impl()->node()->nn_param.gather.axis = axis_; } diff --git a/src/tim/vx/ops/gathernd.cc b/src/tim/vx/ops/gathernd.cc index 3428891..b6221f4 100644 --- a/src/tim/vx/ops/gathernd.cc +++ b/src/tim/vx/ops/gathernd.cc @@ -23,14 +23,14 @@ *****************************************************************************/ #include "tim/vx/ops/gathernd.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace tim { namespace vx { namespace ops { -GatherNd::GatherNd(Graph* graph) : Operation(graph, VSI_NN_OP_GATHER_ND) {} +GatherNd::GatherNd(Graph* graph) : DirectMapOp(graph, VSI_NN_OP_GATHER_ND) {} std::shared_ptr GatherNd::Clone(std::shared_ptr& graph) const { return graph->CreateOperation(); diff --git a/src/tim/vx/ops/groupedconv1d.cc b/src/tim/vx/ops/groupedconv1d.cc index 273b936..46a8ca8 100644 --- a/src/tim/vx/ops/groupedconv1d.cc +++ b/src/tim/vx/ops/groupedconv1d.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/groupedconv1d.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "type_utils.h" #include "vsi_nn_pub.h" @@ -37,7 +37,7 @@ GroupedConv1d::GroupedConv1d(Graph* graph, const uint32_t dilation, uint32_t group, DataLayout input_layout, DataLayout kernel_layout) - : Operation(graph, VSI_NN_OP_GROUPED_CONV1D, 3, 1, input_layout), + : DirectMapOp(graph, VSI_NN_OP_GROUPED_CONV1D, 3, 1, input_layout), padding_(padding), stride_(stride), dilation_(dilation), pad_({0,0}), group_(group), kernel_layout_(kernel_layout) { diff --git a/src/tim/vx/ops/groupedconv2d.cc b/src/tim/vx/ops/groupedconv2d.cc index a6de274..a686e97 100644 --- a/src/tim/vx/ops/groupedconv2d.cc +++ b/src/tim/vx/ops/groupedconv2d.cc @@ -23,7 +23,7 @@ 
*****************************************************************************/ #include "tim/vx/ops/groupedconv2d.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "type_utils.h" #include "vsi_nn_pub.h" @@ -37,7 +37,7 @@ GroupedConv2d::GroupedConv2d(Graph* graph, const std::array& dilation, int32_t group_number, DataLayout input_layout, DataLayout kernel_layout) - : Operation(graph, VSI_NN_OP_GROUPED_CONV2D, 3, 1, input_layout), + : DirectMapOp(graph, VSI_NN_OP_GROUPED_CONV2D, 3, 1, input_layout), padding_(padding), strides_(strides), dilation_(dilation), pad_({0,0,0,0}), group_number_(group_number), kernel_layout_(kernel_layout) { @@ -55,7 +55,7 @@ GroupedConv2d::GroupedConv2d(Graph* graph, const std::array& dilation, int32_t group_number, DataLayout input_layout, DataLayout kernel_layout) - : Operation(graph, VSI_NN_OP_GROUPED_CONV2D, 3, 1, input_layout), + : DirectMapOp(graph, VSI_NN_OP_GROUPED_CONV2D, 3, 1, input_layout), padding_(PadType::AUTO), strides_(strides), dilation_(dilation), pad_(pad), group_number_(group_number), kernel_layout_(kernel_layout) { this->impl()->node()->nn_param.grouped_conv2d.stride[0] = strides_[0]; diff --git a/src/tim/vx/ops/instancenormalization.cc b/src/tim/vx/ops/instancenormalization.cc index 0fd4423..de2989a 100644 --- a/src/tim/vx/ops/instancenormalization.cc +++ b/src/tim/vx/ops/instancenormalization.cc @@ -23,14 +23,14 @@ *****************************************************************************/ #include "tim/vx/ops/instancenormalization.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace tim { namespace vx { namespace ops { InstanceNormalization::InstanceNormalization(Graph* graph, float eps) - : Operation(graph, VSI_NN_OP_INSTANCE_NORM), eps_(eps) { + : DirectMapOp(graph, VSI_NN_OP_INSTANCE_NORM), eps_(eps) { this->impl()->node()->nn_param.instancenorm.eps = eps_; } diff --git a/src/tim/vx/ops/l2normalization.cc 
b/src/tim/vx/ops/l2normalization.cc index ec4d1b9..1a7a688 100644 --- a/src/tim/vx/ops/l2normalization.cc +++ b/src/tim/vx/ops/l2normalization.cc @@ -23,14 +23,14 @@ *****************************************************************************/ #include "tim/vx/ops/l2normalization.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace tim { namespace vx { namespace ops { L2Normalization::L2Normalization(Graph* graph, int32_t axis) - : Operation(graph, VSI_NN_OP_L2_NORMALIZE), axis_(axis) { + : DirectMapOp(graph, VSI_NN_OP_L2_NORMALIZE), axis_(axis) { this->impl()->node()->nn_param.l2_normalize.axis = axis_; } diff --git a/src/tim/vx/ops/layernormalization.cc b/src/tim/vx/ops/layernormalization.cc index 19509c4..6c9278c 100644 --- a/src/tim/vx/ops/layernormalization.cc +++ b/src/tim/vx/ops/layernormalization.cc @@ -24,14 +24,14 @@ #include "tim/vx/ops/layernormalization.h" #include -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace tim { namespace vx { namespace ops { LayerNormalization::LayerNormalization(Graph* graph, int32_t axis, float eps) - : Operation(graph, VSI_NN_OP_LAYER_NORM), axis_(axis), eps_(eps) { + : DirectMapOp(graph, VSI_NN_OP_LAYER_NORM), axis_(axis), eps_(eps) { // Layer normalization shares the parameters of instance normalization. 
if (axis != 0) { VSILOGE("Layer norm only support axis 0."); diff --git a/src/tim/vx/ops/localresponsenormalization.cc b/src/tim/vx/ops/localresponsenormalization.cc index 151ba9b..0937915 100644 --- a/src/tim/vx/ops/localresponsenormalization.cc +++ b/src/tim/vx/ops/localresponsenormalization.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/localresponsenormalization.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace tim { @@ -33,7 +33,7 @@ LocalResponseNormalization::LocalResponseNormalization(Graph* graph, uint32_t size, float alpha, float beta, float bias, int32_t axis) - : Operation(graph, VSI_NN_OP_LRN2), + : DirectMapOp(graph, VSI_NN_OP_LRN2), size_(size), alpha_(alpha), beta_(beta), diff --git a/src/tim/vx/ops/logical.cc b/src/tim/vx/ops/logical.cc index 4864dde..c7a324b 100644 --- a/src/tim/vx/ops/logical.cc +++ b/src/tim/vx/ops/logical.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/logical.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace tim { @@ -32,7 +32,7 @@ namespace ops { #define DEFINE_LOGICAL_OP(NAME, VSI_OP_CODE) \ Logical##NAME::Logical##NAME(Graph* graph) \ - : Operation(graph, VSI_NN_OP_LOGICAL_OPS) { \ + : DirectMapOp(graph, VSI_NN_OP_LOGICAL_OPS) { \ this->impl()->node()->nn_param.relational_ops.op = \ VSI_NN_LOGICAL_##VSI_OP_CODE; \ } \ diff --git a/src/tim/vx/ops/logsoftmax.cc b/src/tim/vx/ops/logsoftmax.cc index 8d523d0..ed91336 100644 --- a/src/tim/vx/ops/logsoftmax.cc +++ b/src/tim/vx/ops/logsoftmax.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/logsoftmax.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace tim { @@ -31,7 +31,7 @@ namespace vx { namespace ops { 
LogSoftmax::LogSoftmax(Graph* graph, int32_t axis, float beta) - : Operation(graph, VSI_NN_OP_LOG_SOFTMAX), axis_(axis), beta_(beta) { + : DirectMapOp(graph, VSI_NN_OP_LOG_SOFTMAX), axis_(axis), beta_(beta) { this->impl()->node()->nn_param.log_softmax.betaValue = beta_; this->impl()->node()->nn_param.log_softmax.axis = axis_; } diff --git a/src/tim/vx/ops/matmul.cc b/src/tim/vx/ops/matmul.cc index e02e407..eeb3928 100644 --- a/src/tim/vx/ops/matmul.cc +++ b/src/tim/vx/ops/matmul.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/matmul.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" #include "type_utils.h" @@ -33,7 +33,7 @@ namespace ops { Matmul::Matmul(Graph* graph, bool transpose_a, bool transpose_b, bool adjoint_a, bool adjoint_b) - : Operation(graph, VSI_NN_OP_MATRIXMUL), transpose_a_(transpose_a), + : DirectMapOp(graph, VSI_NN_OP_MATRIXMUL), transpose_a_(transpose_a), transpose_b_(transpose_b), adjoint_a_(adjoint_a), adjoint_b_(adjoint_b) { this->impl()->node()->nn_param.matrixmul.transpose[0] = ToVxBool(transpose_a_); this->impl()->node()->nn_param.matrixmul.transpose[1] = ToVxBool(transpose_b_); diff --git a/src/tim/vx/ops/maxpoolwithargmax.cc b/src/tim/vx/ops/maxpoolwithargmax.cc index f2126de..de8a038 100644 --- a/src/tim/vx/ops/maxpoolwithargmax.cc +++ b/src/tim/vx/ops/maxpoolwithargmax.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/maxpoolwithargmax.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "type_utils.h" #include "vsi_nn_pub.h" @@ -36,7 +36,7 @@ MaxpoolWithArgmax::MaxpoolWithArgmax(Graph* graph, PadType padding, const std::array& stride, RoundType round_type, DataLayout layout) - : Operation(graph, VSI_NN_OP_POOLWITHARGMAX, 1, 2, layout), + : DirectMapOp(graph, VSI_NN_OP_POOLWITHARGMAX, 1, 2, layout), padding_(padding), 
ksize_(ksize), stride_(stride), diff --git a/src/tim/vx/ops/maxunpool2d.cc b/src/tim/vx/ops/maxunpool2d.cc index 29ca625..ea6fa3b 100644 --- a/src/tim/vx/ops/maxunpool2d.cc +++ b/src/tim/vx/ops/maxunpool2d.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/maxunpool2d.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "type_utils.h" #include "vsi_nn_pub.h" @@ -33,7 +33,7 @@ namespace ops { MaxUnpool2d::MaxUnpool2d(Graph* graph, const std::array& ksize, const std::array& stride, DataLayout layout) - : Operation(graph, VSI_NN_OP_UPSAMPLE, 2, 1, layout), + : DirectMapOp(graph, VSI_NN_OP_UPSAMPLE, 2, 1, layout), ksize_(ksize), stride_(stride) { this->impl()->node()->nn_param.upsample.scale[0] = stride_[0]; this->impl()->node()->nn_param.upsample.scale[1] = stride_[1]; diff --git a/src/tim/vx/ops/mements.cc b/src/tim/vx/ops/mements.cc index b0f0229..f9a5b6e 100644 --- a/src/tim/vx/ops/mements.cc +++ b/src/tim/vx/ops/mements.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/moments.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "type_utils.h" #include "vsi_nn_pub.h" @@ -31,7 +31,7 @@ namespace tim { namespace vx { namespace ops { Moments::Moments(Graph* graph, const std::vector& axes, bool keep_dims) - : Operation(graph, VSI_NN_OP_MOMENTS), axes_(axes), keep_dims_(keep_dims) { + : DirectMapOp(graph, VSI_NN_OP_MOMENTS), axes_(axes), keep_dims_(keep_dims) { this->impl()->node()->nn_param.moments.axis = axes_.data(); this->impl()->node()->nn_param.moments.axis_num = axes_.size(); this->impl()->node()->nn_param.moments.keep_dim = ToVxBool(keep_dims_); diff --git a/src/tim/vx/ops/nbg.cc b/src/tim/vx/ops/nbg.cc index 00e3787..34dcdd7 100644 --- a/src/tim/vx/ops/nbg.cc +++ b/src/tim/vx/ops/nbg.cc @@ -23,7 +23,7 @@ 
*****************************************************************************/ #include "tim/vx/ops/nbg.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace tim { @@ -31,13 +31,13 @@ namespace vx { namespace ops { NBG::NBG(Graph* graph, const char* binary, size_t input_count, size_t output_count) - : Operation(graph, VSI_NN_OP_NBG, input_count, output_count) { + : DirectMapOp(graph, VSI_NN_OP_NBG, input_count, output_count) { this->impl()->node()->nn_param.nbg.url = binary; this->impl()->node()->nn_param.nbg.type = VSI_NN_NBG_POINTER; } std::shared_ptr NBG::Clone(std::shared_ptr& graph) const { - return graph->CreateOperation(this->impl_->node_->nn_param.nbg.url, + return graph->CreateOperation(this->impl_->node()->nn_param.nbg.url, this->impl_->input_cnt_, this->impl_->output_cnt_); } diff --git a/src/tim/vx/ops/pad.cc b/src/tim/vx/ops/pad.cc index 46e8d9c..e0f12db 100644 --- a/src/tim/vx/ops/pad.cc +++ b/src/tim/vx/ops/pad.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/pad.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace tim { @@ -31,7 +31,7 @@ namespace vx { namespace ops { Pad::Pad(Graph* graph, const std::vector& front_size, const std::vector& back_size, int32_t const_val) - : Operation(graph, VSI_NN_OP_PAD), + : DirectMapOp(graph, VSI_NN_OP_PAD), front_size_(front_size), back_size_(back_size), const_val_(const_val) { diff --git a/src/tim/vx/ops/pool2d.cc b/src/tim/vx/ops/pool2d.cc index 083d86f..169afb2 100644 --- a/src/tim/vx/ops/pool2d.cc +++ b/src/tim/vx/ops/pool2d.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/pool2d.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "type_utils.h" #include "vsi_nn_pub.h" @@ -36,7 +36,7 @@ Pool2d::Pool2d(Graph* graph, PoolType type, PadType padding, 
const std::array& ksize, const std::array& stride, RoundType round_type, DataLayout layout) - : Operation(graph, VSI_NN_OP_POOL, 1, 1, layout), + : DirectMapOp(graph, VSI_NN_OP_POOL, 1, 1, layout), type_(type), padding_(padding), ksize_(ksize), @@ -59,7 +59,7 @@ Pool2d::Pool2d(Graph* graph, PoolType type, const std::array& ksize, const std::array& stride, RoundType round_type, DataLayout layout) - : Operation(graph, VSI_NN_OP_POOL, 1, 1, layout), + : DirectMapOp(graph, VSI_NN_OP_POOL, 1, 1, layout), type_(type), padding_(PadType::AUTO), ksize_(ksize), stride_(stride), round_type_(round_type), pad_(pad) { Init(); @@ -70,7 +70,7 @@ Pool2d::Pool2d(Graph* graph, PoolType type, const std::array& input_size, RoundType round_type, DataLayout layout) - : Operation(graph, VSI_NN_OP_POOL, 1, 1, layout), + : DirectMapOp(graph, VSI_NN_OP_POOL, 1, 1, layout), type_(type), padding_(PadType::AUTO), ksize_(input_size), stride_(input_size), round_type_(round_type), pad_({0, 0, 0, 0}) { Init(); @@ -82,7 +82,7 @@ Pool2d::Pool2d(Graph* graph, PoolType type, const std::array& output_size, RoundType round_type, DataLayout layout) - : Operation(graph, VSI_NN_OP_POOL, 1, 1, layout), + : DirectMapOp(graph, VSI_NN_OP_POOL, 1, 1, layout), type_(type), padding_(PadType::AUTO), round_type_(round_type), pad_({0, 0, 0, 0}) { stride_[0] = floor(input_size[0] / (float)(output_size[0])); diff --git a/src/tim/vx/ops/reduce.cc b/src/tim/vx/ops/reduce.cc index 8669efa..d758bb3 100644 --- a/src/tim/vx/ops/reduce.cc +++ b/src/tim/vx/ops/reduce.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/reduce.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace tim { @@ -33,7 +33,7 @@ namespace ops { #define DEFINE_REDUCE_OP(NAME, VSI_OP_CODE) \ Reduce##NAME::Reduce##NAME(Graph* graph, const std::vector& axis, \ bool keep_dims) \ - : Operation(graph, VSI_NN_OP_REDUCE), \ + : DirectMapOp(graph, 
VSI_NN_OP_REDUCE), \ axis_(std::move(axis)), \ keep_dims_(keep_dims) { \ this->impl()->node()->nn_param.reduce.type = VSI_OP_CODE; \ diff --git a/src/tim/vx/ops/relational_operations.cc b/src/tim/vx/ops/relational_operations.cc index 6993a0d..c6a9ec5 100644 --- a/src/tim/vx/ops/relational_operations.cc +++ b/src/tim/vx/ops/relational_operations.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/relational_operations.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace tim { @@ -32,7 +32,7 @@ namespace ops { #define DEFINE_RELATIONAL_OP(NAME, VSI_OP_CODE) \ NAME::NAME(Graph* graph) \ - : Operation(graph, VSI_NN_OP_RELATIONAL_OPS, 2, 1) { \ + : DirectMapOp(graph, VSI_NN_OP_RELATIONAL_OPS, 2, 1) { \ this->impl()->node()->nn_param.relational_ops.op = VSI_OP_CODE; \ } \ std::shared_ptr NAME::Clone(std::shared_ptr& graph) \ diff --git a/src/tim/vx/ops/reorg.cc b/src/tim/vx/ops/reorg.cc index b791952..8b3169a 100644 --- a/src/tim/vx/ops/reorg.cc +++ b/src/tim/vx/ops/reorg.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/reorg.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace tim { @@ -31,7 +31,7 @@ namespace vx { namespace ops { Reorg::Reorg(Graph* graph, const uint32_t stride) - : Operation(graph, VSI_NN_OP_REORG), stride_(stride) { + : DirectMapOp(graph, VSI_NN_OP_REORG), stride_(stride) { this->impl()->node()->nn_param.reorg.stride = stride_; } diff --git a/src/tim/vx/ops/reshape.cc b/src/tim/vx/ops/reshape.cc index 79f0fb0..d3e80f0 100644 --- a/src/tim/vx/ops/reshape.cc +++ b/src/tim/vx/ops/reshape.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/reshape.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace tim 
{ @@ -31,7 +31,7 @@ namespace vx { namespace ops { Reshape::Reshape(Graph* graph, const std::vector& size) - : Operation(graph, VSI_NN_OP_RESHAPE), size_(std::move(size)) { + : DirectMapOp(graph, VSI_NN_OP_RESHAPE), size_(std::move(size)) { this->impl()->node()->nn_param.reshape.size = size_.data(); this->impl()->node()->nn_param.reshape.dim_num = size_.size(); } diff --git a/src/tim/vx/ops/resize.cc b/src/tim/vx/ops/resize.cc index 82a20d3..ded4a98 100644 --- a/src/tim/vx/ops/resize.cc +++ b/src/tim/vx/ops/resize.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/resize.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "type_utils.h" #include "vsi_nn_pub.h" @@ -34,7 +34,7 @@ namespace ops { Resize::Resize(Graph* graph, ResizeType type, float factor, bool align_corners, bool half_pixel_centers, int target_height, int target_width, DataLayout layout) - : Operation(graph, VSI_NN_OP_RESIZE, 0, 0, layout), + : DirectMapOp(graph, VSI_NN_OP_RESIZE, 0, 0, layout), type_(type), factor_(factor), align_corners_(align_corners), diff --git a/src/tim/vx/ops/resize1d.cc b/src/tim/vx/ops/resize1d.cc index a436b24..9e6bbcd 100644 --- a/src/tim/vx/ops/resize1d.cc +++ b/src/tim/vx/ops/resize1d.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/resize1d.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "type_utils.h" #include "vsi_nn_pub.h" @@ -33,7 +33,7 @@ namespace ops { Resize1d::Resize1d(Graph* graph, ResizeType type, float factor, bool align_corners, bool half_pixel_centers, int target_size, DataLayout layout) - : Operation(graph, VSI_NN_OP_RESIZE_1D, 0, 0, layout), + : DirectMapOp(graph, VSI_NN_OP_RESIZE_1D, 0, 0, layout), type_(type), factor_(factor), align_corners_(align_corners), diff --git a/src/tim/vx/ops/reverse.cc b/src/tim/vx/ops/reverse.cc index 6de5502..487d116 100644 --- 
a/src/tim/vx/ops/reverse.cc +++ b/src/tim/vx/ops/reverse.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/reverse.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace tim { @@ -31,7 +31,7 @@ namespace vx { namespace ops { Reverse::Reverse(Graph* graph, const std::vector& axis) - : Operation(graph, VSI_NN_OP_REVERSE), axis_(axis) { + : DirectMapOp(graph, VSI_NN_OP_REVERSE), axis_(axis) { this->impl()->node()->nn_param.reverse.axis = axis_.data(); this->impl()->node()->nn_param.reverse.axis_num = axis_.size(); } diff --git a/src/tim/vx/ops/rnn_cell.cc b/src/tim/vx/ops/rnn_cell.cc new file mode 100644 index 0000000..6b784a7 --- /dev/null +++ b/src/tim/vx/ops/rnn_cell.cc @@ -0,0 +1,142 @@ +/**************************************************************************** +* +* Copyright (c) 2021 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#include "tim/vx/ops.h" +#include "vsi_nn_pub.h" +#include "op_impl.h" + +#include +namespace tim { +namespace vx { +namespace ops { + +class RNNCellImpl : public OpImpl{ + public: + + enum { + // signature + FULLY_CONNECTED_0_IN = 0, + FULLY_CONNECTED_0_WEIGHT = 1, + FULLY_CONNECTED_0_BIAS = 2, + FULLY_CONNECTED_1_WEIGHT = 3, + FULLY_CONNECTED_1_STATE_IN = 4, + + INPUT_CNT, + + OUT = 0, + STATE_OUT, + OUT_CNT, + // signature end + }; + + RNNCellImpl(Graph* graph, int input_cnt, + int output_cnt, DataLayout layout = DataLayout::ANY) + : OpImpl(graph, -1, input_cnt, output_cnt, layout){ + fc0_ = graph->CreateOperation(0, 4); + fc1_ = graph->CreateOperation(0, 4); + add_ = graph->CreateOperation(); + tanh_ = graph->CreateOperation(); + data_convert_ = graph->CreateOperation(); + } + + ~RNNCellImpl() {} + + RNNCellImpl& BindInput(const std::shared_ptr& tensor) override + { + in_tensors_[input_tensor_index] = tensor; + + if (this->input_tensor_index == INPUT_CNT - 1) { + // Get all input tensor + tim::vx::ShapeType shape = {0, 0}; + tim::vx::TensorSpec FC0_spec(tim::vx::DataType::FLOAT32, shape, + tim::vx::TensorAttribute::TRANSIENT); + tim::vx::TensorSpec FC1_spec(tim::vx::DataType::FLOAT32, shape, + tim::vx::TensorAttribute::TRANSIENT); + tim::vx::TensorSpec add_spec(tim::vx::DataType::FLOAT32, shape, + tim::vx::TensorAttribute::TRANSIENT); + + + auto FC0_tensor = graph_->CreateTensor(FC0_spec); + auto FC1_tensor = graph_->CreateTensor(FC1_spec); + auto add_tensor = graph_->CreateTensor(add_spec); + + fc0_->BindInput(in_tensors_[FULLY_CONNECTED_0_IN]); + fc0_->BindInput(in_tensors_[FULLY_CONNECTED_0_WEIGHT]); + 
fc0_->BindInput(in_tensors_[FULLY_CONNECTED_0_BIAS]); + fc0_->BindOutput(FC0_tensor); + + fc1_->BindInput(in_tensors_[FULLY_CONNECTED_1_WEIGHT]); + fc1_->BindInput(in_tensors_[FULLY_CONNECTED_1_STATE_IN]); + fc1_->BindOutput(FC1_tensor); + + add_->BindInput(FC0_tensor); + add_->BindInput(FC1_tensor); + add_->BindOutput(add_tensor); + + tanh_->BindInput(add_tensor); + } + this->input_tensor_index++; + return *this; + } + + RNNCellImpl& BindOutput(const std::shared_ptr& tensor) override{ + out_tensors_[output_tensor_index] = tensor; + + tanh_->BindOutput(out_tensors_[OUT]); + data_convert_->BindInput(out_tensors_[OUT]); + if (this->output_tensor_index == OUT_CNT - 1){ + data_convert_->BindOutput(out_tensors_[STATE_OUT]); + } + this->output_tensor_index++; + return *this; + } + + vsi_nn_node_t* node() override{ return nullptr; } + + std::vector> InputsTensor() { return inputs_tensor_; } + std::vector> OutputsTensor() { + return outputs_tensor_; + } + + private: + std::shared_ptr fc0_; + std::shared_ptr fc1_; + std::shared_ptr add_; + std::shared_ptr tanh_; + std::shared_ptr data_convert_; + + std::array, INPUT_CNT> in_tensors_; + std::array, OUT_CNT> out_tensors_; +}; + +RNNCell::RNNCell(Graph* graph, ActivationType activation) : activation_(activation){ + impl_ = std::make_unique(graph, 0, 0, DataLayout::ANY); +} + +std::shared_ptr RNNCell::Clone(std::shared_ptr& graph) const { + return graph->CreateOperation(this->activation_); +} + +} // namespace ops +} // namespace vx +} // namespace tim diff --git a/src/tim/vx/ops/rnn_cell_test.cc b/src/tim/vx/ops/rnn_cell_test.cc new file mode 100644 index 0000000..aaa499e --- /dev/null +++ b/src/tim/vx/ops/rnn_cell_test.cc @@ -0,0 +1,241 @@ +/**************************************************************************** +* +* Copyright (c) 2021 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal 
in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#include "tim/vx/context.h" +#include "tim/vx/graph.h" +#include "tim/vx/ops.h" +#include "tim/vx/types.h" +#include "test_utils.h" +#include "gtest/gtest.h" + +TEST(RNNCell, shape_3_2_4_float) { + auto ctx = tim::vx::Context::Create(); + auto graph = ctx->CreateGraph(); + + uint32_t input_size = 3, batch_size = 2, num_units = 4; + + tim::vx::ShapeType input_shape({input_size, batch_size}); + tim::vx::ShapeType weights_shape({input_size, num_units}); + tim::vx::ShapeType recurrent_weights_shape({num_units, num_units}); + tim::vx::ShapeType bias_shape({num_units}); + tim::vx::ShapeType state_in_shape({num_units, batch_size}); + tim::vx::ShapeType output_shape({num_units, batch_size}); + tim::vx::ShapeType state_out_shape({num_units, batch_size}); + tim::vx::Quantization quant(tim::vx::QuantType::ASYMMETRIC, 0.0036, 0); + + tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, + input_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec weights_spec(tim::vx::DataType::FLOAT32, + 
weights_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec recurrent_weights_spec(tim::vx::DataType::FLOAT32, + recurrent_weights_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec bias_spec(tim::vx::DataType::FLOAT32, + bias_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec state_in_spec(tim::vx::DataType::FLOAT32, + state_in_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, + output_shape, tim::vx::TensorAttribute::OUTPUT); + tim::vx::TensorSpec state_out_spec(tim::vx::DataType::UINT8, + state_out_shape, tim::vx::TensorAttribute::OUTPUT,quant); + + auto input_tensor = graph->CreateTensor(input_spec); + auto weights_tensor = graph->CreateTensor(weights_spec); + auto recurrent_weights_tensor = graph->CreateTensor(recurrent_weights_spec); + auto bias_tensor = graph->CreateTensor(bias_spec); + auto state_in_tensor = graph->CreateTensor(state_in_spec); + auto output_tensor = graph->CreateTensor(output_spec); + auto state_out_tensor = graph->CreateTensor(state_out_spec); + + std::vector in_data = { + 0.12609188, 0.46347019, 0.89598465, + 0.35867718, 0.36897406, 0.73463392, + }; + std::vector weights_data = { + 0.12609188, 0.46347019, 0.89598465, + 0.35867718, 0.36897406, 0.73463392, + 0.12609188, 0.46347019, 0.89598465, + 0.35867718, 0.36897406, 0.73463392, + }; + std::vector recurrent_weights_data = { + -0.31930989, 0.37613347, 0.27901134, 0.36137494, + -1.36916667, 0.38031587, 0.21580373, 0.27072677, + 1.01580888, 0.14943552, 1.15465137, 0.09784451, + -1.02702999, 1.39296314, 0.15785322, 0.21931258, + }; + std::vector bias_data = { + 0.01580888, 0.14943552, 0.15465137, 0.09784451, + }; + std::vector state_in_data = { + 0,0,0,0,0,0,0,0 + }; + std::vector output_golden = { + 0.781534, 0.771447, 0.830002, 0.749713, 0.711524, 0.74155, 0.77355, 0.717427 + }; + std::vector state_out_golden = { + 217, 214, 231, 208, 198, 206, 215, 199, + }; + + 
EXPECT_TRUE(input_tensor->CopyDataToTensor( + in_data.data(), in_data.size() * sizeof(float))); + EXPECT_TRUE(weights_tensor->CopyDataToTensor( + weights_data.data(), weights_data.size() * sizeof(float))); + EXPECT_TRUE(recurrent_weights_tensor->CopyDataToTensor( + recurrent_weights_data.data(), recurrent_weights_data.size() * sizeof(float))); + EXPECT_TRUE(bias_tensor->CopyDataToTensor( + bias_data.data(), bias_data.size() * sizeof(float))); + EXPECT_TRUE(state_in_tensor->CopyDataToTensor( + state_in_data.data(), state_in_data.size() * sizeof(float))); + + auto op = graph->CreateOperation<tim::vx::ops::RNNCell>(tim::vx::ops::RNNCell::ActivationType::kSIGMOID); + (*op).BindInputs({input_tensor, weights_tensor, bias_tensor, state_in_tensor, recurrent_weights_tensor}) + .BindOutputs({output_tensor, state_out_tensor}); + + EXPECT_TRUE(graph->Compile()); + EXPECT_TRUE(graph->Run()); + + std::vector<float> output(output_golden.size()); + std::vector<uint8_t> state_out(state_out_golden.size()); + EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); + EXPECT_TRUE(state_out_tensor->CopyDataFromTensor(state_out.data())); + + EXPECT_TRUE(ArraysMatch(output_golden, output,1e-5f)); + EXPECT_EQ(state_out_golden, state_out); +} + +TEST(RNNCell, seperate) { + auto ctx = tim::vx::Context::Create(); + auto graph = ctx->CreateGraph(); + std::vector<float> in_data = { + 0.12609188, 0.46347019, 0.89598465, + 0.35867718, 0.36897406, 0.73463392, + }; + std::vector<float> weights_data = { + 0.12609188, 0.46347019, 0.89598465, + 0.35867718, 0.36897406, 0.73463392, + 0.12609188, 0.46347019, 0.89598465, + 0.35867718, 0.36897406, 0.73463392, + }; + std::vector<float> recurrent_weights_data = { + -0.31930989, 0.37613347, 0.27901134, 0.36137494, + -1.36916667, 0.38031587, 0.21580373, 0.27072677, + 1.01580888, 0.14943552, 1.15465137, 0.09784451, + -1.02702999, 1.39296314, 0.15785322, 0.21931258, + }; + std::vector<float> bias_data = { + 0.01580888, 0.14943552, 0.15465137, 0.09784451, + }; + std::vector<float> state_in_data = { + 0,0,0,0,0,0,0,0 + }; + 
std::vector<float> output_golden = { + 0.781534, 0.771447, 0.830002, 0.749713, 0.711524, 0.74155, 0.77355, 0.717427 + }; + std::vector<uint8_t> state_out_golden = { + 217, 214, 231, 208, 198, 206, 215, 199, + }; + uint32_t input_size = 3, batch_size = 2, num_units = 4; + + tim::vx::ShapeType input_shape({input_size, batch_size}); + tim::vx::ShapeType weights_shape({input_size, num_units}); + tim::vx::ShapeType recurrent_weights_shape({num_units, num_units}); + tim::vx::ShapeType bias_shape({num_units}); + tim::vx::ShapeType state_in_shape({num_units, batch_size}); + + tim::vx::ShapeType FC1_shape({num_units, batch_size}); + tim::vx::ShapeType add_shape({num_units, batch_size}); + tim::vx::ShapeType FC2_shape({num_units, batch_size}); + tim::vx::ShapeType activation_shape({num_units, batch_size}); + tim::vx::ShapeType convert_shape({num_units, batch_size}); + tim::vx::Quantization quant(tim::vx::QuantType::ASYMMETRIC, 0.0036, 0); + + tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, + input_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec weights_spec(tim::vx::DataType::FLOAT32, + weights_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec recurrent_weights_spec(tim::vx::DataType::FLOAT32, + recurrent_weights_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec bias_spec(tim::vx::DataType::FLOAT32, + bias_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec state_in_spec(tim::vx::DataType::FLOAT32, + state_in_shape, tim::vx::TensorAttribute::INPUT); + + tim::vx::TensorSpec FC1_spec(tim::vx::DataType::FLOAT32, + FC1_shape, tim::vx::TensorAttribute::TRANSIENT); + tim::vx::TensorSpec FC2_spec(tim::vx::DataType::FLOAT32, + FC2_shape, tim::vx::TensorAttribute::TRANSIENT); + tim::vx::TensorSpec add_spec(tim::vx::DataType::FLOAT32, + add_shape, tim::vx::TensorAttribute::TRANSIENT); + tim::vx::TensorSpec activation_out_spec(tim::vx::DataType::FLOAT32, + activation_shape, tim::vx::TensorAttribute::OUTPUT); + tim::vx::TensorSpec 
convert_spec(tim::vx::DataType::UINT8, + convert_shape, tim::vx::TensorAttribute::OUTPUT, quant); + + auto input_tensor = graph->CreateTensor(input_spec); + auto weights_tensor = graph->CreateTensor(weights_spec); + auto recurrent_weights_tensor = graph->CreateTensor(recurrent_weights_spec); + auto bias_tensor = graph->CreateTensor(bias_spec); + auto state_in_tensor = graph->CreateTensor(state_in_spec); + + auto FC1_tensor = graph->CreateTensor(FC1_spec); + auto FC2_tensor = graph->CreateTensor(FC2_spec); + auto add_tensor = graph->CreateTensor(add_spec); + auto activation_out_tensor = graph->CreateTensor(activation_out_spec); + auto convert_tensor = graph->CreateTensor(convert_spec); + + EXPECT_TRUE(input_tensor->CopyDataToTensor( + in_data.data(), in_data.size() * sizeof(float))); + EXPECT_TRUE(weights_tensor->CopyDataToTensor( + weights_data.data(), weights_data.size() * sizeof(float))); + EXPECT_TRUE(recurrent_weights_tensor->CopyDataToTensor( + recurrent_weights_data.data(), recurrent_weights_data.size() * sizeof(float))); + EXPECT_TRUE(bias_tensor->CopyDataToTensor( + bias_data.data(), bias_data.size() * sizeof(float))); + EXPECT_TRUE(state_in_tensor->CopyDataToTensor( + state_in_data.data(), state_in_data.size() * sizeof(float))); + + auto op1 = graph->CreateOperation<tim::vx::ops::FullyConnected>(0,4); + (*op1).BindInputs({input_tensor, weights_tensor, bias_tensor}) + .BindOutputs({FC1_tensor}); + auto op2 = graph->CreateOperation<tim::vx::ops::FullyConnected>(0,4); + (*op2).BindInputs({state_in_tensor, recurrent_weights_tensor}) + .BindOutputs({FC2_tensor}); + auto op3 = graph->CreateOperation<tim::vx::ops::Add>(); + (*op3).BindInputs({FC1_tensor, FC2_tensor}) + .BindOutputs({add_tensor}); + auto op4 = graph->CreateOperation<tim::vx::ops::Sigmoid>(); + (*op4).BindInputs({add_tensor}) + .BindOutputs({activation_out_tensor}); + auto op5 = graph->CreateOperation<tim::vx::ops::DataConvert>(); + (*op5).BindInputs({activation_out_tensor}) + .BindOutputs({convert_tensor}); + + EXPECT_TRUE(graph->Compile()); + EXPECT_TRUE(graph->Run()); + + std::vector<float> output(output_golden.size()); + 
std::vector<uint8_t> state_out(state_out_golden.size()); + EXPECT_TRUE(activation_out_tensor->CopyDataFromTensor(output.data())); + EXPECT_TRUE(convert_tensor->CopyDataFromTensor(state_out.data())); + EXPECT_TRUE(ArraysMatch(output_golden, output, 1e-5f)); + EXPECT_EQ(state_out_golden, state_out); +} \ No newline at end of file diff --git a/src/tim/vx/ops/scatternd.cc b/src/tim/vx/ops/scatternd.cc index f62a917..0372c41 100644 --- a/src/tim/vx/ops/scatternd.cc +++ b/src/tim/vx/ops/scatternd.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/scatternd.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace tim { @@ -31,7 +31,7 @@ namespace vx { namespace ops { ScatterND::ScatterND(Graph* graph, const std::vector<uint32_t>& shape) - : Operation(graph, VSI_NN_OP_SCATTER_ND), shape_(shape) { + : DirectMapOp(graph, VSI_NN_OP_SCATTER_ND), shape_(shape) { this->impl()->node()->nn_param.scatter_nd.dim_num = shape_.size(); this->impl()->node()->nn_param.scatter_nd.shape = shape_.data(); } diff --git a/src/tim/vx/ops/select.cc b/src/tim/vx/ops/select.cc index cf1b889..27ad69b 100644 --- a/src/tim/vx/ops/select.cc +++ b/src/tim/vx/ops/select.cc @@ -23,7 +23,7 @@ *****************************************************************************/ #include "tim/vx/ops/select.h" -#include "operation_private.h" +#include "direct_map_op_impl.h" #include "vsi_nn_pub.h" namespace tim { @@ -31,7 +31,7 @@ namespace vx { namespace ops { Select::Select(Graph* graph) - : Operation(graph, VSI_NN_OP_SELECT) {} + : DirectMapOp(graph, VSI_NN_OP_SELECT) {} std::shared_ptr<Operation> Select::Clone(std::shared_ptr<Graph>& graph) const { return graph->CreateOperation