import gc
import datetime
import inspect

import torch
import numpy as np
import torch.nn as nn

# Size in bytes of a single element for each torch dtype.
dtype_memory_size_dict = {
    torch.float64: 64 / 8,
    torch.double: 64 / 8,
    torch.float32: 32 / 8,
    torch.float: 32 / 8,
    torch.float16: 16 / 8,
    torch.half: 16 / 8,
    torch.int64: 64 / 8,
    torch.long: 64 / 8,
    torch.int32: 32 / 8,
    torch.int: 32 / 8,
    torch.int16: 16 / 8,
    torch.short: 16 / 8,  # BUGFIX: was 16 / 6 — int16 occupies 2 bytes, not ~2.67
    torch.uint8: 8 / 8,
    torch.int8: 8 / 8,
}
# Compatibility with torch 1.0, where these dtypes may not exist.
if getattr(torch, "bfloat16", None) is not None:
    dtype_memory_size_dict[torch.bfloat16] = 16 / 8
if getattr(torch, "bool", None) is not None:
    # pytorch uses 1 byte for a bool, see https://github.com/pytorch/pytorch/issues/41571
    dtype_memory_size_dict[torch.bool] = 8 / 8


def get_mem_space(x):
    """Return the memory size in bytes of one element of dtype *x*.

    Unknown dtypes are reported and counted as 0 bytes.
    (BUGFIX: the original printed the warning and then raised
    UnboundLocalError, because ``ret`` was never bound on the KeyError path.)
    """
    try:
        return dtype_memory_size_dict[x]
    except KeyError:
        print(f"dtype {x} is not supported!")
        return 0


class MemTracker(object):
    """
    Class used to track pytorch memory usage
    Arguments:
        detail(bool, default True): whether the function shows the detail gpu memory usage
        path(str): where to save log file
        verbose(bool, default False): whether show the trivial exception
        device(int): GPU number, default is 0
    """

    def __init__(self, detail=True, path="", verbose=False, device=0):
        self.print_detail = detail
        self.last_tensor_sizes = set()  # snapshot of tensor signatures from the previous track() call
        # NOTE(review): ':' in the timestamp makes this filename invalid on Windows — confirm target OS.
        self.gpu_profile_fn = path + f"{datetime.datetime.now():%d-%b-%y-%H:%M:%S}-gpu_mem_track.txt"
        self.verbose = verbose
        self.begin = True  # first track() call writes the log header
        self.device = device

    def get_tensors(self):
        """Yield every live CUDA tensor currently reachable via the garbage collector."""
        for obj in gc.get_objects():
            try:
                # Accept plain tensors and tensor-wrapping objects (e.g. nn.Parameter via .data).
                if torch.is_tensor(obj) or (hasattr(obj, "data") and torch.is_tensor(obj.data)):
                    tensor = obj
                else:
                    continue
                if tensor.is_cuda:
                    yield tensor
            except Exception as e:
                # Some gc-tracked objects raise on attribute access; skip them quietly.
                if self.verbose:
                    print("A trivial exception occured: {}".format(e))

    def get_tensor_usage(self):
        """Return the total memory (MB) of all live CUDA tensors found via gc."""
        sizes = [np.prod(np.array(tensor.size())) * get_mem_space(tensor.dtype) for tensor in self.get_tensors()]
        return np.sum(sizes) / 1024**2

    def get_allocate_usage(self):
        """Return the memory (MB) the CUDA caching allocator reports as allocated."""
        return torch.cuda.memory_allocated() / 1024**2

    def clear_cache(self):
        """Collect garbage and release cached CUDA memory back to the driver."""
        gc.collect()
        torch.cuda.empty_cache()

    def print_all_gpu_tensor(self, file=None):
        """Print size, dtype and MB footprint of every live CUDA tensor."""
        for x in self.get_tensors():
            print(x.size(), x.dtype, np.prod(np.array(x.size())) * get_mem_space(x.dtype) / 1024**2, file=file)

    def track(self):
        """
        Track the GPU memory usage
        """
        # Record where track() was called from, so log entries are attributable.
        frameinfo = inspect.stack()[1]
        where_str = frameinfo.filename + " line " + str(frameinfo.lineno) + ": " + frameinfo.function

        with open(self.gpu_profile_fn, "a+") as f:
            if self.begin:
                f.write(
                    f"GPU Memory Track | {datetime.datetime.now():%d-%b-%y-%H:%M:%S} |"
                    f" Total Tensor Used Memory:{self.get_tensor_usage():<7.1f}Mb"
                    f" Total Allocated Memory:{self.get_allocate_usage():<7.1f}Mb\n\n"
                )
                self.begin = False

            if self.print_detail is True:
                ts_list = [(tensor.size(), tensor.dtype) for tensor in self.get_tensors()]
                new_tensor_sizes = {
                    (
                        type(x),
                        tuple(x.size()),
                        ts_list.count((x.size(), x.dtype)),
                        np.prod(np.array(x.size())) * get_mem_space(x.dtype) / 1024**2,
                        x.dtype,
                    )
                    for x in self.get_tensors()
                }
                # Tensors that appeared since the previous snapshot...
                for t, s, n, m, data_type in new_tensor_sizes - self.last_tensor_sizes:
                    f.write(f"+ | {str(n)} * Size:{str(s):<20} | Memory: {str(m*n)[:6]} M | {str(t):<20} | {data_type}\n")
                # ...and tensors that disappeared.
                for t, s, n, m, data_type in self.last_tensor_sizes - new_tensor_sizes:
                    f.write(f"- | {str(n)} * Size:{str(s):<20} | Memory: {str(m*n)[:6]} M | {str(t):<20} | {data_type}\n")

                self.last_tensor_sizes = new_tensor_sizes

            f.write(
                f"\nAt {where_str:<50}"
                f" Total Tensor Used Memory:{self.get_tensor_usage():<7.1f}Mb"
                f" Total Allocated Memory:{self.get_allocate_usage():<7.1f}Mb\n\n"
            )


def ModelSize(model, input, type_size=4):
    """Print the parameter count and intermediate-activation footprint of *model*.

    Feeds *input* through ``model.modules()`` one after another, so the
    activation estimate is only meaningful for purely sequential models.

    Arguments:
        model: the nn.Module to measure
        input: an example input tensor (cloned, never modified)
        type_size(int): bytes per element, default 4 (float32)
    """
    para = sum([np.prod(list(p.size())) for p in model.parameters()])
    print("Model {} : params: {:4f}M".format(model._get_name(), para * type_size / 1000 / 1000))

    input_ = input.clone()
    input_.requires_grad_(requires_grad=False)

    mods = list(model.modules())
    out_sizes = []

    # Skip index 0: modules()[0] is the model itself.
    for i in range(1, len(mods)):
        m = mods[i]
        if isinstance(m, nn.ReLU):
            # In-place ReLU does not create a new activation; skip it.
            if m.inplace:
                continue
        out = m(input_)
        out_sizes.append(np.array(out.size()))
        input_ = out

    total_nums = 0
    for i in range(len(out_sizes)):
        s = out_sizes[i]
        nums = np.prod(np.array(s))
        total_nums += nums

    print(
        "Model {} : intermedite variables: {:3f} M (without backward)".format(
            model._get_name(), total_nums * type_size / 1000 / 1000
        )
    )
    print(
        "Model {} : intermedite variables: {:3f} M (with backward)".format(
            model._get_name(), total_nums * type_size * 2 / 1000 / 1000
        )
    )
# Render a 3-channel random tensor as an image.
radata = torch.randn(3, 127, 127)
show.DumpTensorToImage(radata, "test.png")

# Dump a 2-D random tensor as a text log.
radata = torch.randn(127, 127)
show.DumpTensorToLog(radata, "test.log")

# Probability of elements being >= 0 for a shifted random tensor.
radata = torch.randn(127, 127) - 0.5
show.ProbGE0(radata)
show.DumpProb()

# Allocate a CUDA tensor and record GPU memory usage.
radata = torch.randn(127, 127).cuda()
tracker = mem_tracker.MemTracker()
tracker.track()