Refine meaning dataset import.
This commit is contained in:
		
							parent
							
								
									71ab0bb57d
								
							
						
					
					
						commit
						b62444a9dc
					
				| 
						 | 
				
			
			@ -1,12 +1,13 @@
 | 
			
		|||
import os
 | 
			
		||||
import torch, datasets
 | 
			
		||||
import tracemalloc
 | 
			
		||||
import math, gc, time, random, copy
 | 
			
		||||
from itertools import chain
 | 
			
		||||
from typing import Dict, Tuple
 | 
			
		||||
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
 | 
			
		||||
import numpy as np
 | 
			
		||||
from torch.utils.data import BatchSampler
 | 
			
		||||
from meaning.node_tree import NodeTree
 | 
			
		||||
from node_tree import NodeTree
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MeaningMap:
 | 
			
		||||
| 
						 | 
				
			
			@ -518,6 +519,24 @@ class BatchGroupMeaningDataloader(Dataset):
 | 
			
		|||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
 | 
			
		||||
    tracemalloc.start()
 | 
			
		||||
 | 
			
		||||
    md = MeaningDataset(
 | 
			
		||||
        100000,
 | 
			
		||||
        300000,
 | 
			
		||||
        min_subitem=2,
 | 
			
		||||
        max_subitem=6,
 | 
			
		||||
        vocab_size=32,
 | 
			
		||||
        size=1024,
 | 
			
		||||
        stride=2,
 | 
			
		||||
        with_tree=False,
 | 
			
		||||
        use_cache=True,
 | 
			
		||||
    )
 | 
			
		||||
    current, peak = tracemalloc.get_traced_memory()
 | 
			
		||||
    print(f"当前内存使用: {current / 1024 / 1024 / 1024:.4f} GB")
 | 
			
		||||
    print(f"峰值内存使用: {peak / 1024 / 1024 / 1024:.4f} GB")
 | 
			
		||||
    tracemalloc.stop()
 | 
			
		||||
 | 
			
		||||
    md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, stride=2, with_tree=False, use_cache=False)
 | 
			
		||||
    item = md.__getitem__(920)
 | 
			
		||||
    mm = md.get_meaning_map()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue