Catch the correct output when output has consumer (#239)

Signed-off-by: yuenan.li <yuenan.li@verisilicon.com>

Co-authored-by: yuenan.li <yuenan.li@verisilicon.com>
This commit is contained in:
liyuenan 2021-12-15 09:54:54 +08:00 committed by GitHub
parent 1f85d21558
commit 2c38f89d06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 2 deletions

View File

@ -26,11 +26,18 @@ class LayoutInferContext {
void UpdateGraphInputMap(const std::shared_ptr<vx::Tensor>& i_src,
const std::shared_ptr<vx::Tensor>& i_layout);
void UpdateGraphOutputMap(const std::shared_ptr<vx::Tensor>& o_src,
const std::shared_ptr<vx::Tensor>& o_layout);
std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
GetGraphInputMap() const {
return graph_input_map_;
}
std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
GetGraphOutputMap() const {
return graph_output_map_;
}
const std::shared_ptr<vx::Graph>& src_graph_;
std::shared_ptr<vx::Graph>& infer_graph_;
@ -43,6 +50,8 @@ class LayoutInferContext {
tensor_map_;
std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
graph_input_map_;
std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
graph_output_map_;
};
} // namespace layout_inference_impl

View File

@ -149,6 +149,11 @@ void LayoutInferContext::UpdateGraphInputMap(const std::shared_ptr<vx::Tensor>&
graph_input_map_[i_src] = i_layout;
}
void LayoutInferContext::UpdateGraphOutputMap(const std::shared_ptr<vx::Tensor>& o_src,
const std::shared_ptr<vx::Tensor>& o_layout) {
graph_output_map_[o_src] = o_layout;
}
#define REGIST_LAYOUT_INFERENCE(op_idx, name) \
case op_idx: { \
auto op_infer = std::make_shared<name##LayoutInfer>(op, ctx); \
@ -305,8 +310,8 @@ std::pair<std::shared_ptr<vx::Graph>,
for (const auto& graph_input : layout_infer_ctx->GetGraphInputMap()) {
graph_io_map[graph_input.first] = graph_input.second;
}
for (const auto& out_src : src_graph->OutputsTensor()) {
graph_io_map[out_src] = layout_infer_ctx->GetMapedTensor(out_src);
for (const auto& graph_output : layout_infer_ctx->GetGraphOutputMap()) {
graph_io_map[graph_output.first] = graph_output.second;
}
return std::make_pair(infer_graph, graph_io_map);
}

View File

@ -40,12 +40,14 @@ void OpLayoutInfer::OnOutputs(
for (const auto& out : op_outputs) {
if (graph_outputs.end() !=
std::find(graph_outputs.begin(), graph_outputs.end(), out)) {
context_->UpdateGraphOutputMap(out, context_->GetMapedTensor(out));
auto pv = context_->GetPermuteVector(out);
if (!pv->IsAligned()) {
auto perm_out = InsertPermute(context_->GetMapedTensor(out),
pv->Reverse(), true, out);
// Update graph out tensor
context_->UpdateTensorMap(out, perm_out);
context_->UpdateGraphOutputMap(out, perm_out);
}
if (!context_->src_graph_->GetConsumersOp(out).empty()) {
// The tensor is output of graph, but it also is the input of other operations