#ifndef TIM_VX_LAYOUT_INFER_CONTEXT_H_ #define TIM_VX_LAYOUT_INFER_CONTEXT_H_ #include "permute_vector.h" #include "tim/layout_infer/layout_inference.h" namespace tim { namespace transform { namespace layout_inference_impl { class LayoutInferContext { public: LayoutInferContext(const std::shared_ptr& src_graph, std::shared_ptr& infer_graph) : src_graph_(src_graph), infer_graph_(infer_graph) {} void SetPermuteVector(std::shared_ptr tensor, std::shared_ptr pv); const std::shared_ptr GetPermuteVector( const std::shared_ptr& tensor) const; void MarkVisited(const std::shared_ptr& op); bool IsVisited(const std::shared_ptr& op) const; bool IsReadyForInfer(const std::shared_ptr& op) const; void UpdateTensorMap(const std::shared_ptr& t_src, const std::shared_ptr& t_layout); std::shared_ptr GetMapedTensor( const std::shared_ptr& t_src) const; void UpdateGraphInputMap(const std::shared_ptr& i_src, const std::shared_ptr& i_layout); std::map, std::shared_ptr> GetGraphInputMap() const { return graph_input_map_; } const std::shared_ptr& src_graph_; std::shared_ptr& infer_graph_; private: std::map, std::shared_ptr> tensor_pv_; std::vector> visited_op_; // tensor_in_src -> tensor_in_layout std::map, std::shared_ptr> tensor_map_; std::map, std::shared_ptr> graph_input_map_; }; } // namespace layout_inference_impl } // namespace transform } // namespace tim #endif