Update meaning dataset.

This commit is contained in:
Colin 2024-04-10 00:34:47 +08:00
parent 7560166b76
commit 7434427ec9
2 changed files with 216 additions and 323 deletions

View File

@ -7,14 +7,16 @@ meaning数据集是一个模仿自然语言以及抽象表达的数据集。
1. token表示最终体现的基本数据表达类似单词。vocab_size表示代表token的数量。 1. token表示最终体现的基本数据表达类似单词。vocab_size表示代表token的数量。
2. meaning表示一种语义符号所有的meaning都由一个编号表达编号越大表示语义越复杂 2. meaning表示一种语义符号所有的meaning都由一个编号表达编号越大表示语义越复杂
3. 所有的meaning都可以由更低标号表达 3. 所有的meaning都可以由更低标号表达
4. 从0到vocab_size的编号表示基本meaning是不能被拆解的也就是token 4. 从0到(vocab_size-1)的编号表示基本meaning是不能被拆解的也就是token
5. meaning通过一层层的向低编号的meaning进行组合替换最终形成一个最底层是token的树形数据 5. meaning通过一层层的向低编号的meaning进行组合替换最终形成一个最底层是token的树形数据
6. level表示当前token相对于root meaning的距离 6. level表示当前token相对于root meaning的距离
7. idx表示当前token在不同层的排序编号每4位表示在一层里面的编号低4位表示最低层级的index高位无用的位用1填充 7. rank_idx表示当前token在不同层的排序编号每4位表示在一层里面的编号低4位表示最低层级的rank_idx高位无用的位用1填充
7. rank_all表示当前token在不同层的分子个数每4位表示在一层里面的编号低4位表示最低层级的rank_all高位无用的位用1填充
8. tree用于存储每个meaning的拆解的数据使用字典表达一个树形结构 8. tree用于存储每个meaning的拆解的数据使用字典表达一个树形结构
9. get_seq_mask返回一个sequence每个token在对应level是不是对应的index 9. get_seq_mask返回一个sequence每个token在对应level是不是对应的index,level=0:最底层index=-1:最后一个index=0:第一个
10. meaning_height 10. meaning_height 当前meaning的总高度
11. meaning_weight 11. meaning_weight 当前meaning的总宽度
``` ```
vocab_size = 256 meaning = 115200 vocab_size = 256 meaning = 115200

View File

@ -8,6 +8,7 @@ from typing import Dict, Tuple
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
import numpy as np import numpy as np
from torch.utils.data import BatchSampler from torch.utils.data import BatchSampler
import copy
class MeaningMap: class MeaningMap:
@ -18,7 +19,7 @@ class MeaningMap:
path = "./data/" path = "./data/"
file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem) file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
file = path + file file = path + file + ".npz"
if not os.path.exists(path): if not os.path.exists(path):
os.mkdir(path) os.mkdir(path)
@ -26,13 +27,14 @@ class MeaningMap:
print("Load from disk cache: " + file) print("Load from disk cache: " + file)
loaded = np.load(file) loaded = np.load(file)
slhwm = loaded["slhwm"] slhwm = loaded["slhwm"]
dli = loaded["dli"] dlra = loaded["dlra"]
self.ms_map = slhwm[:, 4:] self.ms_map = slhwm[:, 4:]
self.ms_data = dli[:, 0] self.ms_data = dlra[:, 0]
self.ms_start = slhwm[:, 0] self.ms_start = slhwm[:, 0]
self.ms_len = slhwm[:, 1] self.ms_len = slhwm[:, 1]
self.ms_level = dli[:, 1] self.ms_level = dlra[:, 1]
self.ms_idx = dli[:, 2].astype(np.uint32) self.ms_rank_idx = dlra[:, 2].astype(np.uint32)
self.ms_rank_all = dlra[:, 3].astype(np.uint32)
self.ms_height = slhwm[:, 2] self.ms_height = slhwm[:, 2]
self.ms_weight = slhwm[:, 3] self.ms_weight = slhwm[:, 3]
print("Load end") print("Load end")
@ -61,7 +63,8 @@ class MeaningMap:
ms_data = [] # meaning sequence ms_data = [] # meaning sequence
ms_level = [] # meaning level, vocab's level is 0 ms_level = [] # meaning level, vocab's level is 0
ms_idx = [] # meaning index of lowest level ms_rank_idx = [] # meaning index of all level
ms_rank_all = [] # meaning all of all level
ms_start = [] # meaning sequence start ms_start = [] # meaning sequence start
ms_len = [] # meaning sequence length ms_len = [] # meaning sequence length
ms_height = [] # meaning tree height ms_height = [] # meaning tree height
@ -70,7 +73,8 @@ class MeaningMap:
for i in range(self.vocab_size): for i in range(self.vocab_size):
ms_data.append(np.array([i])) ms_data.append(np.array([i]))
ms_level.append(np.array([0])) ms_level.append(np.array([0]))
ms_idx.append(np.array([0])) ms_rank_idx.append(np.array([0]))
ms_rank_all.append(np.array([0]))
ms_start.append(index) ms_start.append(index)
ms_len.append(1) ms_len.append(1)
ms_height.append(0) ms_height.append(0)
@ -79,59 +83,70 @@ class MeaningMap:
for i in range(self.vocab_size, size): for i in range(self.vocab_size, size):
m = map[i] m = map[i]
m = m[m >= 0] m = m[m >= 0] # donot cut off the map such as [0]
m_list = m.tolist() m_list = m.tolist()
m_len = len(m_list)
assert m_list, "map list can not be empty list" assert m_list, "map list can not be empty list"
ma = np.concatenate([ms_data[newm] for newm in m_list]) ma = np.concatenate([ms_data[newm] for newm in m_list])
ml = np.concatenate([ms_level[newm] + 1 for newm in m_list]) ml = np.concatenate([ms_level[newm] + 1 for newm in m_list])
mi = np.concatenate( mr = np.concatenate(
[ [
([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_idx[newm] * 16 + i) ([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_rank_idx[newm] * 16 + i)
for i, newm in enumerate(m_list) for i, newm in enumerate(m_list)
] ]
) )
ml = ml[ma > 0] mrl = np.concatenate(
mi = mi[ma > 0] [
ma = ma[ma > 0] ([0xFFFFFFF0 + m_len] if newm < self.vocab_size else ms_rank_all[newm] * 16 + m_len)
for i, newm in enumerate(m_list)
]
)
# ml = ml[ma > 0] # cut off the 0 token, such as [12,32,0,42,32]
# mr = mr[ma > 0]
# mrl = mrl[ma > 0]
# ma = ma[ma > 0]
ms_data.append(ma) ms_data.append(ma)
ms_level.append(ml) ms_level.append(ml)
ms_idx.append(mi) ms_rank_idx.append(mr)
ms_rank_all.append(mrl)
ms_start.append(index) ms_start.append(index)
ms_len.append(len(ma)) ms_len.append(len(ma))
ms_height.append(max([ms_height[sub_m] for sub_m in m_list]) + 1) ms_height.append(max([ms_height[sub_m] for sub_m in m_list]) + 1)
ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list)) ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list))
index = index + len(ma) index = index + len(ma)
# offsets = [0, 0, 4, 8, 12, 16, 20, 24, 28]
# for idxmi, mi in enumerate(ms_idx):
# level = ms_level[idxmi]
# for idxnum, num in enumerate(mi):
# l = level[idxnum]
# elements = [(num >> offset) & 0xF for offset in offsets[l:0:-1]]
# num = (num >> (l * 4)) << (l * 4)
# num += sum(elem << (i * 4) for i, elem in enumerate(elements))
# mi[idxnum] = num
ms_data = np.array(list(chain(*ms_data))).astype(np.int32) ms_data = np.array(list(chain(*ms_data))).astype(np.int32)
ms_level = np.array(list(chain(*ms_level))).astype(np.int32) ms_level = np.array(list(chain(*ms_level))).astype(np.int32)
ms_idx = np.array(list(chain(*ms_idx))).astype(np.uint32) ms_rank_idx = np.array(list(chain(*ms_rank_idx))).astype(np.uint32)
ms_rank_all = np.array(list(chain(*ms_rank_all))).astype(np.uint32)
d = np.ones(ms_idx.shape, dtype=np.uint32) d = np.ones(ms_rank_idx.shape, dtype=np.uint32)
d = ((d * 0xFFFFFFFF) << (ms_level * 4)).astype(np.uint32) d = ((d * 0xFFFFFFFF) << (ms_level * 4)).astype(np.uint32)
ms_idx = ( ms_rank_idx = (
((ms_idx & 0xF) << 28) ((ms_rank_idx & 0xF) << 28)
+ ((ms_idx & 0xF0) << 20) + ((ms_rank_idx & 0xF0) << 20)
+ ((ms_idx & 0xF00) << 12) + ((ms_rank_idx & 0xF00) << 12)
+ ((ms_idx & 0xF000) << 4) + ((ms_rank_idx & 0xF000) << 4)
+ ((ms_idx & 0xF0000) >> 4) + ((ms_rank_idx & 0xF0000) >> 4)
+ ((ms_idx & 0xF00000) >> 12) + ((ms_rank_idx & 0xF00000) >> 12)
+ ((ms_idx & 0xF000000) >> 20) + ((ms_rank_idx & 0xF000000) >> 20)
+ ((ms_idx & 0xF0000000) >> 28) + ((ms_rank_idx & 0xF0000000) >> 28)
) )
ms_idx = ((ms_idx >> ((8 - ms_level) * 4)) + d).astype(np.uint32) ms_rank_idx = ((ms_rank_idx >> ((8 - ms_level) * 4)) + d).astype(np.uint32)
ms_rank_all = (
((ms_rank_all & 0xF) << 28)
+ ((ms_rank_all & 0xF0) << 20)
+ ((ms_rank_all & 0xF00) << 12)
+ ((ms_rank_all & 0xF000) << 4)
+ ((ms_rank_all & 0xF0000) >> 4)
+ ((ms_rank_all & 0xF00000) >> 12)
+ ((ms_rank_all & 0xF000000) >> 20)
+ ((ms_rank_all & 0xF0000000) >> 28)
)
ms_rank_all = ((ms_rank_all >> ((8 - ms_level) * 4)) + d).astype(np.uint32)
ms_start = np.array(ms_start).astype(np.int32) ms_start = np.array(ms_start).astype(np.int32)
ms_height = np.array(ms_height).astype(np.int32) ms_height = np.array(ms_height).astype(np.int32)
@ -148,15 +163,17 @@ class MeaningMap:
), ),
axis=1, axis=1,
) )
dli = np.stack((ms_data, ms_level, ms_idx.astype(np.int32)), axis=1) dlra = np.stack((ms_data, ms_level, ms_rank_idx.astype(np.int32), ms_rank_all.astype(np.int32)), axis=1)
np.savez(file, slhwm=slhwm, dli=dli) np.savez(file, slhwm=slhwm, dlra=dlra)
self.ms_data = ms_data # map[i]=ms_data[ms_start[i]:ms_start[i]+ms_len[i]]
self.ms_level = ms_level
self.ms_rank_idx = ms_rank_idx
self.ms_rank_all = ms_rank_all
self.ms_map = map # ms_map[i] = [sub(i),sub(i),sub(i),sub(i)...sub(i)] self.ms_map = map # ms_map[i] = [sub(i),sub(i),sub(i),sub(i)...sub(i)]
self.ms_data = ms_data # map[i]=ms_data[ms_start[i]:ms_start[i]+ms_len[i]]
self.ms_start = ms_start self.ms_start = ms_start
self.ms_len = ms_len self.ms_len = ms_len
self.ms_level = ms_level
self.ms_idx = ms_idx
self.ms_height = ms_height self.ms_height = ms_height
self.ms_weight = ms_weight self.ms_weight = ms_weight
print("Disk cache build end.") print("Disk cache build end.")
@ -164,7 +181,12 @@ class MeaningMap:
def get_sequence(self, meaning): # return sequence[meaning] def get_sequence(self, meaning): # return sequence[meaning]
start = self.ms_start[meaning] start = self.ms_start[meaning]
len = self.ms_len[meaning] len = self.ms_len[meaning]
return self.ms_data[start : start + len], self.ms_level[start : start + len], self.ms_idx[start : start + len] return (
self.ms_data[start : start + len],
self.ms_level[start : start + len],
self.ms_rank_idx[start : start + len],
self.ms_rank_all[start : start + len],
)
def get_tree(self, meaning): # return meaning all sub items def get_tree(self, meaning): # return meaning all sub items
tree = {} tree = {}
@ -203,73 +225,70 @@ class MeaningMap:
class MeaningDataset(Dataset): class MeaningDataset(Dataset):
def __init__( def __init__(
self, self,
start=131072, start,
end=1048576, end,
size=32768, size,
vocab_size=4096, vocab_size,
max_subitem=10, max_subitem=10,
min_seq_len=2, min_seq_len=2,
seed=42, seed=42,
data=None,
length=None,
tree=None,
level=None,
idx=None,
use_cache=True, use_cache=True,
): ):
if data != None and length != None and tree != None and level != None and idx != None:
self.data = data
self.length = length
self.tree = tree
self.level = level
self.idx = idx
return
np.random.seed(seed) np.random.seed(seed)
map = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem, use_cache=use_cache) map = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem, use_cache=use_cache)
np.random.seed(seed)
self.tree = [] self.tree = []
self.data = [] self.seq = []
self.level = [] self.level = []
self.idx = [] self.rank_idx = []
self.length = [] self.rank_all = []
self.seq_meaning = []
self.m_height = map.ms_height
self.m_weight = map.ms_weight
meanings = np.random.randint(start, end, size=(size)) meanings = np.random.randint(start, end, size=(size))
seq_len = []
for m in meanings: for m in meanings:
d, l, i = map.get_sequence(m) d, l, i, a = map.get_sequence(m)
if len(d) >= min_seq_len: if len(d) >= min_seq_len:
self.tree.append({m: map.get_tree(m)}) self.tree.append({m: map.get_tree(m)})
self.data.append(d) self.seq.append(d)
self.level.append(l) self.level.append(l)
self.idx.append(i) self.rank_idx.append(i)
self.length.append(len(d)) self.rank_all.append(a)
self.seq_meaning.append(m)
seq_len.append(len(d))
unique, counts = np.unique(self.length, return_counts=True) unique, counts = np.unique(seq_len, return_counts=True)
print("----------------------------------------------------------------") print("----------------------------------------------------------------")
print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start)) print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))
print("MeaningDataset size:" + str(len(self.length))) print("MeaningDataset size:" + str(len(seq_len)))
print("MeaningDataset max sequence length:" + str(max(unique))) print("MeaningDataset max sequence length:" + str(max(unique)))
print("MeaningDataset most popular sequence length:" + str(unique[np.argmax(counts)])) print("MeaningDataset most popular sequence length:" + str(unique[np.argmax(counts)]))
print("----------------------------------------------------------------") print("----------------------------------------------------------------")
def __len__(self): def __len__(self):
return len(self.data) return len(self.seq)
def len(self): def len(self):
return len(self.data) return len(self.seq)
def __getitem__(self, idx): def __getitem__(self, idx):
output = {} output = {}
data = torch.tensor(self.data[idx]).long() data = torch.tensor(self.seq[idx]).long()
output["input_ids"] = data output["input_ids"] = data
output["labels"] = data.clone() output["labels"] = data.clone()
output["token_type_ids"] = torch.zeros(data.shape) output["token_type_ids"] = torch.zeros(data.shape)
output["tree"] = self.tree[idx] output["tree"] = self.tree[idx]
output["level"] = self.level[idx] output["level"] = self.level[idx]
output["idx"] = self.idx[idx]
return output return output
def get_batch(self, idx_list): # must equal sequence length def get_batch(self, idx_list): # must equal sequence length
data = [self.data[i] for i in idx_list] data = [self.seq[i] for i in idx_list]
output = {} output = {}
data = torch.tensor(np.stack(data, axis=0)).long() data = torch.tensor(np.stack(data, axis=0)).long()
output["input_ids"] = data output["input_ids"] = data
@ -277,45 +296,35 @@ class MeaningDataset(Dataset):
output["token_type_ids"] = torch.zeros(data.shape) output["token_type_ids"] = torch.zeros(data.shape)
output["tree"] = [self.tree[i] for i in idx_list] output["tree"] = [self.tree[i] for i in idx_list]
output["level"] = [self.level[i] for i in idx_list] output["level"] = [self.level[i] for i in idx_list]
output["idx"] = [self.idx[i] for i in idx_list]
return output return output
def get_token(self, idx): # must equal sequence length def get_token(self, idx): # must equal sequence length
return self.data[idx] return self.seq[idx]
def get_tree(self, idx): def get_tree(self, idx):
return self.tree[idx] return self.tree[idx]
def print_tree(self, idx): def print_tree(self, idx):
tokens = self.data[idx] tokens = self.seq[idx]
tree = self.get_tree(idx) tree = self.get_tree(idx)
s = str(tokens) + "\n" s = str(tokens) + "\n"
s += MeaningMap.get_tree_str(tree, "") s += MeaningMap.get_tree_str(tree, "")
return s return s
def copy(self, start, end):
new = copy.deepcopy(self)
new.tree = new.tree[start:end]
new.seq = new.seq[start:end]
new.level = new.level[start:end]
new.rank_idx = new.rank_idx[start:end]
new.rank_all = new.rank_all[start:end]
new.seq_meaning = new.seq_meaning[start:end]
return new
def split(self, ratio): def split(self, ratio):
l = len(self.data) l = self.len()
middle = int(l * ratio) middle = int(l * ratio)
d_shuffle = self.data.copy() return self.copy(0, middle), self.copy(middle, l)
l_shuffle = self.length.copy()
m_shuffle = self.tree.copy()
level_shuffle = self.level.copy()
i_shuffle = self.idx.copy()
md1 = MeaningDataset(
data=d_shuffle[:middle],
length=l_shuffle[:middle],
tree=m_shuffle[:middle],
level=level_shuffle[:middle],
idx=i_shuffle[:middle],
)
md2 = MeaningDataset(
data=d_shuffle[middle:],
length=l_shuffle[middle:],
tree=m_shuffle[middle:],
level=level_shuffle[middle:],
idx=i_shuffle[middle:],
)
return md1, md2
def token_frequency(self): def token_frequency(self):
freq = {} freq = {}
@ -323,10 +332,12 @@ class MeaningDataset(Dataset):
MeaningMap.token_frequency(t, freq) MeaningMap.token_frequency(t, freq)
return freq return freq
def get_seq_mask(idx, level, index): def get_seq_mask(self, idx, level, index):
assert index < 15, "index must < 15" assert index < 15, "index must < 15"
assert level < 8, "level must < 8" assert level < 8, "level must < 8"
return [((int(i / (16**level)) & 0xF) == index) for i in idx] rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF
rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF
return rank_idx == (rank_all + index if index < 0 else index)
class BatchGroupMeaningDataloader(Dataset): class BatchGroupMeaningDataloader(Dataset):
@ -335,11 +346,11 @@ class BatchGroupMeaningDataloader(Dataset):
self.batch_size = batch_size self.batch_size = batch_size
self.drop_last = drop_last self.drop_last = drop_last
length = dataset.length seq_len = [len(s) for s in dataset.seq]
unique, counts = np.unique(length, return_counts=True) unique, counts = np.unique(seq_len, return_counts=True)
gl = {} gl = {}
for u in unique: for u in unique:
gl[u] = np.where(length == u)[0] gl[u] = np.where(seq_len == u)[0]
lens = list(gl.keys()) lens = list(gl.keys())
gs = {} gs = {}
@ -365,7 +376,7 @@ class BatchGroupMeaningDataloader(Dataset):
index = index[index_shuffle] index = index[index_shuffle]
self.indexBatch = index self.indexBatch = index
print("Dataloader batch size:" + str(batch_size) + " count:" + str(len(index))) print("Dataloader batch size:" + str(batch_size) + " count:" + str(len(index)))
print("Dataloader total:" + str(len(length)) + " drop:" + str(len(length) - len(index) * batch_size)) print("Dataloader total:" + str(len(seq_len)) + " drop:" + str(len(seq_len) - len(index) * batch_size))
def __len__(self): def __len__(self):
return len(self.indexBatch) return len(self.indexBatch)
@ -387,229 +398,109 @@ class BatchGroupMeaningDataloader(Dataset):
if __name__ == "__main__": if __name__ == "__main__":
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=True) md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
train, val = md.split(0.95) train, val = md.split(0.95)
fdaf = md.__getitem__(920) fdaf = md.__getitem__(920)
print(md.print_tree(920)) print(md.print_tree(920))
print(md.idx[920]) print(md.rank_idx[920])
mask = MeaningDataset.get_seq_mask(md.idx[920], 1, 1) print(md.rank_all[920])
mask = md.get_seq_mask(920, 0, -1)
print(mask) print(mask)
assert mask == [ mask = md.get_seq_mask(920, 1, 0)
False, print(mask)
True, mask = md.get_seq_mask(920, 1, -1)
True, print(mask)
True, mask = md.get_seq_mask(920, 1, 1)
True, print(mask)
True, assert all(
True, np.equal(
True, mask[0:57],
True, np.array(
True, [
True, False,
False, False,
False, False,
False, False,
False, False,
False, False,
False, False,
False, False,
False, False,
True, False,
True, False,
True, True,
True, True,
True, True,
True, True,
True, True,
False, True,
False, True,
False, True,
False, True,
False, False,
False, False,
False, False,
False, False,
False, False,
False, False,
True, True,
True, False,
False, False,
False, False,
False, False,
False, False,
False, False,
False, False,
True, True,
True, True,
True, True,
True, True,
True, False,
True, False,
True, False,
False, False,
False, True,
False, True,
False, True,
False, True,
False, True,
False, True,
False, True,
False, True,
False, True,
False, False,
False, False,
False, False,
True, False,
True, False,
True, False,
True, ]
True, ),
False, )
False, ), "False"
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
True,
True,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
True,
True,
True,
True,
True,
True,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
], "False"
freq = md.token_frequency() freq = md.token_frequency()
dl = BatchGroupMeaningDataloader(train, 2) dl = BatchGroupMeaningDataloader(train, 2)
length = len(dl) # length = len(dl)
it = iter(dl) # it = iter(dl)
ne1 = next(it) # ne1 = next(it)
ne2 = next(it) # ne2 = next(it)
ne3 = next(it) # ne3 = next(it)
map1 = dl.get_tree(0) # map1 = dl.get_tree(0)
map2 = dl.get_tree(1) # map2 = dl.get_tree(1)
print(dl.print_tree(0)) # print(dl.print_tree(0))
dl = DataLoader( # dl = DataLoader(
train, # train,
num_workers=1, # num_workers=1,
persistent_workers=True, # persistent_workers=True,
shuffle=False, # shuffle=False,
) # )
it = iter(dl) # it = iter(dl)
ne1 = next(it) # ne1 = next(it)
ne2 = next(it) # ne2 = next(it)
ne3 = next(it) # ne3 = next(it)
for i in range(10): # for i in range(10):
print(next(it)["input_ids"].numpy().tolist()) # print(next(it)["input_ids"].numpy().tolist())