Refine meaning dataset import.

This commit is contained in:
Colin 2025-08-10 14:30:51 +08:00
parent 71ab0bb57d
commit b62444a9dc
2 changed files with 20 additions and 1 deletions

0
wit/meaning/__init__.py Normal file
View File

View File

@ -1,12 +1,13 @@
import os import os
import torch, datasets import torch, datasets
import tracemalloc
import math, gc, time, random, copy import math, gc, time, random, copy
from itertools import chain from itertools import chain
from typing import Dict, Tuple from typing import Dict, Tuple
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
import numpy as np import numpy as np
from torch.utils.data import BatchSampler from torch.utils.data import BatchSampler
from meaning.node_tree import NodeTree from node_tree import NodeTree
class MeaningMap: class MeaningMap:
@ -518,6 +519,24 @@ class BatchGroupMeaningDataloader(Dataset):
if __name__ == "__main__": if __name__ == "__main__":
tracemalloc.start()
md = MeaningDataset(
100000,
300000,
min_subitem=2,
max_subitem=6,
vocab_size=32,
size=1024,
stride=2,
with_tree=False,
use_cache=True,
)
current, peak = tracemalloc.get_traced_memory()
print(f"当前内存使用: {current / 1024 / 1024 / 1024:.4f} GB")
print(f"峰值内存使用: {peak / 1024 / 1024 / 1024:.4f} GB")
tracemalloc.stop()
md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, stride=2, with_tree=False, use_cache=False) md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, stride=2, with_tree=False, use_cache=False)
item = md.__getitem__(920) item = md.__getitem__(920)
mm = md.get_meaning_map() mm = md.get_meaning_map()