"""Minimal example: capture a model's forward pass in a CUDA graph and replay it.

CUDA graphs record a fixed sequence of GPU kernels once and relaunch it with
very low CPU overhead.  The recorded kernels always read from and write to the
same memory addresses, so new data must be copied *into* the captured input
tensor before each replay, and results must be read from the captured output.
"""

import torch

# Shape of the fixed input buffer baked into the captured graph
# (batch, channels, height, width).
INPUT_SHAPE = (16, 3, 224, 224)


# 1. Define the model (must use static shapes and static control flow).
class SimpleModel(torch.nn.Module):
    """Flatten the input and apply a single linear layer followed by ReLU.

    Args:
        in_features: Number of elements per sample after flattening.
        num_classes: Size of the output vector per sample.
    """

    def __init__(self, in_features: int = 3 * 224 * 224,
                 num_classes: int = 1000) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(fc(x))`` with ``x`` flattened to ``(batch, -1)``."""
        x = x.view(x.size(0), -1)  # static-shape operation
        return torch.relu(self.fc(x))  # no data-dependent control flow


if torch.cuda.is_available():
    model = SimpleModel().cuda().eval()

    # 2. Allocate the static input placeholder.  The static output tensor is
    #    produced by the capture itself (step 4), so it is not preallocated.
    static_input = torch.randn(*INPUT_SHAPE, device="cuda")

    # 3. Warm up on a side (non-default) stream before capturing.
    #    Inference only, so autograd is disabled for warmup and capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s), torch.no_grad():
        for _ in range(3):  # three warmup iterations
            model(static_input)
    torch.cuda.current_stream().wait_stream(s)

    # 4. Create and capture the CUDA graph.
    g = torch.cuda.CUDAGraph()
    with torch.no_grad(), torch.cuda.graph(g):
        # Nothing is actually executed here; the kernels are only recorded.
        static_output = model(static_input)

    # 5. Use the graph: refresh the input data, then replay.
    def run_graph(new_input: torch.Tensor) -> torch.Tensor:
        """Run the captured graph on ``new_input`` and return a copy of the result.

        Args:
            new_input: Tensor with exactly the shape of ``static_input``.

        Returns:
            A fresh tensor holding the model output for ``new_input``.

        Raises:
            ValueError: If ``new_input`` does not match the captured shape.
        """
        # copy_() would silently broadcast a smaller tensor into the static
        # buffer and give plausible-looking but wrong results, so be strict.
        if new_input.shape != static_input.shape:
            raise ValueError(
                f"expected input of shape {tuple(static_input.shape)}, "
                f"got {tuple(new_input.shape)}"
            )
        # Copy the new data into the tensor the graph was captured with.
        static_input.copy_(new_input)
        # Replay the recorded kernels.
        g.replay()
        # Clone because static_output is overwritten by the next replay.
        return static_output.clone()

    # Smoke test.
    new_data = torch.randn(*INPUT_SHAPE, device="cuda")
    result = run_graph(new_data)
    print(result.shape)  # torch.Size([16, 1000])
elif __name__ == "__main__":
    # Importing on a CPU-only machine is harmless; running the demo is not.
    raise SystemExit("This CUDA graph demo requires a CUDA-capable GPU.")