Refine meaning dataset import.
This commit is contained in:
		
							parent
							
								
									71ab0bb57d
								
							
						
					
					
						commit
						b62444a9dc
					
				|  | @ -1,12 +1,13 @@ | |||
| import os | ||||
| import torch, datasets | ||||
| import tracemalloc | ||||
| import math, gc, time, random, copy | ||||
| from itertools import chain | ||||
| from typing import Dict, Tuple | ||||
| from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split | ||||
| import numpy as np | ||||
| from torch.utils.data import BatchSampler | ||||
| from meaning.node_tree import NodeTree | ||||
| from node_tree import NodeTree | ||||
| 
 | ||||
| 
 | ||||
| class MeaningMap: | ||||
|  | @ -518,6 +519,24 @@ class BatchGroupMeaningDataloader(Dataset): | |||
| 
 | ||||
| if __name__ == "__main__": | ||||
| 
 | ||||
|     tracemalloc.start() | ||||
| 
 | ||||
|     md = MeaningDataset( | ||||
|         100000, | ||||
|         300000, | ||||
|         min_subitem=2, | ||||
|         max_subitem=6, | ||||
|         vocab_size=32, | ||||
|         size=1024, | ||||
|         stride=2, | ||||
|         with_tree=False, | ||||
|         use_cache=True, | ||||
|     ) | ||||
|     current, peak = tracemalloc.get_traced_memory() | ||||
|     print(f"当前内存使用: {current / 1024 / 1024 / 1024:.4f} GB") | ||||
|     print(f"峰值内存使用: {peak / 1024 / 1024 / 1024:.4f} GB") | ||||
|     tracemalloc.stop() | ||||
| 
 | ||||
|     md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, stride=2, with_tree=False, use_cache=False) | ||||
|     item = md.__getitem__(920) | ||||
|     mm = md.get_meaning_map() | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Colin
						Colin