support conv3d (#238)

Signed-off-by: Zongwu Yang <zongwu.yang@verisilicon.com>
Authored by Zongwu.Yang on 2022-01-11 14:13:15 +08:00, committed by GitHub
parent ff25226adb
commit 4229ad88b3
9 changed files with 489 additions and 0 deletions

include/tim/vx/ops.h

@@ -84,5 +84,6 @@
#include "tim/vx/ops/transpose.h"
#include "tim/vx/ops/unidirectional_sequence_lstm.h"
#include "tim/vx/ops/unstack.h"
#include "tim/vx/ops/conv3d.h"
#endif /* TIM_VX_OPS_H_ */

include/tim/vx/ops/conv3d.h (new file, 102 lines)

@@ -0,0 +1,102 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_VX_OPS_CONV3D_H_
#define TIM_VX_OPS_CONV3D_H_
#include <array>
#include "tim/vx/direct_map_op.h"
namespace tim {
namespace vx {
namespace ops {
/**
* ## Conv3d
*
 * Performs a 3-D convolution operation.
 *
 * Input:
 * - input [WHDCN].
 * - kernel [ WHDIcOc ] (Ic: input channels, Oc: output channels).
 * - bias [ Oc ]. Optional.
 *
 * Attribute:
 * - weights : the number of output channels of the weight tensor.
 * - ksize : the kernel size along each spatial axis of the weight tensor.
 * - padding : AUTO, VALID or SAME.
 * - pad : pad values for each spatial axis (left, right, top, bottom, front, rear).
 * - stride : stride along each spatial axis.
 * - dilation : dilation value along each spatial axis of the filter.
 * - multiplier : similar to the group attribute in other frameworks,
 * but with a different value: multiplier = weights / group.
 * - input_layout : WHDCN or CWHDN.
 * - kernel_layout : WHDIcOc or OcIcWHD.
*/
class Conv3d : public DirectMapOp {
public:
Conv3d(Graph* graph, PadType padding,
const std::array<int32_t, 3>& stride,
const std::array<int32_t, 3>& dilation, int32_t multiplier = 0,
DataLayout input_layout = DataLayout::WHDCN,
DataLayout kernel_layout = DataLayout::WHDIcOc);
Conv3d(Graph* graph, const std::array<int32_t, 6> pad,
const std::array<int32_t, 3>& stride,
const std::array<int32_t, 3>& dilation, int32_t multiplier = 0,
DataLayout input_layout = DataLayout::WHDCN,
DataLayout kernel_layout = DataLayout::WHDIcOc);
Conv3d(Graph* graph, int32_t weights, PadType padding,
const std::array<int32_t, 3>& ksize,
const std::array<int32_t, 3>& stride,
const std::array<int32_t, 3>& dilation, int32_t multiplier = 0,
DataLayout input_layout = DataLayout::WHDCN,
DataLayout kernel_layout = DataLayout::WHDIcOc);
Conv3d(Graph* graph, int32_t weights, PadType padding,
const std::array<int32_t, 3>& ksize,
const std::array<int32_t, 3>& stride,
const std::array<int32_t, 3>& dilation,
const std::array<int32_t, 6>& pad, int32_t multiplier = 0,
DataLayout input_layout = DataLayout::WHDCN,
DataLayout kernel_layout = DataLayout::WHDIcOc);
DataLayout KernelDataLayout() { return kernel_layout_; }
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
protected:
const int32_t weights_;
const PadType padding_;
const std::array<int32_t, 3> ksize_;
const std::array<int32_t, 3> stride_;
const std::array<int32_t, 3> dilation_;
const std::array<int32_t, 6> pad_;
const int32_t multiplier_;
const DataLayout kernel_layout_;
};
} // namespace ops
} // namespace vx
} // namespace tim
#endif /* TIM_VX_OPS_CONV3D_H_ */
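
A quick usage sketch of the constructors above (hedged: `graph` is assumed to be an existing tim::vx::Graph, and the grouped-conv numbers are illustrative, not part of this change):

// Sketch only: explicit per-axis padding with the default WHDCN/WHDIcOc layouts.
std::array<int32_t, 3> stride{1, 1, 1};
std::array<int32_t, 3> dilation{1, 1, 1};
std::array<int32_t, 6> pad{0, 0, 0, 0, 0, 0};  // left, right, top, bottom, front, rear
auto conv = graph->CreateOperation<tim::vx::ops::Conv3d>(pad, stride, dilation);

// Grouped sketch: per the doc comment, multiplier = weights / group, so with
// weights (Oc) = 8 and group = 4 the multiplier attribute is 2.
std::array<int32_t, 3> ksize{2, 2, 2};
auto grouped = graph->CreateOperation<tim::vx::ops::Conv3d>(
    /*weights=*/8, tim::vx::PadType::VALID, ksize, stride, dilation,
    /*multiplier=*/2);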

include/tim/vx/types.h

@@ -72,6 +72,10 @@ enum class DataLayout {
WHIcOc, /*TIM-VX default*/
WCN, /*for conv1d*/
WIcOc, /*for conv1d*/
WHDCN, /* pytorch conv3d input */
WHDIcOc, /* pytorch conv3d kernel */
CWHDN, /* tensorflow conv3d input */
OcIcWHD, /* tensorflow conv3d kernel */
};
} // namespace vx
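
TIM-VX lists tensor dimensions fastest-changing first, so each new entry is a framework's dimension order reversed: PyTorch conv3d takes NCDHW input and OcIcDHW kernels (reversed: WHDCN and WHDIcOc), while TensorFlow conv3d takes NDHWC input and DHWIcOc filters (reversed: CWHDN and OcIcWHD). A minimal sketch of that reversal:

// Sketch: derive the TIM-VX layout name from a framework's dim order by
// reversing it (TIM-VX shapes run from the fastest-changing dim outward).
#include <algorithm>
#include <string>
#include <vector>

std::vector<std::string> ToTimVxOrder(std::vector<std::string> dims) {
  std::reverse(dims.begin(), dims.end());
  return dims;
}
// ToTimVxOrder({"N","C","D","H","W"})   -> W,H,D,C,N   == WHDCN
// ToTimVxOrder({"Oc","Ic","D","H","W"}) -> W,H,D,Ic,Oc == WHDIcOc
// ToTimVxOrder({"N","D","H","W","C"})   -> C,W,H,D,N   == CWHDN
// ToTimVxOrder({"D","H","W","Ic","Oc"}) -> Oc,Ic,W,H,D == OcIcWHD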

src/tim/transform/layout_inference.cc

@@ -58,6 +58,7 @@
#include "ops/arg_layout_inference.h"
#include "ops/deconv2d_layout_inference.h"
#include "ops/batchnorm_layout_inference.h"
#include "ops/conv3d_layout_inference.h"
#include "ops/default_layout_inference.h"
#include "ops/transpose_layout_inference.h"
@@ -259,6 +260,7 @@ std::vector<std::shared_ptr<vx::Tensor>> HandleLayoutInfer(
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_DECONVOLUTION, DeConv2d);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_BATCH_NORM, BatchNorm);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_PERMUTE, Transpose);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_CONV3D, Conv3d);
REGIST_LOGICAL_LAYOUT_INFERENCE(VSI_NN_OP_LOGICAL_OPS);
REGIST_REDUCE_LAYOUT_INFERENCE(VSI_NN_OP_REDUCE);
// use default layout inference
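
The macro body is not part of this diff; as a hypothetical sketch (not the actual expansion), REGIST_LAYOUT_INFERENCE presumably turns the op kind into a dispatch case that instantiates the matching *LayoutInfer class:

// Hypothetical expansion of REGIST_LAYOUT_INFERENCE(VSI_NN_OP_CONV3D, Conv3d);
// the real macro lives in this file and may differ.
case VSI_NN_OP_CONV3D: {
  auto infer = std::make_shared<Conv3dLayoutInfer>(op, ctx);
  infer->OnInputs(next_tensors);
  infer->OnOutputs(next_tensors);
  break;
}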

src/tim/transform/ops/conv3d_layout_inference.h (new file, 142 lines)

@@ -0,0 +1,142 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_LAYOUT_INFER_CONV3D_LAYOUT_INFERENCE_H_
#define TIM_LAYOUT_INFER_CONV3D_LAYOUT_INFERENCE_H_
#include "tim/vx/ops/conv3d.h"
#include "permute_vector.h"
#include "ops/op_layout_inference.h"
namespace tim {
namespace transform {
class Conv3dLayoutInfer : public OpLayoutInfer {
public:
Conv3dLayoutInfer(
const std::shared_ptr<vx::Operation> op,
std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
: OpLayoutInfer(op, context) {}
void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
vx::DataLayout layout = op_->impl()->layout_;
auto required_pv = MakeShared(5);
if (layout == vx::DataLayout::CWHDN) {
required_pv = std::make_shared<PermuteVector<5>>(kCWHDN2WHDCN);
}
auto input_tensors = op_->impl()->InputsTensor();
for (const auto& in : input_tensors) {
std::shared_ptr<vx::Tensor> infer_tensor;
std::shared_ptr<IPermuteVector> trans_pv;
if (in->IsConstTensor() &&
!(in->GetSpec().attr_ & vx::TensorAttribute::INPUT)) {
// For bias
if (in->GetShape().size() == 1) {
infer_tensor = context_->infer_graph_->CreateTensor(
in->GetSpec(), in->GetDataRef());
trans_pv = MakeShared(1);
} else {
// For input/weight
if (!required_pv->IsAligned()) {
auto src_conv3d = std::static_pointer_cast<vx::ops::Conv3d>(op_);
if (src_conv3d->KernelDataLayout() == vx::DataLayout::OcIcWHD) {
trans_pv = std::make_shared<PermuteVector<5>>(kOcIcWHD2WHDIcOc);
infer_tensor = PermuteConstTensor(in, trans_pv);
} else {
infer_tensor = PermuteConstTensor(in, required_pv);
trans_pv = required_pv;
}
} else {
infer_tensor = context_->infer_graph_->CreateTensor(
in->GetSpec(), in->GetDataRef());
trans_pv = MakeShared(required_pv->Rank());
}
}
} else {
// For bias
if (in->GetShape().size() == 1) {
infer_tensor = context_->GetMapedTensor(in);
trans_pv = MakeShared(1);
} else {
// For input/weight
auto pv = context_->GetPermuteVector(in);
auto final_pv = pv->Reverse()->Add(required_pv);
if (!final_pv->IsAligned()) {
infer_tensor =
InsertPermute(context_->GetMapedTensor(in), final_pv);
trans_pv = required_pv;
} else {
infer_tensor = context_->GetMapedTensor(in);
trans_pv = pv;
}
}
}
context_->UpdateTensorMap(in, infer_tensor);
context_->SetPermuteVector(in, trans_pv);
}
auto pad_type = TranslatePadType(op_->impl()->node()->nn_param.conv3d.pad_type);
std::array<int32_t, 3> ksize = {
op_->impl()->node()->nn_param.conv3d.ksize[0],
op_->impl()->node()->nn_param.conv3d.ksize[1],
op_->impl()->node()->nn_param.conv3d.ksize[2]};
std::array<int32_t, 3> stride = {
op_->impl()->node()->nn_param.conv3d.stride[0],
op_->impl()->node()->nn_param.conv3d.stride[1],
op_->impl()->node()->nn_param.conv3d.stride[2]
};
std::array<int32_t, 3> dilation = {
op_->impl()->node()->nn_param.conv3d.dilation[0],
op_->impl()->node()->nn_param.conv3d.dilation[1],
op_->impl()->node()->nn_param.conv3d.dilation[2]
};
std::array<int32_t, 6> pad = {
op_->impl()->node()->nn_param.conv3d.pad[0],
op_->impl()->node()->nn_param.conv3d.pad[1],
op_->impl()->node()->nn_param.conv3d.pad[2],
op_->impl()->node()->nn_param.conv3d.pad[3],
op_->impl()->node()->nn_param.conv3d.pad[4],
op_->impl()->node()->nn_param.conv3d.pad[5]
};
int32_t multiplier = op_->impl()->node()->nn_param.conv3d.multiplier;
int32_t out_channels = op_->impl()->node()->nn_param.conv3d.weights;
auto conv3d = context_->infer_graph_->CreateOperation<vx::ops::Conv3d>(
out_channels, pad_type, ksize, stride, dilation, pad, multiplier,
vx::DataLayout::WHDCN, vx::DataLayout::WHDIcOc);
auto otensor_infer = CreateOutputsTensor(required_pv);
for (const auto& i_src : input_tensors) {
(*conv3d).BindInput(context_->GetMapedTensor(i_src));
}
(*conv3d).BindOutput(otensor_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
// Add the output tensor of the source graph to next_tensors
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
}
};
} // namespace transform
} // namespace tim
#endif
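
The pivotal step in OnInputs above is `pv->Reverse()->Add(required_pv)`: undo whatever permutation the mapped tensor already carries, then apply the one this op requires, and insert a transpose only when the composition is not the identity. A standalone sketch of that arithmetic on plain index vectors (the real IPermuteVector API lives in permute_vector.h; its composition convention is assumed here):

// Sketch of permute-vector composition, gather semantics: out[i] = in[p[i]].
#include <cstdint>
#include <vector>

using Perm = std::vector<uint32_t>;

Perm Reverse(const Perm& p) {  // inverse permutation: r[p[i]] = i
  Perm r(p.size());
  for (uint32_t i = 0; i < p.size(); ++i) r[p[i]] = i;
  return r;
}

Perm Add(const Perm& first, const Perm& then) {  // apply `first`, then `then`
  Perm out(then.size());
  for (uint32_t i = 0; i < then.size(); ++i) out[i] = first[then[i]];
  return out;
}

bool IsAligned(const Perm& p) {  // identity permutation => no transpose needed
  for (uint32_t i = 0; i < p.size(); ++i)
    if (p[i] != i) return false;
  return true;
}
// final_pv = Add(Reverse(pv), required_pv); transpose only if !IsAligned(final_pv).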

src/tim/transform/ops/op_layout_inference.cc

@@ -322,10 +322,13 @@ bool OpLayoutInfer::TransposeConstTensorData(
std::vector<uint32_t> perm = KOcHWIc2OcIcHW;
std::vector<uint32_t> tmp_vec0 = kOcIcWH2WHIcOc;
std::vector<uint32_t> tmp_vec1 = kIcOcWH2WHIcOc;
std::vector<uint32_t> tmp_vec2 = kOcIcWHD2WHDIcOc;
if (pv->AsStdVec() == tmp_vec0) {
perm = kHWIcOc2OcIcHW;
} else if (pv->AsStdVec() == tmp_vec1) {
perm = kHWOcIc2OcIcHW;
} else if (pv->AsStdVec() == tmp_vec2) {
perm = kDHWIcOc2OcIcDHW;
}
std::vector<vsi_size_t> native_shape_array;
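
The pairing here is systematic: kOcIcWHD2WHDIcOc = {2, 3, 4, 1, 0} is stated over TIM-VX's fastest-first dim order, while the transpose that actually moves bytes, kDHWIcOc2OcIcDHW = {4, 3, 0, 1, 2}, is stated over row-major memory order. Presumably the two are related by reversing both index spaces, q[i] = n-1-p[n-1-i]; a quick check of that identity:

// Sketch: convert a fastest-first dim permutation to its row-major memory
// equivalent by reversing both index spaces, then verify the paired constants.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<uint32_t> ToMemoryOrder(const std::vector<uint32_t>& p) {
  const uint32_t n = static_cast<uint32_t>(p.size());
  std::vector<uint32_t> q(n);
  for (uint32_t i = 0; i < n; ++i) q[i] = n - 1 - p[n - 1 - i];
  return q;
}

int main() {
  // kOcIcWHD2WHDIcOc -> kDHWIcOc2OcIcDHW (the new 5-D pair)
  assert((ToMemoryOrder({2, 3, 4, 1, 0}) ==
          std::vector<uint32_t>{4, 3, 0, 1, 2}));
  // kOcIcWH2WHIcOc -> kHWIcOc2OcIcHW (same rule on the existing 4-D pair)
  assert((ToMemoryOrder({2, 3, 1, 0}) ==
          std::vector<uint32_t>{3, 2, 0, 1}));
  return 0;
}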

src/tim/transform/ops/op_layout_inference.h

@@ -44,6 +44,10 @@ constexpr std::initializer_list<uint32_t> kHWOcIc2OcIcHW = {2, 3, 0, 1};
constexpr std::initializer_list<uint32_t> kOcIcWH2WHIcOc = {2, 3, 1, 0};
constexpr std::initializer_list<uint32_t> kIcOcWH2WHIcOc = {2, 3, 0, 1};
constexpr std::initializer_list<uint32_t> kCWHDN2WHDCN = {1, 2, 3, 0, 4};
constexpr std::initializer_list<uint32_t> kOcIcWHD2WHDIcOc = {2, 3, 4, 1, 0};
constexpr std::initializer_list<uint32_t> kDHWIcOc2OcIcDHW = {4, 3, 0, 1, 2};
class OpLayoutInfer {
public:
OpLayoutInfer(

src/tim/vx/ops/conv3d.cc (new file, 95 lines)

@@ -0,0 +1,95 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#include "tim/vx/ops/conv3d.h"
#include "direct_map_op_impl.h"
#include "type_utils.h"
#include "vsi_nn_pub.h"
namespace tim {
namespace vx {
namespace ops {
Conv3d::Conv3d(Graph* graph, PadType padding,
const std::array<int32_t, 3>& stride,
const std::array<int32_t, 3>& dilation, int32_t multiplier,
DataLayout input_layout, DataLayout kernel_layout)
: Conv3d(graph, 0, padding, {0, 0, 0}, stride, dilation, {0, 0, 0, 0, 0, 0},
multiplier, input_layout, kernel_layout) {}
Conv3d::Conv3d(Graph* graph, const std::array<int32_t, 6> pad,
const std::array<int32_t, 3>& stride,
const std::array<int32_t, 3>& dilation, int32_t multiplier,
DataLayout input_layout, DataLayout kernel_layout)
: Conv3d(graph, 0, PadType::AUTO, {0, 0, 0}, stride, dilation, pad,
multiplier, input_layout, kernel_layout) {}
Conv3d::Conv3d(Graph* graph, int32_t weights, PadType padding,
const std::array<int32_t, 3>& ksize,
const std::array<int32_t, 3>& stride,
const std::array<int32_t, 3>& dilation, int32_t multiplier,
DataLayout input_layout, DataLayout kernel_layout)
: Conv3d(graph, weights, padding, ksize, stride, dilation,
{0, 0, 0, 0, 0, 0}, multiplier, input_layout, kernel_layout) {}
Conv3d::Conv3d(Graph* graph, int32_t weights, PadType padding,
const std::array<int32_t, 3>& ksize,
const std::array<int32_t, 3>& stride,
const std::array<int32_t, 3>& dilation,
const std::array<int32_t, 6>& pad, int32_t multiplier,
DataLayout input_layout, DataLayout kernel_layout)
: DirectMapOp(graph, VSI_NN_OP_CONV3D, 0, 0, input_layout),
weights_(weights),
padding_(padding),
ksize_(ksize),
stride_(stride),
dilation_(dilation),
pad_(pad),
multiplier_(multiplier),
kernel_layout_(kernel_layout) {
this->impl()->node()->nn_param.conv3d.stride[0] = stride_[0];
this->impl()->node()->nn_param.conv3d.stride[1] = stride_[1];
this->impl()->node()->nn_param.conv3d.stride[2] = stride_[2];
this->impl()->node()->nn_param.conv3d.pad_type = TranslatePadType(padding_);
this->impl()->node()->nn_param.conv3d.dilation[0] = dilation_[0];
this->impl()->node()->nn_param.conv3d.dilation[1] = dilation_[1];
this->impl()->node()->nn_param.conv3d.dilation[2] = dilation_[2];
this->impl()->node()->nn_param.conv3d.pad[0] = pad_[0];
this->impl()->node()->nn_param.conv3d.pad[1] = pad_[1];
this->impl()->node()->nn_param.conv3d.pad[2] = pad_[2];
this->impl()->node()->nn_param.conv3d.pad[3] = pad_[3];
this->impl()->node()->nn_param.conv3d.pad[4] = pad_[4];
this->impl()->node()->nn_param.conv3d.pad[5] = pad_[5];
this->impl()->node()->nn_param.conv3d.weights = weights_;
this->impl()->node()->nn_param.conv3d.multiplier = multiplier_;
}
std::shared_ptr<Operation> Conv3d::Clone(std::shared_ptr<Graph>& graph) const {
return graph->CreateOperation<Conv3d>(
this->weights_, this->padding_, this->ksize_, this->stride_,
this->dilation_, this->pad_, this->multiplier_, this->impl_->layout_,
this->kernel_layout_);
}
} // namespace ops
} // namespace vx
} // namespace tim
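
Nothing in this file computes output extents; that happens downstream in the vsi_nn driver. For reference, the standard per-axis convolution size arithmetic that the tests below rely on (a sketch, not code from this change):

// Standard convolution output size per spatial axis; shown only to explain
// the test shapes below, the driver computes this internally.
#include <cstdint>

int32_t OutDim(int32_t in, int32_t k, int32_t stride, int32_t dilation,
               int32_t pad_before, int32_t pad_after) {
  const int32_t effective_k = dilation * (k - 1) + 1;  // dilated kernel extent
  return (in + pad_before + pad_after - effective_k) / stride + 1;
}
// First test below: W = OutDim(3,2,1,1,0,0) = 2, H = 2, D = OutDim(2,1,1,1,0,0) = 2,
// giving an output of {2, 2, 2, Oc = 2, N = 1} in WHDCN.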

src/tim/vx/ops/conv3d_test.cc (new file, 136 lines)

@@ -0,0 +1,136 @@
#include "tim/vx/ops/conv3d.h"
#include "tim/transform/layout_inference.h"
#include "gtest/gtest.h"
#include "test_utils.h"
#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/types.h"
TEST(Conv3d, shape_1_1_2_3_3_float32_simple_whdcn) {
auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph();
tim::vx::ShapeType input_shape({3, 3, 2, 1, 1}); //whdcn
tim::vx::ShapeType weight_shape({2, 2, 1, 1, 2}); //whdIcOc
tim::vx::ShapeType output_shape(
{2, 2, 2, weight_shape[4], input_shape[4]}); //whdcn
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
tim::vx::TensorAttribute::INPUT);
tim::vx::TensorSpec weight_spec(tim::vx::DataType::FLOAT32, weight_shape,
tim::vx::TensorAttribute::CONSTANT);
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
tim::vx::TensorAttribute::OUTPUT);
std::vector<float> input_data = {
0.222290, -0.735840, -2.349609, 1.327148, 0.645020, 0.059631,
-1.081055, -1.307617, -0.306641, -0.520996, 0.041046, 3.234375,
-2.269531, -2.121094, 1.269531, -0.593750, -1.734375, -2.640625};
std::vector<float> weight_data = {1.345703, 1.777344, -1.022461, -1.070312,
1.372070, -0.918945, 0.480713, 1.415039};
// whdcn
std::vector<float> golden = {-3.056641, -5.890625, 5.437500, 2.638672,
3.962891, 6.613281, -4.359375, 4.000000,
2.531250, 1.543945, -1.141602, -0.232300,
-4.843750, -2.138672, -3.904297, -8.648438};
auto input_tensor = graph->CreateTensor(input_spec);
auto weight_tensor = graph->CreateTensor(weight_spec, weight_data.data());
auto output_tensor = graph->CreateTensor(output_spec);
std::array<int32_t, 6> padding({0, 0, 0, 0, 0, 0});
std::array<int32_t, 3> stride({1, 1, 1});
std::array<int32_t, 3> dilation({1, 1, 1});
auto conv3d = graph->CreateOperation<tim::vx::ops::Conv3d>(
padding, stride, dilation);
(*conv3d)
.BindInput(input_tensor)
.BindInput(weight_tensor)
.BindOutput(output_tensor);
EXPECT_TRUE(graph->Compile());
input_tensor->CopyDataToTensor(input_data.data());
EXPECT_TRUE(graph->Run());
uint32_t output_size = 1;
for (auto i : output_tensor->GetShape()) {
output_size *= i;
}
std::vector<float> output(output_size);
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
for (uint32_t idx = 0; idx < golden.size(); idx++) {
EXPECT_NEAR(golden[idx], output[idx], 0.01f);
}
}
TEST(Conv3d, shape_1_1_2_3_3_float32_simple_cwhdn) {
auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph();
tim::vx::ShapeType input_shape({1, 3, 3, 2, 1}); //cwhdn
tim::vx::ShapeType weight_shape({2, 1, 2, 2, 1}); //OcIcWHD
tim::vx::ShapeType output_shape(
{weight_shape[0], 2, 2, 2, input_shape[4]}); //cwhdn
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
tim::vx::TensorAttribute::INPUT);
tim::vx::TensorSpec weight_spec(tim::vx::DataType::FLOAT32, weight_shape,
tim::vx::TensorAttribute::CONSTANT);
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
tim::vx::TensorAttribute::OUTPUT);
std::vector<float> input_data = {
0.97471274, 0.76463452, 0.86721926, 0.92130888, 0.03260213, 0.08942557,
0.44689693, 0.97484119, 0.55602722, 0.82500644, 0.9202445, 0.37466433,
0.91804717, 0.56083073, 0.98317178, 0.60991722, 0.39409797, 0.40177473};
std::vector<float> weight_data = {0.88074152, 0.43367621, 0.74519104,
0.30248252, 0.93564262, 0.78602735,
0.66508319, 0.84253425};
std::vector<float> golden = {2.3119678, 1.4056407, 1.4096688, 0.69489276,
1.902216, 1.5820216, 1.3772604, 1.2759123,
2.6443386, 1.8302729, 2.268322, 1.7816017,
2.0592608, 1.3792293, 1.8625461, 1.1888919};
auto input_tensor = graph->CreateTensor(input_spec);
auto weight_tensor = graph->CreateTensor(weight_spec, weight_data.data());
auto output_tensor = graph->CreateTensor(output_spec);
std::array<int32_t, 6> padding({0, 0, 0, 0, 0, 0});
std::array<int32_t, 3> stride({1, 1, 1});
std::array<int32_t, 3> dilation({1, 1, 1});
auto conv3d = graph->CreateOperation<tim::vx::ops::Conv3d>(
padding, stride, dilation, 0, tim::vx::DataLayout::CWHDN, tim::vx::DataLayout::OcIcWHD);
(*conv3d)
.BindInput(input_tensor)
.BindInput(weight_tensor)
.BindOutput(output_tensor);
auto final_graph = tim::transform::LayoutInference(graph, ctx);
EXPECT_TRUE(final_graph.first->Compile());
final_graph.second[input_tensor]->CopyDataToTensor(input_data.data());
EXPECT_TRUE(final_graph.first->Run());
uint32_t output_size = 1;
for (auto i : output_tensor->GetShape()) {
output_size *= i;
}
std::vector<float> output(output_size);
EXPECT_TRUE(final_graph.second[output_tensor]->CopyDataFromTensor(output.data()));
for (uint32_t idx = 0; idx < golden.size(); idx++) {
EXPECT_NEAR(golden[idx], output[idx], 0.01f);
}
}