Fixed layout inference crash (assert) if a node has multiple outputs

Signed-off-by: xiang.zhang <xiang.zhang@verisilicon.com>
This commit is contained in:
xiang.zhang 2021-09-15 14:52:43 +08:00 committed by Sven
parent 374841cbd9
commit 994f8a9c2a
3 changed files with 47 additions and 11 deletions

View File

@ -55,14 +55,19 @@ class DefaultLayoutInfer : public OpLayoutInfer {
for (const auto& i_src : op_->impl()->InputsTensor()) {
(*cloned_op).BindInput(context_->GetMapedTensor(i_src));
}
auto required_pv =
MakeShared(op_->impl()->OutputsTensor()[0]->GetShape().size());
auto out_infer = CreateOutputsTensor(required_pv);
// TODO: bind all output
std::vector<std::shared_ptr<IPermuteVector>> required_pv_lst;
for (auto out_tensor: op_->impl()->OutputsTensor()) {
required_pv_lst.push_back(MakeShared(out_tensor->GetShape().size()));
}
auto out_infer = CreateOutputsTensor(required_pv_lst);
(*cloned_op).BindOutputs(out_infer);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
uint32_t i = 0;
for (auto out_tensor : op_->impl()->OutputsTensor()) {
context_->SetPermuteVector(out_tensor, required_pv_lst[i++]);
next_tensors.push_back(out_tensor);
}
}
};

View File

@ -83,19 +83,48 @@ std::shared_ptr<vx::Tensor> OpLayoutInfer::InsertPermute(
}
std::vector<std::shared_ptr<vx::Tensor>> OpLayoutInfer::CreateOutputsTensor(
    std::shared_ptr<IPermuteVector> required_pv) {
  // Create one tensor in the inference graph for every output of op_,
  // applying the SAME required permute vector to all outputs.
  // Each created tensor is registered in the context's tensor map so that
  // downstream ops can resolve the mapped tensor, and is returned to the
  // caller for output binding.
  std::vector<std::shared_ptr<vx::Tensor>> outputs_tensor;
  if (op_->impl()->OutputsTensor().size() > 1) {
    // todo(sven): potential bug here if node has multiple outputs and
    // requires layout inference — a single shared pv may not fit every
    // output; prefer the per-output overload below in that case.
    std::cout <<"warning at "<< __FUNCTION__ << ", #" << __LINE__ << std::endl;
  }
  for (const auto& o : op_->impl()->OutputsTensor()) {
    auto out_spec = o->GetSpec();
    if (!(required_pv->IsAligned())) {
      // Layout still permuted: mark the tensor transient so a permute op
      // can be inserted later to restore the required layout.
      out_spec = out_spec.AsTransientSpec();
    }
    auto t_infer = context_->infer_graph_->CreateTensor(out_spec);
    context_->UpdateTensorMap(o, t_infer);
    outputs_tensor.push_back(t_infer);
  }
  return outputs_tensor;
}
std::vector<std::shared_ptr<vx::Tensor>> OpLayoutInfer::CreateOutputsTensor(
    const std::vector<std::shared_ptr<IPermuteVector>>& required_pv) {
  // Per-output overload: required_pv[i] is the permute vector required for
  // the i-th output of op_. This avoids the multi-output assert/crash that
  // the single-pv overload can hit, since each output gets its own layout
  // decision. Each created tensor is registered in the context's tensor map
  // and returned for output binding, in the op's output order.
  std::vector<std::shared_ptr<vx::Tensor>> outputs_tensor;
  // One permute vector is required per output; mismatch is a programmer bug.
  assert(required_pv.size() == (op_->impl()->OutputsTensor().size()));
  uint32_t i = 0;
  for (const auto& o : op_->impl()->OutputsTensor()) {
    auto out_spec = o->GetSpec();
    if (!(required_pv[i]->IsAligned())) {
      // Layout still permuted: mark the tensor transient so a permute op
      // can be inserted later to restore the required layout.
      out_spec = out_spec.AsTransientSpec();
    }
    auto t_infer = context_->infer_graph_->CreateTensor(out_spec);
    context_->UpdateTensorMap(o, t_infer);
    outputs_tensor.push_back(t_infer);
    ++i;
  }
  return outputs_tensor;
}
vx::PadType OpLayoutInfer::TranslatePadType(int32_t pad) {

View File

@ -61,6 +61,8 @@ class OpLayoutInfer {
std::vector<std::shared_ptr<vx::Tensor>> CreateOutputsTensor(
std::shared_ptr<IPermuteVector> required_pv);
std::vector<std::shared_ptr<vx::Tensor>> CreateOutputsTensor(
const std::vector<std::shared_ptr<IPermuteVector>>& required_pv);
vx::PadType TranslatePadType(int32_t pad);
vx::PoolType TranslatePoolType(int32_t pool);