From 4229ad88b398d657bd36f8bad257ec0e1946356e Mon Sep 17 00:00:00 2001 From: "Zongwu.Yang" Date: Tue, 11 Jan 2022 14:13:15 +0800 Subject: [PATCH] support conv3d (#238) Signed-off-by: Zongwu Yang --- include/tim/vx/ops.h | 1 + include/tim/vx/ops/conv3d.h | 102 +++++++++++++ include/tim/vx/types.h | 4 + src/tim/transform/layout_inference.cc | 2 + .../transform/ops/conv3d_layout_inference.h | 142 ++++++++++++++++++ src/tim/transform/ops/op_layout_inference.cc | 3 + src/tim/transform/ops/op_layout_inference.h | 4 + src/tim/vx/ops/conv3d.cc | 95 ++++++++++++ src/tim/vx/ops/conv3d_test.cc | 136 +++++++++++++++++ 9 files changed, 489 insertions(+) create mode 100644 include/tim/vx/ops/conv3d.h create mode 100644 src/tim/transform/ops/conv3d_layout_inference.h create mode 100644 src/tim/vx/ops/conv3d.cc create mode 100644 src/tim/vx/ops/conv3d_test.cc diff --git a/include/tim/vx/ops.h b/include/tim/vx/ops.h index 1382d34..7e36e3f 100644 --- a/include/tim/vx/ops.h +++ b/include/tim/vx/ops.h @@ -84,5 +84,6 @@ #include "tim/vx/ops/transpose.h" #include "tim/vx/ops/unidirectional_sequence_lstm.h" #include "tim/vx/ops/unstack.h" +#include "tim/vx/ops/conv3d.h" #endif /* TIM_VX_OPS_H_ */ diff --git a/include/tim/vx/ops/conv3d.h b/include/tim/vx/ops/conv3d.h new file mode 100644 index 0000000..c57b677 --- /dev/null +++ b/include/tim/vx/ops/conv3d.h @@ -0,0 +1,102 @@ +/**************************************************************************** +* +* Copyright (c) 2020 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this 
permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. +* +*****************************************************************************/ +#ifndef TIM_VX_OPS_CONV3D_H_ +#define TIM_VX_OPS_CONV3D_H_ + +#include +#include "tim/vx/direct_map_op.h" + +namespace tim { +namespace vx { +namespace ops { + +/** + * ## Conv3d + * + * Performs a 3-D convolution operation. + * + * Input: + * - input [WHDCN]. + * - kernel [ WHDIcOc ] (Ic: Input Channels. Oc: Output Channels). + * - bias [ Oc ]. Optional. + * + * Attribute: + * - weights : the output channel number for weight tensor. + * - ksize : the width, height and depth of the weight tensor. + * - padding : AUTO, VALID or SAME. + * - pad : pad value for each spatial axis. (left, right, top, bottom, front, rear). + * - stride : stride along each spatial axis. + * - dilation : dilation value along each spatial axis of the filter. + * - multiplier: similar in function to the group attribute in other frameworks, + * but with a different value: multiplier = weights / group. + * - input_layout : WHDCN or CWHDN. 
+ * - kernel_layout : WHDIcOc or OcIcWHD. + */ + +class Conv3d : public DirectMapOp { + public: + Conv3d(Graph* graph, PadType padding, + const std::array& stride, + const std::array& dilation, int32_t multiplier = 0, + DataLayout input_layout = DataLayout::WHDCN, + DataLayout kernel_layout = DataLayout::WHDIcOc); + Conv3d(Graph* graph, const std::array pad, + const std::array& stride, + const std::array& dilation, int32_t multiplier = 0, + DataLayout input_layout = DataLayout::WHDCN, + DataLayout kernel_layout = DataLayout::WHDIcOc); + Conv3d(Graph* graph, int32_t weights, PadType padding, + const std::array& ksize, + const std::array& stride, + const std::array& dilation, int32_t multiplier = 0, + DataLayout input_layout = DataLayout::WHDCN, + DataLayout kernel_layout = DataLayout::WHDIcOc); + Conv3d(Graph* graph, int32_t weights, PadType padding, + const std::array& ksize, + const std::array& stride, + const std::array& dilation, + const std::array& pad, int32_t multiplier = 0, + DataLayout input_layout = DataLayout::WHDCN, + DataLayout kernel_layout = DataLayout::WHDIcOc); + + DataLayout KernelDataLayout() { return kernel_layout_; } + + std::shared_ptr Clone(std::shared_ptr& graph) const override; + + protected: + const int32_t weights_; + const PadType padding_; + const std::array ksize_; + const std::array stride_; + const std::array dilation_; + const std::array pad_; + const int32_t multiplier_; + const DataLayout kernel_layout_; +}; + +} // namespace ops +} // namespace vx +} // namespace tim + +#endif /* TIM_VX_OPS_CONV3D_H_ */ \ No newline at end of file diff --git a/include/tim/vx/types.h b/include/tim/vx/types.h index 150744b..d7f74b7 100644 --- a/include/tim/vx/types.h +++ b/include/tim/vx/types.h @@ -72,6 +72,10 @@ enum class DataLayout { WHIcOc, /*TIM-VX default*/ WCN, /*for conv1d*/ WIcOc, /*for conv1d*/ + WHDCN, /* pytorch conv3d input */ + WHDIcOc, /* pytorch conv3d kernel */ + CWHDN, /* tensorflow conv3d input */ + OcIcWHD, /* tensorflow conv3d kernel */ }; 
} // namespace vx diff --git a/src/tim/transform/layout_inference.cc b/src/tim/transform/layout_inference.cc index 632d417..a576a00 100644 --- a/src/tim/transform/layout_inference.cc +++ b/src/tim/transform/layout_inference.cc @@ -58,6 +58,7 @@ #include "ops/arg_layout_inference.h" #include "ops/deconv2d_layout_inference.h" #include "ops/batchnorm_layout_inference.h" +#include "ops/conv3d_layout_inference.h" #include "ops/default_layout_inference.h" #include "ops/transpose_layout_inference.h" @@ -259,6 +260,7 @@ std::vector> HandleLayoutInfer( REGIST_LAYOUT_INFERENCE(VSI_NN_OP_DECONVOLUTION, DeConv2d); REGIST_LAYOUT_INFERENCE(VSI_NN_OP_BATCH_NORM, BatchNorm); REGIST_LAYOUT_INFERENCE(VSI_NN_OP_PERMUTE, Transpose); + REGIST_LAYOUT_INFERENCE(VSI_NN_OP_CONV3D, Conv3d); REGIST_LOGICAL_LAYOUT_INFERENCE(VSI_NN_OP_LOGICAL_OPS); REGIST_REDUCE_LAYOUT_INFERENCE(VSI_NN_OP_REDUCE); // use default layout inference diff --git a/src/tim/transform/ops/conv3d_layout_inference.h b/src/tim/transform/ops/conv3d_layout_inference.h new file mode 100644 index 0000000..980b958 --- /dev/null +++ b/src/tim/transform/ops/conv3d_layout_inference.h @@ -0,0 +1,142 @@ +/**************************************************************************** + * + * Copyright (c) 2020 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#ifndef TIM_LAYOUT_INFER_CONV3D_LAYOUT_INFERENCE_H_ +#define TIM_LAYOUT_INFER_CONV3D_LAYOUT_INFERENCE_H_ + +#include "tim/vx/ops/conv3d.h" + +#include "permute_vector.h" +#include "ops/op_layout_inference.h" + +namespace tim { +namespace transform { +class Conv3dLayoutInfer : public OpLayoutInfer { + public: + Conv3dLayoutInfer( + const std::shared_ptr op, + std::shared_ptr& context) + : OpLayoutInfer(op, context) {} + void OnInputs( + std::vector>& next_tensors) override { + vx::DataLayout layout = op_->impl()->layout_; + auto required_pv = MakeShared(5); + if (layout == vx::DataLayout::CWHDN) { + required_pv = std::make_shared>(kCWHDN2WHDCN); + } + auto input_tensors = op_->impl()->InputsTensor(); + + for (const auto& in : input_tensors) { + std::shared_ptr infer_tensor; + std::shared_ptr trans_pv; + if (in->IsConstTensor() && + !(in->GetSpec().attr_ & vx::TensorAttribute::INPUT)) { + // For bias + if (in->GetShape().size() == 1) { + infer_tensor = context_->infer_graph_->CreateTensor( + in->GetSpec(), in->GetDataRef()); + trans_pv = MakeShared(1); + } else { + // For input/weight + if (!required_pv->IsAligned()) { + auto src_conv3d = std::static_pointer_cast(op_); + if (src_conv3d->KernelDataLayout() == vx::DataLayout::OcIcWHD) { + trans_pv = std::make_shared>(kOcIcWHD2WHDIcOc); + infer_tensor = PermuteConstTensor( + in, trans_pv); + } else { + infer_tensor = PermuteConstTensor(in, 
required_pv); + trans_pv = required_pv; + } + } else { + infer_tensor = context_->infer_graph_->CreateTensor( + in->GetSpec(), in->GetDataRef()); + trans_pv = MakeShared(required_pv->Rank()); + } + } + } else { + // For bias + if (in->GetShape().size() == 1) { + infer_tensor = context_->GetMapedTensor(in); + trans_pv = MakeShared(1); + } else { + // For input/weight + auto pv = context_->GetPermuteVector(in); + auto final_pv = pv->Reverse()->Add(required_pv); + if (!final_pv->IsAligned()) { + infer_tensor = + InsertPermute(context_->GetMapedTensor(in), final_pv); + trans_pv = required_pv; + } else { + infer_tensor = context_->GetMapedTensor(in); + trans_pv = pv; + } + } + } + context_->UpdateTensorMap(in, infer_tensor); + context_->SetPermuteVector(in, trans_pv); + } + + auto pad_type = TranslatePadType(op_->impl()->node()->nn_param.conv3d.pad_type); + std::array ksize = { + op_->impl()->node()->nn_param.conv3d.ksize[0], + op_->impl()->node()->nn_param.conv3d.ksize[1], + op_->impl()->node()->nn_param.conv3d.ksize[2]}; + std::array stride = { + op_->impl()->node()->nn_param.conv3d.stride[0], + op_->impl()->node()->nn_param.conv3d.stride[1], + op_->impl()->node()->nn_param.conv3d.stride[2] + }; + std::array dilation = { + op_->impl()->node()->nn_param.conv3d.dilation[0], + op_->impl()->node()->nn_param.conv3d.dilation[1], + op_->impl()->node()->nn_param.conv3d.dilation[2] + }; + std::array pad = { + op_->impl()->node()->nn_param.conv3d.pad[0], + op_->impl()->node()->nn_param.conv3d.pad[1], + op_->impl()->node()->nn_param.conv3d.pad[2], + op_->impl()->node()->nn_param.conv3d.pad[3], + op_->impl()->node()->nn_param.conv3d.pad[4], + op_->impl()->node()->nn_param.conv3d.pad[5] + }; + int32_t multiplier = op_->impl()->node()->nn_param.conv3d.multiplier; + int32_t out_channels = op_->impl()->node()->nn_param.conv3d.weights; + auto conv3d = context_->infer_graph_->CreateOperation( + out_channels, pad_type, ksize, stride, dilation, pad, multiplier, + vx::DataLayout::WHDCN, 
vx::DataLayout::WHDIcOc); + auto otensor_infer = CreateOutputsTensor(required_pv); + for (const auto& i_src : input_tensors) { + (*conv3d).BindInput(context_->GetMapedTensor(i_src)); + } + (*conv3d).BindOutput(otensor_infer[0]); + context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv); + // Add out tensor of src_graph into next_tensor + next_tensors.push_back(op_->impl()->OutputsTensor()[0]); + } +}; + +} // namespace transform +} // namespace tim + +#endif \ No newline at end of file diff --git a/src/tim/transform/ops/op_layout_inference.cc b/src/tim/transform/ops/op_layout_inference.cc index 9e89f43..65608d8 100644 --- a/src/tim/transform/ops/op_layout_inference.cc +++ b/src/tim/transform/ops/op_layout_inference.cc @@ -322,10 +322,13 @@ bool OpLayoutInfer::TransposeConstTensorData( std::vector perm = KOcHWIc2OcIcHW; std::vector tmp_vec0 = kOcIcWH2WHIcOc; std::vector tmp_vec1 = kIcOcWH2WHIcOc; + std::vector tmp_vec2 = kOcIcWHD2WHDIcOc; if (pv->AsStdVec() == tmp_vec0) { perm = kHWIcOc2OcIcHW; } else if (pv->AsStdVec() == tmp_vec1) { perm = kHWOcIc2OcIcHW; + } else if (pv->AsStdVec() == tmp_vec2) { + perm = kDHWIcOc2OcIcDHW; } std::vector native_shape_array; diff --git a/src/tim/transform/ops/op_layout_inference.h b/src/tim/transform/ops/op_layout_inference.h index ed3c45d..ea7e5e5 100644 --- a/src/tim/transform/ops/op_layout_inference.h +++ b/src/tim/transform/ops/op_layout_inference.h @@ -44,6 +44,10 @@ constexpr std::initializer_list kHWOcIc2OcIcHW = {2, 3, 0, 1}; constexpr std::initializer_list kOcIcWH2WHIcOc = {2, 3, 1, 0}; constexpr std::initializer_list kIcOcWH2WHIcOc = {2, 3, 0, 1}; +constexpr std::initializer_list kCWHDN2WHDCN = {1, 2, 3, 0, 4}; +constexpr std::initializer_list kOcIcWHD2WHDIcOc = {2, 3, 4, 1, 0}; +constexpr std::initializer_list kDHWIcOc2OcIcDHW = {4, 3, 0, 1, 2}; + class OpLayoutInfer { public: OpLayoutInfer( diff --git a/src/tim/vx/ops/conv3d.cc b/src/tim/vx/ops/conv3d.cc new file mode 100644 index 0000000..7135a60 --- 
/dev/null +++ b/src/tim/vx/ops/conv3d.cc @@ -0,0 +1,95 @@ +/**************************************************************************** +* +* Copyright (c) 2020 Vivante Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +* DEALINGS IN THE SOFTWARE. 
+* +*****************************************************************************/ +#include "tim/vx/ops/conv3d.h" +#include "direct_map_op_impl.h" +#include "type_utils.h" +#include "vsi_nn_pub.h" + +namespace tim { +namespace vx { +namespace ops { +Conv3d::Conv3d(Graph* graph, PadType padding, + const std::array& stride, + const std::array& dilation, int32_t multiplier, + DataLayout input_layout, DataLayout kernel_layout) + : Conv3d(graph, 0, padding, {0, 0, 0}, stride, dilation, {0, 0, 0, 0, 0, 0}, + multiplier, input_layout, kernel_layout) {} + +Conv3d::Conv3d(Graph* graph, const std::array pad, + const std::array& stride, + const std::array& dilation, int32_t multiplier, + DataLayout input_layout, DataLayout kernel_layout) + : Conv3d(graph, 0, PadType::AUTO, {0, 0, 0}, stride, dilation, pad, + multiplier, input_layout, kernel_layout) {} + +Conv3d::Conv3d(Graph* graph, int32_t weights, PadType padding, + const std::array& ksize, + const std::array& stride, + const std::array& dilation, int32_t multiplier, + DataLayout input_layout, DataLayout kernel_layout) + : Conv3d(graph, weights, padding, ksize, stride, dilation, + {0, 0, 0, 0, 0, 0}, multiplier, input_layout, kernel_layout) {} + +Conv3d::Conv3d(Graph* graph, int32_t weights, PadType padding, + const std::array& ksize, + const std::array& stride, + const std::array& dilation, + const std::array& pad, int32_t multiplier, + DataLayout input_layout, DataLayout kernel_layout) + : DirectMapOp(graph, VSI_NN_OP_CONV3D, 0, 0, input_layout), + weights_(weights), + padding_(padding), + ksize_(ksize), + stride_(stride), + dilation_(dilation), + pad_(pad), + multiplier_(multiplier), + kernel_layout_(kernel_layout) { + this->impl()->node()->nn_param.conv3d.stride[0] = stride_[0]; + this->impl()->node()->nn_param.conv3d.stride[1] = stride_[1]; + this->impl()->node()->nn_param.conv3d.stride[2] = stride_[2]; + this->impl()->node()->nn_param.conv3d.pad_type = TranslatePadType(padding_); + 
this->impl()->node()->nn_param.conv3d.dilation[0] = dilation_[0]; + this->impl()->node()->nn_param.conv3d.dilation[1] = dilation_[1]; + this->impl()->node()->nn_param.conv3d.dilation[2] = dilation_[2]; + this->impl()->node()->nn_param.conv3d.pad[0] = pad_[0]; + this->impl()->node()->nn_param.conv3d.pad[1] = pad_[1]; + this->impl()->node()->nn_param.conv3d.pad[2] = pad_[2]; + this->impl()->node()->nn_param.conv3d.pad[3] = pad_[3]; + this->impl()->node()->nn_param.conv3d.pad[4] = pad_[4]; + this->impl()->node()->nn_param.conv3d.pad[5] = pad_[5]; + this->impl()->node()->nn_param.conv3d.weights = weights_; + this->impl()->node()->nn_param.conv3d.multiplier = multiplier_; +} + +std::shared_ptr Conv3d::Clone(std::shared_ptr& graph) const { + return graph->CreateOperation( + this->weights_, this->padding_, this->ksize_, this->stride_, + this->dilation_, this->pad_, this->multiplier_, this->impl_->layout_, + this->kernel_layout_); +} + +} // namespace ops +} // namespace vx +} // namespace tim \ No newline at end of file diff --git a/src/tim/vx/ops/conv3d_test.cc b/src/tim/vx/ops/conv3d_test.cc new file mode 100644 index 0000000..53ff47e --- /dev/null +++ b/src/tim/vx/ops/conv3d_test.cc @@ -0,0 +1,136 @@ +#include "tim/vx/ops/conv3d.h" +#include "tim/transform/layout_inference.h" + +#include "gtest/gtest.h" +#include "test_utils.h" +#include "tim/vx/context.h" +#include "tim/vx/graph.h" +#include "tim/vx/types.h" + +TEST(Conv3d, shape_1_1_2_3_3_float32_simple_whdcn) { + auto ctx = tim::vx::Context::Create(); + auto graph = ctx->CreateGraph(); + + tim::vx::ShapeType input_shape({3, 3, 2, 1, 1}); //whdcn + tim::vx::ShapeType weight_shape({2, 2, 1, 1, 2}); //whdIcOc + tim::vx::ShapeType output_shape( + {2, 2, 2, weight_shape[4], input_shape[4]}); //whdcn + + tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape, + tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec weight_spec(tim::vx::DataType::FLOAT32, weight_shape, + 
tim::vx::TensorAttribute::CONSTANT); + tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape, + tim::vx::TensorAttribute::OUTPUT); + + std::vector input_data = { + 0.222290, -0.735840, -2.349609, 1.327148, 0.645020, 0.059631, + -1.081055, -1.307617, -0.306641, -0.520996, 0.041046, 3.234375, + -2.269531, -2.121094, 1.269531, -0.593750, -1.734375, -2.640625}; + + std::vector weight_data = {1.345703, 1.777344, -1.022461, -1.070312, + 1.372070, -0.918945, 0.480713, 1.415039}; + + // whdcn + std::vector golden = {-3.056641, -5.890625, 5.437500, 2.638672, + 3.962891, 6.613281, -4.359375, 4.000000, + 2.531250, 1.543945, -1.141602, -0.232300, + -4.843750, -2.138672, -3.904297, -8.648438}; + + auto input_tensor = graph->CreateTensor(input_spec); + auto weight_tensor = graph->CreateTensor(weight_spec, weight_data.data()); + + auto output_tensor = graph->CreateTensor(output_spec); + + std::array padding ({0, 0, 0, 0, 0, 0}); + std::array stride({1, 1, 1}); + std::array dilation({1, 1, 1}); + + auto conv3d = graph->CreateOperation( + padding, stride, dilation); + (*conv3d) + .BindInput(input_tensor) + .BindInput(weight_tensor) + .BindOutput(output_tensor); + + EXPECT_TRUE(graph->Compile()); + + input_tensor->CopyDataToTensor(input_data.data()); + + EXPECT_TRUE(graph->Run()); + + uint32_t output_size = 1; + for (auto i : output_tensor->GetShape()) { + output_size *= i; + } + std::vector output(output_size); + EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); + for (uint32_t idx = 0; idx < golden.size(); idx++) { + EXPECT_TRUE(std::abs(golden[idx] - output[idx]) < 0.01); + } +} + +TEST(Conv3d, shape_1_1_2_3_3_float32_simple_cwhdn) { + auto ctx = tim::vx::Context::Create(); + auto graph = ctx->CreateGraph(); + + tim::vx::ShapeType input_shape({1, 3, 3, 2, 1}); //cwhdn + tim::vx::ShapeType weight_shape({2, 1, 2, 2, 1}); //OcIcWHD + tim::vx::ShapeType output_shape( + {weight_shape[0], 2, 2, 2, input_shape[4]}); //cwhdn + + tim::vx::TensorSpec 
input_spec(tim::vx::DataType::FLOAT32, input_shape, + tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec weight_spec(tim::vx::DataType::FLOAT32, weight_shape, + tim::vx::TensorAttribute::CONSTANT); + tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape, + tim::vx::TensorAttribute::OUTPUT); + + std::vector input_data = { + 0.97471274, 0.76463452, 0.86721926, 0.92130888, 0.03260213, 0.08942557, + 0.44689693, 0.97484119, 0.55602722, 0.82500644, 0.9202445, 0.37466433, + 0.91804717, 0.56083073, 0.98317178, 0.60991722, 0.39409797, 0.40177473}; + + std::vector weight_data = {0.88074152, 0.43367621, 0.74519104, + 0.30248252, 0.93564262, 0.78602735, + 0.66508319, 0.84253425}; + + std::vector golden = {2.3119678, 1.4056407, 1.4096688, 0.69489276, + 1.902216, 1.5820216, 1.3772604, 1.2759123, + 2.6443386, 1.8302729, 2.268322, 1.7816017, + 2.0592608, 1.3792293, 1.8625461, 1.1888919}; + + auto input_tensor = graph->CreateTensor(input_spec); + auto weight_tensor = graph->CreateTensor(weight_spec, weight_data.data()); + + auto output_tensor = graph->CreateTensor(output_spec); + + std::array padding ({0, 0, 0, 0, 0, 0}); + std::array stride({1, 1, 1}); + std::array dilation({1, 1, 1}); + + auto conv3d = graph->CreateOperation( + padding, stride, dilation, 0, tim::vx::DataLayout::CWHDN, tim::vx::DataLayout::OcIcWHD); + (*conv3d) + .BindInput(input_tensor) + .BindInput(weight_tensor) + .BindOutput(output_tensor); + + auto final_graph = tim::transform::LayoutInference(graph, ctx); + + EXPECT_TRUE(final_graph.first->Compile()); + + final_graph.second[input_tensor]->CopyDataToTensor(input_data.data()); + + EXPECT_TRUE(final_graph.first->Run()); + + uint32_t output_size = 1; + for (auto i : output_tensor->GetShape()) { + output_size *= i; + } + std::vector output(output_size); + EXPECT_TRUE(final_graph.second[output_tensor]->CopyDataFromTensor(output.data())); + for (uint32_t idx = 0; idx < golden.size(); idx++) { + EXPECT_TRUE(std::abs(golden[idx] - output[idx]) 
< 0.01); + } +} \ No newline at end of file