Add group parameter for deconv API

Limitation: only supports depthwise deconvolution (group == output channel count).

Signed-off-by: xiang.zhang <xiang.zhang@verisilicon.com>
This commit is contained in:
parent
8ab7759e3c
commit
b1b7eadefc
|
|
@ -34,23 +34,25 @@ namespace ops {
|
||||||
|
|
||||||
class DeConv2d : public Operation {
|
class DeConv2d : public Operation {
|
||||||
public:
|
public:
|
||||||
DeConv2d(Graph* graph, int32_t weights, PadType pad_type,
|
DeConv2d(Graph* graph, int32_t oc_count_, PadType pad_type,
|
||||||
const std::array<uint32_t, 2>& ksize,
|
const std::array<uint32_t, 2>& ksize,
|
||||||
const std::array<uint32_t, 2>& stride,
|
const std::array<uint32_t, 2>& stride,
|
||||||
const std::array<uint32_t, 2>& output_padding);
|
const std::array<uint32_t, 2>& output_padding);
|
||||||
DeConv2d(Graph* graph, int32_t weights, PadType pad_type,
|
DeConv2d(Graph* graph, int32_t oc_count_, PadType pad_type,
|
||||||
const std::array<uint32_t, 2>& ksize,
|
const std::array<uint32_t, 2>& ksize,
|
||||||
const std::array<uint32_t, 2>& stride,
|
const std::array<uint32_t, 2>& stride,
|
||||||
const std::array<uint32_t, 2>& output_padding,
|
const std::array<uint32_t, 2>& output_padding,
|
||||||
const std::array<uint32_t, 4>& pad);
|
const std::array<uint32_t, 4>& pad,
|
||||||
|
const uint32_t group = 1);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
const uint32_t weights_;
|
const uint32_t oc_count_; // output channel count
|
||||||
const PadType pad_type_;
|
const PadType pad_type_;
|
||||||
const std::array<uint32_t, 2> ksize_;
|
const std::array<uint32_t, 2> ksize_;
|
||||||
const std::array<uint32_t, 2> stride_;
|
const std::array<uint32_t, 2> stride_;
|
||||||
const std::array<uint32_t, 2> output_padding_;
|
const std::array<uint32_t, 2> output_padding_;
|
||||||
const std::array<uint32_t, 4> pad_;
|
const std::array<uint32_t, 4> pad_;
|
||||||
|
const uint32_t group_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
||||||
|
|
@ -30,9 +30,6 @@ if(TIM_VX_ENABLE_LAYOUT_INFER)
|
||||||
)
|
)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
set(UT_SRC)
|
|
||||||
aux_source_directory(./vx/ut VX_UT_SRC)
|
|
||||||
list(APPEND UT_SRC ${VX_UT_SRC})
|
|
||||||
foreach(src_file ${SRC})
|
foreach(src_file ${SRC})
|
||||||
if(${src_file} MATCHES ".*_test\.cc")
|
if(${src_file} MATCHES ".*_test\.cc")
|
||||||
list(REMOVE_ITEM SRC ${src_file})
|
list(REMOVE_ITEM SRC ${src_file})
|
||||||
|
|
|
||||||
|
|
@ -21,9 +21,10 @@
|
||||||
* DEALINGS IN THE SOFTWARE.
|
* DEALINGS IN THE SOFTWARE.
|
||||||
*
|
*
|
||||||
*****************************************************************************/
|
*****************************************************************************/
|
||||||
|
|
||||||
#include "tim/vx/ops/deconv.h"
|
#include "tim/vx/ops/deconv.h"
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
#include "operation_private.h"
|
#include "operation_private.h"
|
||||||
#include "type_utils.h"
|
#include "type_utils.h"
|
||||||
#include "vsi_nn_pub.h"
|
#include "vsi_nn_pub.h"
|
||||||
|
|
@ -32,33 +33,38 @@ namespace tim {
|
||||||
namespace vx {
|
namespace vx {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
// Convenience overload: forwards to the full constructor with zero explicit
// padding ({0, 0, 0, 0}) and the default group of 1.
// NOTE(review): the delegated constructor contains
// `assert(group != 1 && group == oc_count)`, which the default group of 1
// cannot satisfy — this overload would assert-fail in debug builds; verify.
DeConv2d::DeConv2d(Graph* graph, int32_t oc_count, PadType pad_type,
                   const std::array<uint32_t, 2>& ksize,
                   const std::array<uint32_t, 2>& stride,
                   const std::array<uint32_t, 2>& output_padding)
    : DeConv2d(graph, oc_count, pad_type, ksize, stride, output_padding,
               {0, 0, 0, 0}) {
}
|
||||||
|
|
||||||
DeConv2d::DeConv2d(Graph* graph, int32_t weights, PadType pad_type,
|
DeConv2d::DeConv2d(Graph* graph, int32_t oc_count, PadType pad_type,
|
||||||
const std::array<uint32_t, 2>& ksize,
|
const std::array<uint32_t, 2>& ksize,
|
||||||
const std::array<uint32_t, 2>& stride,
|
const std::array<uint32_t, 2>& stride,
|
||||||
const std::array<uint32_t, 2>& output_padding,
|
const std::array<uint32_t, 2>& output_padding,
|
||||||
const std::array<uint32_t, 4>& pad)
|
const std::array<uint32_t, 4>& pad,
|
||||||
|
const uint32_t group)
|
||||||
: Operation(graph, VSI_NN_OP_DECONVOLUTION),
|
: Operation(graph, VSI_NN_OP_DECONVOLUTION),
|
||||||
weights_(weights),
|
oc_count_(oc_count),
|
||||||
pad_type_(pad_type),
|
pad_type_(pad_type),
|
||||||
ksize_(ksize),
|
ksize_(ksize),
|
||||||
stride_(stride),
|
stride_(stride),
|
||||||
output_padding_(output_padding),
|
output_padding_(output_padding),
|
||||||
pad_(pad) {
|
pad_(pad),
|
||||||
|
group_(group) {
|
||||||
|
|
||||||
|
// TODO(Sven): only support depthwise usage
|
||||||
|
assert(group != 1 && group == oc_count);
|
||||||
this->impl()->node()->nn_param.deconv.ksize[0] = ksize_[0];
|
this->impl()->node()->nn_param.deconv.ksize[0] = ksize_[0];
|
||||||
this->impl()->node()->nn_param.deconv.ksize[1] = ksize_[1];
|
this->impl()->node()->nn_param.deconv.ksize[1] = ksize_[1];
|
||||||
this->impl()->node()->nn_param.deconv.stride[0] = stride_[0];
|
this->impl()->node()->nn_param.deconv.stride[0] = stride_[0];
|
||||||
this->impl()->node()->nn_param.deconv.stride[1] = stride_[1];
|
this->impl()->node()->nn_param.deconv.stride[1] = stride_[1];
|
||||||
this->impl()->node()->nn_param.deconv.pad_type = TranslatePadType(pad_type_);
|
this->impl()->node()->nn_param.deconv.pad_type = TranslatePadType(pad_type_);
|
||||||
this->impl()->node()->nn_param.deconv.weights = weights_;
|
this->impl()->node()->nn_param.deconv.weights = oc_count_;
|
||||||
this->impl()->node()->nn_param.deconv.group = 1;
|
this->impl()->node()->nn_param.deconv.group = group_;
|
||||||
this->impl()->node()->nn_param.deconv.output_padding[0] = output_padding_[0];
|
this->impl()->node()->nn_param.deconv.output_padding[0] = output_padding_[0];
|
||||||
this->impl()->node()->nn_param.deconv.output_padding[1] = output_padding_[1];
|
this->impl()->node()->nn_param.deconv.output_padding[1] = output_padding_[1];
|
||||||
this->impl()->node()->nn_param.deconv.pad[0] = pad_[0];
|
this->impl()->node()->nn_param.deconv.pad[0] = pad_[0];
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,86 @@
|
||||||
|
#include "tim/vx/context.h"
|
||||||
|
#include "tim/vx/graph.h"
|
||||||
|
#include "tim/vx/ops/deconv.h"
|
||||||
|
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
size_t element_count(const tim::vx::ShapeType& shape) {
|
||||||
|
size_t sz = 1;
|
||||||
|
for (auto d : shape) {
|
||||||
|
sz *= d;
|
||||||
|
}
|
||||||
|
|
||||||
|
return sz;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Depthwise-style grouped transposed convolution: 2 input channels,
// group == channel count, one 3x3 kernel per channel.
TEST(OP, deconv_group) {
  auto ctx = tim::vx::Context::Create();
  auto graph = ctx->CreateGraph();

  tim::vx::ShapeType input_shape({3, 3, 2, 1});   //whcn
  tim::vx::ShapeType kernel_shape({3, 3, 2, 1});  //whc1 same as depthwise convolution
  tim::vx::ShapeType output_shape({5, 5, 2, 1});  //whcn

  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
                                 tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec kernel_spec(tim::vx::DataType::FLOAT32, kernel_shape,
                                  tim::vx::TensorAttribute::CONSTANT);
  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
                                  tim::vx::TensorAttribute::OUTPUT);

  auto input_tensor = graph->CreateTensor(input_spec);
  auto output_tensor = graph->CreateTensor(output_spec);
  auto kernel_tensor = graph->CreateTensor(kernel_spec);

  std::vector<float> input_data = {3.0f, 8.0f, 1.0f,
                                   9.0f, 5.0f, 7.0f,
                                   3.0f, 2.0f, 3.0f,

                                   7.0f, 9.0f, 1.0f,
                                   5.0f, 2.0f, 3.0f,
                                   9.0f, 0.0f, 2.0f};
  std::vector<float> kernel_data =
      {9.0f, 0.0f, 3.0f,
       0.0f, 0.0f, 0.0f,
       1.0f, 0.0f, 2.0f,

       3.0f, 0.0f, 7.0f,
       0.0f, 0.0f, 0.0f,
       0.0f, 0.0f, 8.0f,
      };

  std::vector<float> output_data(element_count(output_shape));

  // CopyDataToTensor takes a size in bytes, not an element count.
  EXPECT_TRUE(input_tensor->CopyDataToTensor(
      input_data.data(), input_data.size() * sizeof(float)));
  EXPECT_TRUE(kernel_tensor->CopyDataToTensor(
      kernel_data.data(), kernel_data.size() * sizeof(float)));

  auto deconv = graph->CreateOperation<tim::vx::ops::DeConv2d>(
      2, /*oc_count*/
      tim::vx::PadType::SAME,
      std::array<uint32_t, 2>({3, 3}),       /*ksize*/
      std::array<uint32_t, 2>({1, 1}),       /*stride*/
      std::array<uint32_t, 2>({1, 1}),       /*output_padding*/
      std::array<uint32_t, 4>({0, 0, 0, 0}), /*pad*/
      2 /*group*/);
  (*deconv).BindInputs({input_tensor, kernel_tensor}).BindOutputs({output_tensor});

  EXPECT_TRUE(graph->Compile());
  EXPECT_TRUE(graph->Run());

  EXPECT_TRUE(output_tensor->CopyDataFromTensor(output_data.data()));
  std::vector<float> golden = {
      27.0f, 72.0f, 18.0f, 24.0f, 3.0f,
      81.0f, 45.0f, 90.0f, 15.0f, 21.0f,
      30.0f, 26.0f, 43.0f, 22.0f, 11.0f,
      9.0f,  5.0f,  25.0f, 10.0f, 14.0f,
      3.0f,  2.0f,  9.0f,  4.0f,  6.0f,

      21.0f, 27.0f, 52.0f, 63.0f, 7.0f,
      15.0f, 6.0f,  44.0f, 14.0f, 21.0f,
      27.0f, 0.0f,  125.0f, 72.0f, 22.0f,
      0.0f,  0.0f,  40.0f, 16.0f, 24.0f,
      0.0f,  0.0f,  72.0f, 0.0f,  16.0f};

  EXPECT_EQ(golden, output_data) << "Result mismatch";
}
|
||||||
Loading…
Reference in New Issue