Fix layout inference for traspose convolution
Signed-off-by: yuenan.li <yuenan.li@verisilicon.com>
This commit is contained in:
parent
1672ef99ed
commit
f8f2c6d519
|
|
@ -58,7 +58,7 @@ class DeConv2dLayoutInfer : public OpLayoutInfer {
|
||||||
in->GetDataRef());
|
in->GetDataRef());
|
||||||
trans_pv = MakeShared(1);
|
trans_pv = MakeShared(1);
|
||||||
} else {
|
} else {
|
||||||
// For input/weight
|
// For weight
|
||||||
if (!required_pv->IsAligned()) {
|
if (!required_pv->IsAligned()) {
|
||||||
auto src_deconv2d =
|
auto src_deconv2d =
|
||||||
std::static_pointer_cast<vx::ops::DeConv2d>(op_);
|
std::static_pointer_cast<vx::ops::DeConv2d>(op_);
|
||||||
|
|
@ -66,6 +66,11 @@ class DeConv2dLayoutInfer : public OpLayoutInfer {
|
||||||
if (src_deconv2d->KernelDataLayout() == vx::DataLayout::OcIcWH) {
|
if (src_deconv2d->KernelDataLayout() == vx::DataLayout::OcIcWH) {
|
||||||
trans_pv = std::make_shared<PermuteVector<4>>(kOcIcWH2WHIcOc);
|
trans_pv = std::make_shared<PermuteVector<4>>(kOcIcWH2WHIcOc);
|
||||||
infer_tensor = PermuteConstTensor(in, trans_pv);
|
infer_tensor = PermuteConstTensor(in, trans_pv);
|
||||||
|
} else if (src_deconv2d->KernelDataLayout() ==
|
||||||
|
vx::DataLayout::WHIcOc) {
|
||||||
|
infer_tensor = context_->infer_graph_->CreateTensor(
|
||||||
|
in->GetSpec(), in->GetDataRef());
|
||||||
|
trans_pv = MakeShared(required_pv->Rank());
|
||||||
} else {
|
} else {
|
||||||
infer_tensor = PermuteConstTensor(in, required_pv);
|
infer_tensor = PermuteConstTensor(in, required_pv);
|
||||||
trans_pv = required_pv;
|
trans_pv = required_pv;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue