diff --git a/src/tim/transform/ops/deconv2d_layout_inference.h b/src/tim/transform/ops/deconv2d_layout_inference.h index 4343437..cdf9068 100644 --- a/src/tim/transform/ops/deconv2d_layout_inference.h +++ b/src/tim/transform/ops/deconv2d_layout_inference.h @@ -58,7 +58,7 @@ class DeConv2dLayoutInfer : public OpLayoutInfer { in->GetDataRef()); trans_pv = MakeShared(1); } else { - // For input/weight + // For weight if (!required_pv->IsAligned()) { auto src_deconv2d = std::static_pointer_cast(op_); @@ -66,6 +66,11 @@ class DeConv2dLayoutInfer : public OpLayoutInfer { if (src_deconv2d->KernelDataLayout() == vx::DataLayout::OcIcWH) { trans_pv = std::make_shared>(kOcIcWH2WHIcOc); infer_tensor = PermuteConstTensor(in, trans_pv); + } else if (src_deconv2d->KernelDataLayout() == + vx::DataLayout::WHIcOc) { + infer_tensor = context_->infer_graph_->CreateTensor( + in->GetSpec(), in->GetDataRef()); + trans_pv = MakeShared(required_pv->Rank()); } else { infer_tensor = PermuteConstTensor(in, required_pv); trans_pv = required_pv;