Refine meaning dataset import.
This commit is contained in:
parent
71ab0bb57d
commit
b62444a9dc
|
@@ -1,12 +1,13 @@
|
|||
import os
|
||||
import torch, datasets
|
||||
import tracemalloc
|
||||
import math, gc, time, random, copy
|
||||
from itertools import chain
|
||||
from typing import Dict, Tuple
|
||||
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
||||
import numpy as np
|
||||
from torch.utils.data import BatchSampler
|
||||
from meaning.node_tree import NodeTree
|
||||
from node_tree import NodeTree
|
||||
|
||||
|
||||
class MeaningMap:
|
||||
|
@@ -518,6 +519,24 @@ class BatchGroupMeaningDataloader(Dataset):
|
|||
|
||||
if __name__ == "__main__":
|
||||
|
||||
tracemalloc.start()
|
||||
|
||||
md = MeaningDataset(
|
||||
100000,
|
||||
300000,
|
||||
min_subitem=2,
|
||||
max_subitem=6,
|
||||
vocab_size=32,
|
||||
size=1024,
|
||||
stride=2,
|
||||
with_tree=False,
|
||||
use_cache=True,
|
||||
)
|
||||
current, peak = tracemalloc.get_traced_memory()
|
||||
print(f"当前内存使用: {current / 1024 / 1024 / 1024:.4f} GB")
|
||||
print(f"峰值内存使用: {peak / 1024 / 1024 / 1024:.4f} GB")
|
||||
tracemalloc.stop()
|
||||
|
||||
md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, stride=2, with_tree=False, use_cache=False)
|
||||
item = md.__getitem__(920)
|
||||
mm = md.get_meaning_map()
|
||||
|
|
Loading…
Reference in New Issue