From 2c38f89d06ef0d8429e223fadb384a21f9e88350 Mon Sep 17 00:00:00 2001 From: liyuenan <37231553+liyuenan2333@users.noreply.github.com> Date: Wed, 15 Dec 2021 09:54:54 +0800 Subject: [PATCH] Catch the correct output when output has consumer (#239) Signed-off-by: yuenan.li Co-authored-by: yuenan.li --- src/tim/transform/layout_infer_context.h | 9 +++++++++ src/tim/transform/layout_inference.cc | 9 +++++++-- src/tim/transform/ops/op_layout_inference.cc | 2 ++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/tim/transform/layout_infer_context.h b/src/tim/transform/layout_infer_context.h index c85b0eb..d63960c 100644 --- a/src/tim/transform/layout_infer_context.h +++ b/src/tim/transform/layout_infer_context.h @@ -26,11 +26,18 @@ class LayoutInferContext { void UpdateGraphInputMap(const std::shared_ptr& i_src, const std::shared_ptr& i_layout); + void UpdateGraphOutputMap(const std::shared_ptr& o_src, + const std::shared_ptr& o_layout); + std::map, std::shared_ptr> GetGraphInputMap() const { return graph_input_map_; } + std::map, std::shared_ptr> + GetGraphOutputMap() const { + return graph_output_map_; + } const std::shared_ptr& src_graph_; std::shared_ptr& infer_graph_; @@ -43,6 +50,8 @@ class LayoutInferContext { tensor_map_; std::map, std::shared_ptr> graph_input_map_; + std::map, std::shared_ptr> + graph_output_map_; }; } // namespace layout_inference_impl diff --git a/src/tim/transform/layout_inference.cc b/src/tim/transform/layout_inference.cc index 447e00c..a9d0547 100644 --- a/src/tim/transform/layout_inference.cc +++ b/src/tim/transform/layout_inference.cc @@ -149,6 +149,11 @@ void LayoutInferContext::UpdateGraphInputMap(const std::shared_ptr& graph_input_map_[i_src] = i_layout; } +void LayoutInferContext::UpdateGraphOutputMap(const std::shared_ptr& o_src, + const std::shared_ptr& o_layout) { + graph_output_map_[o_src] = o_layout; +} + #define REGIST_LAYOUT_INFERENCE(op_idx, name) \ case op_idx: { \ auto op_infer = std::make_shared(op, ctx); \ @@ -305,8 +310,8 @@ std::pair, for (const auto& graph_input : layout_infer_ctx->GetGraphInputMap()) { graph_io_map[graph_input.first] = graph_input.second; } - for (const auto& out_src : src_graph->OutputsTensor()) { - graph_io_map[out_src] = layout_infer_ctx->GetMapedTensor(out_src); + for (const auto& graph_output : layout_infer_ctx->GetGraphOutputMap()) { + graph_io_map[graph_output.first] = graph_output.second; } return std::make_pair(infer_graph, graph_io_map); } diff --git a/src/tim/transform/ops/op_layout_inference.cc b/src/tim/transform/ops/op_layout_inference.cc index 8dbdb74..d48c770 100644 --- a/src/tim/transform/ops/op_layout_inference.cc +++ b/src/tim/transform/ops/op_layout_inference.cc @@ -40,12 +40,14 @@ void OpLayoutInfer::OnOutputs( for (const auto& out : op_outputs) { if (graph_outputs.end() != std::find(graph_outputs.begin(), graph_outputs.end(), out)) { + context_->UpdateGraphOutputMap(out, context_->GetMapedTensor(out)); auto pv = context_->GetPermuteVector(out); if (!pv->IsAligned()) { auto perm_out = InsertPermute(context_->GetMapedTensor(out), pv->Reverse(), true, out); // Update graph out tensor context_->UpdateTensorMap(out, perm_out); + context_->UpdateGraphOutputMap(out, perm_out); } if (!context_->src_graph_->GetConsumersOp(out).empty()) { // The tensor is output of graph, but it also is the input of other operations