Catch the correct output when output has consumer (#239)
Signed-off-by: yuenan.li <yuenan.li@verisilicon.com> Co-authored-by: yuenan.li <yuenan.li@verisilicon.com>
This commit is contained in:
parent
1f85d21558
commit
2c38f89d06
|
|
@ -26,11 +26,18 @@ class LayoutInferContext {
|
||||||
void UpdateGraphInputMap(const std::shared_ptr<vx::Tensor>& i_src,
|
void UpdateGraphInputMap(const std::shared_ptr<vx::Tensor>& i_src,
|
||||||
const std::shared_ptr<vx::Tensor>& i_layout);
|
const std::shared_ptr<vx::Tensor>& i_layout);
|
||||||
|
|
||||||
|
void UpdateGraphOutputMap(const std::shared_ptr<vx::Tensor>& o_src,
|
||||||
|
const std::shared_ptr<vx::Tensor>& o_layout);
|
||||||
|
|
||||||
std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
|
std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
|
||||||
GetGraphInputMap() const {
|
GetGraphInputMap() const {
|
||||||
return graph_input_map_;
|
return graph_input_map_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
|
||||||
|
GetGraphOutputMap() const {
|
||||||
|
return graph_output_map_;
|
||||||
|
}
|
||||||
const std::shared_ptr<vx::Graph>& src_graph_;
|
const std::shared_ptr<vx::Graph>& src_graph_;
|
||||||
std::shared_ptr<vx::Graph>& infer_graph_;
|
std::shared_ptr<vx::Graph>& infer_graph_;
|
||||||
|
|
||||||
|
|
@ -43,6 +50,8 @@ class LayoutInferContext {
|
||||||
tensor_map_;
|
tensor_map_;
|
||||||
std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
|
std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
|
||||||
graph_input_map_;
|
graph_input_map_;
|
||||||
|
std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
|
||||||
|
graph_output_map_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace layout_inference_impl
|
} // namespace layout_inference_impl
|
||||||
|
|
|
||||||
|
|
@ -149,6 +149,11 @@ void LayoutInferContext::UpdateGraphInputMap(const std::shared_ptr<vx::Tensor>&
|
||||||
graph_input_map_[i_src] = i_layout;
|
graph_input_map_[i_src] = i_layout;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void LayoutInferContext::UpdateGraphOutputMap(const std::shared_ptr<vx::Tensor>& o_src,
|
||||||
|
const std::shared_ptr<vx::Tensor>& o_layout) {
|
||||||
|
graph_output_map_[o_src] = o_layout;
|
||||||
|
}
|
||||||
|
|
||||||
#define REGIST_LAYOUT_INFERENCE(op_idx, name) \
|
#define REGIST_LAYOUT_INFERENCE(op_idx, name) \
|
||||||
case op_idx: { \
|
case op_idx: { \
|
||||||
auto op_infer = std::make_shared<name##LayoutInfer>(op, ctx); \
|
auto op_infer = std::make_shared<name##LayoutInfer>(op, ctx); \
|
||||||
|
|
@ -305,8 +310,8 @@ std::pair<std::shared_ptr<vx::Graph>,
|
||||||
for (const auto& graph_input : layout_infer_ctx->GetGraphInputMap()) {
|
for (const auto& graph_input : layout_infer_ctx->GetGraphInputMap()) {
|
||||||
graph_io_map[graph_input.first] = graph_input.second;
|
graph_io_map[graph_input.first] = graph_input.second;
|
||||||
}
|
}
|
||||||
for (const auto& out_src : src_graph->OutputsTensor()) {
|
for (const auto& graph_output : layout_infer_ctx->GetGraphOutputMap()) {
|
||||||
graph_io_map[out_src] = layout_infer_ctx->GetMapedTensor(out_src);
|
graph_io_map[graph_output.first] = graph_output.second;
|
||||||
}
|
}
|
||||||
return std::make_pair(infer_graph, graph_io_map);
|
return std::make_pair(infer_graph, graph_io_map);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -40,12 +40,14 @@ void OpLayoutInfer::OnOutputs(
|
||||||
for (const auto& out : op_outputs) {
|
for (const auto& out : op_outputs) {
|
||||||
if (graph_outputs.end() !=
|
if (graph_outputs.end() !=
|
||||||
std::find(graph_outputs.begin(), graph_outputs.end(), out)) {
|
std::find(graph_outputs.begin(), graph_outputs.end(), out)) {
|
||||||
|
context_->UpdateGraphOutputMap(out, context_->GetMapedTensor(out));
|
||||||
auto pv = context_->GetPermuteVector(out);
|
auto pv = context_->GetPermuteVector(out);
|
||||||
if (!pv->IsAligned()) {
|
if (!pv->IsAligned()) {
|
||||||
auto perm_out = InsertPermute(context_->GetMapedTensor(out),
|
auto perm_out = InsertPermute(context_->GetMapedTensor(out),
|
||||||
pv->Reverse(), true, out);
|
pv->Reverse(), true, out);
|
||||||
// Update graph out tensor
|
// Update graph out tensor
|
||||||
context_->UpdateTensorMap(out, perm_out);
|
context_->UpdateTensorMap(out, perm_out);
|
||||||
|
context_->UpdateGraphOutputMap(out, perm_out);
|
||||||
}
|
}
|
||||||
if (!context_->src_graph_->GetConsumersOp(out).empty()) {
|
if (!context_->src_graph_->GetConsumersOp(out).empty()) {
|
||||||
// The tensor is output of graph, but it also is the input of other operations
|
// The tensor is output of graph, but it also is the input of other operations
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue