Add new APIs for conv, deconv and fc

The new apis remvoe weights, oc_count and ksize.

Signed-off-by: zhao.xia <zhao.xia@verisilicon.com>
This commit is contained in:
zhao.xia 2021-06-07 10:17:56 +08:00 committed by Kainan Cha
parent 8d35c4dd7a
commit 0ed1e8947f
8 changed files with 96 additions and 15 deletions

View File

@ -34,6 +34,14 @@ namespace ops {
class Conv1d : public Operation {
public:
Conv1d(Graph* graph, PadType padding, uint32_t stride,
uint32_t dilation, int32_t multiplier = 0,
DataLayout input_layout = DataLayout::WHCN,
DataLayout kernel_layout = DataLayout::WHIcOc);
Conv1d(Graph* graph, const std::array<uint32_t, 2>& pad,
uint32_t stride, uint32_t dilation, int32_t multiplier = 0,
DataLayout input_layout = DataLayout::WHCN,
DataLayout kernel_layout = DataLayout::WHIcOc);
Conv1d(Graph* graph, int32_t weights, PadType padding,
uint32_t ksize, uint32_t stride,
uint32_t dilation, int32_t multiplier = 0,

View File

@ -57,6 +57,16 @@ namespace ops {
class Conv2d : public Operation {
public:
Conv2d(Graph* graph, PadType padding,
const std::array<uint32_t, 2>& stride,
const std::array<uint32_t, 2>& dilation, int32_t multiplier = 0,
DataLayout input_layout = DataLayout::WHCN,
DataLayout kernel_layout = DataLayout::WHIcOc);
Conv2d(Graph* graph, const std::array<uint32_t, 4> pad,
const std::array<uint32_t, 2>& stride,
const std::array<uint32_t, 2>& dilation, int32_t multiplier = 0,
DataLayout input_layout = DataLayout::WHCN,
DataLayout kernel_layout = DataLayout::WHIcOc);
Conv2d(Graph* graph, int32_t weights, PadType padding,
const std::array<uint32_t, 2>& ksize,
const std::array<uint32_t, 2>& stride,

View File

@ -51,6 +51,14 @@ namespace ops {
class DeConv1d : public Operation {
public:
DeConv1d(Graph* graph, PadType pad_type,
uint32_t stride, uint32_t output_padding, uint32_t group = 1,
DataLayout input_layout = DataLayout::WHCN,
DataLayout kernel_layout = DataLayout::WHIcOc);
DeConv1d(Graph* graph, const std::array<uint32_t, 2>& pad,
uint32_t stride, uint32_t output_padding, uint32_t group = 1,
DataLayout input_layout = DataLayout::WHCN,
DataLayout kernel_layout = DataLayout::WHIcOc);
DeConv1d(Graph* graph, int32_t oc_count_, PadType pad_type,
uint32_t ksize,
uint32_t stride,
@ -61,6 +69,10 @@ class DeConv1d : public Operation {
uint32_t output_padding,
const std::array<uint32_t, 2>& pad,
uint32_t group = 1);
DeConv1d(Graph* graph, PadType pad_type,
uint32_t stride, uint32_t output_padding,
const std::array<uint32_t, 2>& pad, uint32_t group,
DataLayout input_layout, DataLayout kernel_layout);
protected:
const uint32_t oc_count_; // output channel count
@ -70,6 +82,7 @@ class DeConv1d : public Operation {
const uint32_t output_padding_;
const std::array<uint32_t, 2> pad_;
const uint32_t group_;
const DataLayout kernel_layout_;
};
} // namespace ops

View File

@ -41,6 +41,7 @@ namespace ops {
class FullyConnected : public Operation {
public:
FullyConnected(Graph* graph, uint32_t axis);
FullyConnected(Graph* graph, uint32_t axis, uint32_t weights);
protected:

View File

@ -31,6 +31,18 @@ namespace tim {
namespace vx {
namespace ops {
Conv1d::Conv1d(Graph* graph, PadType padding, uint32_t stride,
uint32_t dilation, int32_t multiplier,
DataLayout input_layout, DataLayout kernel_layout)
: Conv1d(graph, 0, padding, 0, stride, dilation, {0, 0},
multiplier, input_layout, kernel_layout) {}
Conv1d::Conv1d(Graph* graph, const std::array<uint32_t, 2>& pad,
uint32_t stride, uint32_t dilation, int32_t multiplier,
DataLayout input_layout, DataLayout kernel_layout)
: Conv1d(graph, 0, PadType::AUTO, 0, stride, dilation, pad,
multiplier, input_layout, kernel_layout) {}
Conv1d::Conv1d(Graph* graph, int32_t weights, PadType padding,
uint32_t ksize, uint32_t stride,
uint32_t dilation, int32_t multiplier,
@ -51,10 +63,8 @@ Conv1d::Conv1d(Graph* graph, int32_t weights, PadType padding,
pad_(pad),
multiplier_(multiplier),
kernel_layout_(kernel_layout) {
this->impl()->node()->nn_param.conv1d.ksize = ksize_;
this->impl()->node()->nn_param.conv1d.stride = stride_;
this->impl()->node()->nn_param.conv1d.pad_type = TranslatePadType(padding_);
this->impl()->node()->nn_param.conv1d.weights = weights;
this->impl()->node()->nn_param.conv1d.group = 1;
this->impl()->node()->nn_param.conv1d.dilation = dilation_;
this->impl()->node()->nn_param.conv1d.pad[0] = pad_[0];

View File

@ -31,6 +31,20 @@ namespace tim {
namespace vx {
namespace ops {
Conv2d::Conv2d(Graph* graph, PadType padding,
const std::array<uint32_t, 2>& stride,
const std::array<uint32_t, 2>& dilation, int32_t multiplier,
DataLayout input_layout, DataLayout kernel_layout)
: Conv2d(graph, 0, padding, {0, 0}, stride, dilation, {0, 0, 0, 0},
multiplier, input_layout, kernel_layout) {}
Conv2d::Conv2d(Graph* graph, const std::array<uint32_t, 4> pad,
const std::array<uint32_t, 2>& stride,
const std::array<uint32_t, 2>& dilation, int32_t multiplier,
DataLayout input_layout, DataLayout kernel_layout)
: Conv2d(graph, 0, PadType::AUTO, {0, 0}, stride, dilation, pad,
multiplier, input_layout, kernel_layout) {}
Conv2d::Conv2d(Graph* graph, int32_t weights, PadType padding,
const std::array<uint32_t, 2>& ksize,
const std::array<uint32_t, 2>& stride,
@ -54,12 +68,9 @@ Conv2d::Conv2d(Graph* graph, int32_t weights, PadType padding,
pad_(pad),
multiplier_(multiplier),
kernel_layout_(kernel_layout) {
this->impl()->node()->nn_param.conv2d.ksize[0] = ksize_[0];
this->impl()->node()->nn_param.conv2d.ksize[1] = ksize_[1];
this->impl()->node()->nn_param.conv2d.stride[0] = stride_[0];
this->impl()->node()->nn_param.conv2d.stride[1] = stride_[1];
this->impl()->node()->nn_param.conv2d.pad_type = TranslatePadType(padding_);
this->impl()->node()->nn_param.conv2d.weights = weights;
this->impl()->node()->nn_param.conv2d.group = 1;
this->impl()->node()->nn_param.conv2d.dilation[0] = dilation_[0];
this->impl()->node()->nn_param.conv2d.dilation[1] = dilation_[1];

View File

@ -33,28 +33,52 @@ namespace tim {
namespace vx {
namespace ops {
DeConv1d::DeConv1d(Graph* graph, PadType pad_type,
uint32_t stride, uint32_t output_padding, uint32_t group,
DataLayout input_layout, DataLayout kernel_layout)
: DeConv1d(graph, pad_type, stride, output_padding, {0, 0}, group,
input_layout, kernel_layout) {
}
DeConv1d::DeConv1d(Graph* graph, const std::array<uint32_t, 2>& pad,
uint32_t stride, uint32_t output_padding, uint32_t group,
DataLayout input_layout, DataLayout kernel_layout)
: DeConv1d(graph, PadType::AUTO, stride, output_padding, pad, group,
input_layout, kernel_layout) {
}
DeConv1d::DeConv1d(Graph* graph, int32_t oc_count, PadType pad_type,
uint32_t ksize, uint32_t stride, uint32_t output_padding)
: DeConv1d(graph, oc_count, pad_type, ksize, stride, output_padding,
{0, 0}) {
: DeConv1d(graph, pad_type, stride, output_padding,
{0, 0}, 1, DataLayout::WHCN, DataLayout::WHIcOc) {
(void)ksize;
(void)oc_count;
}
DeConv1d::DeConv1d(Graph* graph, int32_t oc_count, PadType pad_type,
uint32_t ksize, uint32_t stride, uint32_t output_padding,
const std::array<uint32_t, 2>& pad, uint32_t group)
: Operation(graph, VSI_NN_OP_DECONVOLUTION1D),
oc_count_(oc_count),
: DeConv1d(graph, pad_type, stride, output_padding,
pad, group, DataLayout::WHCN, DataLayout::WHIcOc) {
(void)ksize;
(void)oc_count;
}
DeConv1d::DeConv1d(Graph* graph, PadType pad_type,
uint32_t stride, uint32_t output_padding,
const std::array<uint32_t, 2>& pad, uint32_t group,
DataLayout input_layout, DataLayout kernel_layout)
: Operation(graph, VSI_NN_OP_DECONVOLUTION1D, 3, 1, input_layout),
oc_count_(0),
pad_type_(pad_type),
ksize_(ksize),
ksize_(0),
stride_(stride),
output_padding_(output_padding),
pad_(pad),
group_(group) {
this->impl()->node()->nn_param.deconvolution1d.ksize = ksize_;
group_(group),
kernel_layout_(kernel_layout) {
this->impl()->node()->nn_param.deconvolution1d.stride = stride_;
this->impl()->node()->nn_param.deconvolution1d.pad_type = TranslatePadType(pad_type_);
this->impl()->node()->nn_param.deconvolution1d.weights = oc_count_;
this->impl()->node()->nn_param.deconvolution1d.group = group_;
this->impl()->node()->nn_param.deconvolution1d.output_padding = output_padding_;
this->impl()->node()->nn_param.deconvolution1d.pad[0] = pad_[0];

View File

@ -30,10 +30,14 @@ namespace tim {
namespace vx {
namespace ops {
FullyConnected::FullyConnected(Graph* graph, uint32_t axis)
: FullyConnected(graph, axis, 0) {
}
FullyConnected::FullyConnected(Graph* graph, uint32_t axis, uint32_t weights)
: Operation(graph, VSI_NN_OP_FCL2) {
(void)weights;
this->impl()->node()->nn_param.fcl.axis = axis;
this->impl()->node()->nn_param.fcl.weights = weights;
}
} // namespace ops