Fix layout inference for traspose convolution
Signed-off-by: yuenan.li <yuenan.li@verisilicon.com>
This commit is contained in:
parent
1672ef99ed
commit
f8f2c6d519
|
|
@ -58,7 +58,7 @@ class DeConv2dLayoutInfer : public OpLayoutInfer {
|
|||
in->GetDataRef());
|
||||
trans_pv = MakeShared(1);
|
||||
} else {
|
||||
// For input/weight
|
||||
// For weight
|
||||
if (!required_pv->IsAligned()) {
|
||||
auto src_deconv2d =
|
||||
std::static_pointer_cast<vx::ops::DeConv2d>(op_);
|
||||
|
|
@ -66,6 +66,11 @@ class DeConv2dLayoutInfer : public OpLayoutInfer {
|
|||
if (src_deconv2d->KernelDataLayout() == vx::DataLayout::OcIcWH) {
|
||||
trans_pv = std::make_shared<PermuteVector<4>>(kOcIcWH2WHIcOc);
|
||||
infer_tensor = PermuteConstTensor(in, trans_pv);
|
||||
} else if (src_deconv2d->KernelDataLayout() ==
|
||||
vx::DataLayout::WHIcOc) {
|
||||
infer_tensor = context_->infer_graph_->CreateTensor(
|
||||
in->GetSpec(), in->GetDataRef());
|
||||
trans_pv = MakeShared(required_pv->Rank());
|
||||
} else {
|
||||
infer_tensor = PermuteConstTensor(in, required_pv);
|
||||
trans_pv = required_pv;
|
||||
|
|
|
|||
Loading…
Reference in New Issue