Add profile function.
This commit is contained in:
parent
878c690ac4
commit
50fb9bf6dc
|
@ -290,7 +290,20 @@ def test(epoch):
|
|||
f"({accuracy:.0f}%)\n"
|
||||
)
|
||||
model.printWeight()
|
||||
|
||||
def profiler():
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
data, target = data.to(device), target.to(device)
|
||||
optimizer.zero_grad()
|
||||
with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
|
||||
with record_function("model_inference"):
|
||||
output = model(data)
|
||||
loss = criterion(output, target)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||
if batch_idx > 10:
|
||||
prof.export_chrome_trace("local.json")
|
||||
assert False
|
||||
|
||||
for epoch in range(1, 300):
|
||||
train(epoch)
|
||||
|
|
Loading…
Reference in New Issue