Fixed layout inference crash(assert) if node have multiply output
Signed-off-by: xiang.zhang <xiang.zhang@verisilicon.com>
This commit is contained in:
parent
374841cbd9
commit
994f8a9c2a
|
|
@ -55,14 +55,19 @@ class DefaultLayoutInfer : public OpLayoutInfer {
|
|||
for (const auto& i_src : op_->impl()->InputsTensor()) {
|
||||
(*cloned_op).BindInput(context_->GetMapedTensor(i_src));
|
||||
}
|
||||
auto required_pv =
|
||||
MakeShared(op_->impl()->OutputsTensor()[0]->GetShape().size());
|
||||
auto out_infer = CreateOutputsTensor(required_pv);
|
||||
|
||||
// TODO: bind all output
|
||||
std::vector<std::shared_ptr<IPermuteVector>> required_pv_lst;
|
||||
for (auto out_tensor: op_->impl()->OutputsTensor()) {
|
||||
required_pv_lst.push_back(MakeShared(out_tensor->GetShape().size()));
|
||||
}
|
||||
auto out_infer = CreateOutputsTensor(required_pv_lst);
|
||||
|
||||
(*cloned_op).BindOutputs(out_infer);
|
||||
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
|
||||
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
|
||||
uint32_t i = 0;
|
||||
for (auto out_tensor : op_->impl()->OutputsTensor()) {
|
||||
context_->SetPermuteVector(out_tensor, required_pv_lst[i++]);
|
||||
next_tensors.push_back(out_tensor);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -83,19 +83,48 @@ std::shared_ptr<vx::Tensor> OpLayoutInfer::InsertPermute(
|
|||
}
|
||||
|
||||
std::vector<std::shared_ptr<vx::Tensor>> OpLayoutInfer::CreateOutputsTensor(
|
||||
std::shared_ptr<IPermuteVector> required_pv) {
|
||||
std::vector<std::shared_ptr<vx::Tensor>> ouptuts_tensor;
|
||||
std::shared_ptr<IPermuteVector> required_pv) {
|
||||
std::vector<std::shared_ptr<vx::Tensor>> outputs_tensor;
|
||||
|
||||
if (op_->impl()->OutputsTensor().size() > 1) {
|
||||
// todo(sven): potential bug here if node have multi-output and require layout inference
|
||||
std::cout <<"warning at "<< __FUNCTION__ << ", #" << __LINE__ << std::endl;
|
||||
}
|
||||
|
||||
uint32_t i = 0;
|
||||
for (const auto& o : op_->impl()->OutputsTensor()) {
|
||||
auto in_shape = o->GetShape();
|
||||
auto out_spec = o->GetSpec();
|
||||
if (!required_pv->IsAligned()) {
|
||||
if (!(required_pv->IsAligned())) {
|
||||
out_spec = out_spec.AsTransientSpec();
|
||||
}
|
||||
auto t_infer = context_->infer_graph_->CreateTensor(out_spec);
|
||||
context_->UpdateTensorMap(o, t_infer);
|
||||
ouptuts_tensor.push_back(t_infer);
|
||||
outputs_tensor.push_back(t_infer);
|
||||
i++;
|
||||
}
|
||||
return ouptuts_tensor;
|
||||
return outputs_tensor;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<vx::Tensor>> OpLayoutInfer::CreateOutputsTensor(
|
||||
const std::vector<std::shared_ptr<IPermuteVector>>& required_pv) {
|
||||
std::vector<std::shared_ptr<vx::Tensor>> outputs_tensor;
|
||||
|
||||
assert(required_pv.size() == (op_->impl()->OutputsTensor().size()));
|
||||
|
||||
uint32_t i = 0;
|
||||
for (const auto& o : op_->impl()->OutputsTensor()) {
|
||||
auto in_shape = o->GetShape();
|
||||
auto out_spec = o->GetSpec();
|
||||
if (!(required_pv[i]->IsAligned())) {
|
||||
out_spec = out_spec.AsTransientSpec();
|
||||
}
|
||||
auto t_infer = context_->infer_graph_->CreateTensor(out_spec);
|
||||
context_->UpdateTensorMap(o, t_infer);
|
||||
outputs_tensor.push_back(t_infer);
|
||||
i++;
|
||||
}
|
||||
return outputs_tensor;
|
||||
}
|
||||
|
||||
vx::PadType OpLayoutInfer::TranslatePadType(int32_t pad) {
|
||||
|
|
|
|||
|
|
@ -61,6 +61,8 @@ class OpLayoutInfer {
|
|||
std::vector<std::shared_ptr<vx::Tensor>> CreateOutputsTensor(
|
||||
std::shared_ptr<IPermuteVector> required_pv);
|
||||
|
||||
std::vector<std::shared_ptr<vx::Tensor>> CreateOutputsTensor(
|
||||
const std::vector<std::shared_ptr<IPermuteVector>>& required_pv);
|
||||
vx::PadType TranslatePadType(int32_t pad);
|
||||
|
||||
vx::PoolType TranslatePoolType(int32_t pool);
|
||||
|
|
|
|||
Loading…
Reference in New Issue