add DataLayout::IcOcWH for TVM usage (#231)
This commit is contained in:
parent
e001d53ddf
commit
b38bd41933
|
|
@ -63,11 +63,12 @@ enum class RoundingPolicy { TO_ZERO, RTNE };
|
||||||
enum class ResizeType { NEAREST_NEIGHBOR, BILINEAR, AREA };
|
enum class ResizeType { NEAREST_NEIGHBOR, BILINEAR, AREA };
|
||||||
|
|
||||||
enum class DataLayout {
|
enum class DataLayout {
|
||||||
|
ANY,
|
||||||
WHCN,
|
WHCN,
|
||||||
CWHN,
|
CWHN,
|
||||||
ANY,
|
IcWHOc, /*TF*/
|
||||||
IcWHOc /*TF*/,
|
OcIcWH, /*TVM for classic conv2d in tflite model*/
|
||||||
OcIcWH /*TVM*/,
|
IcOcWH, /*TVM for depthwise conv2d in tflite model*/
|
||||||
WHIcOc /*TIM-VX default*/
|
WHIcOc /*TIM-VX default*/
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -66,6 +66,10 @@ class Conv2dLayoutInfer : public OpLayoutInfer {
|
||||||
trans_pv = std::make_shared<PermuteVector<4>>(kOcIcWH2WHIcOc);
|
trans_pv = std::make_shared<PermuteVector<4>>(kOcIcWH2WHIcOc);
|
||||||
infer_tensor = PermuteConstTensor(
|
infer_tensor = PermuteConstTensor(
|
||||||
in, trans_pv);
|
in, trans_pv);
|
||||||
|
} else if (src_conv2d->KernelDataLayout() == vx::DataLayout::IcOcWH) {
|
||||||
|
trans_pv = std::make_shared<PermuteVector<4>>(kIcOcWH2WHIcOc);
|
||||||
|
infer_tensor = PermuteConstTensor(
|
||||||
|
in, trans_pv);
|
||||||
} else {
|
} else {
|
||||||
infer_tensor = PermuteConstTensor(in, required_pv);
|
infer_tensor = PermuteConstTensor(in, required_pv);
|
||||||
trans_pv = required_pv;
|
trans_pv = required_pv;
|
||||||
|
|
|
||||||
|
|
@ -318,9 +318,12 @@ bool OpLayoutInfer::TransposeConstTensorData(
|
||||||
reverse_shape.push_back(input->GetShape()[i]);
|
reverse_shape.push_back(input->GetShape()[i]);
|
||||||
}
|
}
|
||||||
std::vector<uint32_t> perm = KOcHWIc2OcIcHW;
|
std::vector<uint32_t> perm = KOcHWIc2OcIcHW;
|
||||||
std::vector<uint32_t>tmp_vec = kOcIcWH2WHIcOc;
|
std::vector<uint32_t>tmp_vec0 = kOcIcWH2WHIcOc;
|
||||||
if (pv->AsStdVec() == tmp_vec) {
|
std::vector<uint32_t>tmp_vec1 = kIcOcWH2WHIcOc;
|
||||||
|
if (pv->AsStdVec() == tmp_vec0) {
|
||||||
perm = kHWIcOc2OcIcHW;
|
perm = kHWIcOc2OcIcHW;
|
||||||
|
} else if (pv->AsStdVec() == tmp_vec1) {
|
||||||
|
perm = kHWOcIc2OcIcHW;
|
||||||
}
|
}
|
||||||
vsi_nn_Transpose(out_data.data(), (uint8_t*)(input->GetDataRef()),
|
vsi_nn_Transpose(out_data.data(), (uint8_t*)(input->GetDataRef()),
|
||||||
(uint32_t*)(reverse_shape.data()),
|
(uint32_t*)(reverse_shape.data()),
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,10 @@ constexpr std::initializer_list<uint32_t> KOcHWIc2OcIcHW = {0, 3, 1, 2};
|
||||||
constexpr std::initializer_list<uint32_t> kIcWHOc2WHIcOc = {1, 2, 0, 3};
|
constexpr std::initializer_list<uint32_t> kIcWHOc2WHIcOc = {1, 2, 0, 3};
|
||||||
|
|
||||||
constexpr std::initializer_list<uint32_t> kHWIcOc2OcIcHW = {3, 2, 0, 1};
|
constexpr std::initializer_list<uint32_t> kHWIcOc2OcIcHW = {3, 2, 0, 1};
|
||||||
|
constexpr std::initializer_list<uint32_t> kHWOcIc2OcIcHW = {2, 3, 0, 1};
|
||||||
|
|
||||||
constexpr std::initializer_list<uint32_t> kOcIcWH2WHIcOc = {2, 3, 1, 0};
|
constexpr std::initializer_list<uint32_t> kOcIcWH2WHIcOc = {2, 3, 1, 0};
|
||||||
|
constexpr std::initializer_list<uint32_t> kIcOcWH2WHIcOc = {2, 3, 0, 1};
|
||||||
|
|
||||||
class OpLayoutInfer {
|
class OpLayoutInfer {
|
||||||
public:
|
public:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue