From b38bd41933423cdf9b3c6c9028fd9943027f536f Mon Sep 17 00:00:00 2001 From: Antkillerfarm Date: Tue, 30 Nov 2021 21:33:14 +0800 Subject: [PATCH] add DataLayout::IcOcWH for TVM usage (#231) --- include/tim/vx/types.h | 7 ++++--- src/tim/transform/ops/conv2d_layout_inference.h | 4 ++++ src/tim/transform/ops/op_layout_inference.cc | 7 +++++-- src/tim/transform/ops/op_layout_inference.h | 3 +++ 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/include/tim/vx/types.h b/include/tim/vx/types.h index 9ea9727..d8acebf 100644 --- a/include/tim/vx/types.h +++ b/include/tim/vx/types.h @@ -63,11 +63,12 @@ enum class RoundingPolicy { TO_ZERO, RTNE }; enum class ResizeType { NEAREST_NEIGHBOR, BILINEAR, AREA }; enum class DataLayout { + ANY, WHCN, CWHN, - ANY, - IcWHOc /*TF*/, - OcIcWH /*TVM*/, + IcWHOc, /*TF*/ + OcIcWH, /*TVM for classic conv2d in tflite model*/ + IcOcWH, /*TVM for depthwise conv2d in tflite model*/ WHIcOc /*TIM-VX default*/ }; diff --git a/src/tim/transform/ops/conv2d_layout_inference.h b/src/tim/transform/ops/conv2d_layout_inference.h index b24cdc5..368ea02 100644 --- a/src/tim/transform/ops/conv2d_layout_inference.h +++ b/src/tim/transform/ops/conv2d_layout_inference.h @@ -66,6 +66,10 @@ class Conv2dLayoutInfer : public OpLayoutInfer { trans_pv = std::make_shared>(kOcIcWH2WHIcOc); infer_tensor = PermuteConstTensor( in, trans_pv); + } else if (src_conv2d->KernelDataLayout() == vx::DataLayout::IcOcWH) { + trans_pv = std::make_shared>(kIcOcWH2WHIcOc); + infer_tensor = PermuteConstTensor( + in, trans_pv); } else { infer_tensor = PermuteConstTensor(in, required_pv); trans_pv = required_pv; diff --git a/src/tim/transform/ops/op_layout_inference.cc b/src/tim/transform/ops/op_layout_inference.cc index 5457290..8dbdb74 100644 --- a/src/tim/transform/ops/op_layout_inference.cc +++ b/src/tim/transform/ops/op_layout_inference.cc @@ -318,9 +318,12 @@ bool OpLayoutInfer::TransposeConstTensorData( reverse_shape.push_back(input->GetShape()[i]); } std::vector perm = KOcHWIc2OcIcHW; - std::vectortmp_vec = kOcIcWH2WHIcOc; - if (pv->AsStdVec() == tmp_vec) { + std::vectortmp_vec0 = kOcIcWH2WHIcOc; + std::vectortmp_vec1 = kIcOcWH2WHIcOc; + if (pv->AsStdVec() == tmp_vec0) { perm = kHWIcOc2OcIcHW; + } else if (pv->AsStdVec() == tmp_vec1) { + perm = kHWOcIc2OcIcHW; } vsi_nn_Transpose(out_data.data(), (uint8_t*)(input->GetDataRef()), (uint32_t*)(reverse_shape.data()), diff --git a/src/tim/transform/ops/op_layout_inference.h b/src/tim/transform/ops/op_layout_inference.h index 7fe1eb2..ed3c45d 100644 --- a/src/tim/transform/ops/op_layout_inference.h +++ b/src/tim/transform/ops/op_layout_inference.h @@ -39,7 +39,10 @@ constexpr std::initializer_list KOcHWIc2OcIcHW = {0, 3, 1, 2}; constexpr std::initializer_list kIcWHOc2WHIcOc = {1, 2, 0, 3}; constexpr std::initializer_list kHWIcOc2OcIcHW = {3, 2, 0, 1}; +constexpr std::initializer_list kHWOcIc2OcIcHW = {2, 3, 0, 1}; + constexpr std::initializer_list kOcIcWH2WHIcOc = {2, 3, 1, 0}; +constexpr std::initializer_list kIcOcWH2WHIcOc = {2, 3, 0, 1}; class OpLayoutInfer { public: