Refine meaning dataset import.

This commit is contained in:
Colin 2025-08-10 14:30:51 +08:00
parent 71ab0bb57d
commit b62444a9dc
2 changed files with 20 additions and 1 deletions

0
wit/meaning/__init__.py Normal file
View File

View File

@ -1,12 +1,13 @@
import os
import torch, datasets
import tracemalloc
import math, gc, time, random, copy
from itertools import chain
from typing import Dict, Tuple
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
import numpy as np
from torch.utils.data import BatchSampler
from meaning.node_tree import NodeTree
from node_tree import NodeTree
class MeaningMap:
@ -518,6 +519,24 @@ class BatchGroupMeaningDataloader(Dataset):
# Script entry point: the statements below run only when this module is
# executed directly, not when it is imported.
# NOTE(review): this view shows the block with its indentation stripped, and
# the block may continue past the last visible line — confirm against the file.
if __name__ == "__main__":
# Begin tracing Python memory allocations so the dataset build is measured.
tracemalloc.start()
# Build a MeaningDataset with use_cache=True while tracing is active.
# NOTE(review): the meaning of the two positional arguments is not visible
# here (presumably a start value and a size/end) — confirm against
# MeaningDataset.__init__.
md = MeaningDataset(
100000,
300000,
min_subitem=2,
max_subitem=6,
vocab_size=32,
size=1024,
stride=2,
with_tree=False,
use_cache=True,
)
# get_traced_memory() returns a (current, peak) tuple of traced sizes in bytes.
current, peak = tracemalloc.get_traced_memory()
# Dividing by 1024 three times converts bytes to units of 1024**3 bytes; the
# printed label reads "current memory usage" and calls the unit "GB".
print(f"当前内存使用: {current / 1024 / 1024 / 1024:.4f} GB")
# Same conversion for the peak value; the label reads "peak memory usage".
print(f"峰值内存使用: {peak / 1024 / 1024 / 1024:.4f} GB")
# Stop tracing; the statements below are not measured.
tracemalloc.stop()
# Rebind md to a second dataset built with use_cache=False. min_subitem and
# max_subitem are not passed here, so the constructor's defaults apply.
md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, stride=2, with_tree=False, use_cache=False)
# Fetch the sample at index 920 by calling __getitem__ directly.
item = md.__getitem__(920)
# Obtain the meaning map object from the dataset.
mm = md.get_meaning_map()