Fix layout inference for traspose convolution

Signed-off-by: yuenan.li <yuenan.li@verisilicon.com>
This commit is contained in:
yuenan.li 2021-06-21 09:43:16 +08:00 committed by Sven
parent 1672ef99ed
commit f8f2c6d519
1 changed files with 6 additions and 1 deletions

View File

@ -58,7 +58,7 @@ class DeConv2dLayoutInfer : public OpLayoutInfer {
in->GetDataRef());
trans_pv = MakeShared(1);
} else {
// For input/weight
// For weight
if (!required_pv->IsAligned()) {
auto src_deconv2d =
std::static_pointer_cast<vx::ops::DeConv2d>(op_);
@ -66,6 +66,11 @@ class DeConv2dLayoutInfer : public OpLayoutInfer {
if (src_deconv2d->KernelDataLayout() == vx::DataLayout::OcIcWH) {
trans_pv = std::make_shared<PermuteVector<4>>(kOcIcWH2WHIcOc);
infer_tensor = PermuteConstTensor(in, trans_pv);
} else if (src_deconv2d->KernelDataLayout() ==
vx::DataLayout::WHIcOc) {
infer_tensor = context_->infer_graph_->CreateTensor(
in->GetSpec(), in->GetDataRef());
trans_pv = MakeShared(required_pv->Rank());
} else {
infer_tensor = PermuteConstTensor(in, required_pv);
trans_pv = required_pv;