# witnn/Amp/test_gdb.py
#
# 18 lines
# 553 B
# Python
# Raw Normal View History
#
# 2022-01-14 17:16:26 +08:00
import torch
from torch.cuda.amp import autocast as autocast
from torch.profiler import profile, ProfilerActivity, record_function
# Minimal mixed-precision smoke script: multiply a 300x300 matrix of ones by
# itself inside an autocast region and print the result. Every element of the
# product is 300 (a sum of 300 ones), which makes the output easy to eyeball
# when stepping through the run in a debugger.
#
# Optional profiling variant (this is what the torch.profiler imports at the
# top of the file are for). Uncomment to get a per-op timing table:
#
#     with profile(
#         activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
#         with_stack=True,
#         record_shapes=True,
#     ) as prof:
#         with record_function("model_inference"):
#             with autocast():
#                 output = torch.mm(data, data)
#     print(prof.key_averages().table(sort_by="cpu_time_total"))

# Query CUDA availability once so the device placement and the autocast
# switch below always agree with each other.
use_cuda = torch.cuda.is_available()

data = torch.ones((300, 300))
if use_cuda:
    # Only move the tensor to the GPU when one is actually present; an
    # unconditional .cuda() raises on CPU-only hosts.
    data = data.cuda()

# `autocast` here is the CUDA autocast context manager imported at the top of
# the file. It is explicitly disabled when running on CPU so the script still
# produces the same numeric result instead of failing or warning.
with autocast(enabled=use_cuda):
    output = torch.mm(data, data)

# Copy back to host memory before printing (a no-op when already on CPU).
print(output.cpu())