Add profile function.

This commit is contained in:
Colin 2025-06-09 15:57:53 +08:00
parent 878c690ac4
commit 50fb9bf6dc
1 changed files with 14 additions and 1 deletions

View File

@ -290,7 +290,20 @@ def test(epoch):
f"({accuracy:.0f}%)\n"
)
model.printWeight()
def profiler():
    """Profile CUDA activity of training steps and dump a Chrome trace.

    Iterates over the module-level ``train_loader`` and, for every batch,
    runs one forward pass, loss computation, backward pass and optimizer
    step under a fresh ``torch.profiler.profile`` session that records
    CUDA activity and operator input shapes. After each batch the ten
    most expensive entries (sorted by ``cuda_time_total``) are printed.

    Once ``batch_idx`` exceeds 10, the trace of the most recent batch is
    written to ``local.json`` and the run is aborted on purpose.

    Relies on the module-level names ``train_loader``, ``device``,
    ``optimizer``, ``model`` and ``criterion``.

    Raises:
        AssertionError: always, after the trace has been exported, to stop
            the script before it continues past the profiling run.
    """
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # A new profiler session per batch: each printed table and the
        # exported trace describe exactly one training step.
        with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
            # The region is labelled "model_inference" but also covers the
            # backward pass and the optimizer step.
            with record_function("model_inference"):
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
        # Results are only available after the profiler context has exited.
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        if batch_idx > 10:
            prof.export_chrome_trace("local.json")
            # Raise explicitly instead of `assert False`: assert statements
            # are removed under `python -O`, which would silently turn this
            # into a profiling pass over the entire data loader.
            raise AssertionError("profiling finished; trace written to local.json")
# Main training driver: call train() once per epoch for epochs 1 through 299
# (range's upper bound of 300 is exclusive).
for epoch in range(1, 300):
    train(epoch)