2024-03-13 19:41:02 +08:00
|
|
|
import os
|
|
|
|
import datasets
|
|
|
|
import torch
|
|
|
|
import math
|
|
|
|
import random
|
|
|
|
from itertools import chain
|
|
|
|
from typing import Dict, Tuple
|
|
|
|
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
|
|
|
import numpy as np
|
2024-03-18 11:43:41 +08:00
|
|
|
from torch.utils.data import BatchSampler
|
2024-04-10 00:34:47 +08:00
|
|
|
import copy
|
2024-03-13 19:41:02 +08:00
|
|
|
|
|
|
|
|
2024-04-02 19:59:05 +08:00
|
|
|
class MeaningMap:
    """Randomly generated hierarchical "meaning" table, cached on disk as ``.npz``.

    Meanings ``0 .. vocab_size-1`` are plain vocabulary tokens. Every meaning
    ``i >= vocab_size`` is composed of between ``min_subitem`` and ``max_subitem``
    sub-meanings, each with an id strictly smaller than ``i``. Expanding a meaning
    recursively therefore always ends in a flat token sequence.

    Per-meaning arrays (indexed by meaning id):
        ms_map     -- (size, max_subitem) sub-meaning ids; -1 marks unused slots
        ms_start   -- offset of the meaning's token sequence in the flat arrays
        ms_len     -- length of that token sequence
        ms_height  -- height of the meaning tree (0 for vocabulary tokens)
        ms_weight  -- number of vocabulary leaves in the tree

    Flat per-token arrays (all meanings' sequences concatenated):
        ms_data     -- token ids
        ms_level    -- depth of each token below the meaning it belongs to
        ms_rank_idx -- packed 4-bit fields; field k is the position of the
                       token's level-k ancestor (k=0: the token itself) inside
                       its parent; unused upper fields are 0xF
        ms_rank_all -- same layout, holding that parent's number of children

    NOTE(review): the cache file name does not include the NumPy RNG state, so a
    cache built under one seed is reused for any other seed.
    """

    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True):
        """Load the meaning table from ./data/ or build (and cache) a new one.

        Args:
            size: total number of meanings, vocabulary tokens included.
            vocab_size: number of leaf tokens; ids below this are tokens.
            max_subitem: maximum number of sub-meanings per composite meaning.
                The rank encoding packs counts into 4-bit fields, so values of
                16 or more (or trees deeper than 8 levels) are not represented
                faithfully.
            min_subitem: minimum number of sub-meanings per composite meaning.
            use_cache: reuse ./data/structured_language_*.npz when present.

        Building draws from the global ``np.random`` state.
        """
        assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input"
        assert min_subitem <= max_subitem, "Invalid input"
        self.size = size
        self.vocab_size = vocab_size
        self.max_subitem = max_subitem
        self.min_subitem = min_subitem

        path = "./data/"
        file = "structured_language_" + str(size) + "_" + str(vocab_size)
        file += "_" + str(max_subitem) + "_" + str(min_subitem)
        file = path + file + ".npz"

        os.makedirs(path, exist_ok=True)
        if os.path.exists(file) and use_cache:
            print("Load from disk cache: " + file)
            self._load_cache(file)
            print("Load end")
        else:
            print("Disk cache miss, build new one.")
            self._build_and_save(file)
            print("Disk cache build end.")

    def _load_cache(self, file):
        """Populate the ms_* attributes from an .npz written by _build_and_save."""
        # Use a context manager so the underlying zip file handle is closed.
        with np.load(file) as loaded:
            slhwm = loaded["slhwm"]  # columns: start, len, height, weight, map...
            dlra = loaded["dlra"]  # columns: data, level, rank_idx, rank_all
        self.ms_map = slhwm[:, 4:]
        self.ms_start = slhwm[:, 0]
        self.ms_len = slhwm[:, 1]
        self.ms_height = slhwm[:, 2]
        self.ms_weight = slhwm[:, 3]
        self.ms_data = dlra[:, 0]
        self.ms_level = dlra[:, 1]
        # Ranks were stored as int32 bit patterns; reinterpret them as uint32.
        self.ms_rank_idx = dlra[:, 2].astype(np.uint32)
        self.ms_rank_all = dlra[:, 3].astype(np.uint32)

    @staticmethod
    def _reverse_nibbles(x):
        """Reverse the order of the eight 4-bit groups of each 32-bit value."""
        return (
            ((x & 0xF) << 28)
            + ((x & 0xF0) << 20)
            + ((x & 0xF00) << 12)
            + ((x & 0xF000) << 4)
            + ((x & 0xF0000) >> 4)
            + ((x & 0xF00000) >> 12)
            + ((x & 0xF000000) >> 20)
            + ((x & 0xF0000000) >> 28)
        )

    def _build_and_save(self, file):
        """Generate a random meaning table, store it on ``self`` and save it to ``file``."""
        size = self.size
        vocab_size = self.vocab_size
        max_subitem = self.max_subitem
        min_subitem = self.min_subitem

        index = np.arange(0, size)
        raw = np.random.random((size, max_subitem))

        # Decide how many slots each row keeps. The first min_subitem values are
        # forced to 0, so they never exceed the threshold. Each row is sorted
        # ascending, and every value above a per-row random threshold is dropped.
        # Because of the sort, the kept slots are always the leading columns.
        mask_zero = raw.copy()
        mask_zero[:, 0:min_subitem] = 0.0
        mask_zero.sort(axis=1)
        thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1)
        mask_zero = mask_zero > thre

        # Scale each row so its entries sum to the row id. Truncating to int then
        # yields sub-meaning ids in [0, i].
        item_sum = raw.sum(axis=1)
        scale = (index / item_sum).reshape(-1, 1).repeat(max_subitem, axis=1)
        sub = (raw * scale).astype(np.int32)
        # An entry reaches exactly i when the row's other values are ~0; this
        # always happens when max_subitem == 1. Such an entry would make the
        # meaning reference itself, so clamp it to i - 1. Rows below vocab_size
        # are overwritten next.
        sub = np.minimum(sub, (index - 1).reshape(-1, 1)).astype(np.int32)
        sub[mask_zero] = -1

        # Rows below vocab_size are the vocabulary: meaning t is token t.
        sub[:vocab_size, 0] = np.arange(0, vocab_size)
        sub[:vocab_size, 1:] = -1

        ms_data = []  # token sequence of each meaning
        ms_level = []  # per-token depth; vocabulary tokens are level 0
        ms_rank_idx = []  # per-token position inside each ancestor (base-16 digits)
        ms_rank_all = []  # per-token sibling count at each ancestor (base-16 digits)
        ms_start = []  # offset of each meaning's sequence in the flat arrays
        ms_len = []  # length of each meaning's sequence
        ms_height = []  # height of each meaning tree
        ms_weight = []  # number of leaves of each meaning tree

        for tok in range(vocab_size):
            ms_data.append(np.array([tok]))
            ms_level.append(np.array([0]))
            ms_rank_idx.append(np.array([0]))
            ms_rank_all.append(np.array([0]))
            ms_start.append(tok)
            ms_len.append(1)
            ms_height.append(0)
            ms_weight.append(1)
        offset = vocab_size

        for i in range(vocab_size, size):
            row = sub[i]
            m_list = row[row >= 0].tolist()  # keep id 0: it is a real token, only -1 is padding
            m_len = len(m_list)
            assert m_list, "map list can not be empty list"

            ma = np.concatenate([ms_data[child] for child in m_list])
            ml = np.concatenate([ms_level[child] + 1 for child in m_list])
            # Append one base-16 digit per level: the child's position (rank_idx)
            # or the number of children (rank_all). A direct token child starts
            # a fresh value whose upper bits are all 1s.
            mr = np.concatenate(
                [
                    ([0xFFFFFFF0 + pos] if child < vocab_size else ms_rank_idx[child] * 16 + pos)
                    for pos, child in enumerate(m_list)
                ]
            )
            mrl = np.concatenate(
                [
                    ([0xFFFFFFF0 + m_len] if child < vocab_size else ms_rank_all[child] * 16 + m_len)
                    for child in m_list
                ]
            )

            ms_data.append(ma)
            ms_level.append(ml)
            ms_rank_idx.append(mr)
            ms_rank_all.append(mrl)
            ms_start.append(offset)
            ms_len.append(len(ma))
            ms_height.append(max(ms_height[child] for child in m_list) + 1)
            ms_weight.append(sum(ms_weight[child] for child in m_list))
            offset += len(ma)

        ms_data = np.concatenate(ms_data).astype(np.int32)
        ms_level = np.concatenate(ms_level).astype(np.int32)
        ms_rank_idx = np.concatenate(ms_rank_idx).astype(np.uint32)
        ms_rank_all = np.concatenate(ms_rank_all).astype(np.uint32)

        # Reverse the digit order so field 0 refers to the nearest parent. Then
        # shift out the unused low fields and fill the fields above the token's
        # level with 0xF. d is (0xFFFFFFFF << 4*level) truncated to 32 bits.
        d = np.ones(ms_rank_idx.shape, dtype=np.uint32)
        d = ((d * 0xFFFFFFFF) << (ms_level * 4)).astype(np.uint32)
        ms_rank_idx = ((self._reverse_nibbles(ms_rank_idx) >> ((8 - ms_level) * 4)) + d).astype(np.uint32)
        ms_rank_all = ((self._reverse_nibbles(ms_rank_all) >> ((8 - ms_level) * 4)) + d).astype(np.uint32)

        ms_start = np.array(ms_start).astype(np.int32)
        ms_height = np.array(ms_height).astype(np.int32)
        ms_weight = np.array(ms_weight).astype(np.int32)
        ms_len = np.array(ms_len).astype(np.int32)

        slhwm = np.concatenate(
            (
                ms_start.reshape((-1, 1)),
                ms_len.reshape((-1, 1)),
                ms_height.reshape((-1, 1)),
                ms_weight.reshape((-1, 1)),
                sub,
            ),
            axis=1,
        )
        dlra = np.stack((ms_data, ms_level, ms_rank_idx.astype(np.int32), ms_rank_all.astype(np.int32)), axis=1)
        np.savez(file, slhwm=slhwm, dlra=dlra)

        self.ms_data = ms_data  # sequence of meaning i = ms_data[ms_start[i]:ms_start[i]+ms_len[i]]
        self.ms_level = ms_level
        self.ms_rank_idx = ms_rank_idx
        self.ms_rank_all = ms_rank_all
        self.ms_map = sub  # ms_map[i] = [sub(i), sub(i), ..., -1, -1]
        self.ms_start = ms_start
        self.ms_len = ms_len
        self.ms_height = ms_height
        self.ms_weight = ms_weight

    def get_sequence(self, meaning):
        """Return ``(tokens, levels, rank_idx, rank_all)`` slices for ``meaning``."""
        start = self.ms_start[meaning]
        stop = start + self.ms_len[meaning]
        return (
            self.ms_data[start:stop],
            self.ms_level[start:stop],
            self.ms_rank_idx[start:stop],
            self.ms_rank_all[start:stop],
        )

    def get_tree(self, meaning):
        """Return the nested dict ``{sub_meaning: subtree_or_token}`` of ``meaning``.

        NOTE(review): dict keys collapse repeated sub-meanings. A meaning that
        uses the same child twice yields a tree with fewer leaves than its
        token sequence.
        """
        tree = {}
        ms = self.ms_map[meaning]
        # Only -1 marks an unused slot; id 0 is a valid vocabulary token and
        # must stay in the tree so leaves line up with get_sequence().
        for m in ms[ms >= 0].tolist():
            tree[m] = self.get_tree(m) if m >= self.vocab_size else m
        return tree

    def max_length(self):
        """Return the longest token-sequence length over all meanings."""
        return self.ms_len.max()

    @staticmethod
    def get_tree_str(tree, prefix):
        """Render a tree from ``get_tree`` (or a list of trees) as indented text.

        Returns None for a leaf (anything that is neither list nor dict), so the
        caller can print the leaf key inline.
        """
        if isinstance(tree, list):
            return "".join(MeaningMap.get_tree_str(t, "") for t in tree)
        if isinstance(tree, dict):
            base = ""
            last_is_dict = None
            for key, value in tree.items():
                new_prefix = (len(str(key)) + 2) * " " + prefix
                dict_string = MeaningMap.get_tree_str(value, new_prefix)
                if dict_string:
                    base += "\n" + prefix + str(key) + ": " + dict_string
                    last_is_dict = True
                else:
                    # A leaf after a subtree starts a new line; otherwise stay inline.
                    base += ("\n" + prefix + str(key) + " ") if last_is_dict else (str(key) + " ")
                    last_is_dict = False
            return base
        return None

    @staticmethod
    def get_tree_indexed_str(tree, data, prefix):
        """Like ``get_tree_str`` but print ``data[k]`` for the k-th leaf instead of its key.

        ``data`` is indexed per leaf in traversal order; for a list of trees,
        ``data[i]`` is used for ``tree[i]``. Returns ``(text, leaves_consumed)``;
        a leaf returns ``(None, 1)``.
        """
        if isinstance(tree, list):
            base = ""
            qlen = 0
            for i, t in enumerate(tree):
                s, consumed = MeaningMap.get_tree_indexed_str(t, data[i], "")
                base += s
                qlen += consumed
            return (base, qlen)
        if isinstance(tree, dict):
            base = ""
            qlen = 0
            last_is_dict = None
            for key, value in tree.items():
                new_prefix = (len(str(key)) + 2) * " " + prefix
                dict_string, consumed = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix)
                if dict_string:
                    base += "\n" + prefix + str(key) + ": " + dict_string
                    last_is_dict = True
                else:
                    base += ("\n" + prefix + str(data[qlen]) + " ") if last_is_dict else (str(data[qlen]) + " ")
                    last_is_dict = False
                qlen += consumed
            return (base, qlen)
        return (None, 1)

    @staticmethod
    def token_frequency(tree, freq):
        """Count, in place in ``freq``, how often each key occurs anywhere in ``tree``."""
        if isinstance(tree, dict):
            for key, value in tree.items():
                freq[key] = freq.get(key, 0) + 1
                MeaningMap.token_frequency(value, freq)
|
2024-03-13 19:41:02 +08:00
|
|
|
|
2024-04-07 00:25:21 +08:00
|
|
|
|
|
|
|
class MeaningDataset(Dataset):
    """Dataset of token sequences sampled from a MeaningMap.

    Each sample is the flattened token sequence of a random meaning id drawn
    from ``[start, end)``, together with its tree, per-token levels and packed
    rank fields. Optional masks select tokens by (level, position) via
    ``set_mask``.
    """

    def __init__(
        self,
        start,
        end,
        size,
        vocab_size,
        max_subitem=10,
        min_subitem=1,
        min_seq_len=2,
        seed=42,
        use_cache=True,
    ):
        """Sample ``size`` meanings from ``[start, end)`` and keep those long enough.

        Args:
            start, end: range of meaning ids to sample from; ``end`` is also the
                size of the underlying MeaningMap.
            size: number of meanings drawn (before length filtering).
            vocab_size, max_subitem, min_subitem, use_cache: forwarded to MeaningMap.
            min_seq_len: sequences shorter than this are discarded.
            seed: global ``np.random`` seed, reset both before building the map
                and before sampling, so sampling does not depend on the cache.
        """
        np.random.seed(seed)
        meaning_map = MeaningMap(end, vocab_size, max_subitem, min_subitem, use_cache=use_cache)
        np.random.seed(seed)

        self.mask_level = None
        self.mask_idx = None
        self.tree = []
        self.seq = []
        self.level = []
        self.rank_idx = []
        self.rank_all = []
        self.seq_meaning = []
        self.m_height = meaning_map.ms_height
        self.m_weight = meaning_map.ms_weight
        meanings = np.random.randint(start, end, size=(size))

        seq_len = []
        for m in meanings:
            d, lvl, ridx, rall = meaning_map.get_sequence(m)
            if len(d) >= min_seq_len:
                self.tree.append({m: meaning_map.get_tree(m)})
                self.seq.append(d)
                self.level.append(lvl)
                self.rank_idx.append(ridx)
                self.rank_all.append(rall)
                self.seq_meaning.append(m)
                seq_len.append(len(d))

        unique, counts = np.unique(seq_len, return_counts=True)
        print("----------------------------------------------------------------")
        print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))
        print("MeaningDataset size:" + str(len(seq_len)))
        # Guard: max()/argmax() of an empty array would raise when no sequence
        # reaches min_seq_len.
        if len(unique) > 0:
            print("MeaningDataset max sequence length:" + str(max(unique)))
            print("MeaningDataset most popular sequence length:" + str(unique[np.argmax(counts)]))
        print("----------------------------------------------------------------")

    def __len__(self):
        return len(self.seq)

    def len(self):
        """Alias of ``len(self)``."""
        return len(self.seq)

    def set_mask(self, level=None, idx=None):
        """Configure the mask returned in each batch under key ``"mask"``.

        Args:
            level: list of levels to test (see ``get_seq_mask``), or None.
            idx: list of positions, one per level; negative values count back
                from the number of siblings. Or None.

        Passing None for either argument disables masking.
        """
        # Validate the NEW arguments (previously the old state was checked,
        # so invalid input slipped through).
        if level is not None and idx is not None:
            assert isinstance(level, list), "mask level must be list"
            assert isinstance(idx, list), "mask index must be list"
            assert len(level) > 0, "len must > 0"
            assert len(level) == len(idx), "mask level and mask index must be same length"
        self.mask_level = level
        self.mask_idx = idx

    def __getitem__(self, idx):
        """Return sample ``idx`` as a batch of one (see ``get_batch``)."""
        return self.get_batch([idx])

    def get_batch(self, idx_list):
        """Build a batch dict; all sequences in ``idx_list`` must have equal length."""
        data = torch.tensor(np.stack([self.seq[i] for i in idx_list], axis=0)).long()
        output = {}
        output["input_ids"] = data
        output["labels"] = data.clone()
        output["token_type_ids"] = torch.zeros(data.shape)
        output["tree"] = [self.tree[i] for i in idx_list]
        output["level"] = [self.level[i] for i in idx_list]
        output["mask"] = self.get_seq_mask_tensor(idx_list)
        return output

    def get_token(self, idx):
        return self.seq[idx]

    def get_tree(self, idx):
        return self.tree[idx]

    def print_tree(self, idx):
        """Return the token sequence followed by the rendered tree of sample ``idx``."""
        tokens = self.seq[idx]
        tree = self.get_tree(idx)
        s = str(tokens) + "\n"
        s += MeaningMap.get_tree_str(tree, "")
        return s

    def copy(self, start, end):
        """Return a deep copy restricted to samples ``[start, end)``.

        The whole dataset is deep-copied before slicing. The mask lists are
        shared with ``self``.
        """
        new = copy.deepcopy(self)
        new.tree = new.tree[start:end]
        new.seq = new.seq[start:end]
        new.level = new.level[start:end]
        new.rank_idx = new.rank_idx[start:end]
        new.rank_all = new.rank_all[start:end]
        new.seq_meaning = new.seq_meaning[start:end]
        new.mask_level = self.mask_level
        new.mask_idx = self.mask_idx
        return new

    def split(self, ratio):
        """Split into ``(first ratio, remainder)`` datasets, preserving order."""
        total = self.len()
        middle = int(total * ratio)
        return self.copy(0, middle), self.copy(middle, total)

    def token_frequency(self):
        """Count how often each meaning/token id occurs across all sample trees."""
        freq = {}
        for t in self.tree:
            MeaningMap.token_frequency(t, freq)
        return freq

    def get_seq_mask(self, idx, level, index):
        """Boolean mask over sample ``idx``'s tokens.

        A token is True when its 4-bit rank field ``level`` equals ``index``.
        A negative ``index`` is taken relative to the sibling count stored in
        the matching ``rank_all`` field (e.g. -1 selects the last sibling).
        """
        assert index < 15, "index must < 15"
        assert level < 8, "level must < 8"
        rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF
        rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF
        return rank_idx == (rank_all + index if index < 0 else index)

    def get_seq_mask_tensor(self, idx_list):
        """AND together the masks of every configured (level, index) pair, or None."""
        if self.mask_level is None or self.mask_idx is None:
            return None
        mask = None
        for level, index in zip(self.mask_level, self.mask_idx):
            part = torch.tensor(np.stack([self.get_seq_mask(i, level, index) for i in idx_list], axis=0))
            mask = part if mask is None else mask & part
        return mask
|
|
|
|
|
2024-04-07 00:25:21 +08:00
|
|
|
|
|
|
|
class BatchGroupMeaningDataloader(Dataset):
    """Dataset of batches over a MeaningDataset, batching only equal-length sequences.

    Each item is a full batch (the dict returned by ``dataset.get_batch``).
    Sequences are grouped by length because ``get_batch`` stacks them with
    ``np.stack``.
    """

    def __init__(self, dataset: "MeaningDataset", batch_size, shuffle=True, drop_last=True):
        """Group ``dataset`` indices into batches of equal sequence length.

        Args:
            dataset: object exposing ``seq`` (list of 1-D sequences) and
                ``get_batch(idx_list)``; normally a MeaningDataset.
            batch_size: number of sequences per batch.
            shuffle: shuffle indices within each length group and then the
                order of batches, using the global ``np.random`` state.
            drop_last: if True (default), sequences that do not fill a whole
                batch are dropped and ``indexBatch`` is a
                ``(num_batches, batch_size)`` int array. If False, each length
                group's remainder becomes a smaller batch and ``indexBatch`` is
                a list of 1-D index arrays.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        seq_len = np.array([len(s) for s in dataset.seq])
        batches = []
        for length in np.unique(seq_len):
            members = np.where(seq_len == length)[0]  # fresh array, safe to shuffle in place
            if shuffle:
                np.random.shuffle(members)
            full = len(members) // batch_size
            batches.extend(members[0 : full * batch_size].reshape(full, batch_size))
            if not drop_last and len(members) % batch_size:
                batches.append(members[full * batch_size :])

        if shuffle:
            order = np.arange(0, len(batches))
            np.random.shuffle(order)
            batches = [batches[k] for k in order]

        if drop_last:
            index = np.array(batches, dtype=np.int64).reshape(-1, batch_size)
        else:
            index = batches
        self.indexBatch = index

        kept = sum(len(b) for b in batches)
        print("Dataloader batch size:" + str(batch_size) + " count:" + str(len(index)))
        print("Dataloader total:" + str(len(seq_len)) + " drop:" + str(len(seq_len) - kept))

    def __len__(self):
        return len(self.indexBatch)

    def __getitem__(self, idx):
        """Return batch ``idx`` as produced by ``dataset.get_batch``."""
        return self.dataset.get_batch(self.indexBatch[idx])

    def get_tree(self, idx):
        """Return the trees of every sample in batch ``idx``."""
        return [self.dataset.get_tree(i) for i in self.indexBatch[idx]]

    def print_tree(self, idx):
        """Return the rendered trees of every sample in batch ``idx``."""
        idx_list = self.indexBatch[idx]
        s = "--------------------------------------------------------\n"
        for i in idx_list:
            s += self.dataset.print_tree(i)
            s += "--------------------------------------------------------\n"
        return s
|
2024-04-03 11:24:00 +08:00
|
|
|
|
2024-03-13 19:41:02 +08:00
|
|
|
|
|
|
|
if __name__ == "__main__":
    # --- Scenario 1: cached dataset, mask = last item at level 1 -------------
    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=True)
    md.set_mask([1], [-1])
    train, val = md.split(0.95)
    sample = md[920]
    print(md.print_tree(920))
    print(md.rank_idx[920])
    print(md.rank_all[920])

    # Probe several (level, index) masks on sample 920; the last probe is the
    # one checked against the expected pattern below.
    for probe_level, probe_index in [(0, -1), (1, 0), (1, -1), (1, 1)]:
        mask = md.get_seq_mask(920, probe_level, probe_index)
        print(mask)

    # Expected first 57 entries of the (level=1, index=1) mask, written as run
    # lengths of alternating False/True values. Presumably recorded from a
    # specific ./data cache with the default seed — update if the map changes.
    run_values = np.array([False, True, False, True, False, True, False, True, False])
    run_lengths = [11, 9, 6, 1, 7, 4, 4, 9, 6]
    expected = np.repeat(run_values, run_lengths)
    assert all(np.equal(mask[0:57], expected)), "False"

    freq = md.token_frequency()

    # --- Scenario 2: fresh build, min_subitem=2, two-level mask ---------------
    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False)
    md.set_mask([0, 1], [0, -1])
    dl = BatchGroupMeaningDataloader(md, 1)
    length = len(dl)
    it = iter(dl)

    ne1 = next(it)
    tree = ne1["tree"]
    mask = ne1["mask"].cpu().numpy()
    print(MeaningMap.get_tree_str(tree, ""))
    m, l = MeaningMap.get_tree_indexed_str(tree, mask, "")
    print(m)

    ne2 = next(it)
    ne3 = next(it)
|