diff --git a/include/tim/vx/operation.h b/include/tim/vx/operation.h
index 2367879..57fa7ef 100644
--- a/include/tim/vx/operation.h
+++ b/include/tim/vx/operation.h
@@ -37,6 +37,7 @@ class Operation {
   Operation(Graph* graph, uint32_t operation_id, int input_cnt = 0,
             int ouput_cnt = 0, DataLayout layout = DataLayout::ANY);
   virtual ~Operation();
+  virtual std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const = 0;
   Operation& BindInput(const std::shared_ptr<Tensor>& tensor);
   Operation& BindOutput(const std::shared_ptr<Tensor>& tensor);
   Operation& BindInputs(const std::vector<std::shared_ptr<Tensor>>& tensors);
diff --git a/include/tim/vx/ops/activations.h b/include/tim/vx/ops/activations.h
index 5a28555..a4b5002 100644
--- a/include/tim/vx/ops/activations.h
+++ b/include/tim/vx/ops/activations.h
@@ -64,10 +64,12 @@ namespace ops {
  * ```
  */

-#define DECLARE_NO_PARAMETER_ACTIVATION(NAME) \
-  class NAME : public Operation {             \
-   public:                                    \
-    NAME(Graph* graph);                       \
+#define DECLARE_NO_PARAMETER_ACTIVATION(NAME)           \
+  class NAME : public Operation {                       \
+   public:                                              \
+    NAME(Graph* graph);                                 \
+    std::shared_ptr<Operation> Clone(                   \
+        std::shared_ptr<Graph>& graph) const override;  \
   };

 DECLARE_NO_PARAMETER_ACTIVATION(Relu)
@@ -86,6 +88,8 @@ DECLARE_NO_PARAMETER_ACTIVATION(SoftRelu)
 class Prelu : public Operation {
  public:
   Prelu(Graph* graph, int axis);
+  std::shared_ptr<Operation> Clone(
+      std::shared_ptr<Graph>& graph) const override;

  protected:
   int axis_;
@@ -94,6 +98,8 @@ class Prelu : public Operation {
 class LeakyRelu : public Operation {
  public:
   LeakyRelu(Graph* graph, float alpha);
+  std::shared_ptr<Operation> Clone(
+      std::shared_ptr<Graph>& graph) const override;

  protected:
   float alpha_;
@@ -101,7 +107,10 @@ class LeakyRelu : public Operation {

 class Linear : public Operation {
  public:
-  Linear(Graph* graph, float a, float b=0.0);
+  Linear(Graph* graph, float a, float b = 0.0);
+  std::shared_ptr<Operation> Clone(
+      std::shared_ptr<Graph>& graph) const override;
+
  protected:
   float a_;
   float b_;
diff --git a/include/tim/vx/ops/addn.h b/include/tim/vx/ops/addn.h
index 2fb1bd6..3de97e5 100644
--- a/include/tim/vx/ops/addn.h
+++ b/include/tim/vx/ops/addn.h
@@ -40,6 +40,8 @@ namespace ops {
 class AddN : public Operation {
  public:
   AddN(Graph* graph, uint32_t num_inputs);
+
+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
 };

 }  // namespace ops
diff --git a/include/tim/vx/ops/arg.h b/include/tim/vx/ops/arg.h
index 26a3b11..316ae8b 100644
--- a/include/tim/vx/ops/arg.h
+++ b/include/tim/vx/ops/arg.h
@@ -36,13 +36,15 @@ namespace ops {
  * along the provided **axis**. The type of the output tensor is integer.
 */

-#define DECLARE_ARG_OP(NAME)                \
-  class Arg##NAME : public Operation {      \
-   public:                                  \
-    Arg##NAME(Graph* graph, int32_t axis);  \
-                                            \
-   protected:                               \
-    int32_t axis_;                          \
+#define DECLARE_ARG_OP(NAME)                            \
+  class Arg##NAME : public Operation {                  \
+   public:                                              \
+    Arg##NAME(Graph* graph, int32_t axis);              \
+    std::shared_ptr<Operation> Clone(                   \
+        std::shared_ptr<Graph>& graph) const override;  \
+                                                        \
+   protected:                                           \
+    int32_t axis_;                                      \
   };

 DECLARE_ARG_OP(Min);
diff --git a/include/tim/vx/ops/batch2space.h b/include/tim/vx/ops/batch2space.h
index 576420e..61fe51d 100644
--- a/include/tim/vx/ops/batch2space.h
+++ b/include/tim/vx/ops/batch2space.h
@@ -49,6 +49,8 @@ class Batch2Space : public Operation {
               const std::vector<int>& crop,
               DataLayout layout = DataLayout::WHCN);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   std::vector<int> block_size_;
   std::vector<int> crop_;
diff --git a/include/tim/vx/ops/batchnorm.h b/include/tim/vx/ops/batchnorm.h
index 8d52586..30d25d4 100644
--- a/include/tim/vx/ops/batchnorm.h
+++ b/include/tim/vx/ops/batchnorm.h
@@ -44,7 +44,9 @@ class BatchNorm : public Operation {
  public:
   BatchNorm(Graph* graph, float eps);

-  protected:
+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
+ protected:
   float eps_;
 };
diff --git a/include/tim/vx/ops/clip.h b/include/tim/vx/ops/clip.h
index 35ad07f..9cc4c8a 100644
--- a/include/tim/vx/ops/clip.h
+++ b/include/tim/vx/ops/clip.h
@@ -39,7 +39,10 @@ namespace ops {
 class Clip : public Operation {
  public:
   Clip(Graph* graph, float min, float max);
-  protected:
+
+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
+ protected:
   float min_;
   float max_;
 };
diff --git a/include/tim/vx/ops/concat.h b/include/tim/vx/ops/concat.h
index 2ed5dcd..9c3c9aa 100644
--- a/include/tim/vx/ops/concat.h
+++ b/include/tim/vx/ops/concat.h
@@ -41,6 +41,8 @@ class Concat : public Operation {
  public:
   Concat(Graph* graph, uint32_t axis, int input_cnt);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   uint32_t axis_;
 };
diff --git a/include/tim/vx/ops/conv1d.h b/include/tim/vx/ops/conv1d.h
index 64cb1de..60383c3 100644
--- a/include/tim/vx/ops/conv1d.h
+++ b/include/tim/vx/ops/conv1d.h
@@ -55,6 +55,8 @@ class Conv1d : public Operation {

   DataLayout KernelDataLayout() { return kernel_layout_; }

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   const uint32_t weights_;
   const PadType padding_;
diff --git a/include/tim/vx/ops/conv2d.h b/include/tim/vx/ops/conv2d.h
index 6527ae1..a997cf2 100644
--- a/include/tim/vx/ops/conv2d.h
+++ b/include/tim/vx/ops/conv2d.h
@@ -83,6 +83,8 @@ class Conv2d : public Operation {

   DataLayout KernelDataLayout() { return kernel_layout_; }

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   const uint32_t weights_;
   const PadType padding_;
diff --git a/include/tim/vx/ops/deconv.h b/include/tim/vx/ops/deconv.h
index ab7def4..79c8f9b 100644
--- a/include/tim/vx/ops/deconv.h
+++ b/include/tim/vx/ops/deconv.h
@@ -71,6 +71,9 @@ class DeConv2d : public Operation {
            DataLayout kernel_layout = DataLayout::WHIcOc);

   DataLayout KernelDataLayout() { return kernel_layout_; }
+
+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   const uint32_t oc_count_;
   const PadType pad_type_;
diff --git a/include/tim/vx/ops/deconv1d.h b/include/tim/vx/ops/deconv1d.h
index 3f7377d..1e30017 100644
--- a/include/tim/vx/ops/deconv1d.h
+++ b/include/tim/vx/ops/deconv1d.h
@@ -74,6 +74,8 @@ class DeConv1d : public Operation {
            const std::array<uint32_t, 2>& pad, uint32_t group,
            DataLayout input_layout, DataLayout kernel_layout);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   const uint32_t oc_count_;  // output channel count
   const PadType pad_type_;
diff --git a/include/tim/vx/ops/depth2space.h b/include/tim/vx/ops/depth2space.h
index fc5e8b4..ba33d14 100644
--- a/include/tim/vx/ops/depth2space.h
+++ b/include/tim/vx/ops/depth2space.h
@@ -50,6 +50,8 @@ class DepthToSpace : public Operation {
   DepthToSpace(Graph* Graph, int block_size,
                DataLayout layout = DataLayout::WHCN);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   int block_size_;
 };
diff --git a/include/tim/vx/ops/dropout.h b/include/tim/vx/ops/dropout.h
index b4e4d3a..7a5b6ff 100644
--- a/include/tim/vx/ops/dropout.h
+++ b/include/tim/vx/ops/dropout.h
@@ -41,11 +41,13 @@ namespace ops {
  */

 class Dropout : public Operation {
-  public:
-    Dropout(Graph* graph, float ratio);
+ public:
+  Dropout(Graph* graph, float ratio);

-  protected:
-    float ratio_;
+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
+ protected:
+  float ratio_;
 };

 }  // namespace ops
diff --git a/include/tim/vx/ops/elementwise.h b/include/tim/vx/ops/elementwise.h
index 30dea90..905e0cb 100644
--- a/include/tim/vx/ops/elementwise.h
+++ b/include/tim/vx/ops/elementwise.h
@@ -66,10 +66,12 @@ namespace ops {
  * FloorDiv(x, y): floor( x / y ). This operation supports broadcasting.
 */

-#define DECLARE_ELEMENTWISE_OP(NAME) \
-  class NAME : public Operation {    \
-   public:                           \
-    NAME(Graph* graph);              \
+#define DECLARE_ELEMENTWISE_OP(NAME)                    \
+  class NAME : public Operation {                       \
+   public:                                              \
+    NAME(Graph* graph);                                 \
+    std::shared_ptr<Operation> Clone(                   \
+        std::shared_ptr<Graph>& graph) const override;  \
   };

 DECLARE_ELEMENTWISE_OP(Minimum)
@@ -81,8 +83,10 @@ DECLARE_ELEMENTWISE_OP(Pow)
 DECLARE_ELEMENTWISE_OP(FloorDiv)

 class Multiply : public Operation {
-  public:
-    Multiply(Graph* graph, float scale = 1.0f);
+ public:
+  Multiply(Graph* graph, float scale = 1.0f);
+
+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
 };

 #undef DECLARE_ELEMENTWISE_OP
diff --git a/include/tim/vx/ops/fullyconnected.h b/include/tim/vx/ops/fullyconnected.h
index 3aad178..38ffb04 100644
--- a/include/tim/vx/ops/fullyconnected.h
+++ b/include/tim/vx/ops/fullyconnected.h
@@ -44,6 +44,8 @@ class FullyConnected : public Operation {
   FullyConnected(Graph* graph, uint32_t axis);
   FullyConnected(Graph* graph, uint32_t axis, uint32_t weights);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   uint32_t axis_;
   uint32_t weights_;
diff --git a/include/tim/vx/ops/gather.h b/include/tim/vx/ops/gather.h
index de294fa..953190f 100644
--- a/include/tim/vx/ops/gather.h
+++ b/include/tim/vx/ops/gather.h
@@ -39,6 +39,8 @@ class Gather : public Operation {
  public:
   Gather(Graph* Graph, int axis);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   int axis_;
 };
diff --git a/include/tim/vx/ops/gathernd.h b/include/tim/vx/ops/gathernd.h
index f4d92e6..31f8151 100644
--- a/include/tim/vx/ops/gathernd.h
+++ b/include/tim/vx/ops/gathernd.h
@@ -38,6 +38,9 @@ namespace ops {
 class GatherNd : public Operation {
  public:
   GatherNd(Graph* Graph);
+
+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
 };

 }  // namespace ops
diff --git a/include/tim/vx/ops/groupedconv2d.h b/include/tim/vx/ops/groupedconv2d.h
index 68f4508..9e9a5eb 100644
--- a/include/tim/vx/ops/groupedconv2d.h
+++ b/include/tim/vx/ops/groupedconv2d.h
@@ -71,6 +71,8 @@ class GroupedConv2d : public Operation {

   DataLayout KernelDataLayout() { return kernel_layout_; }

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   const PadType padding_;
   const std::array<uint32_t, 2> strides_;
diff --git a/include/tim/vx/ops/instancenormalization.h b/include/tim/vx/ops/instancenormalization.h
index 9f20eb6..0ffd60e 100644
--- a/include/tim/vx/ops/instancenormalization.h
+++ b/include/tim/vx/ops/instancenormalization.h
@@ -32,6 +32,8 @@ class InstanceNormalization : public Operation {
  public:
   InstanceNormalization(Graph* graph, float eps = 1e-5f);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   float eps_;
 };
diff --git a/include/tim/vx/ops/l2normalization.h b/include/tim/vx/ops/l2normalization.h
index 2e1f355..9e0e00e 100644
--- a/include/tim/vx/ops/l2normalization.h
+++ b/include/tim/vx/ops/l2normalization.h
@@ -44,6 +44,8 @@ class L2Normalization : public Operation {
  public:
   L2Normalization(Graph* graph, int32_t axis);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   int32_t axis_;
 };
diff --git a/include/tim/vx/ops/layernormalization.h b/include/tim/vx/ops/layernormalization.h
index e2803a2..55e2d4e 100644
--- a/include/tim/vx/ops/layernormalization.h
+++ b/include/tim/vx/ops/layernormalization.h
@@ -34,6 +34,8 @@ class LayerNormalization : public Operation {
  public:
   LayerNormalization(Graph* graph, int32_t axis = 0, float eps = 1e-5f);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   int32_t axis_;
   int32_t eps_;
diff --git a/include/tim/vx/ops/localresponsenormalization.h b/include/tim/vx/ops/localresponsenormalization.h
index 8f31ff0..0b0bc24 100644
--- a/include/tim/vx/ops/localresponsenormalization.h
+++ b/include/tim/vx/ops/localresponsenormalization.h
@@ -45,6 +45,8 @@ class LocalResponseNormalization : public Operation {
   LocalResponseNormalization(Graph* graph, uint32_t size, float alpha,
                              float beta, float bias, int32_t axis);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   uint32_t size_;
   float alpha_;
@@ -52,6 +54,8 @@ class LocalResponseNormalization : public Operation {
   float bias_;
   int32_t axis_;
 };
+
+using LRN = LocalResponseNormalization;
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
diff --git a/include/tim/vx/ops/logical.h b/include/tim/vx/ops/logical.h
index 911ac16..66a2573 100644
--- a/include/tim/vx/ops/logical.h
+++ b/include/tim/vx/ops/logical.h
@@ -39,10 +39,12 @@ namespace ops {
  * Returns the truth value of x OR y element-wise. This operation supports broadcasting.
 */

-#define DECLARE_LOGICAL_OP(NAME)              \
-  class Logical##NAME : public Operation {    \
-   public:                                    \
-    Logical##NAME(Graph* graph);              \
+#define DECLARE_LOGICAL_OP(NAME)                        \
+  class Logical##NAME : public Operation {              \
+   public:                                              \
+    Logical##NAME(Graph* graph);                        \
+    std::shared_ptr<Operation> Clone(                   \
+        std::shared_ptr<Graph>& graph) const override;  \
   };

 DECLARE_LOGICAL_OP(And);
diff --git a/include/tim/vx/ops/logsoftmax.h b/include/tim/vx/ops/logsoftmax.h
index f05de74..67cdde2 100644
--- a/include/tim/vx/ops/logsoftmax.h
+++ b/include/tim/vx/ops/logsoftmax.h
@@ -43,6 +43,8 @@ class LogSoftmax : public Operation {
  public:
   LogSoftmax(Graph* graph, int32_t axis, float beta = 1.f);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   int32_t axis_;
   float beta_;
diff --git a/include/tim/vx/ops/matmul.h b/include/tim/vx/ops/matmul.h
index af47a58..9a63013 100644
--- a/include/tim/vx/ops/matmul.h
+++ b/include/tim/vx/ops/matmul.h
@@ -45,6 +45,8 @@ class Matmul : public Operation {
   Matmul(Graph* graph, bool transpose_a = false, bool transpose_b = false,
          bool adjoint_a = false, bool adjoint_b = false);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   bool transpose_a_;
   bool transpose_b_;
diff --git a/include/tim/vx/ops/maxpoolwithargmax.h b/include/tim/vx/ops/maxpoolwithargmax.h
index 23d6f04..5bfba13 100644
--- a/include/tim/vx/ops/maxpoolwithargmax.h
+++ b/include/tim/vx/ops/maxpoolwithargmax.h
@@ -52,6 +52,8 @@ class MaxpoolWithArgmax : public Operation {
                     RoundType round_type = RoundType::FLOOR,
                     DataLayout layout = DataLayout::WHCN);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   const PadType padding_;
   const std::array<uint32_t, 2> ksize_;
diff --git a/include/tim/vx/ops/maxunpool2d.h b/include/tim/vx/ops/maxunpool2d.h
index 943985c..bd004d6 100644
--- a/include/tim/vx/ops/maxunpool2d.h
+++ b/include/tim/vx/ops/maxunpool2d.h
@@ -47,6 +47,8 @@ class MaxUnpool2d : public Operation {
   MaxUnpool2d(Graph* graph, const std::array<uint32_t, 2>& ksize,
               const std::array<uint32_t, 2>& stride,
               DataLayout layout = DataLayout::WHCN);
+
+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   const std::array<uint32_t, 2> ksize_;
   const std::array<uint32_t, 2> stride_;
diff --git a/include/tim/vx/ops/moments.h b/include/tim/vx/ops/moments.h
index e7a6191..6d34481 100644
--- a/include/tim/vx/ops/moments.h
+++ b/include/tim/vx/ops/moments.h
@@ -40,13 +40,15 @@ namespace ops {
  */

 class Moments : public Operation {
-  public:
-    Moments(Graph* graph, const std::vector<int32_t>& axes,
-            bool keep_dims = false);
+ public:
+  Moments(Graph* graph, const std::vector<int32_t>& axes,
+          bool keep_dims = false);

-  protected:
-    const std::vector<int32_t> axes_;
-    const bool keep_dims_;
+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
+ protected:
+  const std::vector<int32_t> axes_;
+  const bool keep_dims_;
 };

 }  // namespace ops
diff --git a/include/tim/vx/ops/nbg.h b/include/tim/vx/ops/nbg.h
index fffac21..0d0bd6d 100644
--- a/include/tim/vx/ops/nbg.h
+++ b/include/tim/vx/ops/nbg.h
@@ -40,6 +40,8 @@ class NBG : public Operation {
  public:
   NBG(Graph* graph, const char* binary, size_t input_count,
       size_t output_count);
+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
 };

 }  // namespace ops
diff --git a/include/tim/vx/ops/pad.h b/include/tim/vx/ops/pad.h
index 49448e8..cb863a1 100644
--- a/include/tim/vx/ops/pad.h
+++ b/include/tim/vx/ops/pad.h
@@ -42,6 +42,8 @@ class Pad : public Operation {
   Pad(Graph* graph, const std::vector<uint32_t>& front_size,
       const std::vector<uint32_t>& back_size, int32_t const_val);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   std::vector<uint32_t> front_size_;
   std::vector<uint32_t> back_size_;
diff --git a/include/tim/vx/ops/pool2d.h b/include/tim/vx/ops/pool2d.h
index 19c2864..fbb2041 100644
--- a/include/tim/vx/ops/pool2d.h
+++ b/include/tim/vx/ops/pool2d.h
@@ -59,6 +59,8 @@ class Pool2d : public Operation {
          RoundType round_type = RoundType::FLOOR,
          DataLayout layout = DataLayout::WHCN);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   const PoolType type_;
   const PadType padding_;
diff --git a/include/tim/vx/ops/reduce.h b/include/tim/vx/ops/reduce.h
index f846d16..592bfd7 100644
--- a/include/tim/vx/ops/reduce.h
+++ b/include/tim/vx/ops/reduce.h
@@ -99,6 +99,8 @@ namespace ops {
    public:                                                        \
     Reduce##NAME(Graph* graph, const std::vector<int32_t>& axis,  \
                  bool keep_dims);                                 \
+    std::shared_ptr<Operation> Clone(                             \
+        std::shared_ptr<Graph>& graph) const override;            \
                                                                   \
    protected:                                                     \
     std::vector<int32_t> axis_;                                   \
diff --git a/include/tim/vx/ops/relational_operations.h b/include/tim/vx/ops/relational_operations.h
index 11c4fcc..a2a2732 100644
--- a/include/tim/vx/ops/relational_operations.h
+++ b/include/tim/vx/ops/relational_operations.h
@@ -55,10 +55,12 @@ namespace ops {
  * For input tensors x and y, computes x == y elementwise.
 */

-#define DECLARE_RELATIONAL_OP(NAME) \
-  class NAME : public Operation {   \
-   public:                          \
-    NAME(Graph* graph);             \
+#define DECLARE_RELATIONAL_OP(NAME)                     \
+  class NAME : public Operation {                       \
+   public:                                              \
+    NAME(Graph* graph);                                 \
+    std::shared_ptr<Operation> Clone(                   \
+        std::shared_ptr<Graph>& graph) const override;  \
   };

 DECLARE_RELATIONAL_OP(Greater)
diff --git a/include/tim/vx/ops/reorg.h b/include/tim/vx/ops/reorg.h
index 22d115e..cd8fe62 100644
--- a/include/tim/vx/ops/reorg.h
+++ b/include/tim/vx/ops/reorg.h
@@ -39,6 +39,8 @@ class Reorg : public Operation {
  public:
   Reorg(Graph* graph, const uint32_t stride);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   uint32_t stride_;
 };
diff --git a/include/tim/vx/ops/reshape.h b/include/tim/vx/ops/reshape.h
index eae55e0..8b6b16b 100644
--- a/include/tim/vx/ops/reshape.h
+++ b/include/tim/vx/ops/reshape.h
@@ -39,7 +39,9 @@ namespace ops {

 class Reshape : public Operation {
  public:
-  Reshape(Graph* graph, const std::vector<uint32_t>& perm);
+  Reshape(Graph* graph, const std::vector<uint32_t>& size);
+
+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;

  protected:
   std::vector<uint32_t> size_;
diff --git a/include/tim/vx/ops/resize.h b/include/tim/vx/ops/resize.h
index db0bf1b..32e5687 100644
--- a/include/tim/vx/ops/resize.h
+++ b/include/tim/vx/ops/resize.h
@@ -50,6 +50,8 @@ class Resize : public Operation {
          bool half_pixel_centers, int target_height, int target_width,
          DataLayout layout = DataLayout::WHCN);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   const ResizeType type_;
   const float factor_;
diff --git a/include/tim/vx/ops/resize1d.h b/include/tim/vx/ops/resize1d.h
index 15dae8b..0f76c76 100644
--- a/include/tim/vx/ops/resize1d.h
+++ b/include/tim/vx/ops/resize1d.h
@@ -50,6 +50,8 @@ class Resize1d : public Operation {
            bool half_pixel_centers, int target_size,
            DataLayout layout = DataLayout::WHCN);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   const ResizeType type_;
   const float factor_;
diff --git a/include/tim/vx/ops/reverse.h b/include/tim/vx/ops/reverse.h
index 69c9bef..731950a 100644
--- a/include/tim/vx/ops/reverse.h
+++ b/include/tim/vx/ops/reverse.h
@@ -41,6 +41,8 @@ class Reverse : public Operation {
  public:
   Reverse(Graph* graph, const std::vector<int32_t>& axis);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   const std::vector<int32_t> axis_;
 };
diff --git a/include/tim/vx/ops/scatternd.h b/include/tim/vx/ops/scatternd.h
index ed28ba8..8de7d4b 100644
--- a/include/tim/vx/ops/scatternd.h
+++ b/include/tim/vx/ops/scatternd.h
@@ -41,6 +41,8 @@ class ScatterND : public Operation {
  public:
   ScatterND(Graph* graph, const std::vector<uint32_t>& shape);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   const std::vector<uint32_t> shape_;
 };
diff --git a/include/tim/vx/ops/select.h b/include/tim/vx/ops/select.h
index 8a7ab34..38dfb32 100644
--- a/include/tim/vx/ops/select.h
+++ b/include/tim/vx/ops/select.h
@@ -39,6 +39,8 @@ namespace ops {
 class Select : public Operation {
  public:
   Select(Graph* graph);
+
+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
 };

 }  // namespace ops
diff --git a/include/tim/vx/ops/simple_operations.h b/include/tim/vx/ops/simple_operations.h
index 3d9d347..715848c 100644
--- a/include/tim/vx/ops/simple_operations.h
+++ b/include/tim/vx/ops/simple_operations.h
@@ -29,10 +29,12 @@ namespace tim {
 namespace vx {
 namespace ops {

-#define DECLARE_SIMPLE_OP(NAME) \
-  class NAME : public Operation { \
-   public:                        \
-    NAME(Graph* graph);           \
+#define DECLARE_SIMPLE_OP(NAME)                         \
+  class NAME : public Operation {                       \
+   public:                                              \
+    NAME(Graph* graph);                                 \
+    std::shared_ptr<Operation> Clone(                   \
+        std::shared_ptr<Graph>& graph) const override;  \
   };

 /**
diff --git a/include/tim/vx/ops/slice.h b/include/tim/vx/ops/slice.h
index 4bacce0..8cc6e31 100644
--- a/include/tim/vx/ops/slice.h
+++ b/include/tim/vx/ops/slice.h
@@ -45,6 +45,8 @@ class Slice : public Operation {
         const std::vector<int32_t>& start,
         const std::vector<int32_t>& length);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   uint32_t dims_;
   const std::vector<int32_t> start_;
diff --git a/include/tim/vx/ops/softmax.h b/include/tim/vx/ops/softmax.h
index 54f9425..525ce55 100644
--- a/include/tim/vx/ops/softmax.h
+++ b/include/tim/vx/ops/softmax.h
@@ -46,6 +46,8 @@ class Softmax : public Operation {
  public:
   Softmax(Graph* graph, float beta, int32_t axis);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   float beta_;
   int32_t axis_;
diff --git a/include/tim/vx/ops/space2batch.h b/include/tim/vx/ops/space2batch.h
index 298d182..e8e855a 100644
--- a/include/tim/vx/ops/space2batch.h
+++ b/include/tim/vx/ops/space2batch.h
@@ -52,6 +52,8 @@ class Space2Batch : public Operation {
               const std::vector<int>& pad,
               DataLayout layout = DataLayout::WHCN);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   std::vector<int> block_size_;
   std::vector<int> pad_;
diff --git a/include/tim/vx/ops/space2depth.h b/include/tim/vx/ops/space2depth.h
index 2ec07bc..832197f 100644
--- a/include/tim/vx/ops/space2depth.h
+++ b/include/tim/vx/ops/space2depth.h
@@ -43,6 +43,8 @@ class SpaceToDepth : public Operation {
   SpaceToDepth(Graph* graph, std::vector<int> block_size,
                DataLayout layout = DataLayout::WHCN);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   std::vector<int> block_size_;
 };
diff --git a/include/tim/vx/ops/split.h b/include/tim/vx/ops/split.h
index 0359aca..ec3ec09 100644
--- a/include/tim/vx/ops/split.h
+++ b/include/tim/vx/ops/split.h
@@ -44,6 +44,8 @@ class Split : public Operation {
  public:
   Split(Graph* graph, uint32_t axis, std::vector<uint32_t> slices);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   uint32_t axis_;
   std::vector<uint32_t> slices_;
diff --git a/include/tim/vx/ops/squeeze.h b/include/tim/vx/ops/squeeze.h
index 1e78832..af06edb 100644
--- a/include/tim/vx/ops/squeeze.h
+++ b/include/tim/vx/ops/squeeze.h
@@ -42,6 +42,8 @@ class Squeeze : public Operation {
  public:
   Squeeze(Graph* graph, std::vector<uint32_t> axis);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   std::vector<uint32_t> axis_;
 };
diff --git a/include/tim/vx/ops/stack.h b/include/tim/vx/ops/stack.h
index f5bd76d..eefe2c3 100644
--- a/include/tim/vx/ops/stack.h
+++ b/include/tim/vx/ops/stack.h
@@ -40,6 +40,8 @@ class Stack : public Operation {
  public:
   Stack(Graph* graph, uint32_t axis, int input_cnt);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   uint32_t axis_;
 };
diff --git a/include/tim/vx/ops/stridedslice.h b/include/tim/vx/ops/stridedslice.h
index fc2e8e9..d369069 100644
--- a/include/tim/vx/ops/stridedslice.h
+++ b/include/tim/vx/ops/stridedslice.h
@@ -59,6 +59,8 @@ class StridedSlice : public Operation {
                const std::vector<int32_t> stride_dims, int32_t begin_mask,
                int32_t end_mask, int32_t shrink_axis_mask);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   std::vector<int32_t> begin_dims_;
   std::vector<int32_t> end_dims_;
diff --git a/include/tim/vx/ops/tile.h b/include/tim/vx/ops/tile.h
index d9e0d39..14c393b 100644
--- a/include/tim/vx/ops/tile.h
+++ b/include/tim/vx/ops/tile.h
@@ -40,6 +40,9 @@ namespace ops {
 class Tile : public Operation {
  public:
   Tile(Graph* graph, const std::vector<int32_t>& multiples);
+
+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   const std::vector<int32_t> multiples_;
 };
diff --git a/include/tim/vx/ops/transpose.h b/include/tim/vx/ops/transpose.h
index ead3f7b..6cc6176 100644
--- a/include/tim/vx/ops/transpose.h
+++ b/include/tim/vx/ops/transpose.h
@@ -45,6 +45,8 @@ class Transpose : public Operation {
  public:
   Transpose(Graph* graph, const std::vector<uint32_t>& perm);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   std::vector<uint32_t> perm_;
 };
diff --git a/include/tim/vx/ops/unstack.h b/include/tim/vx/ops/unstack.h
index c01eaea..73e1d77 100644
--- a/include/tim/vx/ops/unstack.h
+++ b/include/tim/vx/ops/unstack.h
@@ -41,6 +41,8 @@ class Unstack : public Operation {
  public:
   Unstack(Graph* graph, int32_t axis, uint32_t output_num);

+  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+
  protected:
   int32_t axis_;
 };
diff --git a/src/tim/transform/layout_inference.cc b/src/tim/transform/layout_inference.cc
index e2de993..447e00c 100644
--- a/src/tim/transform/layout_inference.cc
+++ b/src/tim/transform/layout_inference.cc
@@ -31,7 +31,6 @@
 #include "ops/elementwise_layout_inference.h"
 #include "ops/activation_layout_inference.h"
 #include "ops/concat_layout_inferene.h"
-#include "ops/reshape_layout_inference.h"
 #include "ops/simple_ops_layout_inference.h"
 #include "ops/pool2d_layout_inference.h"
 #include "ops/softmax_layout_inference.h"
@@ -58,7 +57,7 @@
 #include "ops/logical_layout_inference.h"
 #include "ops/arg_layout_inference.h"
 #include "ops/deconv2d_layout_inference.h"
-#include "ops/nbg_layout_inference.h"
+#include "ops/default_layout_inference.h"

 #include <algorithm>
 #include <deque>
@@ -216,7 +215,6 @@ std::vector<std::shared_ptr<vx::Tensor>> HandleLayoutInfer(
       REGIST_LAYOUT_INFERENCE(VSI_NN_OP_POW, Pow);
       REGIST_LAYOUT_INFERENCE(VSI_NN_OP_MINIMUM, Minimum);
       REGIST_LAYOUT_INFERENCE(VSI_NN_OP_MAXIMUM, Maximum);
-      REGIST_LAYOUT_INFERENCE(VSI_NN_OP_RESHAPE, Reshape);
      REGIST_LAYOUT_INFERENCE(VSI_NN_OP_DATACONVERT, DataConvert);
       REGIST_LAYOUT_INFERENCE(VSI_NN_OP_NEG, Neg);
       REGIST_LAYOUT_INFERENCE(VSI_NN_OP_ABS, Abs);
@@ -233,8 +231,8 @@ std::vector<std::shared_ptr<vx::Tensor>> HandleLayoutInfer(
       REGIST_LAYOUT_INFERENCE(VSI_NN_OP_STACK, Stack);
       REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SPACE2DEPTH, SpaceToDepth);
       REGIST_LAYOUT_INFERENCE(VSI_NN_OP_DEPTH2SPACE, DepthToSpace);
-      REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SPACE2BATCH, SpaceToBatch);
-      REGIST_LAYOUT_INFERENCE(VSI_NN_OP_BATCH2SPACE, BatchToSpace);
+      REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SPACE2BATCH, Space2Batch);
+      REGIST_LAYOUT_INFERENCE(VSI_NN_OP_BATCH2SPACE, Batch2Space);
       REGIST_LAYOUT_INFERENCE(VSI_NN_OP_PAD, Pad);
       REGIST_LAYOUT_INFERENCE(VSI_NN_OP_FCL2, FullyConnected);
       REGIST_LAYOUT_INFERENCE(VSI_NN_OP_RESIZE, Resize);
@@ -249,15 +247,18 @@ std::vector<std::shared_ptr<vx::Tensor>> HandleLayoutInfer(
       REGIST_LAYOUT_INFERENCE(VSI_NN_OP_REVERSE, Reverse);
       REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SLICE, Slice);
       REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SELECT, Select);
-      REGIST_LAYOUT_INFERENCE(VSI_NN_OP_ARGMAX, ArgMax);
-      REGIST_LAYOUT_INFERENCE(VSI_NN_OP_ARGMIN, ArgMin);
+      REGIST_LAYOUT_INFERENCE(VSI_NN_OP_ARGMAX, Arg);
+      REGIST_LAYOUT_INFERENCE(VSI_NN_OP_ARGMIN, Arg);
       REGIST_LAYOUT_INFERENCE(VSI_NN_OP_DECONVOLUTION, DeConv2d);
-      REGIST_LAYOUT_INFERENCE(VSI_NN_OP_NBG, Nbg);
       REGIST_LOGICAL_LAYOUT_INFERENCE(VSI_NN_OP_LOGICAL_OPS);
       REGIST_REDUCE_LAYOUT_INFERENCE(VSI_NN_OP_REDUCE);
-    default:
-      VSILOGW("Op %d: not support layout inference.", op_id);
-      assert(false);
+    // use default layout inference
+    default: {
+      VSILOGW("Op %d: default layout inference pass.", op_id);
+      auto op_infer = std::make_shared<DefaultLayoutInfer>(op, ctx);
+      op_infer->OnInputs(next_tensors);
+      op_infer->OnOutputs(next_tensors);
+    }
   }
   return next_tensors;
 }
diff --git a/src/tim/transform/ops/addn_layout_inference.h b/src/tim/transform/ops/addn_layout_inference.h
index aaabdc7..1ace581 100644
--- a/src/tim/transform/ops/addn_layout_inference.h
+++ b/src/tim/transform/ops/addn_layout_inference.h
@@ -36,13 +36,12 @@ class AddNLayoutInfer : public OpLayoutInfer {
       const std::shared_ptr<vx::Operation> op,
       std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
       : OpLayoutInfer(op, context) {}
+
   void OnInputs(
       std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
     auto required_pv = AlignPermuteVectorForMutilInputs();
-    uint32_t num_inputs = op_->impl()->input_cnt_;
-    auto addn =
-        context_->infer_graph_->CreateOperation<vx::ops::AddN>(num_inputs);
+    auto addn = op_->Clone(context_->infer_graph_);

     for (const auto& i_src : op_->impl()->InputsTensor()) {
       (*addn).BindInput(context_->GetMapedTensor(i_src));
diff --git a/src/tim/transform/ops/arg_layout_inference.h b/src/tim/transform/ops/arg_layout_inference.h
index eb0a745..3c69f37 100644
--- a/src/tim/transform/ops/arg_layout_inference.h
+++ b/src/tim/transform/ops/arg_layout_inference.h
@@ -29,9 +29,9 @@
 #include "tim/vx/ops/arg.h"
 namespace tim {
 namespace transform {
-class ArgMaxLayoutInfer : public OpLayoutInfer {
+class ArgLayoutInfer : public OpLayoutInfer {
  public:
-  ArgMaxLayoutInfer(
+  ArgLayoutInfer(
       const std::shared_ptr<vx::Operation> op,
       std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
       : OpLayoutInfer(op, context) {}
@@ -43,40 +43,10 @@ class ArgMaxLayoutInfer : public OpLayoutInfer {
     auto src_input = op_->impl()->InputsTensor()[0];
     auto input_pv = context_->GetPermuteVector(src_input);

-    uint32_t axis = op_->impl()->node()->nn_param.argmax.axis;
-
-    auto argmax =
-        context_->infer_graph_->CreateOperation<vx::ops::ArgMax>(axis);
+    auto arg = op_->Clone(context_->infer_graph_);
     auto infer_out = CreateOutputsTensor(input_pv);
-    (*argmax).BindInput(context_->GetMapedTensor(src_input));
-    (*argmax).BindOutput(infer_out[0]);
-
-    context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], input_pv);
-    next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
-  }
-};
-
-class ArgMinLayoutInfer : public OpLayoutInfer {
- public:
-  ArgMinLayoutInfer(
-      const std::shared_ptr<vx::Operation> op,
-      std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
-      : OpLayoutInfer(op, context) {}
-
-  void OnInputs(
-      std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
-    ReverseInputsPermuteVector();
-    assert(1 == op_->impl()->InputsTensor().size());
-    auto src_input = op_->impl()->InputsTensor()[0];
-    auto input_pv = context_->GetPermuteVector(src_input);
-
-    uint32_t axis = op_->impl()->node()->nn_param.argmin.axis;
-
-    auto argmin =
-        context_->infer_graph_->CreateOperation<vx::ops::ArgMin>(axis);
-    auto infer_out = CreateOutputsTensor(input_pv);
-    (*argmin).BindInput(context_->GetMapedTensor(src_input));
-    (*argmin).BindOutput(infer_out[0]);
+    (*arg).BindInput(context_->GetMapedTensor(src_input));
+    (*arg).BindOutput(infer_out[0]);

     context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], input_pv);
     next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
diff --git a/src/tim/transform/ops/batch2space_layout_inference.h b/src/tim/transform/ops/batch2space_layout_inference.h
index 85e952d..5ae1b35 100644
--- a/src/tim/transform/ops/batch2space_layout_inference.h
+++ b/src/tim/transform/ops/batch2space_layout_inference.h
@@ -31,9 +31,9 @@
 #include "src/tim/vx/operation_private.h"
 namespace tim {
 namespace transform {
-class BatchToSpaceLayoutInfer : public OpLayoutInfer {
+class Batch2SpaceLayoutInfer : public OpLayoutInfer {
  public:
-  BatchToSpaceLayoutInfer(
+  Batch2SpaceLayoutInfer(
       const std::shared_ptr<vx::Operation> op,
       std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
       : OpLayoutInfer(op, context) {}
diff --git a/src/tim/transform/ops/conv2d_layout_inference.h b/src/tim/transform/ops/conv2d_layout_inference.h
index 1c47977..87b753a 100644
--- a/src/tim/transform/ops/conv2d_layout_inference.h
+++ b/src/tim/transform/ops/conv2d_layout_inference.h
@@ -32,7 +32,6 @@

 namespace tim {
 namespace transform {
-
 class Conv2dLayoutInfer : public OpLayoutInfer {
  public:
   Conv2dLayoutInfer(
diff --git a/src/tim/transform/ops/reshape_layout_inference.h b/src/tim/transform/ops/default_layout_inference.h
similarity index 78%
rename from src/tim/transform/ops/reshape_layout_inference.h
rename to src/tim/transform/ops/default_layout_inference.h
index 743e4a2..09252cc 100644
--- a/src/tim/transform/ops/reshape_layout_inference.h
+++ b/src/tim/transform/ops/default_layout_inference.h
@@ -21,10 +21,15 @@
  * DEALINGS IN THE SOFTWARE.
 *
 *****************************************************************************/
-#ifndef TIM_LAYOUT_INFER_RESHAPE_LAYOUT_INFERENCE_H_
-#define TIM_LAYOUT_INFER_RESHAPE_LAYOUT_INFERENCE_H_
+#ifndef TIM_LAYOUT_INFER_DEFAULT_LAYOUT_INFERENCE_H_
+#define TIM_LAYOUT_INFER_DEFAULT_LAYOUT_INFERENCE_H_

 #include "tim/vx/ops/reshape.h"
+#include "tim/vx/ops/nbg.h"
+#include "tim/vx/ops/transpose.h"
+#include "tim/vx/ops/batchnorm.h"
+#include "tim/vx/ops/clip.h"
+
 #include "src/tim/transform/ops/op_layout_inference.h"
 #include "src/tim/transform/permute_vector.h"

@@ -32,30 +37,30 @@
 namespace tim {
 namespace transform {
-class ReshapeLayoutInfer : public OpLayoutInfer {
+
+class DefaultLayoutInfer : public OpLayoutInfer {
  public:
-  ReshapeLayoutInfer(
+  DefaultLayoutInfer(
       const std::shared_ptr<vx::Operation> op,
       std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
       : OpLayoutInfer(op, context) {}

+  // reverse any applied permute on its input tensor
   void OnInputs(
       std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
     ReverseInputsPermuteVector();
-    std::vector<uint32_t> perm;
-    for (uint32_t i = 0; i < op_->impl()->node()->nn_param.reshape.dim_num;
-         i++) {
-      perm.push_back(op_->impl()->node()->nn_param.reshape.size[i]);
-    }
-    auto reshape =
-        context_->infer_graph_->CreateOperation<vx::ops::Reshape>(perm);
-    (*reshape).BindInput(
-        context_->GetMapedTensor(op_->impl()->InputsTensor()[0]));
+    auto cloned_op = op_->Clone(context_->infer_graph_);
+
+    for (const auto& i_src : op_->impl()->InputsTensor()) {
+      (*cloned_op).BindInput(context_->GetMapedTensor(i_src));
+    }

     auto required_pv =
         MakeShared(op_->impl()->OutputsTensor()[0]->GetShape().size());
     auto out_infer = CreateOutputsTensor(required_pv);
-    (*reshape).BindOutput(out_infer[0]);
+
+    // TODO: bind all outputs
+    (*cloned_op).BindOutputs(out_infer);
     context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
     next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
   }
diff --git a/src/tim/transform/ops/fullyconnected_layout_inference.h b/src/tim/transform/ops/fullyconnected_layout_inference.h
index c329630..3708dfb 100644
--- a/src/tim/transform/ops/fullyconnected_layout_inference.h
+++ b/src/tim/transform/ops/fullyconnected_layout_inference.h
@@ -41,7 +41,7 @@ class FullyConnectedLayoutInfer : public OpLayoutInfer {

   void OnInputs(
       std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
-    
+
     auto input_tensors = op_->impl()->InputsTensor();
     for (const auto& in : input_tensors) {
       if (in->IsConstTensor()) {
diff --git a/src/tim/transform/ops/nbg_layout_inference.h b/src/tim/transform/ops/nbg_layout_inference.h
deleted file mode 100644
index cdc531e..0000000
--- a/src/tim/transform/ops/nbg_layout_inference.h
+++ /dev/null
@@ -1,68 +0,0 @@
-/****************************************************************************
- *
- * Copyright (c) 2020 Vivante Corporation
- *
- * Permission is hereby granted, free of charge, to any person obtaining a
- * copy of this software and associated documentation files (the "Software"),
- * to deal in the Software without restriction, including without limitation
- * the rights to use, copy, modify, merge, publish, distribute, sublicense,
- * and/or sell copies of the Software, and to permit persons to whom the
- * Software is furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
- * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
- * DEALINGS IN THE SOFTWARE.
- *
- *****************************************************************************/
-#ifndef TIM_LAYOUT_INFER_DEFAULT_LAYOUT_INFERENCE_H_
-#define TIM_LAYOUT_INFER_DEFAULT_LAYOUT_INFERENCE_H_
-
-#include "tim/vx/ops/nbg.h"
-
-#include "src/tim/transform/ops/op_layout_inference.h"
-#include "src/tim/vx/operation_private.h"
-
-namespace tim {
-namespace transform {
-class NbgLayoutInfer : public OpLayoutInfer {
- public:
-  NbgLayoutInfer(
-      const std::shared_ptr<vx::Operation> op,
-      std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
-      : OpLayoutInfer(op, context) {}
-  // reverse any applied permute on it's input tensor
-  void OnInputs(
-      std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
-    ReverseInputsPermuteVector();
-
-    auto url = op_->impl()->node()->nn_param.nbg.url;
-    uint32_t input_count = op_->impl()->input_cnt_;
-    uint32_t output_count = op_->impl()->output_cnt_;
-    auto nbg = context_->infer_graph_->CreateOperation<vx::ops::NBG>(
-        url, input_count, output_count);
-
-    for (auto i_src : op_->impl()->InputsTensor()) {
-      (*nbg).BindInput(context_->GetMapedTensor(i_src));
-      auto input_pv = MakeShared(i_src->GetShape().size());
-      context_->SetPermuteVector(i_src, input_pv);
-    }
-    auto infer_out = CreateOutputsTensor(MakeShared(1));
-    (*nbg).BindOutputs(infer_out);
-    for (const auto& out : op_->impl()->OutputsTensor()) {
-      context_->SetPermuteVector(out, MakeShared(out->GetShape().size()));
-      next_tensors.push_back(out);
-    }
-  }
-};
-
-}  // namespace transform
-}  // namespace tim
-
-#endif
\ No newline at end of file
diff --git a/src/tim/transform/ops/pool2d_layout_inference.h b/src/tim/transform/ops/pool2d_layout_inference.h
index f41e0eb..aea557c 100644
--- a/src/tim/transform/ops/pool2d_layout_inference.h
+++ b/src/tim/transform/ops/pool2d_layout_inference.h
@@ -31,7 +31,6 @@

 namespace tim {
 namespace transform {
-
 class Pool2dLayoutInfer : public OpLayoutInfer {
  public:
   Pool2dLayoutInfer(
diff --git a/src/tim/transform/ops/softmax_layout_inference.h b/src/tim/transform/ops/softmax_layout_inference.h
index 8557985..2aa798e 100644
--- a/src/tim/transform/ops/softmax_layout_inference.h
+++ b/src/tim/transform/ops/softmax_layout_inference.h
@@ -32,7 +32,6 @@

 namespace tim {
 namespace transform {
-
 class SoftmaxLayoutInfer : public OpLayoutInfer {
  public:
   SoftmaxLayoutInfer(
diff --git a/src/tim/transform/ops/space2batch_layout_inference.h b/src/tim/transform/ops/space2batch_layout_inference.h
index 2a99b05..2b25720 100644
--- a/src/tim/transform/ops/space2batch_layout_inference.h
+++ b/src/tim/transform/ops/space2batch_layout_inference.h
@@ -31,9 +31,9 @@
 #include "src/tim/vx/operation_private.h"
 namespace tim {
 namespace transform {
-class SpaceToBatchLayoutInfer : public OpLayoutInfer {
+class Space2BatchLayoutInfer : public OpLayoutInfer {
  public:
-  SpaceToBatchLayoutInfer(
+  Space2BatchLayoutInfer(
       const std::shared_ptr<vx::Operation> op,
       std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
       : OpLayoutInfer(op, context) {}
diff --git a/src/tim/transform/ops/stack_layout_inference.h b/src/tim/transform/ops/stack_layout_inference.h
index 7df6dba..5df8a7d 100644
--- a/src/tim/transform/ops/stack_layout_inference.h
+++ b/src/tim/transform/ops/stack_layout_inference.h
@@ -32,7 +32,6 @@

 namespace tim {
 namespace transform {
-
 class StackLayoutInfer : public OpLayoutInfer {
  public:
   StackLayoutInfer(
diff --git a/src/tim/vx/ops/activations.cc b/src/tim/vx/ops/activations.cc
index 1a84409..16439e9 100644
--- a/src/tim/vx/ops/activations.cc
+++ b/src/tim/vx/ops/activations.cc
@@ -30,8 +30,12 @@ namespace tim {
 namespace vx {
 namespace ops {

-#define DEFINE_NO_PARAMETER_ACTIVATION(NAME, VSI_OP_CODE) \
-  NAME::NAME(Graph* graph) : Operation(graph, VSI_OP_CODE) {}
+#define DEFINE_NO_PARAMETER_ACTIVATION(NAME, VSI_OP_CODE)                \
+  NAME::NAME(Graph* graph) : Operation(graph, VSI_OP_CODE) {}            \
+  std::shared_ptr<Operation> NAME::Clone(std::shared_ptr<Graph>& graph)  \
+      const {                                                            \
+    return graph->CreateOperation<NAME>();                               \
+  }

 DEFINE_NO_PARAMETER_ACTIVATION(Relu, VSI_NN_OP_RELU)
 DEFINE_NO_PARAMETER_ACTIVATION(Relu1, VSI_NN_OP_RELU1)
@@ -50,27 +54,49 @@ HardSwish::HardSwish(Graph* graph) : Operation(graph, VSI_NN_OP_SWISH) {
   this->impl()->node()->nn_param.swish.beta = 1.0f;
 }

+std::shared_ptr<Operation> HardSwish::Clone(
+    std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<HardSwish>();
+}
+
 Prelu::Prelu(Graph* graph, int axis)
     : Operation(graph, VSI_NN_OP_PRELU), axis_(axis) {
   this->impl()->node()->nn_param.prelu.axis = axis_;
 }

+std::shared_ptr<Operation> Prelu::Clone(std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<Prelu>(this->axis_);
+}
+
 Tanh::Tanh(Graph* graph) : Operation(graph, VSI_NN_OP_TANH) {
   this->impl()->node()->nn_param.tanh.scale_a = 1.0;
   this->impl()->node()->nn_param.tanh.scale_b = 1.0;
 }

+std::shared_ptr<Operation> Tanh::Clone(std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<Tanh>();
+}
+
 LeakyRelu::LeakyRelu(Graph* graph, float alpha)
     : Operation(graph, VSI_NN_OP_LEAKY_RELU), alpha_(alpha) {
   this->impl()->node()->nn_param.activation.leaky_ratio = alpha_;
 }

+std::shared_ptr<Operation> LeakyRelu::Clone(
+    std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<LeakyRelu>(this->alpha_);
+}
+
 Linear::Linear(Graph* graph, float a, float b)
     : Operation(graph, VSI_NN_OP_LINEAR), a_(a), b_(b) {
   this->impl()->node()->nn_param.linear.a = a_;
   this->impl()->node()->nn_param.linear.b = b_;
 }

+std::shared_ptr<Operation> Linear::Clone(std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<Linear>(this->a_, this->b_);
+}
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
diff --git a/src/tim/vx/ops/addn.cc b/src/tim/vx/ops/addn.cc
index 1428ebc..a8ae283 100644
--- a/src/tim/vx/ops/addn.cc
+++ b/src/tim/vx/ops/addn.cc
@@ -33,6 +33,10 @@ namespace ops {
 AddN::AddN(Graph* graph, uint32_t num_inputs)
     : Operation(graph, VSI_NN_OP_ADDN, num_inputs, 1) {}

+std::shared_ptr<Operation> AddN::Clone(std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<AddN>(this->impl_->input_cnt_);
+};
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
diff --git a/src/tim/vx/ops/arg.cc b/src/tim/vx/ops/arg.cc
index b71791f..9d1e586 100644
--- a/src/tim/vx/ops/arg.cc
+++ b/src/tim/vx/ops/arg.cc
@@ -31,10 +31,14 @@ namespace tim {
 namespace vx {
 namespace ops {

-#define DEFINE_ARG_OP(NAME, VSI_OP_TYPE, OP_PARAM)                   \
-  Arg##NAME::Arg##NAME(Graph* graph, int32_t axis)                   \
-      : Operation(graph, VSI_NN_OP_ARG##VSI_OP_TYPE), axis_(axis) {  \
-    this->impl()->node()->nn_param.arg##OP_PARAM.axis = axis_;       \
+#define DEFINE_ARG_OP(NAME, VSI_OP_TYPE, OP_PARAM)                   \
+  Arg##NAME::Arg##NAME(Graph* graph, int32_t axis)                   \
+      : Operation(graph, VSI_NN_OP_ARG##VSI_OP_TYPE), axis_(axis) {  \
+    this->impl()->node()->nn_param.arg##OP_PARAM.axis = axis_;       \
+  }                                                                  \
+  std::shared_ptr<Operation> Arg##NAME::Clone(                       \
+      std::shared_ptr<Graph>& graph) const {                         \
+    return graph->CreateOperation<Arg##NAME>(this->axis_);           \
   }

 DEFINE_ARG_OP(Max, MAX, max);
diff --git a/src/tim/vx/ops/batch2space.cc b/src/tim/vx/ops/batch2space.cc
index df74766..1d3f07c 100644
--- a/src/tim/vx/ops/batch2space.cc
+++ b/src/tim/vx/ops/batch2space.cc
@@ -42,6 +42,13 @@ Batch2Space::Batch2Space(Graph* graph, const std::vector<int>& block_size,
     this->impl()->node()->nn_param.batch2space.crop[i] = crop_[i];
   }
 }
+
+std::shared_ptr<Operation> Batch2Space::Clone(
+    std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<Batch2Space>(this->block_size_, this->crop_,
+                                             this->impl_->layout_);
+}
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
diff --git a/src/tim/vx/ops/batchnorm.cc b/src/tim/vx/ops/batchnorm.cc
index 21a5c41..fc221cc 100644
--- a/src/tim/vx/ops/batchnorm.cc
+++ b/src/tim/vx/ops/batchnorm.cc
@@ -23,21 +23,24 @@
  *****************************************************************************/
 #include "tim/vx/ops/batchnorm.h"

-#include "vsi_nn_pub.h"
-
 #include "operation_private.h"
+#include "vsi_nn_pub.h"

 namespace tim {
 namespace vx {
 namespace ops {
-
 BatchNorm::BatchNorm(Graph* graph, float eps)
-    : Operation(graph, VSI_NN_OP_BATCH_NORM),
-      eps_(eps) {
+    : Operation(graph, VSI_NN_OP_BATCH_NORM), eps_(eps) {
   this->impl()->node()->nn_param.batch_norm.eps = eps_;
 }

+std::shared_ptr<Operation> BatchNorm::Clone(
+    std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<BatchNorm>(
+      this->impl_->node_->nn_param.batch_norm.eps);
+}
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
diff --git a/src/tim/vx/ops/clip.cc b/src/tim/vx/ops/clip.cc
index b76d448..af51aca 100644
--- a/src/tim/vx/ops/clip.cc
+++ b/src/tim/vx/ops/clip.cc
@@ -40,6 +40,11 @@ Clip::Clip(Graph* graph, float min, float max)
   this->impl()->node()->nn_param.clip.max = max_;
 }

+std::shared_ptr<Operation> Clip::Clone(std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<Clip>(this->impl_->node()->nn_param.clip.min,
+                                      this->impl_->node_->nn_param.clip.max);
+}
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
diff --git a/src/tim/vx/ops/concat.cc b/src/tim/vx/ops/concat.cc
index 18e5337..86bda04 100644
--- a/src/tim/vx/ops/concat.cc
+++ b/src/tim/vx/ops/concat.cc
@@ -35,6 +35,10 @@ Concat::Concat(Graph* graph, uint32_t axis, int input_cnt)
   this->impl()->node()->nn_param.concat.axis = axis_;
 }

+std::shared_ptr<Operation> Concat::Clone(std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<Concat>(this->axis_, this->impl_->input_cnt_);
+}
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
\ No newline at end of file
diff --git a/src/tim/vx/ops/conv1d.cc b/src/tim/vx/ops/conv1d.cc
index 96a60f8..413be40 100644
--- a/src/tim/vx/ops/conv1d.cc
+++ b/src/tim/vx/ops/conv1d.cc
@@ -72,6 +72,13 @@ Conv1d::Conv1d(Graph* graph, int32_t weights, PadType padding,
   this->impl()->node()->nn_param.conv1d.multiplier = multiplier_;
 }

+std::shared_ptr<Operation> Conv1d::Clone(std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<Conv1d>(
+      this->weights_, this->padding_, this->ksize_, this->stride_,
+      this->dilation_, this->pad_, this->multiplier_, this->impl_->layout_,
+      this->kernel_layout_);
+}
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
\ No newline at end of file
diff --git a/src/tim/vx/ops/conv2d.cc b/src/tim/vx/ops/conv2d.cc
index e0f6d9c..66faf5b 100644
--- a/src/tim/vx/ops/conv2d.cc
+++ b/src/tim/vx/ops/conv2d.cc
@@ -81,6 +81,13 @@ Conv2d::Conv2d(Graph* graph, int32_t weights, PadType padding,
   this->impl()->node()->nn_param.conv2d.multiplier = multiplier_;
 }

+std::shared_ptr<Operation> Conv2d::Clone(std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<Conv2d>(
+      this->weights_, this->padding_, this->ksize_, this->stride_,
+      this->dilation_, this->pad_, this->multiplier_, this->impl_->layout_,
+      this->kernel_layout_);
+}
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
\ No newline at end of file
diff --git a/src/tim/vx/ops/deconv.cc b/src/tim/vx/ops/deconv.cc
index ca1f554..7704555 100644
--- a/src/tim/vx/ops/deconv.cc
+++ b/src/tim/vx/ops/deconv.cc
@@ -77,6 +77,14 @@ DeConv2d::DeConv2d(Graph* graph, int32_t oc_count, PadType pad_type,
   this->impl()->node()->nn_param.deconv.pad[3] = pad_[3];
 }

+std::shared_ptr<Operation> DeConv2d::Clone(
+    std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<DeConv2d>(
+      this->oc_count_, this->pad_type_, this->ksize_, this->stride_,
+      this->output_padding_, this->pad_, this->group_, this->impl_->layout_,
+      this->kernel_layout_);
+}
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
diff --git a/src/tim/vx/ops/deconv1d.cc b/src/tim/vx/ops/deconv1d.cc
index d9be3ac..f833e99 100644
--- a/src/tim/vx/ops/deconv1d.cc
+++ b/src/tim/vx/ops/deconv1d.cc
@@ -85,6 +85,12 @@ DeConv1d::DeConv1d(Graph* graph, PadType pad_type,
   this->impl()->node()->nn_param.deconvolution1d.pad[1] = pad_[1];
 }

+std::shared_ptr<Operation> DeConv1d::Clone(
+    std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<DeConv1d>(
+      this->pad_type_, this->stride_, this->output_padding_, this->pad_,
+      this->group_, this->impl_->layout_, this->kernel_layout_);
+}
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
diff --git a/src/tim/vx/ops/depth2space.cc b/src/tim/vx/ops/depth2space.cc
index eda6bb6..6f35494 100644
--- a/src/tim/vx/ops/depth2space.cc
+++ b/src/tim/vx/ops/depth2space.cc
@@ -35,6 +35,13 @@ DepthToSpace::DepthToSpace(Graph* graph, int block_size, DataLayout layout)
       block_size_(block_size) {
   this->impl()->node()->nn_param.depth2space.block_size = block_size_;
 }
+
+std::shared_ptr<Operation> DepthToSpace::Clone(
+    std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<DepthToSpace>(this->block_size_,
+                                              this->impl_->layout_);
+}
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
\ No newline at end of file
diff --git a/src/tim/vx/ops/dropout.cc b/src/tim/vx/ops/dropout.cc
index 598cfbe..8869846 100644
--- a/src/tim/vx/ops/dropout.cc
+++ b/src/tim/vx/ops/dropout.cc
@@ -38,6 +38,10 @@ Dropout::Dropout(Graph* graph, float ratio)
   this->impl()->node()->nn_param.dropout.ratio = ratio_;
 }

+std::shared_ptr<Operation> Dropout::Clone(std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<Dropout>(this->ratio_);
+}
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
diff --git a/src/tim/vx/ops/elementwise.cc b/src/tim/vx/ops/elementwise.cc
index 07b22a8..98abcad 100644
--- a/src/tim/vx/ops/elementwise.cc
+++ b/src/tim/vx/ops/elementwise.cc
@@ -30,8 +30,12 @@ namespace tim {
 namespace vx {
 namespace ops {

-#define DEFINE_ELEMENTWISE_OP(NAME, VSI_OP_CODE) \
-  NAME::NAME(Graph* graph) : Operation(graph, VSI_OP_CODE, 2, 1) {}
+#define DEFINE_ELEMENTWISE_OP(NAME, VSI_OP_CODE)                         \
+  NAME::NAME(Graph* graph) : Operation(graph, VSI_OP_CODE, 2, 1) {}      \
+  std::shared_ptr<Operation> NAME::Clone(std::shared_ptr<Graph>& graph)  \
+      const {                                                            \
+    return graph->CreateOperation<NAME>();                               \
+  }

 DEFINE_ELEMENTWISE_OP(Minimum, VSI_NN_OP_MINIMUM)
 DEFINE_ELEMENTWISE_OP(Maximum, VSI_NN_OP_MAXIMUM)
@@ -48,6 +52,12 @@ Multiply::Multiply(Graph* graph, float scale)
   this->impl()->node()->nn_param.multiply.scale = scale;
 }

+std::shared_ptr<Operation> Multiply::Clone(
+    std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<Multiply>(
+      this->impl_->node_->nn_param.multiply.scale);
+}
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
diff --git a/src/tim/vx/ops/fullyconnected.cc b/src/tim/vx/ops/fullyconnected.cc
index 16088b9..2b296d8 100644
--- a/src/tim/vx/ops/fullyconnected.cc
+++ b/src/tim/vx/ops/fullyconnected.cc
@@ -35,11 +35,16 @@ FullyConnected::FullyConnected(Graph* graph, uint32_t axis)
 }

 FullyConnected::FullyConnected(Graph* graph, uint32_t axis, uint32_t weights)
-    : Operation(graph, VSI_NN_OP_FCL2) {
+    : Operation(graph, VSI_NN_OP_FCL2), axis_(axis), weights_(weights) {
   this->impl()->node()->nn_param.fcl.axis = axis;
   this->impl()->node()->nn_param.fcl.weights = weights;
 }

+std::shared_ptr<Operation> FullyConnected::Clone(
+    std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<FullyConnected>(this->axis_, this->weights_);
+}
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
diff --git a/src/tim/vx/ops/gather.cc b/src/tim/vx/ops/gather.cc
index e1d9e6f..f0ef74e 100644
--- a/src/tim/vx/ops/gather.cc
+++ b/src/tim/vx/ops/gather.cc
@@ -34,6 +34,11 @@ Gather::Gather(Graph* graph, int axis)
     : Operation(graph, VSI_NN_OP_GATHER), axis_(axis) {
   this->impl()->node()->nn_param.gather.axis = axis_;
 }
+
+std::shared_ptr<Operation> Gather::Clone(std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<Gather>(this->axis_);
+}
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
\ No newline at end of file
diff --git a/src/tim/vx/ops/gathernd.cc b/src/tim/vx/ops/gathernd.cc
index 8457e06..3428891 100644
--- a/src/tim/vx/ops/gathernd.cc
+++ b/src/tim/vx/ops/gathernd.cc
@@ -31,6 +31,11 @@ namespace vx {
 namespace ops {

 GatherNd::GatherNd(Graph* graph) : Operation(graph, VSI_NN_OP_GATHER_ND) {}
+
+std::shared_ptr<Operation> GatherNd::Clone(std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<GatherNd>();
+}
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
\ No newline at end of file
diff --git a/src/tim/vx/ops/groupedconv2d.cc b/src/tim/vx/ops/groupedconv2d.cc
index 60e510e..f66d0a4 100644
--- a/src/tim/vx/ops/groupedconv2d.cc
+++ b/src/tim/vx/ops/groupedconv2d.cc
@@ -65,6 +65,13 @@ GroupedConv2d::GroupedConv2d(Graph* graph,
   this->impl()->node()->nn_param.conv2d.dilation[1] = dilation_[1];
 }

+std::shared_ptr<Operation> GroupedConv2d::Clone(
+    std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<GroupedConv2d>(
+      this->pad_, this->strides_, this->dilation_, this->group_number_,
+      this->impl_->layout_, this->kernel_layout_);
+}
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
\ No newline at end of file
diff --git a/src/tim/vx/ops/instancenormalization.cc b/src/tim/vx/ops/instancenormalization.cc
index 45b0d36..0fd4423 100644
--- a/src/tim/vx/ops/instancenormalization.cc
+++ b/src/tim/vx/ops/instancenormalization.cc
@@ -34,6 +34,11 @@ InstanceNormalization::InstanceNormalization(Graph* graph, float eps)
   this->impl()->node()->nn_param.instancenorm.eps = eps_;
 }

+std::shared_ptr<Operation> InstanceNormalization::Clone(
+    std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<InstanceNormalization>(this->eps_);
+}
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
\ No newline at end of file
diff --git a/src/tim/vx/ops/l2normalization.cc b/src/tim/vx/ops/l2normalization.cc
index 4b882a7..ec4d1b9 100644
--- a/src/tim/vx/ops/l2normalization.cc
+++ b/src/tim/vx/ops/l2normalization.cc
@@ -34,6 +34,11 @@ L2Normalization::L2Normalization(Graph* graph, int32_t axis)
   this->impl()->node()->nn_param.l2_normalize.axis = axis_;
 }

+std::shared_ptr<Operation> L2Normalization::Clone(
+    std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<L2Normalization>(this->axis_);
+}
+
 }  // namespace ops
 }  // namespace vx
// namespace vx } // namespace tim \ No newline at end of file diff --git a/src/tim/vx/ops/layernormalization.cc b/src/tim/vx/ops/layernormalization.cc index 2cb20a7..19509c4 100644 --- a/src/tim/vx/ops/layernormalization.cc +++ b/src/tim/vx/ops/layernormalization.cc @@ -40,6 +40,11 @@ LayerNormalization::LayerNormalization(Graph* graph, int32_t axis, float eps) this->impl()->node()->nn_param.instancenorm.eps = eps_; } +std::shared_ptr LayerNormalization::Clone( + std::shared_ptr& graph) const { + return graph->CreateOperation(this->axis_, this->eps_); +} + } // namespace ops } // namespace vx } // namespace tim \ No newline at end of file diff --git a/src/tim/vx/ops/localresponsenormalization.cc b/src/tim/vx/ops/localresponsenormalization.cc index a4bcaa2..151ba9b 100644 --- a/src/tim/vx/ops/localresponsenormalization.cc +++ b/src/tim/vx/ops/localresponsenormalization.cc @@ -47,6 +47,13 @@ LocalResponseNormalization::LocalResponseNormalization(Graph* graph, this->impl()->node()->nn_param.lrn.type = VX_CONVOLUTIONAL_NETWORK_NORM_ACROSS_MAPS; } + +std::shared_ptr LocalResponseNormalization::Clone( + std::shared_ptr& graph) const { + return graph->CreateOperation( + this->size_, this->alpha_, this->beta_, this->bias_, this->axis_); +} + } // namespace ops } // namespace vx } // namespace tim \ No newline at end of file diff --git a/src/tim/vx/ops/logical.cc b/src/tim/vx/ops/logical.cc index 0924570..4864dde 100644 --- a/src/tim/vx/ops/logical.cc +++ b/src/tim/vx/ops/logical.cc @@ -35,6 +35,10 @@ namespace ops { : Operation(graph, VSI_NN_OP_LOGICAL_OPS) { \ this->impl()->node()->nn_param.relational_ops.op = \ VSI_NN_LOGICAL_##VSI_OP_CODE; \ + } \ + std::shared_ptr Logical##NAME::Clone( \ + std::shared_ptr& graph) const { \ + return graph->CreateOperation(); \ } DEFINE_LOGICAL_OP(And, AND); diff --git a/src/tim/vx/ops/logsoftmax.cc b/src/tim/vx/ops/logsoftmax.cc index 5ea9130..8d523d0 100644 --- a/src/tim/vx/ops/logsoftmax.cc +++ b/src/tim/vx/ops/logsoftmax.cc @@ -36,6 +36,11 @@ LogSoftmax::LogSoftmax(Graph* graph, int32_t axis, float beta) this->impl()->node()->nn_param.log_softmax.axis = axis_; } +std::shared_ptr LogSoftmax::Clone( + std::shared_ptr& graph) const { + return graph->CreateOperation(this->axis_, this->beta_); +} + } // namespace ops } // namespace vx } // namespace tim \ No newline at end of file diff --git a/src/tim/vx/ops/matmul.cc b/src/tim/vx/ops/matmul.cc index b062e64..e02e407 100644 --- a/src/tim/vx/ops/matmul.cc +++ b/src/tim/vx/ops/matmul.cc @@ -41,6 +41,11 @@ Matmul::Matmul(Graph* graph, bool transpose_a, bool transpose_b, this->impl()->node()->nn_param.matrixmul.adjoint[1] = ToVxBool(adjoint_b_); } +std::shared_ptr Matmul::Clone(std::shared_ptr& graph) const { + return graph->CreateOperation(this->transpose_a_, this->transpose_b_, + this->adjoint_a_, this->adjoint_b_); +} + } // namespace ops } // namespace vx } // namespace tim \ No newline at end of file diff --git a/src/tim/vx/ops/maxpoolwithargmax.cc b/src/tim/vx/ops/maxpoolwithargmax.cc index cd71161..f2126de 100644 --- a/src/tim/vx/ops/maxpoolwithargmax.cc +++ b/src/tim/vx/ops/maxpoolwithargmax.cc @@ -52,6 +52,13 @@ MaxpoolWithArgmax::MaxpoolWithArgmax(Graph* graph, PadType padding, this->SetRoundingPolicy(OverflowPolicy::SATURATE, RoundingPolicy::RTNE, round_type_); } +std::shared_ptr MaxpoolWithArgmax::Clone( + std::shared_ptr& graph) const { + return graph->CreateOperation( + this->padding_, this->ksize_, this->stride_, this->round_type_, + this->impl_->layout_); +} + } // namespace ops } // namespace vx } 
// namespace tim diff --git a/src/tim/vx/ops/maxunpool2d.cc b/src/tim/vx/ops/maxunpool2d.cc index a50e198..29ca625 100644 --- a/src/tim/vx/ops/maxunpool2d.cc +++ b/src/tim/vx/ops/maxunpool2d.cc @@ -41,6 +41,12 @@ MaxUnpool2d::MaxUnpool2d(Graph* graph, const std::array& ksize, this->impl()->node()->nn_param.upsample.size[1] = ksize_[1]; } +std::shared_ptr MaxUnpool2d::Clone( + std::shared_ptr& graph) const { + return graph->CreateOperation(this->ksize_, this->stride_, + this->impl_->layout_); +} + } // namespace ops } // namespace vx } // namespace tim diff --git a/src/tim/vx/ops/mements.cc b/src/tim/vx/ops/mements.cc index f4670e4..b0f0229 100644 --- a/src/tim/vx/ops/mements.cc +++ b/src/tim/vx/ops/mements.cc @@ -30,12 +30,17 @@ namespace tim { namespace vx { namespace ops { - Moments::Moments(Graph* graph, const std::vector& axes, bool keep_dims) - : Operation(graph, VSI_NN_OP_MOMENTS), axes_(axes), keep_dims_(keep_dims) { - this->impl()->node()->nn_param.moments.axis = axes_.data(); - this->impl()->node()->nn_param.moments.axis_num = axes_.size(); - this->impl()->node()->nn_param.moments.keep_dim = ToVxBool(keep_dims_); - } +Moments::Moments(Graph* graph, const std::vector& axes, bool keep_dims) + : Operation(graph, VSI_NN_OP_MOMENTS), axes_(axes), keep_dims_(keep_dims) { + this->impl()->node()->nn_param.moments.axis = axes_.data(); + this->impl()->node()->nn_param.moments.axis_num = axes_.size(); + this->impl()->node()->nn_param.moments.keep_dim = ToVxBool(keep_dims_); +} + +std::shared_ptr Moments::Clone(std::shared_ptr& graph) const { + return graph->CreateOperation(this->axes_, this->keep_dims_); +} + } // namespace ops } // namespace vx } // namespace tim \ No newline at end of file diff --git a/src/tim/vx/ops/nbg.cc b/src/tim/vx/ops/nbg.cc index 1e60359..00e3787 100644 --- a/src/tim/vx/ops/nbg.cc +++ b/src/tim/vx/ops/nbg.cc @@ -29,10 +29,19 @@ namespace tim { namespace vx { namespace ops { - NBG::NBG(Graph* graph, const char* binary, size_t input_count, size_t output_count) : Operation(graph, VSI_NN_OP_NBG, input_count, output_count) { - this->impl()->node()->nn_param.nbg.url = binary; - this->impl()->node()->nn_param.nbg.type = VSI_NN_NBG_POINTER; - } +NBG::NBG(Graph* graph, const char* binary, size_t input_count, + size_t output_count) + : Operation(graph, VSI_NN_OP_NBG, input_count, output_count) { + this->impl()->node()->nn_param.nbg.url = binary; + this->impl()->node()->nn_param.nbg.type = VSI_NN_NBG_POINTER; +} + +std::shared_ptr NBG::Clone(std::shared_ptr& graph) const { + return graph->CreateOperation(this->impl_->node_->nn_param.nbg.url, + this->impl_->input_cnt_, + this->impl_->output_cnt_); +} + } // namespace ops } // namespace vx } // namespace tim \ No newline at end of file diff --git a/src/tim/vx/ops/pad.cc b/src/tim/vx/ops/pad.cc index eb17e39..46e8d9c 100644 --- a/src/tim/vx/ops/pad.cc +++ b/src/tim/vx/ops/pad.cc @@ -41,6 +41,11 @@ Pad::Pad(Graph* graph, const std::vector& front_size, this->impl()->node()->nn_param.pad.const_val = const_val_; this->impl()->node()->nn_param.pad.mode = VSI_NN_PAD_MODE_CONSTANT; } + +std::shared_ptr Pad::Clone(std::shared_ptr& graph) const { + return graph->CreateOperation(this->front_size_, this->back_size_, this->const_val_); +} + } // namespace ops } // namespace vx } // namespace tim \ No newline at end of file diff --git a/src/tim/vx/ops/pool2d.cc b/src/tim/vx/ops/pool2d.cc index b63a5a2..7ae5d38 100644 --- a/src/tim/vx/ops/pool2d.cc +++ b/src/tim/vx/ops/pool2d.cc @@ -75,6 +75,12 @@ Pool2d::Pool2d(Graph* graph, PoolType type, 
diff --git a/src/tim/vx/ops/pool2d.cc b/src/tim/vx/ops/pool2d.cc
index b63a5a2..7ae5d38 100644
--- a/src/tim/vx/ops/pool2d.cc
+++ b/src/tim/vx/ops/pool2d.cc
@@ -75,6 +75,12 @@ Pool2d::Pool2d(Graph* graph, PoolType type,
   this->SetRoundingPolicy(OverflowPolicy::SATURATE, RoundingPolicy::RTNE,
                           round_type_);
 }
 
+std::shared_ptr<Operation> Pool2d::Clone(std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<Pool2d>(this->type_, this->pad_, this->ksize_,
+                                        this->stride_, this->round_type_,
+                                        this->impl_->layout_);
+}
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
\ No newline at end of file
diff --git a/src/tim/vx/ops/reduce.cc b/src/tim/vx/ops/reduce.cc
index 11cdf9e..8669efa 100644
--- a/src/tim/vx/ops/reduce.cc
+++ b/src/tim/vx/ops/reduce.cc
@@ -40,6 +40,11 @@ namespace ops {
     this->impl()->node()->nn_param.reduce.axis = axis_.data();      \
     this->impl()->node()->nn_param.reduce.axis_num = axis_.size();  \
     this->impl()->node()->nn_param.reduce.keep_dim = keep_dims_;    \
+  }                                                                 \
+  std::shared_ptr<Operation> Reduce##NAME::Clone(                   \
+      std::shared_ptr<Graph>& graph) const {                        \
+    return graph->CreateOperation<Reduce##NAME>(this->axis_,        \
+                                                this->keep_dims_);  \
   }
 
 DEFINE_REDUCE_OP(Min, VSI_NN_REDUCE_MIN);
diff --git a/src/tim/vx/ops/relational_operations.cc b/src/tim/vx/ops/relational_operations.cc
index 8b0dc73..6993a0d 100644
--- a/src/tim/vx/ops/relational_operations.cc
+++ b/src/tim/vx/ops/relational_operations.cc
@@ -30,9 +30,14 @@ namespace tim {
 namespace vx {
 namespace ops {
 
-#define DEFINE_RELATIONAL_OP(NAME, VSI_OP_CODE) \
-  NAME::NAME(Graph* graph) : Operation(graph, VSI_NN_OP_RELATIONAL_OPS, 2, 1) { \
-    this->impl()->node()->nn_param.relational_ops.op = VSI_OP_CODE; \
+#define DEFINE_RELATIONAL_OP(NAME, VSI_OP_CODE)                          \
+  NAME::NAME(Graph* graph)                                               \
+      : Operation(graph, VSI_NN_OP_RELATIONAL_OPS, 2, 1) {               \
+    this->impl()->node()->nn_param.relational_ops.op = VSI_OP_CODE;      \
+  }                                                                      \
+  std::shared_ptr<Operation> NAME::Clone(std::shared_ptr<Graph>& graph)  \
+      const {                                                            \
+    return graph->CreateOperation<NAME>();                               \
   }
 
 DEFINE_RELATIONAL_OP(Greater, VSI_NN_RELATIONAL_OPS_GREAT)
diff --git a/src/tim/vx/ops/reorg.cc b/src/tim/vx/ops/reorg.cc
index 41e9285..b791952 100644
--- a/src/tim/vx/ops/reorg.cc
+++ b/src/tim/vx/ops/reorg.cc
@@ -35,6 +35,10 @@ Reorg::Reorg(Graph* graph, const uint32_t stride)
   this->impl()->node()->nn_param.reorg.stride = stride_;
 }
 
+std::shared_ptr<Operation> Reorg::Clone(std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<Reorg>(this->stride_);
+}
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
diff --git a/src/tim/vx/ops/reshape.cc b/src/tim/vx/ops/reshape.cc
index 63cf948..79f0fb0 100644
--- a/src/tim/vx/ops/reshape.cc
+++ b/src/tim/vx/ops/reshape.cc
@@ -36,6 +36,11 @@ Reshape::Reshape(Graph* graph, const std::vector<uint32_t>& size)
   this->impl()->node()->nn_param.reshape.dim_num = size_.size();
 }
 
+std::shared_ptr<Operation> Reshape::Clone(
+    std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<Reshape>(this->size_);
+}
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
\ No newline at end of file
diff --git a/src/tim/vx/ops/resize.cc b/src/tim/vx/ops/resize.cc
index 41566a5..82a20d3 100644
--- a/src/tim/vx/ops/resize.cc
+++ b/src/tim/vx/ops/resize.cc
@@ -50,6 +50,13 @@ Resize::Resize(Graph* graph, ResizeType type, float factor, bool align_corners,
   impl()->node()->nn_param.resize.size[1] = target_height;
 }
 
+std::shared_ptr<Operation> Resize::Clone(std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<Resize>(
+      this->type_, this->factor_, this->align_corners_,
+      this->half_pixel_centers_, this->target_height_, this->target_width_,
+      this->impl_->layout_);
+}
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
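Note on Pool2d::Clone and Resize::Clone above: layout-aware ops forward this->impl_->layout_, so a clone keeps the data layout its source was created with. A minimal sketch, assuming a Pool2d constructor overload matching the Clone argument order (type, pad, ksize, stride, round_type, layout); ClonePool2dSketch is illustrative only:

  #include <array>
  #include <cstdint>
  #include "tim/vx/context.h"
  #include "tim/vx/graph.h"
  #include "tim/vx/ops/pool2d.h"

  void ClonePool2dSketch() {
    auto ctx = tim::vx::Context::Create();
    auto graph_a = ctx->CreateGraph();
    auto graph_b = ctx->CreateGraph();
    std::array<uint32_t, 4> pad = {0, 0, 0, 0};
    std::array<uint32_t, 2> ksize = {2, 2};
    std::array<uint32_t, 2> stride = {2, 2};
    auto pool = graph_a->CreateOperation<tim::vx::ops::Pool2d>(
        tim::vx::PoolType::MAX, pad, ksize, stride,
        tim::vx::RoundType::FLOOR, tim::vx::DataLayout::CWHN);
    auto copy = pool->Clone(graph_b);  // the copy keeps the CWHN layout
  }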
diff --git a/src/tim/vx/ops/resize1d.cc b/src/tim/vx/ops/resize1d.cc
index d3567c9..a436b24 100644
--- a/src/tim/vx/ops/resize1d.cc
+++ b/src/tim/vx/ops/resize1d.cc
@@ -47,6 +47,13 @@ Resize1d::Resize1d(Graph* graph, ResizeType type, float factor, bool align_corners,
   impl()->node()->nn_param.resize_1d.size[0] = target_size;
 }
 
+std::shared_ptr<Operation> Resize1d::Clone(
+    std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<Resize1d>(
+      this->type_, this->factor_, this->align_corners_,
+      this->half_pixel_centers_, this->target_size_, this->impl_->layout_);
+}
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
diff --git a/src/tim/vx/ops/reverse.cc b/src/tim/vx/ops/reverse.cc
index 9a086d6..6de5502 100644
--- a/src/tim/vx/ops/reverse.cc
+++ b/src/tim/vx/ops/reverse.cc
@@ -35,6 +35,11 @@ Reverse::Reverse(Graph* graph, const std::vector<int32_t>& axis)
   this->impl()->node()->nn_param.reverse.axis = axis_.data();
   this->impl()->node()->nn_param.reverse.axis_num = axis_.size();
 }
+
+std::shared_ptr<Operation> Reverse::Clone(std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<Reverse>(this->axis_);
+}
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
\ No newline at end of file
diff --git a/src/tim/vx/ops/scatternd.cc b/src/tim/vx/ops/scatternd.cc
index d979d81..f62a917 100644
--- a/src/tim/vx/ops/scatternd.cc
+++ b/src/tim/vx/ops/scatternd.cc
@@ -35,6 +35,11 @@ ScatterND::ScatterND(Graph* graph, const std::vector<uint32_t>& shape)
   this->impl()->node()->nn_param.scatter_nd.dim_num = shape_.size();
   this->impl()->node()->nn_param.scatter_nd.shape = shape_.data();
 }
+
+std::shared_ptr<Operation> ScatterND::Clone(std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<ScatterND>(this->shape_);
+}
+
 }  // namespace ops
 }  // namespace vx
 }  // namespace tim
\ No newline at end of file
diff --git a/src/tim/vx/ops/select.cc b/src/tim/vx/ops/select.cc
index b2f4ddd..cf1b889 100644
--- a/src/tim/vx/ops/select.cc
+++ b/src/tim/vx/ops/select.cc
@@ -32,6 +32,11 @@ namespace ops {
 
 Select::Select(Graph* graph) : Operation(graph, VSI_NN_OP_SELECT) {}
 
+std::shared_ptr<Operation> Select::Clone(std::shared_ptr<Graph>& graph) const {
+  return graph->CreateOperation<Select>();
+}
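Taken together, every Clone in this patch rebuilds the op from its stored constructor arguments via CreateOperation; tensor bindings are not copied, so callers re-bind inputs and outputs in the destination graph. A minimal usage sketch (CloneRebindSketch is illustrative; tensor creation is elided):

  #include <cstdint>
  #include <vector>
  #include "tim/vx/context.h"
  #include "tim/vx/graph.h"
  #include "tim/vx/ops/reshape.h"

  void CloneRebindSketch() {
    auto ctx = tim::vx::Context::Create();
    auto src = ctx->CreateGraph();
    auto dst = ctx->CreateGraph();
    std::vector<uint32_t> shape = {1, 6};
    auto reshape = src->CreateOperation<tim::vx::ops::Reshape>(shape);
    // Only the hyper-parameter (shape) travels with the clone.
    auto copy = reshape->Clone(dst);
    // copy->BindInput(input).BindOutput(output);  // tensors created on 'dst'
  }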