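# Witllm/binary/cudagraph.py
# (file-viewer metadata: 44 lines, 1.4 KiB, Python)
#
# Minimal example: capture an inference forward pass of a small PyTorch model
# in a CUDA graph, then replay it on new input data.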

import torch
# 1. Define the model. CUDA-graph capture requires static shapes and static
#    (data-independent) control flow, so forward() contains no branching on
#    tensor shapes or values.
class SimpleModel(torch.nn.Module):
    """Flatten each sample, apply one Linear layer, then ReLU.

    Args:
        in_features: Number of elements per sample after flattening.
            Defaults to ``3 * 224 * 224`` (one 3x224x224 image).
        out_features: Output size of the Linear layer. Defaults to ``1000``.
    """

    def __init__(self, in_features=3 * 224 * 224, out_features=1000):
        super().__init__()
        self.fc = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        """Return ``relu(fc(flatten(x, 1)))``, shaped ``(batch, out_features)``.

        All dimensions after the first are flattened, so ``x`` must contain
        ``in_features`` elements per sample.
        """
        # flatten(x, 1) keeps the batch dimension. Unlike view(), it also
        # accepts non-contiguous inputs and an empty (size-0) batch.
        x = torch.flatten(x, 1)
        # Fixed op sequence with no data-dependent control flow.
        return torch.relu(self.fc(x))
model = SimpleModel().cuda()
# Inference-only demo. eval() is a no-op for SimpleModel (it has no dropout or
# batch-norm) but keeps capture correct if such layers are added later.
model.eval()

# 2. Static input placeholder. The captured graph always reads from this exact
#    tensor, so new data must be copied into it (see run_graph), never rebound.
#    No output placeholder is pre-allocated here: the tensor produced during
#    capture (bound to ``static_output`` below) is the one every replay writes
#    to, so a buffer allocated here would only be discarded.
static_input = torch.randn(16, 3, 224, 224, device='cuda')

# 3. Warm up on a side stream before capture, as the PyTorch CUDA Graphs notes
#    recommend. Nothing here needs autograd. With grad enabled, each forward
#    would also build an autograd graph and keep activations alive for a
#    backward pass that never runs, so run under no_grad.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.no_grad(), torch.cuda.stream(s):
    for _ in range(3):  # 3 warm-up iterations
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# 4. Capture the CUDA graph. Kernels launched inside the context are recorded
#    into ``g`` rather than executed. Capturing under no_grad also means the
#    graph output does not require grad. Otherwise replayed results could be
#    backpropagated through saved tensors that later replays overwrite.
g = torch.cuda.CUDAGraph()
with torch.no_grad(), torch.cuda.graph(g):
    static_output = model(static_input)
# 5. Run inference through the captured graph (refresh input data, then replay).
def run_graph(new_input):
    """Run the captured graph on ``new_input`` and return a copy of the output.

    Args:
        new_input: Tensor with exactly the shape of ``static_input``. It may
            have a different dtype or live on a different device; ``copy_``
            converts it.

    Returns:
        A new tensor holding the graph output. A clone is returned because
        the next replay overwrites ``static_output`` in place.

    Raises:
        ValueError: If ``new_input.shape`` differs from ``static_input.shape``.
    """
    # copy_ broadcasts, so e.g. a (1, 3, 224, 224) input would be silently
    # replicated across the whole static batch. Require an exact match.
    if new_input.shape != static_input.shape:
        raise ValueError(
            f"run_graph expects input of shape {tuple(static_input.shape)}, "
            f"got {tuple(new_input.shape)}"
        )
    # The graph reads from static_input's memory, so write into it in place.
    static_input.copy_(new_input)
    # Re-execute the recorded kernels; results land in static_output.
    g.replay()
    return static_output.clone()
# Smoke test: replay on fresh data and compare with an eager forward pass.
# A shape check alone would pass even if the replay produced stale or wrong
# values.
new_data = torch.randn(16, 3, 224, 224, device='cuda')
result = run_graph(new_data)
print(result.shape)  # torch.Size([16, 1000])
with torch.no_grad():
    # Slightly loose tolerances allow for kernel-selection differences between
    # the captured and eager paths.
    torch.testing.assert_close(result, model(new_data), rtol=1e-4, atol=1e-4)