add DataLayout::IcOcWH for TVM usage (#231)
This commit is contained in:
parent
e001d53ddf
commit
b38bd41933
|
|
@ -63,11 +63,12 @@ enum class RoundingPolicy { TO_ZERO, RTNE };
|
|||
enum class ResizeType { NEAREST_NEIGHBOR, BILINEAR, AREA };
|
||||
|
||||
enum class DataLayout {
|
||||
ANY,
|
||||
WHCN,
|
||||
CWHN,
|
||||
ANY,
|
||||
IcWHOc /*TF*/,
|
||||
OcIcWH /*TVM*/,
|
||||
IcWHOc, /*TF*/
|
||||
OcIcWH, /*TVM for classic conv2d in tflite model*/
|
||||
IcOcWH, /*TVM for depthwise conv2d in tflite model*/
|
||||
WHIcOc /*TIM-VX default*/
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -66,6 +66,10 @@ class Conv2dLayoutInfer : public OpLayoutInfer {
|
|||
trans_pv = std::make_shared<PermuteVector<4>>(kOcIcWH2WHIcOc);
|
||||
infer_tensor = PermuteConstTensor(
|
||||
in, trans_pv);
|
||||
} else if (src_conv2d->KernelDataLayout() == vx::DataLayout::IcOcWH) {
|
||||
trans_pv = std::make_shared<PermuteVector<4>>(kIcOcWH2WHIcOc);
|
||||
infer_tensor = PermuteConstTensor(
|
||||
in, trans_pv);
|
||||
} else {
|
||||
infer_tensor = PermuteConstTensor(in, required_pv);
|
||||
trans_pv = required_pv;
|
||||
|
|
|
|||
|
|
@ -318,9 +318,12 @@ bool OpLayoutInfer::TransposeConstTensorData(
|
|||
reverse_shape.push_back(input->GetShape()[i]);
|
||||
}
|
||||
std::vector<uint32_t> perm = KOcHWIc2OcIcHW;
|
||||
std::vector<uint32_t>tmp_vec = kOcIcWH2WHIcOc;
|
||||
if (pv->AsStdVec() == tmp_vec) {
|
||||
std::vector<uint32_t>tmp_vec0 = kOcIcWH2WHIcOc;
|
||||
std::vector<uint32_t>tmp_vec1 = kIcOcWH2WHIcOc;
|
||||
if (pv->AsStdVec() == tmp_vec0) {
|
||||
perm = kHWIcOc2OcIcHW;
|
||||
} else if (pv->AsStdVec() == tmp_vec1) {
|
||||
perm = kHWOcIc2OcIcHW;
|
||||
}
|
||||
vsi_nn_Transpose(out_data.data(), (uint8_t*)(input->GetDataRef()),
|
||||
(uint32_t*)(reverse_shape.data()),
|
||||
|
|
|
|||
|
|
@ -39,7 +39,10 @@ constexpr std::initializer_list<uint32_t> KOcHWIc2OcIcHW = {0, 3, 1, 2};
|
|||
constexpr std::initializer_list<uint32_t> kIcWHOc2WHIcOc = {1, 2, 0, 3};
|
||||
|
||||
constexpr std::initializer_list<uint32_t> kHWIcOc2OcIcHW = {3, 2, 0, 1};
|
||||
constexpr std::initializer_list<uint32_t> kHWOcIc2OcIcHW = {2, 3, 0, 1};
|
||||
|
||||
constexpr std::initializer_list<uint32_t> kOcIcWH2WHIcOc = {2, 3, 1, 0};
|
||||
constexpr std::initializer_list<uint32_t> kIcOcWH2WHIcOc = {2, 3, 0, 1};
|
||||
|
||||
class OpLayoutInfer {
|
||||
public:
|
||||
|
|
|
|||
Loading…
Reference in New Issue