# PyTorch example: capture a model's forward pass in a CUDA graph, then replay it on new data.
import torch
# 1. Define the model. CUDA graph capture requires static shapes and static
#    (data-independent) control flow, so forward() uses neither dynamic shapes
#    nor data-dependent branches.
class SimpleModel(torch.nn.Module):
    """Flatten each sample, then apply one Linear layer followed by ReLU.

    Args:
        in_features: Number of elements per input sample after flattening.
            Defaults to 3 * 224 * 224, i.e. a (3, 224, 224) sample.
        num_classes: Size of the output vector per sample. Defaults to 1000.
    """

    def __init__(self, in_features=3 * 224 * 224, num_classes=1000):
        super().__init__()
        self.fc = torch.nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return ``relu(fc(flatten(x)))`` of shape ``(x.size(0), num_classes)``.

        Dim 0 of ``x`` is kept as the batch dimension; all remaining dims are
        flattened and must multiply to ``in_features``.
        """
        # flatten(start_dim=1) keeps the batch dimension like the former
        # view(x.size(0), -1), but also accepts non-contiguous inputs (e.g.
        # permuted / channels_last tensors), where view() raises, and empty
        # batches, where view(0, -1) cannot infer the -1 dimension.
        x = torch.flatten(x, start_dim=1)
        # No data-dependent branching: the same operations run on every call,
        # which is what allows this forward pass to be captured and replayed.
        return torch.relu(self.fc(x))
model = SimpleModel().cuda()
# SimpleModel has no dropout/batch-norm today, but eval() makes the
# inference-only intent explicit and keeps the captured graph in inference
# behavior if such layers are added later.
model.eval()

# 2. Static input buffer. The captured graph always reads from this exact
#    memory, so new data must be copied into it (see run_graph) rather than
#    handed to the model as a fresh tensor.
#    No output buffer is pre-allocated here: the tensor the model returns
#    during capture (step 4) *is* the static output, so a placeholder bound to
#    `static_output` beforehand would simply be discarded.
static_input = torch.randn(16, 3, 224, 224, device='cuda')

# 3. Warm-up, which must run on a non-default (side) stream before capture, so
#    one-time setup work happens outside the graph (per PyTorch's CUDA Graphs
#    guidance).
s = torch.cuda.Stream()
# Make the side stream wait for work already queued on the current stream
# (the weight upload from .cuda() and the static_input initialization above).
s.wait_stream(torch.cuda.current_stream())
# no_grad: this graph is inference-only. Under grad mode the warm-up builds
# throw-away autograd graphs, and the captured output would carry a grad_fn
# whose saved input is later overwritten in place by run_graph's copy_.
with torch.cuda.stream(s), torch.no_grad():
    for _ in range(3):  # 3 warm-up iterations
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# 4. Create and capture the CUDA graph.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g), torch.no_grad():
    # Kernels launched here are recorded into `g` rather than executed; the
    # contents of static_output are only meaningful after g.replay().
    static_output = model(static_input)
# 5. Use the captured graph: refresh the input buffer, then replay.
def run_graph(new_input):
    """Run the captured CUDA graph on ``new_input`` and return its output.

    Args:
        new_input: Tensor with exactly the same shape as ``static_input``
            (the shape the graph was captured with). Its values are copied
            into ``static_input`` with ``Tensor.copy_``.

    Returns:
        A new tensor holding the graph's output. It is a clone because
        ``static_output`` is overwritten in place by the next replay.

    Raises:
        RuntimeError: if ``new_input.shape`` differs from ``static_input.shape``.
    """
    # The graph was recorded for one fixed shape. Without this check, copy_
    # would silently broadcast a smaller (e.g. batch-1) input across the whole
    # buffer and return plausible-looking but wrong results.
    if new_input.shape != static_input.shape:
        raise RuntimeError(
            f"run_graph expects input of shape {tuple(static_input.shape)}, "
            f"got {tuple(new_input.shape)}"
        )
    # Replays read the memory captured in the graph, so the new data must be
    # written into that exact buffer rather than passed in directly.
    static_input.copy_(new_input)
    # Re-launch the recorded kernels.
    g.replay()
    # Return a copy: the next replay would overwrite static_output.
    return static_output.clone()
# Smoke test: push one fresh batch through the captured graph and report the
# output shape.
new_data = torch.randn((16, 3, 224, 224), device="cuda")
result = run_graph(new_data)
print(result.shape)  # expected: torch.Size([16, 1000])