diff --git a/src/tim/transform/ops/default_layout_inference.h b/src/tim/transform/ops/default_layout_inference.h index 09252cc..fd8bb6e 100644 --- a/src/tim/transform/ops/default_layout_inference.h +++ b/src/tim/transform/ops/default_layout_inference.h @@ -55,14 +55,19 @@ class DefaultLayoutInfer : public OpLayoutInfer { for (const auto& i_src : op_->impl()->InputsTensor()) { (*cloned_op).BindInput(context_->GetMapedTensor(i_src)); } - auto required_pv = - MakeShared(op_->impl()->OutputsTensor()[0]->GetShape().size()); - auto out_infer = CreateOutputsTensor(required_pv); - // TODO: bind all output + std::vector> required_pv_lst; + for (auto out_tensor: op_->impl()->OutputsTensor()) { + required_pv_lst.push_back(MakeShared(out_tensor->GetShape().size())); + } + auto out_infer = CreateOutputsTensor(required_pv_lst); + (*cloned_op).BindOutputs(out_infer); - context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv); - next_tensors.push_back(op_->impl()->OutputsTensor()[0]); + uint32_t i = 0; + for (auto out_tensor : op_->impl()->OutputsTensor()) { + context_->SetPermuteVector(out_tensor, required_pv_lst[i++]); + next_tensors.push_back(out_tensor); + } } }; diff --git a/src/tim/transform/ops/op_layout_inference.cc b/src/tim/transform/ops/op_layout_inference.cc index e9008b9..3a8ee10 100644 --- a/src/tim/transform/ops/op_layout_inference.cc +++ b/src/tim/transform/ops/op_layout_inference.cc @@ -83,19 +83,48 @@ std::shared_ptr OpLayoutInfer::InsertPermute( } std::vector> OpLayoutInfer::CreateOutputsTensor( - std::shared_ptr required_pv) { - std::vector> ouptuts_tensor; + std::shared_ptr required_pv) { + std::vector> outputs_tensor; + + if (op_->impl()->OutputsTensor().size() > 1) { + // todo(sven): potential bug here if node have multi-output and require layout inference + std::cout <<"warning at "<< __FUNCTION__ << ", #" << __LINE__ << std::endl; + } + + uint32_t i = 0; for (const auto& o : op_->impl()->OutputsTensor()) { auto in_shape = o->GetShape(); auto out_spec = o->GetSpec(); - if (!required_pv->IsAligned()) { + if (!(required_pv->IsAligned())) { out_spec = out_spec.AsTransientSpec(); } auto t_infer = context_->infer_graph_->CreateTensor(out_spec); context_->UpdateTensorMap(o, t_infer); - ouptuts_tensor.push_back(t_infer); + outputs_tensor.push_back(t_infer); + i++; } - return ouptuts_tensor; + return outputs_tensor; +} + +std::vector> OpLayoutInfer::CreateOutputsTensor( + const std::vector>& required_pv) { + std::vector> outputs_tensor; + + assert(required_pv.size() == (op_->impl()->OutputsTensor().size())); + + uint32_t i = 0; + for (const auto& o : op_->impl()->OutputsTensor()) { + auto in_shape = o->GetShape(); + auto out_spec = o->GetSpec(); + if (!(required_pv[i]->IsAligned())) { + out_spec = out_spec.AsTransientSpec(); + } + auto t_infer = context_->infer_graph_->CreateTensor(out_spec); + context_->UpdateTensorMap(o, t_infer); + outputs_tensor.push_back(t_infer); + i++; + } + return outputs_tensor; } vx::PadType OpLayoutInfer::TranslatePadType(int32_t pad) { diff --git a/src/tim/transform/ops/op_layout_inference.h b/src/tim/transform/ops/op_layout_inference.h index b20f08f..7fe1eb2 100644 --- a/src/tim/transform/ops/op_layout_inference.h +++ b/src/tim/transform/ops/op_layout_inference.h @@ -61,6 +61,8 @@ class OpLayoutInfer { std::vector> CreateOutputsTensor( std::shared_ptr required_pv); + std::vector> CreateOutputsTensor( + const std::vector>& required_pv); vx::PadType TranslatePadType(int32_t pad); vx::PoolType TranslatePoolType(int32_t pool);