From 50fb9bf6dcf946501493ce21ed033c66cf9fe83f Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 9 Jun 2025 15:57:53 +0800 Subject: [PATCH] Add profile function. --- binary/mnist.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/binary/mnist.py b/binary/mnist.py index dbad614..302c99d 100644 --- a/binary/mnist.py +++ b/binary/mnist.py @@ -290,7 +290,20 @@ def test(epoch): f"({accuracy:.0f}%)\n" ) model.printWeight() - +def profiler(): + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof: + with record_function("model_inference"): + output = model(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + if batch_idx > 10: + prof.export_chrome_trace("local.json") + assert False for epoch in range(1, 300): train(epoch)