diff --git a/src/tim/vx/ops/deconv.cc b/src/tim/vx/ops/deconv.cc index 66f2171..7d21bf5 100644 --- a/src/tim/vx/ops/deconv.cc +++ b/src/tim/vx/ops/deconv.cc @@ -57,7 +57,7 @@ DeConv2d::DeConv2d(Graph* graph, int32_t oc_count, PadType pad_type, group_(group) { // TODO(Sven): only support depthwise usage - assert(group != 1 && group == oc_count); + assert((group == 1U) || group == oc_count); this->impl()->node()->nn_param.deconv.ksize[0] = ksize_[0]; this->impl()->node()->nn_param.deconv.ksize[1] = ksize_[1]; this->impl()->node()->nn_param.deconv.stride[0] = stride_[0]; diff --git a/src/tim/vx/ops/deconv2d_test.cc b/src/tim/vx/ops/deconv2d_test.cc index 8e6236f..f443f18 100644 --- a/src/tim/vx/ops/deconv2d_test.cc +++ b/src/tim/vx/ops/deconv2d_test.cc @@ -17,7 +17,7 @@ size_t element_count(const tim::vx::ShapeType& shape) { } // namespace -TEST(OP, deconv_group) { +TEST(OP, deconv_depthwise_two_channel) { auto ctx = tim::vx::Context::Create(); auto graph = ctx->CreateGraph(); @@ -82,5 +82,57 @@ TEST(OP, deconv_group) { 0.0f, 0.0f, 40.0f, 16.0f, 24.0f, 0.0f, 0.0f, 72.0f, 0.0f, 16.0f}; + EXPECT_EQ(golden, output_data) << "Result mismatch"; +} + +TEST(OP, deconv_single_channel) { + auto ctx = tim::vx::Context::Create(); + auto graph = ctx->CreateGraph(); + + tim::vx::ShapeType input_shape ({3, 3, 1, 1}); //whcn + tim::vx::ShapeType kernel_shape({3, 3, 1, 1}); //whc1 same as depthwise convolution + tim::vx::ShapeType output_shape({5, 5, 1, 1}); //whcn + + tim::vx::TensorSpec input_spec (tim::vx::DataType::FLOAT32, input_shape, tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec kernel_spec (tim::vx::DataType::FLOAT32, kernel_shape, tim::vx::TensorAttribute::CONSTANT); + tim::vx::TensorSpec output_spec (tim::vx::DataType::FLOAT32, output_shape, tim::vx::TensorAttribute::OUTPUT); + + auto input_tensor = graph->CreateTensor(input_spec); + auto output_tensor = graph->CreateTensor(output_spec); + auto kernel_tensor = graph->CreateTensor(kernel_spec); + + std::vector input_data = { + 3.0f, 8.0f, 1.0f, 9.0f, 5.0f, 7.0f, 3.0f, 2.0f, 3.0f, + }; + + std::vector kernel_data = { + 9.0f, 0.0f, 3.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 2.0f, + }; + + std::vector output_data(element_count(output_shape)); + + EXPECT_TRUE(input_tensor->CopyDataToTensor(input_data.data(), input_data.size()*4)); + EXPECT_TRUE(kernel_tensor->CopyDataToTensor(kernel_data.data(), kernel_data.size()*4)); + + auto add = graph->CreateOperation( + 1, + tim::vx::PadType::SAME, + std::array({3, 3}), /*ksize*/ + std::array({1, 1}), /*stride*/ + std::array({1, 1}), /*dilation*/ + std::array({0, 0, 0, 0}), /*pad*/ + 1/*group*/); + (*add).BindInputs({input_tensor, kernel_tensor}).BindOutputs({output_tensor}); + + EXPECT_TRUE(graph->Compile()); + EXPECT_TRUE(graph->Run()); + + EXPECT_TRUE(output_tensor->CopyDataFromTensor(output_data.data())); + std::vector golden = { + 27.0f, 72.0f, 18.0f, 24.0f, 3.0f, 81.0f, 45.0f, 90.0f, 15.0f, + 21.0f, 30.0f, 26.0f, 43.0f, 22.0f, 11.0f, 9.0f, 5.0f, 25.0f, + 10.0f, 14.0f, 3.0f, 2.0f, 9.0f, 4.0f, 6.0f, + }; + EXPECT_EQ(golden, output_data) << "Result mismatch"; } \ No newline at end of file