#ifndef TIM_VX_LAYOUT_INFER_CONTEXT_H_ #define TIM_VX_LAYOUT_INFER_CONTEXT_H_ #include "permute_vector.h" #include "tim/transform/layout_inference.h" #include namespace tim { namespace transform { namespace layout_inference_impl { class LayoutInferContext { public: LayoutInferContext(const std::shared_ptr& src_graph, std::shared_ptr& infer_graph); void SetPermuteVector(std::shared_ptr tensor, std::shared_ptr pv); const std::shared_ptr GetPermuteVector( const std::shared_ptr& tensor) const; void MarkVisited(const std::shared_ptr& op); bool IsVisited(const std::shared_ptr& op) const; bool IsReadyForInfer(const std::shared_ptr& op) const; void UpdateTensorMap(const std::shared_ptr& t_src, const std::shared_ptr& t_layout); std::shared_ptr GetMappedTensor( const std::shared_ptr& t_src) const; std::shared_ptr GetMappedGraphInputTensor( const std::shared_ptr& t_src) const; std::shared_ptr GetMappedGraphOutputTensor( const std::shared_ptr& t_src) const; void UpdateGraphInputMap(const std::shared_ptr& i_src, const std::shared_ptr& i_layout); void UpdateGraphOutputMap(const std::shared_ptr& o_src, const std::shared_ptr& o_layout); std::map, std::shared_ptr> GetGraphInputMap() const { return graph_input_map_; } std::map, std::shared_ptr> GetGraphOutputMap() const { return graph_output_map_; } const std::shared_ptr& src_graph_; std::shared_ptr& infer_graph_; private: std::map, std::shared_ptr> tensor_pv_; std::unordered_map, bool> op_visited_; // tensor_in_src -> tensor_in_layout std::map, std::shared_ptr> tensor_map_; std::map, std::shared_ptr> graph_input_map_; std::map, std::shared_ptr> graph_output_map_; }; } // namespace layout_inference_impl } // namespace transform } // namespace tim #endif