|
|
@ -8,6 +8,7 @@ from typing import Dict, Tuple
|
|
|
|
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
|
|
|
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
|
from torch.utils.data import BatchSampler
|
|
|
|
from torch.utils.data import BatchSampler
|
|
|
|
|
|
|
|
import copy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MeaningMap:
|
|
|
|
class MeaningMap:
|
|
|
@ -18,22 +19,22 @@ class MeaningMap:
|
|
|
|
|
|
|
|
|
|
|
|
path = "./data/"
|
|
|
|
path = "./data/"
|
|
|
|
file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
|
|
|
|
file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
|
|
|
|
file = path + file
|
|
|
|
file = path + file + ".npz"
|
|
|
|
file_slhwm = file + "_slhwm" + ".npy"
|
|
|
|
|
|
|
|
file_dli = file + "_dli" + ".npy"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not os.path.exists(path):
|
|
|
|
if not os.path.exists(path):
|
|
|
|
os.mkdir(path)
|
|
|
|
os.mkdir(path)
|
|
|
|
if os.path.exists(file_slhwm) and os.path.exists(file_dli) and use_cache:
|
|
|
|
if os.path.exists(file) and use_cache:
|
|
|
|
print("Load from disk cache: " + file)
|
|
|
|
print("Load from disk cache: " + file)
|
|
|
|
slhwm = np.load(file_slhwm)
|
|
|
|
loaded = np.load(file)
|
|
|
|
dli = np.load(file_dli)
|
|
|
|
slhwm = loaded["slhwm"]
|
|
|
|
|
|
|
|
dlra = loaded["dlra"]
|
|
|
|
self.ms_map = slhwm[:, 4:]
|
|
|
|
self.ms_map = slhwm[:, 4:]
|
|
|
|
self.ms_data = dli[:, 0]
|
|
|
|
self.ms_data = dlra[:, 0]
|
|
|
|
self.ms_start = slhwm[:, 0]
|
|
|
|
self.ms_start = slhwm[:, 0]
|
|
|
|
self.ms_len = slhwm[:, 1]
|
|
|
|
self.ms_len = slhwm[:, 1]
|
|
|
|
self.ms_level = dli[:, 1]
|
|
|
|
self.ms_level = dlra[:, 1]
|
|
|
|
self.ms_idx = dli[:, 2].astype(np.uint32)
|
|
|
|
self.ms_rank_idx = dlra[:, 2].astype(np.uint32)
|
|
|
|
|
|
|
|
self.ms_rank_all = dlra[:, 3].astype(np.uint32)
|
|
|
|
self.ms_height = slhwm[:, 2]
|
|
|
|
self.ms_height = slhwm[:, 2]
|
|
|
|
self.ms_weight = slhwm[:, 3]
|
|
|
|
self.ms_weight = slhwm[:, 3]
|
|
|
|
print("Load end")
|
|
|
|
print("Load end")
|
|
|
@ -62,7 +63,8 @@ class MeaningMap:
|
|
|
|
|
|
|
|
|
|
|
|
ms_data = [] # meaning sequence
|
|
|
|
ms_data = [] # meaning sequence
|
|
|
|
ms_level = [] # meaning level, vocab's level is 0
|
|
|
|
ms_level = [] # meaning level, vocab's level is 0
|
|
|
|
ms_idx = [] # meaning index of lowest level
|
|
|
|
ms_rank_idx = [] # meaning index of all level
|
|
|
|
|
|
|
|
ms_rank_all = [] # meaning all of all level
|
|
|
|
ms_start = [] # meaning sequence start
|
|
|
|
ms_start = [] # meaning sequence start
|
|
|
|
ms_len = [] # meaning sequence length
|
|
|
|
ms_len = [] # meaning sequence length
|
|
|
|
ms_height = [] # meaning tree height
|
|
|
|
ms_height = [] # meaning tree height
|
|
|
@ -71,7 +73,8 @@ class MeaningMap:
|
|
|
|
for i in range(self.vocab_size):
|
|
|
|
for i in range(self.vocab_size):
|
|
|
|
ms_data.append(np.array([i]))
|
|
|
|
ms_data.append(np.array([i]))
|
|
|
|
ms_level.append(np.array([0]))
|
|
|
|
ms_level.append(np.array([0]))
|
|
|
|
ms_idx.append(np.array([0]))
|
|
|
|
ms_rank_idx.append(np.array([0]))
|
|
|
|
|
|
|
|
ms_rank_all.append(np.array([0]))
|
|
|
|
ms_start.append(index)
|
|
|
|
ms_start.append(index)
|
|
|
|
ms_len.append(1)
|
|
|
|
ms_len.append(1)
|
|
|
|
ms_height.append(0)
|
|
|
|
ms_height.append(0)
|
|
|
@ -80,59 +83,70 @@ class MeaningMap:
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(self.vocab_size, size):
|
|
|
|
for i in range(self.vocab_size, size):
|
|
|
|
m = map[i]
|
|
|
|
m = map[i]
|
|
|
|
m = m[m >= 0]
|
|
|
|
m = m[m >= 0] # donot cut off the map such as [0]
|
|
|
|
|
|
|
|
|
|
|
|
m_list = m.tolist()
|
|
|
|
m_list = m.tolist()
|
|
|
|
|
|
|
|
m_len = len(m_list)
|
|
|
|
assert m_list, "map list can not be empty list"
|
|
|
|
assert m_list, "map list can not be empty list"
|
|
|
|
|
|
|
|
|
|
|
|
ma = np.concatenate([ms_data[newm] for newm in m_list])
|
|
|
|
ma = np.concatenate([ms_data[newm] for newm in m_list])
|
|
|
|
ml = np.concatenate([ms_level[newm] + 1 for newm in m_list])
|
|
|
|
ml = np.concatenate([ms_level[newm] + 1 for newm in m_list])
|
|
|
|
mi = np.concatenate(
|
|
|
|
mr = np.concatenate(
|
|
|
|
[
|
|
|
|
[
|
|
|
|
([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_idx[newm] * 16 + i)
|
|
|
|
([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_rank_idx[newm] * 16 + i)
|
|
|
|
for i, newm in enumerate(m_list)
|
|
|
|
for i, newm in enumerate(m_list)
|
|
|
|
]
|
|
|
|
]
|
|
|
|
)
|
|
|
|
)
|
|
|
|
ml = ml[ma > 0]
|
|
|
|
mrl = np.concatenate(
|
|
|
|
mi = mi[ma > 0]
|
|
|
|
[
|
|
|
|
ma = ma[ma > 0]
|
|
|
|
([0xFFFFFFF0 + m_len] if newm < self.vocab_size else ms_rank_all[newm] * 16 + m_len)
|
|
|
|
|
|
|
|
for i, newm in enumerate(m_list)
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
# ml = ml[ma > 0] # cut off the 0 token, such as [12,32,0,42,32]
|
|
|
|
|
|
|
|
# mr = mr[ma > 0]
|
|
|
|
|
|
|
|
# mrl = mrl[ma > 0]
|
|
|
|
|
|
|
|
# ma = ma[ma > 0]
|
|
|
|
|
|
|
|
|
|
|
|
ms_data.append(ma)
|
|
|
|
ms_data.append(ma)
|
|
|
|
ms_level.append(ml)
|
|
|
|
ms_level.append(ml)
|
|
|
|
ms_idx.append(mi)
|
|
|
|
ms_rank_idx.append(mr)
|
|
|
|
|
|
|
|
ms_rank_all.append(mrl)
|
|
|
|
ms_start.append(index)
|
|
|
|
ms_start.append(index)
|
|
|
|
ms_len.append(len(ma))
|
|
|
|
ms_len.append(len(ma))
|
|
|
|
ms_height.append(max([ms_height[sub_m] for sub_m in m_list]) + 1)
|
|
|
|
ms_height.append(max([ms_height[sub_m] for sub_m in m_list]) + 1)
|
|
|
|
ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list))
|
|
|
|
ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list))
|
|
|
|
index = index + len(ma)
|
|
|
|
index = index + len(ma)
|
|
|
|
|
|
|
|
|
|
|
|
# offsets = [0, 0, 4, 8, 12, 16, 20, 24, 28]
|
|
|
|
|
|
|
|
# for idxmi, mi in enumerate(ms_idx):
|
|
|
|
|
|
|
|
# level = ms_level[idxmi]
|
|
|
|
|
|
|
|
# for idxnum, num in enumerate(mi):
|
|
|
|
|
|
|
|
# l = level[idxnum]
|
|
|
|
|
|
|
|
# elements = [(num >> offset) & 0xF for offset in offsets[l:0:-1]]
|
|
|
|
|
|
|
|
# num = (num >> (l * 4)) << (l * 4)
|
|
|
|
|
|
|
|
# num += sum(elem << (i * 4) for i, elem in enumerate(elements))
|
|
|
|
|
|
|
|
# mi[idxnum] = num
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ms_data = np.array(list(chain(*ms_data))).astype(np.int32)
|
|
|
|
ms_data = np.array(list(chain(*ms_data))).astype(np.int32)
|
|
|
|
ms_level = np.array(list(chain(*ms_level))).astype(np.int32)
|
|
|
|
ms_level = np.array(list(chain(*ms_level))).astype(np.int32)
|
|
|
|
ms_idx = np.array(list(chain(*ms_idx))).astype(np.uint32)
|
|
|
|
ms_rank_idx = np.array(list(chain(*ms_rank_idx))).astype(np.uint32)
|
|
|
|
|
|
|
|
ms_rank_all = np.array(list(chain(*ms_rank_all))).astype(np.uint32)
|
|
|
|
|
|
|
|
|
|
|
|
d = np.ones(ms_idx.shape, dtype=np.uint32)
|
|
|
|
d = np.ones(ms_rank_idx.shape, dtype=np.uint32)
|
|
|
|
d = ((d * 0xFFFFFFFF) << (ms_level * 4)).astype(np.uint32)
|
|
|
|
d = ((d * 0xFFFFFFFF) << (ms_level * 4)).astype(np.uint32)
|
|
|
|
ms_idx = (
|
|
|
|
ms_rank_idx = (
|
|
|
|
((ms_idx & 0xF) << 28)
|
|
|
|
((ms_rank_idx & 0xF) << 28)
|
|
|
|
+ ((ms_idx & 0xF0) << 20)
|
|
|
|
+ ((ms_rank_idx & 0xF0) << 20)
|
|
|
|
+ ((ms_idx & 0xF00) << 12)
|
|
|
|
+ ((ms_rank_idx & 0xF00) << 12)
|
|
|
|
+ ((ms_idx & 0xF000) << 4)
|
|
|
|
+ ((ms_rank_idx & 0xF000) << 4)
|
|
|
|
+ ((ms_idx & 0xF0000) >> 4)
|
|
|
|
+ ((ms_rank_idx & 0xF0000) >> 4)
|
|
|
|
+ ((ms_idx & 0xF00000) >> 12)
|
|
|
|
+ ((ms_rank_idx & 0xF00000) >> 12)
|
|
|
|
+ ((ms_idx & 0xF000000) >> 20)
|
|
|
|
+ ((ms_rank_idx & 0xF000000) >> 20)
|
|
|
|
+ ((ms_idx & 0xF0000000) >> 28)
|
|
|
|
+ ((ms_rank_idx & 0xF0000000) >> 28)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
ms_idx = ((ms_idx >> ((8 - ms_level) * 4)) + d).astype(np.uint32)
|
|
|
|
ms_rank_idx = ((ms_rank_idx >> ((8 - ms_level) * 4)) + d).astype(np.uint32)
|
|
|
|
|
|
|
|
ms_rank_all = (
|
|
|
|
|
|
|
|
((ms_rank_all & 0xF) << 28)
|
|
|
|
|
|
|
|
+ ((ms_rank_all & 0xF0) << 20)
|
|
|
|
|
|
|
|
+ ((ms_rank_all & 0xF00) << 12)
|
|
|
|
|
|
|
|
+ ((ms_rank_all & 0xF000) << 4)
|
|
|
|
|
|
|
|
+ ((ms_rank_all & 0xF0000) >> 4)
|
|
|
|
|
|
|
|
+ ((ms_rank_all & 0xF00000) >> 12)
|
|
|
|
|
|
|
|
+ ((ms_rank_all & 0xF000000) >> 20)
|
|
|
|
|
|
|
|
+ ((ms_rank_all & 0xF0000000) >> 28)
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
ms_rank_all = ((ms_rank_all >> ((8 - ms_level) * 4)) + d).astype(np.uint32)
|
|
|
|
|
|
|
|
|
|
|
|
ms_start = np.array(ms_start).astype(np.int32)
|
|
|
|
ms_start = np.array(ms_start).astype(np.int32)
|
|
|
|
ms_height = np.array(ms_height).astype(np.int32)
|
|
|
|
ms_height = np.array(ms_height).astype(np.int32)
|
|
|
@ -149,17 +163,17 @@ class MeaningMap:
|
|
|
|
),
|
|
|
|
),
|
|
|
|
axis=1,
|
|
|
|
axis=1,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
dli = np.stack((ms_data, ms_level, ms_idx.astype(np.int32)), axis=1)
|
|
|
|
dlra = np.stack((ms_data, ms_level, ms_rank_idx.astype(np.int32), ms_rank_all.astype(np.int32)), axis=1)
|
|
|
|
|
|
|
|
np.savez(file, slhwm=slhwm, dlra=dlra)
|
|
|
|
|
|
|
|
|
|
|
|
np.save(file_slhwm, slhwm)
|
|
|
|
self.ms_data = ms_data # map[i]=ms_data[ms_start[i]:ms_start[i]+ms_len[i]]
|
|
|
|
np.save(file_dli, dli)
|
|
|
|
self.ms_level = ms_level
|
|
|
|
|
|
|
|
self.ms_rank_idx = ms_rank_idx
|
|
|
|
|
|
|
|
self.ms_rank_all = ms_rank_all
|
|
|
|
|
|
|
|
|
|
|
|
self.ms_map = map # ms_map[i] = [sub(i),sub(i),sub(i),sub(i)...sub(i)]
|
|
|
|
self.ms_map = map # ms_map[i] = [sub(i),sub(i),sub(i),sub(i)...sub(i)]
|
|
|
|
self.ms_data = ms_data # map[i]=ms_data[ms_start[i]:ms_start[i]+ms_len[i]]
|
|
|
|
|
|
|
|
self.ms_start = ms_start
|
|
|
|
self.ms_start = ms_start
|
|
|
|
self.ms_len = ms_len
|
|
|
|
self.ms_len = ms_len
|
|
|
|
self.ms_level = ms_level
|
|
|
|
|
|
|
|
self.ms_idx = ms_idx
|
|
|
|
|
|
|
|
self.ms_height = ms_height
|
|
|
|
self.ms_height = ms_height
|
|
|
|
self.ms_weight = ms_weight
|
|
|
|
self.ms_weight = ms_weight
|
|
|
|
print("Disk cache build end.")
|
|
|
|
print("Disk cache build end.")
|
|
|
@ -167,7 +181,12 @@ class MeaningMap:
|
|
|
|
def get_sequence(self, meaning): # return sequence[meaning]
|
|
|
|
def get_sequence(self, meaning): # return sequence[meaning]
|
|
|
|
start = self.ms_start[meaning]
|
|
|
|
start = self.ms_start[meaning]
|
|
|
|
len = self.ms_len[meaning]
|
|
|
|
len = self.ms_len[meaning]
|
|
|
|
return self.ms_data[start : start + len], self.ms_level[start : start + len], self.ms_idx[start : start + len]
|
|
|
|
return (
|
|
|
|
|
|
|
|
self.ms_data[start : start + len],
|
|
|
|
|
|
|
|
self.ms_level[start : start + len],
|
|
|
|
|
|
|
|
self.ms_rank_idx[start : start + len],
|
|
|
|
|
|
|
|
self.ms_rank_all[start : start + len],
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def get_tree(self, meaning): # return meaning all sub items
|
|
|
|
def get_tree(self, meaning): # return meaning all sub items
|
|
|
|
tree = {}
|
|
|
|
tree = {}
|
|
|
@ -206,73 +225,70 @@ class MeaningMap:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MeaningDataset(Dataset):
|
|
|
|
class MeaningDataset(Dataset):
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
start=131072,
|
|
|
|
start,
|
|
|
|
end=1048576,
|
|
|
|
end,
|
|
|
|
size=32768,
|
|
|
|
size,
|
|
|
|
vocab_size=4096,
|
|
|
|
vocab_size,
|
|
|
|
max_subitem=10,
|
|
|
|
max_subitem=10,
|
|
|
|
min_seq_len=2,
|
|
|
|
min_seq_len=2,
|
|
|
|
seed=42,
|
|
|
|
seed=42,
|
|
|
|
data=None,
|
|
|
|
|
|
|
|
length=None,
|
|
|
|
|
|
|
|
tree=None,
|
|
|
|
|
|
|
|
level=None,
|
|
|
|
|
|
|
|
idx=None,
|
|
|
|
|
|
|
|
use_cache=True,
|
|
|
|
use_cache=True,
|
|
|
|
):
|
|
|
|
):
|
|
|
|
if data != None and length != None and tree != None and level != None and idx != None:
|
|
|
|
|
|
|
|
self.data = data
|
|
|
|
|
|
|
|
self.length = length
|
|
|
|
|
|
|
|
self.tree = tree
|
|
|
|
|
|
|
|
self.level = level
|
|
|
|
|
|
|
|
self.idx = idx
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
np.random.seed(seed)
|
|
|
|
np.random.seed(seed)
|
|
|
|
map = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem, use_cache=use_cache)
|
|
|
|
map = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem, use_cache=use_cache)
|
|
|
|
|
|
|
|
np.random.seed(seed)
|
|
|
|
|
|
|
|
|
|
|
|
self.tree = []
|
|
|
|
self.tree = []
|
|
|
|
self.data = []
|
|
|
|
self.seq = []
|
|
|
|
self.level = []
|
|
|
|
self.level = []
|
|
|
|
self.idx = []
|
|
|
|
self.rank_idx = []
|
|
|
|
self.length = []
|
|
|
|
self.rank_all = []
|
|
|
|
|
|
|
|
self.seq_meaning = []
|
|
|
|
|
|
|
|
self.m_height = map.ms_height
|
|
|
|
|
|
|
|
self.m_weight = map.ms_weight
|
|
|
|
meanings = np.random.randint(start, end, size=(size))
|
|
|
|
meanings = np.random.randint(start, end, size=(size))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
seq_len = []
|
|
|
|
for m in meanings:
|
|
|
|
for m in meanings:
|
|
|
|
d, l, i = map.get_sequence(m)
|
|
|
|
d, l, i, a = map.get_sequence(m)
|
|
|
|
if len(d) >= min_seq_len:
|
|
|
|
if len(d) >= min_seq_len:
|
|
|
|
self.tree.append({m: map.get_tree(m)})
|
|
|
|
self.tree.append({m: map.get_tree(m)})
|
|
|
|
self.data.append(d)
|
|
|
|
self.seq.append(d)
|
|
|
|
self.level.append(l)
|
|
|
|
self.level.append(l)
|
|
|
|
self.idx.append(i)
|
|
|
|
self.rank_idx.append(i)
|
|
|
|
self.length.append(len(d))
|
|
|
|
self.rank_all.append(a)
|
|
|
|
|
|
|
|
self.seq_meaning.append(m)
|
|
|
|
|
|
|
|
seq_len.append(len(d))
|
|
|
|
|
|
|
|
|
|
|
|
unique, counts = np.unique(self.length, return_counts=True)
|
|
|
|
unique, counts = np.unique(seq_len, return_counts=True)
|
|
|
|
print("----------------------------------------------------------------")
|
|
|
|
print("----------------------------------------------------------------")
|
|
|
|
print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))
|
|
|
|
print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))
|
|
|
|
print("MeaningDataset size:" + str(len(self.length)))
|
|
|
|
print("MeaningDataset size:" + str(len(seq_len)))
|
|
|
|
print("MeaningDataset max sequence length:" + str(max(unique)))
|
|
|
|
print("MeaningDataset max sequence length:" + str(max(unique)))
|
|
|
|
print("MeaningDataset most popular sequence length:" + str(unique[np.argmax(counts)]))
|
|
|
|
print("MeaningDataset most popular sequence length:" + str(unique[np.argmax(counts)]))
|
|
|
|
print("----------------------------------------------------------------")
|
|
|
|
print("----------------------------------------------------------------")
|
|
|
|
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
def __len__(self):
|
|
|
|
return len(self.data)
|
|
|
|
return len(self.seq)
|
|
|
|
|
|
|
|
|
|
|
|
def len(self):
|
|
|
|
def len(self):
|
|
|
|
return len(self.data)
|
|
|
|
return len(self.seq)
|
|
|
|
|
|
|
|
|
|
|
|
def __getitem__(self, idx):
|
|
|
|
def __getitem__(self, idx):
|
|
|
|
output = {}
|
|
|
|
output = {}
|
|
|
|
data = torch.tensor(self.data[idx]).long()
|
|
|
|
data = torch.tensor(self.seq[idx]).long()
|
|
|
|
output["input_ids"] = data
|
|
|
|
output["input_ids"] = data
|
|
|
|
output["labels"] = data.clone()
|
|
|
|
output["labels"] = data.clone()
|
|
|
|
output["token_type_ids"] = torch.zeros(data.shape)
|
|
|
|
output["token_type_ids"] = torch.zeros(data.shape)
|
|
|
|
output["tree"] = self.tree[idx]
|
|
|
|
output["tree"] = self.tree[idx]
|
|
|
|
output["level"] = self.level[idx]
|
|
|
|
output["level"] = self.level[idx]
|
|
|
|
output["idx"] = self.idx[idx]
|
|
|
|
|
|
|
|
return output
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
|
|
|
def get_batch(self, idx_list): # must equal sequence length
|
|
|
|
def get_batch(self, idx_list): # must equal sequence length
|
|
|
|
data = [self.data[i] for i in idx_list]
|
|
|
|
data = [self.seq[i] for i in idx_list]
|
|
|
|
output = {}
|
|
|
|
output = {}
|
|
|
|
data = torch.tensor(np.stack(data, axis=0)).long()
|
|
|
|
data = torch.tensor(np.stack(data, axis=0)).long()
|
|
|
|
output["input_ids"] = data
|
|
|
|
output["input_ids"] = data
|
|
|
@ -280,45 +296,35 @@ class MeaningDataset(Dataset):
|
|
|
|
output["token_type_ids"] = torch.zeros(data.shape)
|
|
|
|
output["token_type_ids"] = torch.zeros(data.shape)
|
|
|
|
output["tree"] = [self.tree[i] for i in idx_list]
|
|
|
|
output["tree"] = [self.tree[i] for i in idx_list]
|
|
|
|
output["level"] = [self.level[i] for i in idx_list]
|
|
|
|
output["level"] = [self.level[i] for i in idx_list]
|
|
|
|
output["idx"] = [self.idx[i] for i in idx_list]
|
|
|
|
|
|
|
|
return output
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
|
|
|
def get_token(self, idx): # must equal sequence length
|
|
|
|
def get_token(self, idx): # must equal sequence length
|
|
|
|
return self.data[idx]
|
|
|
|
return self.seq[idx]
|
|
|
|
|
|
|
|
|
|
|
|
def get_tree(self, idx):
|
|
|
|
def get_tree(self, idx):
|
|
|
|
return self.tree[idx]
|
|
|
|
return self.tree[idx]
|
|
|
|
|
|
|
|
|
|
|
|
def print_tree(self, idx):
|
|
|
|
def print_tree(self, idx):
|
|
|
|
tokens = self.data[idx]
|
|
|
|
tokens = self.seq[idx]
|
|
|
|
tree = self.get_tree(idx)
|
|
|
|
tree = self.get_tree(idx)
|
|
|
|
s = str(tokens) + "\n"
|
|
|
|
s = str(tokens) + "\n"
|
|
|
|
s += MeaningMap.get_tree_str(tree, "")
|
|
|
|
s += MeaningMap.get_tree_str(tree, "")
|
|
|
|
return s
|
|
|
|
return s
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def copy(self, start, end):
|
|
|
|
|
|
|
|
new = copy.deepcopy(self)
|
|
|
|
|
|
|
|
new.tree = new.tree[start:end]
|
|
|
|
|
|
|
|
new.seq = new.seq[start:end]
|
|
|
|
|
|
|
|
new.level = new.level[start:end]
|
|
|
|
|
|
|
|
new.rank_idx = new.rank_idx[start:end]
|
|
|
|
|
|
|
|
new.rank_all = new.rank_all[start:end]
|
|
|
|
|
|
|
|
new.seq_meaning = new.seq_meaning[start:end]
|
|
|
|
|
|
|
|
return new
|
|
|
|
|
|
|
|
|
|
|
|
def split(self, ratio):
|
|
|
|
def split(self, ratio):
|
|
|
|
l = len(self.data)
|
|
|
|
l = self.len()
|
|
|
|
middle = int(l * ratio)
|
|
|
|
middle = int(l * ratio)
|
|
|
|
d_shuffle = self.data.copy()
|
|
|
|
return self.copy(0, middle), self.copy(middle, l)
|
|
|
|
l_shuffle = self.length.copy()
|
|
|
|
|
|
|
|
m_shuffle = self.tree.copy()
|
|
|
|
|
|
|
|
level_shuffle = self.level.copy()
|
|
|
|
|
|
|
|
i_shuffle = self.idx.copy()
|
|
|
|
|
|
|
|
md1 = MeaningDataset(
|
|
|
|
|
|
|
|
data=d_shuffle[:middle],
|
|
|
|
|
|
|
|
length=l_shuffle[:middle],
|
|
|
|
|
|
|
|
tree=m_shuffle[:middle],
|
|
|
|
|
|
|
|
level=level_shuffle[:middle],
|
|
|
|
|
|
|
|
idx=i_shuffle[:middle],
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
md2 = MeaningDataset(
|
|
|
|
|
|
|
|
data=d_shuffle[middle:],
|
|
|
|
|
|
|
|
length=l_shuffle[middle:],
|
|
|
|
|
|
|
|
tree=m_shuffle[middle:],
|
|
|
|
|
|
|
|
level=level_shuffle[middle:],
|
|
|
|
|
|
|
|
idx=i_shuffle[middle:],
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return md1, md2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def token_frequency(self):
|
|
|
|
def token_frequency(self):
|
|
|
|
freq = {}
|
|
|
|
freq = {}
|
|
|
@ -326,10 +332,12 @@ class MeaningDataset(Dataset):
|
|
|
|
MeaningMap.token_frequency(t, freq)
|
|
|
|
MeaningMap.token_frequency(t, freq)
|
|
|
|
return freq
|
|
|
|
return freq
|
|
|
|
|
|
|
|
|
|
|
|
def get_seq_mask(idx, level, index):
|
|
|
|
def get_seq_mask(self, idx, level, index):
|
|
|
|
assert index < 15, "index must < 15"
|
|
|
|
assert index < 15, "index must < 15"
|
|
|
|
assert level < 8, "level must < 8"
|
|
|
|
assert level < 8, "level must < 8"
|
|
|
|
return [((int(i / (16**level)) & 0xF) == index) for i in idx]
|
|
|
|
rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF
|
|
|
|
|
|
|
|
rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF
|
|
|
|
|
|
|
|
return rank_idx == (rank_all + index if index < 0 else index)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BatchGroupMeaningDataloader(Dataset):
|
|
|
|
class BatchGroupMeaningDataloader(Dataset):
|
|
|
@ -338,11 +346,11 @@ class BatchGroupMeaningDataloader(Dataset):
|
|
|
|
self.batch_size = batch_size
|
|
|
|
self.batch_size = batch_size
|
|
|
|
self.drop_last = drop_last
|
|
|
|
self.drop_last = drop_last
|
|
|
|
|
|
|
|
|
|
|
|
length = dataset.length
|
|
|
|
seq_len = [len(s) for s in dataset.seq]
|
|
|
|
unique, counts = np.unique(length, return_counts=True)
|
|
|
|
unique, counts = np.unique(seq_len, return_counts=True)
|
|
|
|
gl = {}
|
|
|
|
gl = {}
|
|
|
|
for u in unique:
|
|
|
|
for u in unique:
|
|
|
|
gl[u] = np.where(length == u)[0]
|
|
|
|
gl[u] = np.where(seq_len == u)[0]
|
|
|
|
|
|
|
|
|
|
|
|
lens = list(gl.keys())
|
|
|
|
lens = list(gl.keys())
|
|
|
|
gs = {}
|
|
|
|
gs = {}
|
|
|
@ -368,7 +376,7 @@ class BatchGroupMeaningDataloader(Dataset):
|
|
|
|
index = index[index_shuffle]
|
|
|
|
index = index[index_shuffle]
|
|
|
|
self.indexBatch = index
|
|
|
|
self.indexBatch = index
|
|
|
|
print("Dataloader batch size:" + str(batch_size) + " count:" + str(len(index)))
|
|
|
|
print("Dataloader batch size:" + str(batch_size) + " count:" + str(len(index)))
|
|
|
|
print("Dataloader total:" + str(len(length)) + " drop:" + str(len(length) - len(index) * batch_size))
|
|
|
|
print("Dataloader total:" + str(len(seq_len)) + " drop:" + str(len(seq_len) - len(index) * batch_size))
|
|
|
|
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
def __len__(self):
|
|
|
|
return len(self.indexBatch)
|
|
|
|
return len(self.indexBatch)
|
|
|
@ -390,38 +398,109 @@ class BatchGroupMeaningDataloader(Dataset):
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
md = MeaningDataset(100000, 115200, vocab_size=128, size=1024, use_cache=False)
|
|
|
|
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
|
|
|
|
train, val = md.split(0.95)
|
|
|
|
train, val = md.split(0.95)
|
|
|
|
fdaf = md.__getitem__(920)
|
|
|
|
fdaf = md.__getitem__(920)
|
|
|
|
print(md.print_tree(920))
|
|
|
|
print(md.print_tree(920))
|
|
|
|
print(md.idx[920])
|
|
|
|
print(md.rank_idx[920])
|
|
|
|
fdasfe = MeaningDataset.get_seq_mask(md.idx[920], 1, 1)
|
|
|
|
print(md.rank_all[920])
|
|
|
|
print(fdasfe)
|
|
|
|
mask = md.get_seq_mask(920, 0, -1)
|
|
|
|
|
|
|
|
print(mask)
|
|
|
|
|
|
|
|
mask = md.get_seq_mask(920, 1, 0)
|
|
|
|
|
|
|
|
print(mask)
|
|
|
|
|
|
|
|
mask = md.get_seq_mask(920, 1, -1)
|
|
|
|
|
|
|
|
print(mask)
|
|
|
|
|
|
|
|
mask = md.get_seq_mask(920, 1, 1)
|
|
|
|
|
|
|
|
print(mask)
|
|
|
|
|
|
|
|
assert all(
|
|
|
|
|
|
|
|
np.equal(
|
|
|
|
|
|
|
|
mask[0:57],
|
|
|
|
|
|
|
|
np.array(
|
|
|
|
|
|
|
|
[
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
True,
|
|
|
|
|
|
|
|
True,
|
|
|
|
|
|
|
|
True,
|
|
|
|
|
|
|
|
True,
|
|
|
|
|
|
|
|
True,
|
|
|
|
|
|
|
|
True,
|
|
|
|
|
|
|
|
True,
|
|
|
|
|
|
|
|
True,
|
|
|
|
|
|
|
|
True,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
True,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
True,
|
|
|
|
|
|
|
|
True,
|
|
|
|
|
|
|
|
True,
|
|
|
|
|
|
|
|
True,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
True,
|
|
|
|
|
|
|
|
True,
|
|
|
|
|
|
|
|
True,
|
|
|
|
|
|
|
|
True,
|
|
|
|
|
|
|
|
True,
|
|
|
|
|
|
|
|
True,
|
|
|
|
|
|
|
|
True,
|
|
|
|
|
|
|
|
True,
|
|
|
|
|
|
|
|
True,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
False,
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
), "False"
|
|
|
|
freq = md.token_frequency()
|
|
|
|
freq = md.token_frequency()
|
|
|
|
|
|
|
|
|
|
|
|
dl = BatchGroupMeaningDataloader(train, 2)
|
|
|
|
dl = BatchGroupMeaningDataloader(train, 2)
|
|
|
|
length = len(dl)
|
|
|
|
# length = len(dl)
|
|
|
|
it = iter(dl)
|
|
|
|
# it = iter(dl)
|
|
|
|
ne1 = next(it)
|
|
|
|
# ne1 = next(it)
|
|
|
|
ne2 = next(it)
|
|
|
|
# ne2 = next(it)
|
|
|
|
ne3 = next(it)
|
|
|
|
# ne3 = next(it)
|
|
|
|
|
|
|
|
|
|
|
|
map1 = dl.get_tree(0)
|
|
|
|
# map1 = dl.get_tree(0)
|
|
|
|
map2 = dl.get_tree(1)
|
|
|
|
# map2 = dl.get_tree(1)
|
|
|
|
print(dl.print_tree(0))
|
|
|
|
# print(dl.print_tree(0))
|
|
|
|
|
|
|
|
|
|
|
|
dl = DataLoader(
|
|
|
|
# dl = DataLoader(
|
|
|
|
train,
|
|
|
|
# train,
|
|
|
|
num_workers=1,
|
|
|
|
# num_workers=1,
|
|
|
|
persistent_workers=True,
|
|
|
|
# persistent_workers=True,
|
|
|
|
shuffle=False,
|
|
|
|
# shuffle=False,
|
|
|
|
)
|
|
|
|
# )
|
|
|
|
it = iter(dl)
|
|
|
|
# it = iter(dl)
|
|
|
|
ne1 = next(it)
|
|
|
|
# ne1 = next(it)
|
|
|
|
ne2 = next(it)
|
|
|
|
# ne2 = next(it)
|
|
|
|
ne3 = next(it)
|
|
|
|
# ne3 = next(it)
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(10):
|
|
|
|
# for i in range(10):
|
|
|
|
daf = next(it)["input_ids"].numpy().tolist()
|
|
|
|
# print(next(it)["input_ids"].numpy().tolist())
|
|
|
|
|
|
|
|
|
|
|
|
print(daf)
|
|
|
|
|
|
|
|