diff --git a/include/tim/vx/ops/conv2d.h b/include/tim/vx/ops/conv2d.h
index 8143757..961a220 100644
--- a/include/tim/vx/ops/conv2d.h
+++ b/include/tim/vx/ops/conv2d.h
@@ -38,13 +38,17 @@ class Conv2d : public Operation {
          const std::array<uint32_t, 2>& ksize,
          const std::array<uint32_t, 2>& stride,
          const std::array<uint32_t, 2>& dilation, int32_t multiplier = 0,
-         DataLayout layout = DataLayout::WHCN);
+         DataLayout input_layout = DataLayout::WHCN,
+         DataLayout kernel_layout = DataLayout::WHIcOc);
   Conv2d(Graph* graph, int32_t weights, PadType padding,
          const std::array<uint32_t, 2>& ksize,
          const std::array<uint32_t, 2>& stride,
          const std::array<uint32_t, 2>& dilation,
          const std::array<uint32_t, 4>& pad, int32_t multiplier = 0,
-         DataLayout layout = DataLayout::WHCN);
+         DataLayout input_layout = DataLayout::WHCN,
+         DataLayout kernel_layout = DataLayout::WHIcOc);
+
+  DataLayout KernelDataLayout() { return kernel_layout_; }
 
  protected:
   const uint32_t weights_;
@@ -54,6 +58,7 @@ class Conv2d : public Operation {
   const std::array<uint32_t, 2> dilation_;
   const std::array<uint32_t, 4> pad_;
   const int32_t multiplier_;
+  const DataLayout kernel_layout_;
 };
 
 }  // namespace ops
diff --git a/include/tim/vx/types.h b/include/tim/vx/types.h
index 2cad3fc..f035a1b 100644
--- a/include/tim/vx/types.h
+++ b/include/tim/vx/types.h
@@ -74,7 +74,14 @@ enum class ActivationType {
 
 enum class ResizeType { NEAREST_NEIGHBOR, BILINEAR, AREA };
 
-enum class DataLayout { WHCN, CWHN, ANY };
+enum class DataLayout {
+  WHCN,
+  CWHN,
+  ANY,
+  IcWHOc /*TF*/,
+  OcIcWH /*TVM*/,
+  WHIcOc /*TIM-VX default*/
+};
 
 }  // namespace vx
 }  // namespace tim
diff --git a/src/tim/transform/ops/conv2d_layout_inference.h b/src/tim/transform/ops/conv2d_layout_inference.h
index fb0a0f2..6ddef10 100644
--- a/src/tim/transform/ops/conv2d_layout_inference.h
+++ b/src/tim/transform/ops/conv2d_layout_inference.h
@@ -61,8 +61,16 @@ class Conv2dLayoutInfer : public OpLayoutInfer {
       } else {  // For input/weight
         if (!required_pv->IsAligned()) {
-          infer_tensor = PermuteConstTensor(in, required_pv);
-          trans_pv = required_pv;
+          auto src_conv2d = std::static_pointer_cast<vx::ops::Conv2d>(op_);
+          // Support TVM Kernel Layout
+          if (src_conv2d->KernelDataLayout() == vx::DataLayout::OcIcWH) {
+            trans_pv = std::make_shared<PermuteVector<4>>(kOcIcWH2WHIcOc);
+            infer_tensor = PermuteConstTensor(
+                in, trans_pv);
+          } else {
+            infer_tensor = PermuteConstTensor(in, required_pv);
+            trans_pv = required_pv;
+          }
         } else {
           infer_tensor = context_->infer_graph_->CreateTensor(
               in->GetSpec(), in->GetDataRef());
         }
@@ -108,11 +116,11 @@
     int32_t out_channels = op_->impl()->node()->nn_param.conv2d.weights;
     auto conv2d = context_->infer_graph_->CreateOperation<vx::ops::Conv2d>(
         out_channels, pad_type, ksize, stride, dilation, pad, multiplier,
-        vx::DataLayout::WHCN);
+        vx::DataLayout::WHCN, vx::DataLayout::WHIcOc);
     auto otensor_infer = CreateOutputsTensor(required_pv);
-    (*conv2d).BindInputs({context_->GetMapedTensor(input_tensors[0]),
-                          context_->GetMapedTensor(input_tensors[1]),
-                          context_->GetMapedTensor(input_tensors[2])});
+    for (const auto& i_src : input_tensors) {
+      (*conv2d).BindInput(context_->GetMapedTensor(i_src));
+    }
     (*conv2d).BindOutput(otensor_infer[0]);
     context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
     // Add out tensor of src_graph into next_tensor
diff --git a/src/tim/transform/ops/op_layout_inference.cc b/src/tim/transform/ops/op_layout_inference.cc
index 716c407..17d5957 100644
--- a/src/tim/transform/ops/op_layout_inference.cc
+++ b/src/tim/transform/ops/op_layout_inference.cc
@@ -154,16 +154,36 @@ OpLayoutInfer::AlignPermuteVectorForMutilInputs() {
   auto src_inputs = op_->impl()->InputsTensor();
   // Suppose the inputs have same dimension rank
   // TODO(yzw): should choose a optimal required_pv
-  auto required_pv = context_->GetPermuteVector(src_inputs[0]);
-  for (const auto& i_src : src_inputs) {
-    std::shared_ptr<vx::Tensor> perm_out;
-    auto pv = context_->GetPermuteVector(i_src);
-    auto final_pv = pv->Reverse()->Add(required_pv);
-    if (!final_pv->IsAligned()) {
+  std::shared_ptr<IPermuteVector> required_pv = nullptr;
+  for (const auto& in : src_inputs) {
+    if (!in->IsConstTensor()) {
+      required_pv = context_->GetPermuteVector(in);
+      break;
+    }
+  }
+
+  if (!required_pv) {
+    // all inputs are constant tensors
+    for (const auto& i_src : src_inputs) {
+      context_->UpdateTensorMap(
+          i_src, context_->infer_graph_->CreateTensor(i_src->GetSpec(),
+                                                      i_src->GetDataRef()));
+      context_->SetPermuteVector(i_src, MakeShared(i_src->GetShape().size()));
+    }
+  } else {
+    for (const auto& i_src : src_inputs) {
+      std::shared_ptr<vx::Tensor> perm_out;
       if (i_src->IsConstTensor()) {
-        perm_out = PermuteConstTensor(i_src, final_pv);
+        required_pv->IsAligned()
+            ? perm_out = context_->infer_graph_->CreateTensor(i_src->GetSpec(),
+                                                              i_src->GetDataRef())
+            : perm_out = PermuteConstTensor(i_src, required_pv);
       } else {
-        perm_out = InsertPermute(context_->GetMapedTensor(i_src), final_pv);
+        auto final_pv =
+            context_->GetPermuteVector(i_src)->Reverse()->Add(required_pv);
+        final_pv->IsAligned() ? perm_out = context_->GetMapedTensor(i_src)
+                              : perm_out = InsertPermute(
+                                    context_->GetMapedTensor(i_src), final_pv);
       }
       context_->UpdateTensorMap(i_src, perm_out);
       context_->SetPermuteVector(i_src, required_pv);
@@ -175,17 +195,21 @@
 
 void OpLayoutInfer::ReverseInputsPermuteVector() {
   for (const auto& i_src : op_->impl()->InputsTensor()) {
     std::shared_ptr<vx::Tensor> perm_out;
-    auto input_pv = context_->GetPermuteVector(i_src);
-    if (!input_pv->IsAligned()) {
-      if (i_src->IsConstTensor()) {
-        perm_out = PermuteConstTensor(i_src, input_pv);
-      } else {
+    std::shared_ptr<IPermuteVector> input_pv;
+    if (i_src->IsConstTensor()) {
+      perm_out = context_->infer_graph_->CreateTensor(i_src->GetSpec(),
+                                                      i_src->GetDataRef());
+      input_pv = MakeShared(i_src->GetShape().size());
+    } else {
+      perm_out = context_->GetMapedTensor(i_src);
+      input_pv = context_->GetPermuteVector(i_src);
+      if (!input_pv->IsAligned()) {
         perm_out =
-            InsertPermute(context_->GetMapedTensor(i_src), input_pv->Reverse());
+            InsertPermute(perm_out, input_pv->Reverse());
       }
-      context_->UpdateTensorMap(i_src, perm_out);
-      context_->SetPermuteVector(i_src, MakeShared(input_pv->Rank()));
     }
+    context_->UpdateTensorMap(i_src, perm_out);
+    context_->SetPermuteVector(i_src, MakeShared(input_pv->Rank()));
   }
 }
 
@@ -202,12 +226,15 @@ bool OpLayoutInfer::TransposeConstTensorData(
     return false;
   }
 
-  std::vector<uint32_t> perm = KOcHWIc2OcIcHW;
   vx::ShapeType reverse_shape;
   for (int32_t i = input->GetShape().size() - 1; i >= 0; i--) {
     reverse_shape.push_back(input->GetShape()[i]);
  }
-
+  std::vector<uint32_t> perm = KOcHWIc2OcIcHW;
+  std::vector<uint32_t> tmp_vec = kOcIcWH2WHIcOc;
+  if (pv->AsStdVec() == tmp_vec) {
+    perm = kHWIcOc2OcIcHW;
+  }
   vsi_nn_Transpose(out_data.data(), (uint8_t*)(input->GetDataRef()),
                    (uint32_t*)(reverse_shape.data()),
                    static_cast<uint32_t>(input->GetShape().size()),
@@ -240,7 +267,7 @@ std::vector<uint32_t> OpLayoutInfer::MapPadding(const std::vector<uint32_t>& per
   assert(perm.size() == padding.size());
 
   std::vector<uint32_t> r(padding.size());
-  for (int i = 0; i < padding.size(); ++i) {
+  for (uint32_t i = 0; i < padding.size(); ++i) {
     r[i] = padding[perm[i]];
   }
 
diff --git a/src/tim/transform/ops/op_layout_inference.h b/src/tim/transform/ops/op_layout_inference.h
index 3badeee..61cb6cb 100644
--- a/src/tim/transform/ops/op_layout_inference.h
+++ b/src/tim/transform/ops/op_layout_inference.h
@@ -34,7 +34,12 @@
 namespace tim {
 namespace transform {
 constexpr std::initializer_list<uint32_t> kCWHN2WHCN = {1, 2, 0, 3};
+
 constexpr std::initializer_list<uint32_t> KOcHWIc2OcIcHW = {0, 3, 1, 2};
+constexpr std::initializer_list<uint32_t> kIcWHOc2WHIcOc = {1, 2, 0, 3};
+
+constexpr std::initializer_list<uint32_t> kHWIcOc2OcIcHW = {3, 2, 0, 1};
+constexpr std::initializer_list<uint32_t> kOcIcWH2WHIcOc = {2, 3, 1, 0};
 
 class OpLayoutInfer {
  public:
diff --git a/src/tim/vx/ops/conv2d.cc b/src/tim/vx/ops/conv2d.cc
index 0be25c9..a166a40 100644
--- a/src/tim/vx/ops/conv2d.cc
+++ b/src/tim/vx/ops/conv2d.cc
@@ -34,24 +34,26 @@ namespace ops {
 Conv2d::Conv2d(Graph* graph, int32_t weights, PadType padding,
                const std::array<uint32_t, 2>& ksize,
                const std::array<uint32_t, 2>& stride,
-               const std::array<uint32_t, 2>& dilation,
-               int32_t multiplier, DataLayout layout)
-    : Conv2d(graph, weights, padding, ksize, stride, dilation,
-             {0, 0, 0, 0}, multiplier, layout) {}
+               const std::array<uint32_t, 2>& dilation, int32_t multiplier,
+               DataLayout input_layout, DataLayout kernel_layout)
+    : Conv2d(graph, weights, padding, ksize, stride, dilation, {0, 0, 0, 0},
+             multiplier, input_layout, kernel_layout) {}
 
 Conv2d::Conv2d(Graph* graph, int32_t weights, PadType padding,
                const std::array<uint32_t, 2>& ksize,
                const std::array<uint32_t, 2>& stride,
                const std::array<uint32_t, 2>& dilation,
-               const std::array<uint32_t, 4>& pad, int32_t multiplier, DataLayout layout)
-    : Operation(graph, VSI_NN_OP_CONV2D, 0, 0, layout),
+               const std::array<uint32_t, 4>& pad, int32_t multiplier,
+               DataLayout input_layout, DataLayout kernel_layout)
+    : Operation(graph, VSI_NN_OP_CONV2D, 0, 0, input_layout),
       weights_(weights),
       padding_(padding),
       ksize_(ksize),
       stride_(stride),
       dilation_(dilation),
       pad_(pad),
-      multiplier_(multiplier) {
+      multiplier_(multiplier),
+      kernel_layout_(kernel_layout) {
   this->impl()->node()->nn_param.conv2d.ksize[0] = ksize_[0];
   this->impl()->node()->nn_param.conv2d.ksize[1] = ksize_[1];
   this->impl()->node()->nn_param.conv2d.stride[0] = stride_[0];
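
Usage note (not part of the patch): a minimal sketch of how a TVM-exported kernel could be handed to Conv2d through the new kernel_layout parameter. It assumes the usual TIM-VX graph-building API; the shapes, weight values, and the helper function name are illustrative only, and the optional bias tensor is omitted, which the loop-based BindInput change above now permits.

#include <array>
#include <vector>

#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/ops/conv2d.h"

void BuildConv2dWithTvmKernelLayout() {
  auto ctx = tim::vx::Context::Create();
  auto graph = ctx->CreateGraph();

  // Activation stays in the default WHCN layout: {W, H, C, N}.
  tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, {8, 8, 3, 1},
                                 tim::vx::TensorAttribute::INPUT);
  // Constant kernel as TVM lays it out: OcIcWH, i.e. {Oc, Ic, Kw, Kh}.
  tim::vx::TensorSpec kernel_spec(tim::vx::DataType::FLOAT32, {16, 3, 3, 3},
                                  tim::vx::TensorAttribute::CONSTANT);
  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, {8, 8, 16, 1},
                                  tim::vx::TensorAttribute::OUTPUT);

  std::vector<float> kernel_data(16 * 3 * 3 * 3);  // weights exported from TVM
  auto input = graph->CreateTensor(input_spec);
  auto kernel = graph->CreateTensor(kernel_spec, kernel_data.data());
  auto output = graph->CreateTensor(output_spec);

  // Declaring kernel_layout = OcIcWH is what lets the layout-inference pass
  // fold the OcIcWH -> WHIcOc permute (kOcIcWH2WHIcOc) into the const weights.
  auto conv = graph->CreateOperation<tim::vx::ops::Conv2d>(
      /*weights=*/16, tim::vx::PadType::SAME,
      /*ksize=*/std::array<uint32_t, 2>({3, 3}),
      /*stride=*/std::array<uint32_t, 2>({1, 1}),
      /*dilation=*/std::array<uint32_t, 2>({1, 1}),
      /*multiplier=*/0, tim::vx::DataLayout::WHCN,
      tim::vx::DataLayout::OcIcWH);
  (*conv).BindInputs({input, kernel});
  (*conv).BindOutput(output);
  // Running tim::transform::LayoutInference(graph, ctx) afterwards yields the
  // transformed graph with the kernel transposed to the WHIcOc default.
}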