Added RNNCell & unit test (#249)
Signed-off-by: Chen Xin <jack.chen@verisilicon.com> Co-authored-by: Chen Xin <jack.chen@verisilicon.com>
This commit is contained in:
parent
75d39e2cfd
commit
cea11422b8
8
BUILD
8
BUILD
|
|
@ -28,8 +28,10 @@ cc_library(
|
|||
],
|
||||
hdrs = [
|
||||
"include/tim/vx/context.h",
|
||||
"include/tim/vx/direct_map_op.h",
|
||||
"include/tim/vx/graph.h",
|
||||
"include/tim/vx/operation.h",
|
||||
"include/tim/vx/ops.h",
|
||||
"include/tim/vx/tensor.h",
|
||||
"include/tim/vx/types.h",
|
||||
"include/tim/transform/layout_inference.h",
|
||||
|
|
@ -41,8 +43,12 @@ cc_library(
|
|||
"src/tim/vx/context.cc",
|
||||
"src/tim/vx/graph_private.h",
|
||||
"src/tim/vx/graph.cc",
|
||||
"src/tim/vx/direct_map_op_impl.cc",
|
||||
"src/tim/vx/direct_map_op.cc",
|
||||
"src/tim/vx/direct_map_op_impl.h",
|
||||
"src/tim/vx/op_impl.cc",
|
||||
"src/tim/vx/op_impl.h",
|
||||
"src/tim/vx/operation.cc",
|
||||
"src/tim/vx/operation_private.h",
|
||||
"src/tim/vx/tensor.cc",
|
||||
"src/tim/vx/tensor_private.h",
|
||||
"src/tim/vx/type_utils.h",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,42 @@
|
|||
/****************************************************************************
|
||||
*
|
||||
* Copyright (c) 2021 Vivante Corporation
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||
* copy of this software and associated documentation files (the "Software"),
|
||||
* to deal in the Software without restriction, including without limitation
|
||||
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
* and/or sell copies of the Software, and to permit persons to whom the
|
||||
* Software is furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
* DEALINGS IN THE SOFTWARE.
|
||||
*
|
||||
*****************************************************************************/
|
||||
#ifndef TIM_VX_DIRECTMAPOP_H
|
||||
#define TIM_VX_DIRECTMAPOP_H
|
||||
|
||||
#include "tim/vx/operation.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
// interface
|
||||
class DirectMapOp : public Operation {
|
||||
public:
|
||||
DirectMapOp(Graph* graph, uint32_t kind, int in_cnt = 0, int out_cnt = 0,
|
||||
DataLayout layout = DataLayout::ANY);
|
||||
};
|
||||
|
||||
} // namespace vx
|
||||
|
||||
} // namespace tim
|
||||
|
||||
#endif
|
||||
|
|
@ -30,12 +30,11 @@
|
|||
namespace tim {
|
||||
namespace vx {
|
||||
|
||||
class OperationImpl;
|
||||
class OpImpl;
|
||||
|
||||
class Operation {
|
||||
public:
|
||||
Operation(Graph* graph, uint32_t operation_id,
|
||||
int input_cnt = 0, int ouput_cnt = 0, DataLayout layout = DataLayout::ANY);
|
||||
Operation();
|
||||
virtual ~Operation();
|
||||
virtual std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const = 0;
|
||||
Operation& BindInput(const std::shared_ptr<Tensor>& tensor);
|
||||
|
|
@ -47,11 +46,11 @@ class Operation {
|
|||
RoundingPolicy rounding_policy = RoundingPolicy::RTNE,
|
||||
RoundType down_scale_size_rounding = RoundType::FLOOR,
|
||||
uint32_t accumulator_bits = 0);
|
||||
std::unique_ptr<OperationImpl>& impl();
|
||||
const std::unique_ptr<OperationImpl>& impl() const;
|
||||
std::unique_ptr<OpImpl>& impl();
|
||||
const std::unique_ptr<OpImpl>& impl() const;
|
||||
|
||||
protected:
|
||||
std::unique_ptr<OperationImpl> impl_;
|
||||
std::unique_ptr<OpImpl> impl_;
|
||||
};
|
||||
|
||||
} // namespace vx
|
||||
|
|
|
|||
|
|
@ -63,10 +63,12 @@
|
|||
#include "tim/vx/ops/resize1d.h"
|
||||
#include "tim/vx/ops/resize.h"
|
||||
#include "tim/vx/ops/reverse.h"
|
||||
#include "tim/vx/ops/rnn_cell.h"
|
||||
#include "tim/vx/ops/scatternd.h"
|
||||
#include "tim/vx/ops/select.h"
|
||||
#include "tim/vx/ops/shuffle_channel.h"
|
||||
#include "tim/vx/ops/simple_operations.h"
|
||||
#include "tim/vx/ops/signal_frame.h"
|
||||
#include "tim/vx/ops/slice.h"
|
||||
#include "tim/vx/ops/softmax.h"
|
||||
#include "tim/vx/ops/space2batch.h"
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_ACTIVATIONS_H_
|
||||
#define TIM_VX_OPS_ACTIVATIONS_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -69,7 +69,7 @@ namespace ops {
|
|||
*/
|
||||
|
||||
#define DECLARE_NO_PARAMETER_ACTIVATION(NAME) \
|
||||
class NAME : public Operation { \
|
||||
class NAME : public DirectMapOp { \
|
||||
public: \
|
||||
NAME(Graph* graph); \
|
||||
std::shared_ptr<Operation> Clone( \
|
||||
|
|
@ -90,7 +90,7 @@ DECLARE_NO_PARAMETER_ACTIVATION(SoftRelu)
|
|||
|
||||
#undef DEFINE_NO_PARAMETER_ACTIVATION
|
||||
|
||||
class Prelu : public Operation {
|
||||
class Prelu : public DirectMapOp {
|
||||
public:
|
||||
Prelu(Graph* graph, int axis);
|
||||
std::shared_ptr<Operation> Clone(
|
||||
|
|
@ -100,7 +100,7 @@ class Prelu : public Operation {
|
|||
int axis_;
|
||||
};
|
||||
|
||||
class LeakyRelu : public Operation {
|
||||
class LeakyRelu : public DirectMapOp {
|
||||
public:
|
||||
LeakyRelu(Graph* graph, float alpha);
|
||||
std::shared_ptr<Operation> Clone(
|
||||
|
|
@ -110,7 +110,7 @@ class LeakyRelu : public Operation {
|
|||
float alpha_;
|
||||
};
|
||||
|
||||
class Linear : public Operation {
|
||||
class Linear : public DirectMapOp {
|
||||
public:
|
||||
Linear(Graph* graph, float a, float b = 0.0);
|
||||
std::shared_ptr<Operation> Clone(
|
||||
|
|
@ -121,7 +121,7 @@ class Linear : public Operation {
|
|||
float b_;
|
||||
};
|
||||
|
||||
class Gelu : public Operation {
|
||||
class Gelu : public DirectMapOp {
|
||||
public:
|
||||
/****************************************************************************
|
||||
*Non-approximate calculations will also have errors when the data type is
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_ADDN_H_
|
||||
#define TIM_VX_OPS_ADDN_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -37,7 +37,7 @@ namespace ops {
|
|||
* ```
|
||||
*/
|
||||
|
||||
class AddN : public Operation {
|
||||
class AddN : public DirectMapOp {
|
||||
public:
|
||||
AddN(Graph* graph, uint32_t num_inputs);
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_ARG_H_
|
||||
#define TIM_VX_OPS_ARG_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -37,7 +37,7 @@ namespace ops {
|
|||
*/
|
||||
|
||||
#define DECLARE_ARG_OP(NAME) \
|
||||
class Arg##NAME : public Operation { \
|
||||
class Arg##NAME : public DirectMapOp { \
|
||||
public: \
|
||||
Arg##NAME(Graph* graph, int32_t axis); \
|
||||
std::shared_ptr<Operation> Clone( \
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
#include <vector>
|
||||
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -43,7 +43,7 @@ namespace ops {
|
|||
* - crop : corp the output tensor for ROI usage.
|
||||
*/
|
||||
|
||||
class Batch2Space : public Operation {
|
||||
class Batch2Space : public DirectMapOp {
|
||||
public:
|
||||
Batch2Space(Graph* graph, const std::vector<int>& block_size,
|
||||
const std::vector<int>& crop,
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef OVXLIBXX_OPERATIONS_BATCHNORM_H_
|
||||
#define OVXLIBXX_OPERATIONS_BATCHNORM_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -40,7 +40,7 @@ namespace ops {
|
|||
* $$y_i=\gamma\hat x_i+\beta\equiv BN_{\gamma,\beta}(x_i)$$
|
||||
*/
|
||||
|
||||
class BatchNorm : public Operation {
|
||||
class BatchNorm : public DirectMapOp {
|
||||
public:
|
||||
BatchNorm(Graph* graph, float eps, DataLayout input_layout = DataLayout::WHCN);
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef OVXLIBXX_OPERATIONS_CLIP_H_
|
||||
#define OVXLIBXX_OPERATIONS_CLIP_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
|
||||
namespace tim {
|
||||
|
|
@ -36,7 +36,7 @@ namespace ops {
|
|||
* Clip(x) : min if x <= min; x if min < x < max; max if x >= max
|
||||
*/
|
||||
|
||||
class Clip : public Operation {
|
||||
class Clip : public DirectMapOp {
|
||||
public:
|
||||
Clip(Graph* graph, float min, float max);
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_CONCAT_H_
|
||||
#define TIM_VX_OPS_CONCAT_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -37,7 +37,7 @@ namespace ops {
|
|||
* - axis : Which axis to concat on.
|
||||
*/
|
||||
|
||||
class Concat : public Operation {
|
||||
class Concat : public DirectMapOp {
|
||||
public:
|
||||
Concat(Graph* graph, uint32_t axis, int input_cnt);
|
||||
|
||||
|
|
|
|||
|
|
@ -26,13 +26,13 @@
|
|||
|
||||
#include <array>
|
||||
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
namespace ops {
|
||||
|
||||
class Conv1d : public Operation {
|
||||
class Conv1d : public DirectMapOp {
|
||||
public:
|
||||
Conv1d(Graph* graph, PadType padding, uint32_t stride,
|
||||
uint32_t dilation, int32_t multiplier = 0,
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
#include <array>
|
||||
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -55,7 +55,7 @@ namespace ops {
|
|||
* - layout : WHCN or CWHN.
|
||||
*/
|
||||
|
||||
class Conv2d : public Operation {
|
||||
class Conv2d : public DirectMapOp {
|
||||
public:
|
||||
Conv2d(Graph* graph, PadType padding,
|
||||
const std::array<uint32_t, 2>& stride,
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
#include <array>
|
||||
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -53,7 +53,7 @@ namespace ops {
|
|||
* - kernel_layout: Layout for kernel, WHIO by default.
|
||||
*/
|
||||
|
||||
class DeConv2d : public Operation {
|
||||
class DeConv2d : public DirectMapOp {
|
||||
public:
|
||||
DeConv2d(Graph* graph, int32_t oc_count_, PadType pad_type,
|
||||
const std::array<uint32_t, 2>& ksize,
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
#include <array>
|
||||
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -49,7 +49,7 @@ namespace ops {
|
|||
* the output tensor.
|
||||
*/
|
||||
|
||||
class DeConv1d : public Operation {
|
||||
class DeConv1d : public DirectMapOp {
|
||||
public:
|
||||
DeConv1d(Graph* graph, PadType pad_type,
|
||||
uint32_t stride, uint32_t output_padding, uint32_t group = 1,
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_DEPTH2SPACE_H_
|
||||
#define TIM_VX_OPS_DEPTH2SPACE_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -45,7 +45,7 @@ namespace ops {
|
|||
* - crop : corp the output tensor for ROI usage.
|
||||
*/
|
||||
|
||||
class DepthToSpace : public Operation {
|
||||
class DepthToSpace : public DirectMapOp {
|
||||
public:
|
||||
DepthToSpace(Graph* Graph, int block_size,
|
||||
DataLayout layout = DataLayout::WHCN);
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef OVXLIBXX_OPERATIONS_DROPOUT_H_
|
||||
#define OVXLIBXX_OPERATIONS_DROPOUT_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
|
||||
namespace tim {
|
||||
|
|
@ -40,7 +40,7 @@ namespace ops {
|
|||
* for Dropout operator.
|
||||
*/
|
||||
|
||||
class Dropout : public Operation {
|
||||
class Dropout : public DirectMapOp {
|
||||
public:
|
||||
Dropout(Graph* graph, float ratio);
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_ELEMENTWISE_H_
|
||||
#define TIM_VX_OPS_ELEMENTWISE_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -67,7 +67,7 @@ namespace ops {
|
|||
*/
|
||||
|
||||
#define DECLARE_ELEMENTWISE_OP(NAME) \
|
||||
class NAME : public Operation { \
|
||||
class NAME : public DirectMapOp { \
|
||||
public: \
|
||||
NAME(Graph* graph); \
|
||||
std::shared_ptr<Operation> Clone( \
|
||||
|
|
@ -81,14 +81,14 @@ DECLARE_ELEMENTWISE_OP(Sub)
|
|||
DECLARE_ELEMENTWISE_OP(Pow)
|
||||
DECLARE_ELEMENTWISE_OP(FloorDiv)
|
||||
|
||||
class Multiply : public Operation {
|
||||
class Multiply : public DirectMapOp {
|
||||
public:
|
||||
Multiply(Graph* graph, float scale = 1.0f);
|
||||
|
||||
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
|
||||
};
|
||||
|
||||
class Div : public Operation {
|
||||
class Div : public DirectMapOp {
|
||||
public:
|
||||
Div(Graph* graph, float scale = 1.0f);
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@
|
|||
#ifndef TIM_VX_OPS_ERF_H_
|
||||
#define TIM_VX_OPS_ERF_H_
|
||||
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
#include "tim/vx/types.h"
|
||||
|
||||
namespace tim {
|
||||
|
|
@ -39,7 +39,7 @@ namespace ops {
|
|||
* - no parameters
|
||||
*/
|
||||
|
||||
class Erf : public Operation {
|
||||
class Erf : public DirectMapOp {
|
||||
public:
|
||||
Erf(Graph* graph);
|
||||
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_FULLYCONNECTED_H_
|
||||
#define TIM_VX_OPS_FULLYCONNECTED_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -39,7 +39,7 @@ namespace ops {
|
|||
* - weights: the output channel number for weight tensor.
|
||||
*/
|
||||
|
||||
class FullyConnected : public Operation {
|
||||
class FullyConnected : public DirectMapOp {
|
||||
public:
|
||||
FullyConnected(Graph* graph, uint32_t axis);
|
||||
FullyConnected(Graph* graph, uint32_t axis, uint32_t weights);
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_GATHER_H_
|
||||
#define TIM_VX_OPS_GATHER_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -35,7 +35,7 @@ namespace ops {
|
|||
* Gather slices from input, **axis** according to **indices**.
|
||||
*/
|
||||
|
||||
class Gather : public Operation {
|
||||
class Gather : public DirectMapOp {
|
||||
public:
|
||||
Gather(Graph* Graph, int axis);
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_GATHERND_H_
|
||||
#define TIM_VX_OPS_GATHERND_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -35,7 +35,7 @@ namespace ops {
|
|||
* An operation similar to Gather but gathers across multiple axis at once.
|
||||
*/
|
||||
|
||||
class GatherNd : public Operation {
|
||||
class GatherNd : public DirectMapOp {
|
||||
public:
|
||||
GatherNd(Graph* Graph);
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
#include <array>
|
||||
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -53,7 +53,7 @@ namespace ops {
|
|||
* - layout : WCN or CWN.
|
||||
*/
|
||||
|
||||
class GroupedConv1d : public Operation {
|
||||
class GroupedConv1d : public DirectMapOp {
|
||||
public:
|
||||
GroupedConv1d(Graph* graph, PadType padding,
|
||||
uint32_t stride,
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
#include <array>
|
||||
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -53,7 +53,7 @@ namespace ops {
|
|||
* - layout : WHCN or CWHN.
|
||||
*/
|
||||
|
||||
class GroupedConv2d : public Operation {
|
||||
class GroupedConv2d : public DirectMapOp {
|
||||
public:
|
||||
GroupedConv2d(Graph* graph, PadType padding,
|
||||
const std::array<uint32_t, 2>& strides,
|
||||
|
|
|
|||
|
|
@ -23,12 +23,12 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_INSTANCENOMALIZATION_H_
|
||||
#define TIM_VX_OPS_INSTANCENOMALIZATION_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
namespace ops {
|
||||
class InstanceNormalization : public Operation {
|
||||
class InstanceNormalization : public DirectMapOp {
|
||||
public:
|
||||
InstanceNormalization(Graph* graph, float eps = 1e-5f);
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_L2NOMALIZATION_H_
|
||||
#define TIM_VX_OPS_L2NOMALIZATION_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
/**
|
||||
* ## L2Normalization
|
||||
|
|
@ -40,7 +40,7 @@
|
|||
namespace tim {
|
||||
namespace vx {
|
||||
namespace ops {
|
||||
class L2Normalization : public Operation {
|
||||
class L2Normalization : public DirectMapOp {
|
||||
public:
|
||||
L2Normalization(Graph* graph, int32_t axis);
|
||||
|
||||
|
|
|
|||
|
|
@ -25,12 +25,12 @@
|
|||
#define TIM_VX_OPS_LAYERNOMALIZATION_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
namespace ops {
|
||||
class LayerNormalization : public Operation {
|
||||
class LayerNormalization : public DirectMapOp {
|
||||
public:
|
||||
LayerNormalization(Graph* graph, int32_t axis = 0, float eps = 1e-5f);
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_LOCALRESPONSENORMALIZATION_H_
|
||||
#define TIM_VX_OPS_LOCALRESPONSENORMALIZATION_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
/**
|
||||
* ## LocalResponseNormalization
|
||||
|
|
@ -40,7 +40,7 @@
|
|||
namespace tim {
|
||||
namespace vx {
|
||||
namespace ops {
|
||||
class LocalResponseNormalization : public Operation {
|
||||
class LocalResponseNormalization : public DirectMapOp {
|
||||
public:
|
||||
LocalResponseNormalization(Graph* graph, uint32_t size, float alpha,
|
||||
float beta, float bias, int32_t axis);
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_LOGICAL_H_
|
||||
#define TIM_VX_OPS_LOGICAL_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -40,7 +40,7 @@ namespace ops {
|
|||
*/
|
||||
|
||||
#define DECLARE_LOGICAL_OP(NAME) \
|
||||
class Logical##NAME : public Operation { \
|
||||
class Logical##NAME : public DirectMapOp { \
|
||||
public: \
|
||||
Logical##NAME(Graph* graph); \
|
||||
std::shared_ptr<Operation> Clone( \
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_LOG_SOFTMAX_H_
|
||||
#define TIM_VX_OPS_LOG_SOFTMAX_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -39,7 +39,7 @@ namespace ops {
|
|||
* ```
|
||||
*/
|
||||
|
||||
class LogSoftmax : public Operation {
|
||||
class LogSoftmax : public DirectMapOp {
|
||||
public:
|
||||
LogSoftmax(Graph* graph, int32_t axis, float beta = 1.f);
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_MATMUL_H_
|
||||
#define TIM_VX_OPS_MATMUL_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -40,7 +40,7 @@ namespace ops {
|
|||
* - adjoint_b: If True, b is conjugated and transposed before multiplication.
|
||||
*/
|
||||
|
||||
class Matmul : public Operation {
|
||||
class Matmul : public DirectMapOp {
|
||||
public:
|
||||
Matmul(Graph* graph, bool transpose_a = false, bool transpose_b = false,
|
||||
bool adjoint_a = false, bool adjoint_b = false);
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
#include <array>
|
||||
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
#include "tim/vx/types.h"
|
||||
|
||||
namespace tim {
|
||||
|
|
@ -44,7 +44,7 @@ namespace ops {
|
|||
* - round_type : CEILING or FLOOR.
|
||||
*/
|
||||
|
||||
class MaxpoolWithArgmax : public Operation {
|
||||
class MaxpoolWithArgmax : public DirectMapOp {
|
||||
public:
|
||||
MaxpoolWithArgmax(Graph* graph, PadType padding,
|
||||
const std::array<uint32_t, 2>& ksize,
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
#include <array>
|
||||
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
#include "tim/vx/types.h"
|
||||
|
||||
namespace tim {
|
||||
|
|
@ -42,7 +42,7 @@ namespace ops {
|
|||
* - ksize : filter size.
|
||||
*/
|
||||
|
||||
class MaxUnpool2d : public Operation {
|
||||
class MaxUnpool2d : public DirectMapOp {
|
||||
public:
|
||||
MaxUnpool2d(Graph* graph, const std::array<uint32_t, 2>& ksize,
|
||||
const std::array<uint32_t, 2>& stride, DataLayout layout = DataLayout::WHCN);
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_MOMENTS_H_
|
||||
#define TIM_VX_OPS_MOMENTS_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -39,7 +39,7 @@ namespace ops {
|
|||
* - keep_dims : Produce moments with the same dimensionality as input.
|
||||
*/
|
||||
|
||||
class Moments : public Operation {
|
||||
class Moments : public DirectMapOp {
|
||||
public:
|
||||
Moments(Graph* graph, const std::vector<int32_t>& axes,
|
||||
bool keep_dims = false);
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_NBG_H_
|
||||
#define TIM_VX_OPS_NBG_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -36,7 +36,7 @@ namespace ops {
|
|||
* a bianry file.
|
||||
*/
|
||||
|
||||
class NBG : public Operation {
|
||||
class NBG : public DirectMapOp {
|
||||
public:
|
||||
NBG(Graph* graph, const char* binary, size_t input_count, size_t output_count);
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPERATION_PAD_H_
|
||||
#define TIM_VX_OPERATION_PAD_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -37,7 +37,7 @@ namespace ops {
|
|||
* - const_val : the value to pad.
|
||||
*/
|
||||
|
||||
class Pad : public Operation {
|
||||
class Pad : public DirectMapOp {
|
||||
public:
|
||||
Pad(Graph* graph, const std::vector<uint32_t>& front_size,
|
||||
const std::vector<uint32_t>& back_size, int32_t const_val);
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
#include <array>
|
||||
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
#include "tim/vx/types.h"
|
||||
|
||||
namespace tim {
|
||||
|
|
@ -63,7 +63,7 @@ namespace ops {
|
|||
*
|
||||
*/
|
||||
|
||||
class Pool2d : public Operation {
|
||||
class Pool2d : public DirectMapOp {
|
||||
public:
|
||||
// for Classic Pool2d
|
||||
Pool2d(Graph* graph, PoolType type, PadType padding,
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_REDUCE_H_
|
||||
#define TIM_VX_OPS_REDUCE_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -95,7 +95,7 @@ namespace ops {
|
|||
*/
|
||||
|
||||
#define DECLARE_REDUCE_OP(NAME) \
|
||||
class Reduce##NAME : public Operation { \
|
||||
class Reduce##NAME : public DirectMapOp { \
|
||||
public: \
|
||||
Reduce##NAME(Graph* graph, const std::vector<int32_t>& axis, \
|
||||
bool keep_dims); \
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_RELATIONAL_H_
|
||||
#define TIM_VX_OPS_RELATIONAL_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -56,7 +56,7 @@ namespace ops {
|
|||
*/
|
||||
|
||||
#define DECLARE_RELATIONAL_OP(NAME) \
|
||||
class NAME : public Operation { \
|
||||
class NAME : public DirectMapOp { \
|
||||
public: \
|
||||
NAME(Graph* graph); \
|
||||
std::shared_ptr<Operation> Clone( \
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_REORG_H_
|
||||
#define TIM_VX_OPS_REORG_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -35,7 +35,7 @@ namespace ops {
|
|||
* The layer used in YOLOv2. See also https://github.com/pjreddie/darknet/blob/master/src/reorg_layer.c
|
||||
*/
|
||||
|
||||
class Reorg : public Operation {
|
||||
class Reorg : public DirectMapOp {
|
||||
public:
|
||||
Reorg(Graph* graph, const uint32_t stride);
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_RESHAPE_H_
|
||||
#define TIM_VX_OPS_RESHAPE_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -37,7 +37,7 @@ namespace ops {
|
|||
* - size : defining the shape of the output tensor.
|
||||
*/
|
||||
|
||||
class Reshape : public Operation {
|
||||
class Reshape : public DirectMapOp {
|
||||
public:
|
||||
Reshape(Graph* graph, const std::vector<uint32_t>& size);
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_RESIZE_H_
|
||||
#define TIM_VX_OPS_RESIZE_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -44,7 +44,7 @@ namespace ops {
|
|||
* - target_height / target_width : output height / width. DO NOT use it with factor together.
|
||||
*/
|
||||
|
||||
class Resize : public Operation {
|
||||
class Resize : public DirectMapOp {
|
||||
public:
|
||||
Resize(Graph* graph, ResizeType type, float factor, bool align_corners,
|
||||
bool half_pixel_centers, int target_height, int target_width,
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_RESIZE1D_H_
|
||||
#define TIM_VX_OPS_RESIZE1D_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -44,7 +44,7 @@ namespace ops {
|
|||
* - target_height / target_width : output height / width. DO NOT use it with factor together.
|
||||
*/
|
||||
|
||||
class Resize1d : public Operation {
|
||||
class Resize1d : public DirectMapOp {
|
||||
public:
|
||||
Resize1d(Graph* graph, ResizeType type, float factor, bool align_corners,
|
||||
bool half_pixel_centers, int target_size,
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_REVERSE_H_
|
||||
#define TIM_VX_OPS_REVERSE_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -37,7 +37,7 @@ namespace ops {
|
|||
* - axis : The indices of the dimensions to reverse.
|
||||
*/
|
||||
|
||||
class Reverse : public Operation {
|
||||
class Reverse : public DirectMapOp {
|
||||
public:
|
||||
Reverse(Graph* graph, const std::vector<int32_t>& axis);
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,54 @@
|
|||
/****************************************************************************
|
||||
*
|
||||
* Copyright (c) 2021 Vivante Corporation
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||
* copy of this software and associated documentation files (the "Software"),
|
||||
* to deal in the Software without restriction, including without limitation
|
||||
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
* and/or sell copies of the Software, and to permit persons to whom the
|
||||
* Software is furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
* DEALINGS IN THE SOFTWARE.
|
||||
*
|
||||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_RNN_CELL_H_
|
||||
#define TIM_VX_OPS_RNN_CELL_H_
|
||||
|
||||
#include "tim/vx/operation.h"
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
namespace ops {
|
||||
|
||||
class RNNCell : public Operation{
|
||||
public:
|
||||
enum ActivationType {
|
||||
kNONE = 0,
|
||||
kRELU = 1,
|
||||
kRELU1 = 2,
|
||||
kRELU6 = 3,
|
||||
kTANH = 4,
|
||||
kSIGMOID = 6,
|
||||
kHARDSIGMOID = 31, /* temporary use 31*/
|
||||
};
|
||||
RNNCell(Graph* graph, ActivationType activation);
|
||||
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
|
||||
|
||||
protected:
|
||||
const ActivationType activation_;
|
||||
};
|
||||
|
||||
} // namespace ops
|
||||
} // namespace vx
|
||||
} // namespace tim
|
||||
|
||||
#endif /* TIM_VX_OPS_RNN_CELL_H_ */
|
||||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_SCATTERND_H_
|
||||
#define TIM_VX_OPS_SCATTERND_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -37,7 +37,7 @@ namespace ops {
|
|||
* - shape : The shape of the resulting tensor.
|
||||
*/
|
||||
|
||||
class ScatterND : public Operation {
|
||||
class ScatterND : public DirectMapOp {
|
||||
public:
|
||||
ScatterND(Graph* graph, const std::vector<uint32_t>& shape);
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_SELECT_H_
|
||||
#define TIM_VX_OPS_SELECT_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -36,7 +36,7 @@ namespace ops {
|
|||
* from both input tensors: O[i] = C[i] ? x[i] : y[i].
|
||||
*/
|
||||
|
||||
class Select : public Operation {
|
||||
class Select : public DirectMapOp {
|
||||
public:
|
||||
Select(Graph* graph);
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_SHUFFLE_H_
|
||||
#define TIM_VX_OPS_SHUFFLE_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -38,7 +38,7 @@ namespace ops {
|
|||
* ```
|
||||
*/
|
||||
|
||||
class ShuffleChannel : public Operation {
|
||||
class ShuffleChannel : public DirectMapOp {
|
||||
public:
|
||||
explicit ShuffleChannel(Graph* graph, int32_t num_groups, int32_t index_axis);
|
||||
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_SIGNALFRAME_H_
|
||||
#define TIM_VX_OPS_SIGNALFRAME_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -39,7 +39,7 @@ namespace ops {
|
|||
* ```
|
||||
*/
|
||||
|
||||
class SignalFrame : public Operation {
|
||||
class SignalFrame : public DirectMapOp {
|
||||
public:
|
||||
SignalFrame(Graph* graph, uint32_t window_length, uint32_t step, uint32_t pad_end=0,
|
||||
uint32_t axis=0);
|
||||
|
|
|
|||
|
|
@ -23,14 +23,14 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_SIMPLE_OPERATIONS_H_
|
||||
#define TIM_VX_OPS_SIMPLE_OPERATIONS_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
namespace ops {
|
||||
|
||||
#define DECLARE_SIMPLE_OP(NAME) \
|
||||
class NAME : public Operation { \
|
||||
class NAME : public DirectMapOp { \
|
||||
public: \
|
||||
NAME(Graph* graph); \
|
||||
std::shared_ptr<Operation> Clone( \
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_SLICE_H_
|
||||
#define TIM_VX_OPS_SLICE_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -38,7 +38,7 @@ namespace ops {
|
|||
* - length : the size of the slice in each dimension.
|
||||
*/
|
||||
|
||||
class Slice : public Operation {
|
||||
class Slice : public DirectMapOp {
|
||||
public:
|
||||
Slice(Graph* graph,
|
||||
uint32_t dims,
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_SOFTMAX_H_
|
||||
#define TIM_VX_OPS_SOFTMAX_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -42,7 +42,7 @@ namespace ops {
|
|||
* ```
|
||||
*/
|
||||
|
||||
class Softmax : public Operation {
|
||||
class Softmax : public DirectMapOp {
|
||||
public:
|
||||
Softmax(Graph* graph, float beta, int32_t axis);
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
#include <vector>
|
||||
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -46,7 +46,7 @@ namespace ops {
|
|||
* - pad : the paddings for each spatial dimension of the input tensor.
|
||||
*/
|
||||
|
||||
class Space2Batch : public Operation {
|
||||
class Space2Batch : public DirectMapOp {
|
||||
public:
|
||||
Space2Batch(Graph* graph, const std::vector<int>& block_size,
|
||||
const std::vector<int>& pad,
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_SPACE2DEPTH_H_
|
||||
#define TIM_VX_OPS_SPACE2DEPTH_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -38,7 +38,7 @@ namespace ops {
|
|||
* transformation of DepthToSpace.
|
||||
*/
|
||||
|
||||
class SpaceToDepth : public Operation {
|
||||
class SpaceToDepth : public DirectMapOp {
|
||||
public:
|
||||
SpaceToDepth(Graph* graph, std::vector<int> block_size,
|
||||
DataLayout layout = DataLayout::WHCN);
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_SPATIAL_TRANSFORMER_H_
|
||||
#define TIM_VX_OPS_SPATIAL_TRANSFORMER_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -40,7 +40,7 @@ namespace ops {
|
|||
It is the output of the localization network.
|
||||
*/
|
||||
|
||||
class SpatialTransformer : public Operation {
|
||||
class SpatialTransformer : public DirectMapOp {
|
||||
public:
|
||||
SpatialTransformer(Graph* graph, uint32_t output_h, uint32_t output_w,
|
||||
bool has_theta_1_1, bool has_theta_1_2, bool has_theta_1_3,
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@
|
|||
#define TIM_VX_OPS_SPLIT_H_
|
||||
#include <vector>
|
||||
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -40,7 +40,7 @@ namespace ops {
|
|||
* - slices : indicating the number of splits along given axis.
|
||||
*/
|
||||
|
||||
class Split : public Operation {
|
||||
class Split : public DirectMapOp {
|
||||
public:
|
||||
Split(Graph* graph, uint32_t axis, std::vector<uint32_t> slices);
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@
|
|||
#ifndef TIM_VX_OPS_SQUEEZE_H_
|
||||
#define TIM_VX_OPS_SQUEEZE_H_
|
||||
#include <vector>
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -38,7 +38,7 @@ namespace ops {
|
|||
* - axis : the dimensions to squeeze.
|
||||
*/
|
||||
|
||||
class Squeeze : public Operation {
|
||||
class Squeeze : public DirectMapOp {
|
||||
public:
|
||||
Squeeze(Graph* graph, std::vector<uint32_t> axis);
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_STACK_H_
|
||||
#define TIM_VX_OPS_STACK_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -36,7 +36,7 @@ namespace ops {
|
|||
* each tensor in values, by packing them along the **axis** dimension.
|
||||
*/
|
||||
|
||||
class Stack : public Operation {
|
||||
class Stack : public DirectMapOp {
|
||||
public:
|
||||
Stack(Graph* graph, uint32_t axis, int input_cnt);
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_STRIDEDSLICE_H_
|
||||
#define TIM_VX_OPS_STRIDEDSLICE_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -52,7 +52,7 @@ namespace ops {
|
|||
* e.g. begin[i] = x, end[i] = x + 1.
|
||||
*/
|
||||
|
||||
class StridedSlice : public Operation {
|
||||
class StridedSlice : public DirectMapOp {
|
||||
public:
|
||||
StridedSlice(Graph* graph, const std::vector<int32_t> begin_dims,
|
||||
const std::vector<int32_t> end_dims,
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
#include <array>
|
||||
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
#include "tim/vx/types.h"
|
||||
|
||||
namespace tim {
|
||||
|
|
@ -43,7 +43,7 @@ namespace ops {
|
|||
* - spectrogram_length : corresponds to the fixed-size of the memory.
|
||||
*/
|
||||
|
||||
class Svdf : public Operation {
|
||||
class Svdf : public DirectMapOp {
|
||||
public:
|
||||
Svdf(Graph* graph, int32_t rank, int32_t num_units, int32_t spectrogram_length);
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_TILE_H_
|
||||
#define TIM_VX_OPS_TILE_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -37,7 +37,7 @@ namespace ops {
|
|||
* Length must be the same as the number of dimensions in input.
|
||||
*/
|
||||
|
||||
class Tile : public Operation {
|
||||
class Tile : public DirectMapOp {
|
||||
public:
|
||||
Tile(Graph* graph, const std::vector<int32_t>& multiples);
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_TRANSPOSE_H_
|
||||
#define TIM_VX_OPS_TRANSPOSE_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -41,7 +41,7 @@ namespace ops {
|
|||
* 2-D input Tensors.
|
||||
*/
|
||||
|
||||
class Transpose : public Operation {
|
||||
class Transpose : public DirectMapOp {
|
||||
public:
|
||||
Transpose(Graph* graph, const std::vector<uint32_t>& perm);
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_UNIDIRECTIONAL_SEQUENCE_LSTM_H_
|
||||
#define TIM_VX_OPS_UNIDIRECTIONAL_SEQUENCE_LSTM_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -32,7 +32,7 @@ namespace ops {
|
|||
* ## Unidirectional sequence lstm
|
||||
* how to bind input/output: take unidirectional_sequence_lstm_test.cc
|
||||
*/
|
||||
class UnidirectionalSequenceLstm: public Operation {
|
||||
class UnidirectionalSequenceLstm: public DirectMapOp {
|
||||
public:
|
||||
enum ActivationType {
|
||||
kNONE = 0,
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPS_UNSTACK_H_
|
||||
#define TIM_VX_OPS_UNSTACK_H_
|
||||
#include "tim/vx/operation.h"
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
|
|
@ -37,7 +37,7 @@ namespace ops {
|
|||
* Negative values wrap around, so the valid range is [-R, R).
|
||||
*/
|
||||
|
||||
class Unstack : public Operation {
|
||||
class Unstack : public DirectMapOp {
|
||||
public:
|
||||
Unstack(Graph* graph, int32_t axis, uint32_t output_num);
|
||||
|
||||
|
|
|
|||
|
|
@ -199,7 +199,7 @@ std::vector<std::shared_ptr<vx::Tensor>> HandleLayoutInfer(
|
|||
std::shared_ptr<layout_inference_impl::LayoutInferContext>& ctx,
|
||||
const std::shared_ptr<vx::Operation>& op) {
|
||||
ctx->MarkVisited(op);
|
||||
auto op_id = op->impl()->operation_id_;
|
||||
auto op_id = op->impl()->kind_;
|
||||
std::vector<std::shared_ptr<vx::Tensor>> next_tensors;
|
||||
switch (op_id) {
|
||||
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_CONV2D, Conv2d);
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@
|
|||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "permute_vector.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
|
||||
namespace tim {
|
||||
namespace transform {
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@
|
|||
#define TIM_LAYOUT_INFER_ADDN_LAYOUT_INFERENCE_H_
|
||||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
#include "tim/vx/ops/addn.h"
|
||||
|
||||
namespace tim {
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@
|
|||
#define TIM_LAYOUT_INFER_ARG_OPS_LAYOUT_INFERENCE_H_
|
||||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
#include "tim/vx/ops/arg.h"
|
||||
namespace tim {
|
||||
namespace transform {
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@
|
|||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "permute_vector.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
namespace tim {
|
||||
namespace transform {
|
||||
class Batch2SpaceLayoutInfer : public OpLayoutInfer {
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@
|
|||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "permute_vector.h"
|
||||
#include "operation_private.h"
|
||||
#include "op_impl.h"
|
||||
namespace tim {
|
||||
namespace transform {
|
||||
class BatchNormLayoutInfer : public OpLayoutInfer {
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@
|
|||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "permute_vector.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
|
||||
namespace tim {
|
||||
namespace transform {
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
#include "tim/vx/ops/conv2d.h"
|
||||
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
#include "permute_vector.h"
|
||||
#include "ops/op_layout_inference.h"
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "permute_vector.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
#include "tim/vx/ops/deconv.h"
|
||||
|
||||
namespace tim {
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@
|
|||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "permute_vector.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
|
||||
namespace tim {
|
||||
namespace transform {
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@
|
|||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "permute_vector.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
|
||||
namespace tim {
|
||||
namespace transform {
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@
|
|||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "permute_vector.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
|
||||
namespace tim {
|
||||
namespace transform {
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@
|
|||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "permute_vector.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
|
||||
namespace tim {
|
||||
namespace transform {
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@
|
|||
#define TIM_LAYOUT_INFER_GATHER_LAYOUT_INFERENCE_H_
|
||||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
#include "tim/vx/ops/gather.h"
|
||||
|
||||
namespace tim {
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@
|
|||
#define TIM_LAYOUT_INFER_GATHER_ND_LAYOUT_INFERENCE_H_
|
||||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
#include "tim/vx/ops/gathernd.h"
|
||||
|
||||
namespace tim {
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@
|
|||
#define TIM_LAYOUT_INFER_L2_NORMALIZATION_LAYOUT_INFERENCE_H_
|
||||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
#include "tim/vx/ops/l2normalization.h"
|
||||
|
||||
namespace tim {
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@
|
|||
#define TIM_LAYOUT_INFER_LOGICAL_OPS_LAYOUT_INFERENCE_H_
|
||||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
#include "tim/vx/ops/logical.h"
|
||||
|
||||
namespace tim {
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@
|
|||
#include "tim/vx/ops/localresponsenormalization.h"
|
||||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
|
||||
namespace tim {
|
||||
namespace transform {
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@
|
|||
|
||||
#include "op_layout_inference.h"
|
||||
#include "permute_vector.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
#include "tim/vx/ops/transpose.h"
|
||||
#include "type_utils.h"
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@
|
|||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "permute_vector.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
namespace tim {
|
||||
namespace transform {
|
||||
class PadLayoutInfer : public OpLayoutInfer {
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "permute_vector.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
#include "tim/vx/ops/pool2d.h"
|
||||
|
||||
namespace tim {
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@
|
|||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "permute_vector.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
|
||||
namespace tim {
|
||||
namespace transform {
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@
|
|||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "permute_vector.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
namespace tim {
|
||||
namespace transform {
|
||||
class ResizeLayoutInfer : public OpLayoutInfer {
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@
|
|||
#define TIM_LAYOUT_INFER_REVERSE_LAYOUT_INFERENCE_H_
|
||||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
#include "tim/vx/ops/reverse.h"
|
||||
|
||||
namespace tim {
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@
|
|||
#define TIM_LAYOUT_INFER_SELECT_LAYOUT_INFERENCE_H_
|
||||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
#include "tim/vx/ops/select.h"
|
||||
|
||||
namespace tim {
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@
|
|||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "permute_vector.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
|
||||
namespace tim {
|
||||
namespace transform {
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@
|
|||
#define TIM_LAYOUT_INFER_SLICE_LAYOUT_INFERENCE_H_
|
||||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
#include "tim/vx/ops/slice.h"
|
||||
|
||||
namespace tim {
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
#include "tim/vx/ops/softmax.h"
|
||||
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
#include "permute_vector.h"
|
||||
#include "ops/op_layout_inference.h"
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@
|
|||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "permute_vector.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
namespace tim {
|
||||
namespace transform {
|
||||
class Space2BatchLayoutInfer : public OpLayoutInfer {
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@
|
|||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "permute_vector.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
namespace tim {
|
||||
namespace transform {
|
||||
class SpaceToDepthLayoutInfer : public OpLayoutInfer {
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@
|
|||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "permute_vector.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
|
||||
namespace tim {
|
||||
namespace transform {
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@
|
|||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "permute_vector.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
|
||||
namespace tim {
|
||||
namespace transform {
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
#include "tim/vx/ops/stack.h"
|
||||
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
#include "permute_vector.h"
|
||||
#include "ops/op_layout_inference.h"
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@
|
|||
|
||||
#include "ops/op_layout_inference.h"
|
||||
#include "permute_vector.h"
|
||||
#include "operation_private.h"
|
||||
#include "direct_map_op_impl.h"
|
||||
|
||||
namespace tim {
|
||||
namespace transform {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,36 @@
|
|||
/****************************************************************************
|
||||
*
|
||||
* Copyright (c) 2021 Vivante Corporation
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||
* copy of this software and associated documentation files (the "Software"),
|
||||
* to deal in the Software without restriction, including without limitation
|
||||
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
* and/or sell copies of the Software, and to permit persons to whom the
|
||||
* Software is furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
* DEALINGS IN THE SOFTWARE.
|
||||
*
|
||||
*****************************************************************************/
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
#include "direct_map_op_impl.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
DirectMapOp::DirectMapOp(Graph* graph, uint32_t kind, int in_cnt, int out_cnt,
|
||||
DataLayout layout) {
|
||||
impl_ = std::make_unique<DirectMapOpImpl>(graph, kind, in_cnt, out_cnt, layout);
|
||||
}
|
||||
|
||||
} // namespace vx
|
||||
} // namespace tim
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
/****************************************************************************
|
||||
*
|
||||
* Copyright (c) 2021 Vivante Corporation
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||
* copy of this software and associated documentation files (the "Software"),
|
||||
* to deal in the Software without restriction, including without limitation
|
||||
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
* and/or sell copies of the Software, and to permit persons to whom the
|
||||
* Software is furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
* DEALINGS IN THE SOFTWARE.
|
||||
*
|
||||
*****************************************************************************/
|
||||
#include "direct_map_op_impl.h"
|
||||
#include "type_utils.h"
|
||||
|
||||
namespace tim{
|
||||
namespace vx{
|
||||
|
||||
DirectMapOpImpl::DirectMapOpImpl(Graph* graph, uint32_t kind, int input_cnt,
|
||||
int output_cnt, DataLayout layout)
|
||||
: OpImpl(graph, kind, input_cnt, output_cnt, layout),
|
||||
node_(vsi_nn_AddNode(graph_->graph(), kind_, input_cnt_, output_cnt_,
|
||||
NULL)) {
|
||||
SetRoundingPolicy();
|
||||
node_->uid = graph_->graph()->cur_nid;
|
||||
}
|
||||
|
||||
DirectMapOpImpl& DirectMapOpImpl::BindInput(const std::shared_ptr<Tensor>& tensor) {
|
||||
inputs_tensor_.push_back(tensor);
|
||||
uint32_t tensor_id = tensor->GetId();
|
||||
node_->input.tensors[input_tensor_index++] = tensor_id;
|
||||
if (tensor->GetSpec().attr_ & TensorAttribute::INPUT) {
|
||||
graph_->AddInput(tensor_id);
|
||||
graph_->AddInput(tensor);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
DirectMapOpImpl& DirectMapOpImpl::BindOutput(
|
||||
const std::shared_ptr<Tensor>& tensor) {
|
||||
outputs_tensor_.push_back(tensor);
|
||||
uint32_t tensor_id = tensor->GetId();
|
||||
node_->output.tensors[output_tensor_index++] = tensor_id;
|
||||
if (tensor->GetSpec().attr_ == TensorAttribute::OUTPUT) {
|
||||
graph_->AddOutput(tensor_id);
|
||||
graph_->AddOutput(tensor);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
void DirectMapOpImpl::SetRoundingPolicy(
|
||||
OverflowPolicy overflow_policy,
|
||||
RoundingPolicy rounding_policy,
|
||||
RoundType down_scale_size_rounding,
|
||||
uint32_t accumulator_bits) {
|
||||
node_->vx_param.overflow_policy = TranslateOverflowPolicy(overflow_policy);
|
||||
node_->vx_param.rounding_policy = TranslateRoundingPolicy(rounding_policy);
|
||||
node_->vx_param.down_scale_size_rounding =
|
||||
TranslateDownScaleSizeRounding(down_scale_size_rounding);
|
||||
node_->vx_param.accumulator_bits = accumulator_bits;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue