diff --git a/include/tim/vx/ops/groupedconv1d.h b/include/tim/vx/ops/groupedconv1d.h
index 388bb8d..439291b 100644
--- a/include/tim/vx/ops/groupedconv1d.h
+++ b/include/tim/vx/ops/groupedconv1d.h
@@ -55,22 +55,25 @@ namespace ops {
 
 class GroupedConv1d : public DirectMapOp {
  public:
-  GroupedConv1d(Graph* graph, PadType padding,
-                uint32_t stride,
-                uint32_t dilation,
-                uint32_t group,
-                DataLayout input_layout = DataLayout::WCN,
-                DataLayout kernel_layout = DataLayout::WIcOc);
+  GroupedConv1d(Graph* graph, PadType padding, std::array<uint32_t, 2> pad,
+                uint32_t stride, uint32_t dilation, uint32_t group,
+                DataLayout input_layout = DataLayout::WCN,
+                DataLayout kernel_layout = DataLayout::WIcOc);
+  GroupedConv1d(Graph* graph, PadType padding, const uint32_t stride,
+                const uint32_t dilation, uint32_t group,
+                DataLayout input_layout = DataLayout::WCN,
+                DataLayout kernel_layout = DataLayout::WIcOc);
 
   DataLayout KernelDataLayout() { return kernel_layout_; }
 
-  std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
+  std::shared_ptr<Operation> Clone(
+      std::shared_ptr<Graph>& graph) const override;
 
  protected:
   const PadType padding_;
+  const std::array<uint32_t, 2> pad_;
   const uint32_t stride_;
   const uint32_t dilation_;
-  const std::array<uint32_t, 2> pad_;
   const uint32_t group_;
   const DataLayout kernel_layout_;
 };
diff --git a/src/tim/vx/ops/groupedconv1d.cc b/src/tim/vx/ops/groupedconv1d.cc
index 46a8ca8..bdd9819 100644
--- a/src/tim/vx/ops/groupedconv1d.cc
+++ b/src/tim/vx/ops/groupedconv1d.cc
@@ -31,27 +31,37 @@ namespace tim {
 namespace vx {
 namespace ops {
 
-GroupedConv1d::GroupedConv1d(Graph* graph,
-                             PadType padding,
-                             const uint32_t stride,
-                             const uint32_t dilation,
-                             uint32_t group,
-                             DataLayout input_layout, DataLayout kernel_layout)
+GroupedConv1d::GroupedConv1d(Graph* graph, PadType padding,
+                             const uint32_t stride, const uint32_t dilation,
+                             uint32_t group, DataLayout input_layout,
+                             DataLayout kernel_layout)
+    : GroupedConv1d(graph, padding, {0, 0}, stride, dilation, group, input_layout, kernel_layout) {}
+
+GroupedConv1d::GroupedConv1d(Graph* graph, PadType padding,
+                             std::array<uint32_t, 2> pad, const uint32_t stride,
+                             const uint32_t dilation, uint32_t group,
+                             DataLayout input_layout, DataLayout kernel_layout)
     : DirectMapOp(graph, VSI_NN_OP_GROUPED_CONV1D, 3, 1, input_layout),
-      padding_(padding), stride_(stride), dilation_(dilation),
-      pad_({0,0}), group_(group),
+      padding_(padding),
+      pad_(pad),
+      stride_(stride),
+      dilation_(dilation),
+      group_(group),
       kernel_layout_(kernel_layout) {
-  this->impl()->node()->nn_param.grouped_conv1d.pad_type = TranslatePadType(padding_);
+  this->impl()->node()->nn_param.grouped_conv1d.pad_type =
+      TranslatePadType(padding_);
+  this->impl()->node()->nn_param.grouped_conv1d.pad[0] = pad_[0];
+  this->impl()->node()->nn_param.grouped_conv1d.pad[1] = pad_[1];
   this->impl()->node()->nn_param.grouped_conv1d.stride = stride_;
   this->impl()->node()->nn_param.grouped_conv1d.group = group_;
   this->impl()->node()->nn_param.grouped_conv1d.dilation = dilation_;
-  }
+}
 
 std::shared_ptr<Operation> GroupedConv1d::Clone(
     std::shared_ptr<Graph>& graph) const {
   return graph->CreateOperation<GroupedConv1d>(
-      this->padding_, this->stride_, this->dilation_, this->group_, this->impl_->layout_,
-      this->kernel_layout_);
+      this->padding_, this->pad_, this->stride_, this->dilation_, this->group_,
+      this->impl_->layout_, this->kernel_layout_);
 }
 
 }  // namespace ops
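A minimal usage sketch of the two overloads above, assuming the usual TIM-VX setup (context, graph, and tensors created elsewhere). The old signature now delegates with a zero pad, so existing callers keep their behavior; the new signature takes an explicit {front, back} pad on the W axis, forwarded to nn_param.grouped_conv1d.pad[0] / pad[1]:

    #include <array>
    #include <cstdint>
    #include <memory>

    #include "tim/vx/graph.h"
    #include "tim/vx/ops/groupedconv1d.h"
    #include "tim/vx/types.h"

    void BuildBothOverloads(std::shared_ptr<tim::vx::Graph> graph) {
      // Old signature: pad_ is implicitly {0, 0} via the delegating ctor.
      auto conv = graph->CreateOperation<tim::vx::ops::GroupedConv1d>(
          tim::vx::PadType::VALID, /*stride=*/1, /*dilation=*/1, /*group=*/2);

      // New signature: explicit front/back padding on the W axis.
      auto conv_padded = graph->CreateOperation<tim::vx::ops::GroupedConv1d>(
          tim::vx::PadType::AUTO, std::array<uint32_t, 2>({0, 2}),
          /*stride=*/1, /*dilation=*/1, /*group=*/2);
    }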
diff --git a/src/tim/vx/ops/groupedconv1d_test.cc b/src/tim/vx/ops/groupedconv1d_test.cc
index ce96400..9486a9e 100644
--- a/src/tim/vx/ops/groupedconv1d_test.cc
+++ b/src/tim/vx/ops/groupedconv1d_test.cc
@@ -70,3 +70,49 @@ TEST(GroupedConv1d, shape_6_2_1_float_ksize_6_stride_1_group_2_no_bias_wcn) {
   EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
   EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
 }
+
+TEST(GroupedConv1d, shape_2_4_1_float_ksize_3_stride_1_group_2_no_bias_wcn_PaddingTest) {
+  auto ctx = tim::vx::Context::Create();
+  auto graph = ctx->CreateGraph();
+
+  tim::vx::ShapeType in_shape({2, 4, 1});
+  tim::vx::ShapeType param_shape({3, 2, 4});
+  tim::vx::ShapeType out_shape({2, 4, 1});
+  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32,
+                                 in_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec param_spec(tim::vx::DataType::FLOAT32,
+                                 param_shape, tim::vx::TensorAttribute::INPUT);
+  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32,
+                                  out_shape, tim::vx::TensorAttribute::OUTPUT);
+
+  auto input_tensor = graph->CreateTensor(input_spec);
+  auto weight_tensor = graph->CreateTensor(param_spec);
+  auto output_tensor = graph->CreateTensor(output_spec);
+
+  std::vector<float> in_data = {
+      -1, 0, 1, -1.5, 0.5, 1.5, 1, 1
+  };
+  std::vector<float> weight = {
+      -3, -2, -1.5, 1.5, 2, 3,
+      -2.5, -2, -1.5, 1.5, 2, 2.5,
+      -1, 0, 1, -1.5, 0.5, 1.5,
+      -1.5, 1.5, 2, -1, 0, 1,
+  };
+  std::vector<float> golden = {
+      1.5, -2.25, 1, -2.25, -1.5, -3, 0.5, -3.25
+  };
+
+  EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * sizeof(float)));
+  EXPECT_TRUE(weight_tensor->CopyDataToTensor(weight.data(), weight.size() * sizeof(float)));
+
+  auto op = graph->CreateOperation<tim::vx::ops::GroupedConv1d>(
+      tim::vx::PadType::AUTO, std::array<uint32_t, 2>({0, 2}), 1, 1, 2);
+  (*op).BindInputs({input_tensor, weight_tensor}).BindOutputs({output_tensor});
+
+  EXPECT_TRUE(graph->Compile());
+  EXPECT_TRUE(graph->Run());
+
+  std::vector<float> output(golden.size());
+  EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
+  EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
+}
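The padding test's shapes check out: with input width 2, kernel width 3, stride 1, dilation 1, and pad {0, 2}, the output width is (2 + 0 + 2 - 3) / 1 + 1 = 2, matching out_shape {2, 4, 1}. Below is a hand check of the first golden value, assuming the op computes per-group cross-correlation (a standalone sketch, not TIM-VX code):

    #include <cstdio>

    int main() {
      // Output channel 0 at w == 0 belongs to group 0 (input channels 0, 1).
      // With pad {0, 2}, kernel tap 2 reads right padding (zero), so only
      // taps 0 and 1 contribute at this position.
      const float ch0[2] = {-1.0f, 0.0f};          // input channel 0 (W-major)
      const float ch1[2] = {1.0f, -1.5f};          // input channel 1
      const float k0[3]  = {-3.0f, -2.0f, -1.5f};  // kernel for input channel 0
      const float k1[3]  = {1.5f, 2.0f, 3.0f};     // kernel for input channel 1

      float acc = 0.0f;
      for (int t = 0; t < 2; ++t) acc += ch0[t] * k0[t] + ch1[t] * k1[t];
      std::printf("%g\n", acc);  // prints 1.5, the first golden value
      return 0;
    }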