Add profile function.
This commit is contained in:
parent
878c690ac4
commit
50fb9bf6dc
|
@ -290,7 +290,20 @@ def test(epoch):
|
||||||
f"({accuracy:.0f}%)\n"
|
f"({accuracy:.0f}%)\n"
|
||||||
)
|
)
|
||||||
model.printWeight()
|
model.printWeight()
|
||||||
|
def profiler():
|
||||||
|
for batch_idx, (data, target) in enumerate(train_loader):
|
||||||
|
data, target = data.to(device), target.to(device)
|
||||||
|
optimizer.zero_grad()
|
||||||
|
with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
|
||||||
|
with record_function("model_inference"):
|
||||||
|
output = model(data)
|
||||||
|
loss = criterion(output, target)
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||||
|
if batch_idx > 10:
|
||||||
|
prof.export_chrome_trace("local.json")
|
||||||
|
assert False
|
||||||
|
|
||||||
for epoch in range(1, 300):
|
for epoch in range(1, 300):
|
||||||
train(epoch)
|
train(epoch)
|
||||||
|
|
Loading…
Reference in New Issue