add DataLayout::IcOcWH for TVM usage (#231)

This commit is contained in:
Antkillerfarm 2021-11-30 21:33:14 +08:00 committed by GitHub
parent e001d53ddf
commit b38bd41933
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 16 additions and 5 deletions

View File

@ -63,11 +63,12 @@ enum class RoundingPolicy { TO_ZERO, RTNE };
enum class ResizeType { NEAREST_NEIGHBOR, BILINEAR, AREA };
enum class DataLayout {
ANY,
WHCN,
CWHN,
ANY,
IcWHOc /*TF*/,
OcIcWH /*TVM*/,
IcWHOc, /*TF*/
OcIcWH, /*TVM for classic conv2d in tflite model*/
IcOcWH, /*TVM for depthwise conv2d in tflite model*/
WHIcOc /*TIM-VX default*/
};

View File

@ -66,6 +66,10 @@ class Conv2dLayoutInfer : public OpLayoutInfer {
trans_pv = std::make_shared<PermuteVector<4>>(kOcIcWH2WHIcOc);
infer_tensor = PermuteConstTensor(
in, trans_pv);
} else if (src_conv2d->KernelDataLayout() == vx::DataLayout::IcOcWH) {
trans_pv = std::make_shared<PermuteVector<4>>(kIcOcWH2WHIcOc);
infer_tensor = PermuteConstTensor(
in, trans_pv);
} else {
infer_tensor = PermuteConstTensor(in, required_pv);
trans_pv = required_pv;

View File

@ -318,9 +318,12 @@ bool OpLayoutInfer::TransposeConstTensorData(
reverse_shape.push_back(input->GetShape()[i]);
}
std::vector<uint32_t> perm = KOcHWIc2OcIcHW;
std::vector<uint32_t>tmp_vec = kOcIcWH2WHIcOc;
if (pv->AsStdVec() == tmp_vec) {
std::vector<uint32_t>tmp_vec0 = kOcIcWH2WHIcOc;
std::vector<uint32_t>tmp_vec1 = kIcOcWH2WHIcOc;
if (pv->AsStdVec() == tmp_vec0) {
perm = kHWIcOc2OcIcHW;
} else if (pv->AsStdVec() == tmp_vec1) {
perm = kHWOcIc2OcIcHW;
}
vsi_nn_Transpose(out_data.data(), (uint8_t*)(input->GetDataRef()),
(uint32_t*)(reverse_shape.data()),

View File

@ -39,7 +39,10 @@ constexpr std::initializer_list<uint32_t> KOcHWIc2OcIcHW = {0, 3, 1, 2};
constexpr std::initializer_list<uint32_t> kIcWHOc2WHIcOc = {1, 2, 0, 3};
constexpr std::initializer_list<uint32_t> kHWIcOc2OcIcHW = {3, 2, 0, 1};
constexpr std::initializer_list<uint32_t> kHWOcIc2OcIcHW = {2, 3, 0, 1};
constexpr std::initializer_list<uint32_t> kOcIcWH2WHIcOc = {2, 3, 1, 0};
constexpr std::initializer_list<uint32_t> kIcOcWH2WHIcOc = {2, 3, 0, 1};
class OpLayoutInfer {
public: