Refine meaning dataset.
This commit is contained in:
parent
2bc9e3b57e
commit
33d1e22655
|
@ -0,0 +1,67 @@
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
a = np.array([0, 1, 32 + 1, (32 + 1) * 16, 4, 5, 6, 7, 8, 8]).astype(np.uint32)
|
||||||
|
b = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 8]).astype(np.uint32)
|
||||||
|
|
||||||
|
|
||||||
|
d = np.ones(a.shape, dtype=np.uint32)
|
||||||
|
d = (d * 0xFFFFFFFF) << (b * 4)
|
||||||
|
|
||||||
|
c = a.astype(np.uint32)
|
||||||
|
|
||||||
|
cc = (
|
||||||
|
((c & 0xF) << 28)
|
||||||
|
+ ((c & 0xF0) << 20)
|
||||||
|
+ ((c & 0xF00) << 12)
|
||||||
|
+ ((c & 0xF000) << 4)
|
||||||
|
+ ((c & 0xF0000) >> 4)
|
||||||
|
+ ((c & 0xF00000) >> 12)
|
||||||
|
+ ((c & 0xF000000) >> 20)
|
||||||
|
+ ((c & 0xF0000000) >> 28)
|
||||||
|
)
|
||||||
|
cc = (cc >> ((8 - b) * 4)) + d
|
||||||
|
|
||||||
|
print(cc[3] == 4294963218)
|
||||||
|
|
||||||
|
b = np.ones((10)).astype(np.int32)
|
||||||
|
|
||||||
|
|
||||||
|
def get_tree_str_new(tree, prefix):
|
||||||
|
if isinstance(tree, dict):
|
||||||
|
base = ""
|
||||||
|
last_is_dict = None
|
||||||
|
for key, value in tree.items():
|
||||||
|
new_prefix = (len(str(key)) + 2) * " " + prefix
|
||||||
|
dict_string = get_tree_str_new(value, new_prefix)
|
||||||
|
if dict_string:
|
||||||
|
base += "\n" + prefix + str(key) + ": " + dict_string
|
||||||
|
last_is_dict = True
|
||||||
|
else:
|
||||||
|
base += "\n" + prefix + str(key) + " " if last_is_dict else str(key) + " "
|
||||||
|
last_is_dict = False
|
||||||
|
return base
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
tree = {
|
||||||
|
112377: {
|
||||||
|
2944: {228: 228, 263: 263, 252: 252, 396: 396},
|
||||||
|
10024: {
|
||||||
|
1424: {189: 189, 209: 209, 200: 200, 102: 102, 178: 178, 22: 22, 9: 9},
|
||||||
|
1053: 432,
|
||||||
|
1350: {68: 68, 200: 200, 50: 50, 17: 17, 36: 36, 283: 283},
|
||||||
|
7: 7,
|
||||||
|
},
|
||||||
|
18196: 322,
|
||||||
|
13373: {
|
||||||
|
1420: {99: 99, 189: 189, 163: 163},
|
||||||
|
2109: {320: 320, 92: 92, 95: 95, 224: 224, 435: 435, 4: 4, 373: 373, 27: 27, 228: 228},
|
||||||
|
708: 708,
|
||||||
|
2196: {27: 27, 157: 157, 87: 87, 231: 231},
|
||||||
|
401: 401,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
print(get_tree_str_new(tree, ""))
|
|
@ -0,0 +1,38 @@
|
||||||
|
# meaning dataset
|
||||||
|
|
||||||
|
meaning数据集是一个模仿自然语言,以及抽象表达的数据集。
|
||||||
|
|
||||||
|
## 概念
|
||||||
|
|
||||||
|
1. token表示最终体现的基本数据表达,类似单词。vocab_size表示代表token的数量。
|
||||||
|
2. meaning表示一种语义(符号),所有的meaning都由一个编号表达,编号越大表示语义越复杂
|
||||||
|
3. 所有的meaning都可以由更低标号表达
|
||||||
|
4. 从0到vocab_size的编号表示基本meaning,是不能被拆解的,也就是token
|
||||||
|
5. meaning通过一层层的向低编号的meaning进行组合替换,最终形成一个最底层是token的树形数据
|
||||||
|
6. level表示当前token相对于root meaning的距离
|
||||||
|
7. idx表示当前token在不同层的排序编号,每4位表示在一层里面的编号,低4位表示最低层级的index,高位无用的位用1填充
|
||||||
|
8. tree用于存储每个meaning的拆解的数据,使用字典表达一个树形结构
|
||||||
|
9. get_seq_mask返回一个sequence每个token在对应level是不是对应的index
|
||||||
|
10. meaning_height
|
||||||
|
11. meaning_weight
|
||||||
|
|
||||||
|
```
|
||||||
|
vocab_size = 256 meaning = 115200
|
||||||
|
|
||||||
|
115200
|
||||||
|
/ | \
|
||||||
|
10240 1100 12322
|
||||||
|
/ | \ / \ / | \
|
||||||
|
512 32 1201 245 233 3214 532 324
|
||||||
|
/ \ / \ / \ | / \
|
||||||
|
123 42 320 500 1231 23 324 93 176
|
||||||
|
/ \ / \ / \ / \
|
||||||
|
176 11 255 129 129 99 211 111
|
||||||
|
|
||||||
|
sequence = 123 42 32 176 11 255 129 245 233 129 99 23 211 111 93 176
|
||||||
|
level = 3 3 2 4 4 4 4 2 2 4 4 3 4 4 3 3
|
||||||
|
idx at 0 = 0 1 1 0 1 0 1 0 1 0 1 2 0 1 0 1
|
||||||
|
idx at 1 = 0 0 0 0 0 1 1 1 1 0 0 0 0 0 2 2
|
||||||
|
idx 0 1 1 0 1 16 17 16 17 0 1 2 0 1 32 33
|
||||||
|
|
||||||
|
```
|
|
@ -11,8 +11,7 @@ from torch.utils.data import BatchSampler
|
||||||
|
|
||||||
|
|
||||||
class MeaningMap:
|
class MeaningMap:
|
||||||
|
def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, use_cache=True):
|
||||||
def __init__(self, size=1048576, vocab_size=4096, max_subitem=10):
|
|
||||||
self.size = size
|
self.size = size
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.max_subitem = max_subitem
|
self.max_subitem = max_subitem
|
||||||
|
@ -20,99 +19,186 @@ class MeaningMap:
|
||||||
path = "./data/"
|
path = "./data/"
|
||||||
file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
|
file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
|
||||||
file = path + file
|
file = path + file
|
||||||
file_map = file + "_map" + ".npy"
|
file_slhwm = file + "_slhwm" + ".npy"
|
||||||
file_start = file + "_start" + ".npy"
|
file_dli = file + "_dli" + ".npy"
|
||||||
file_len = file + "_len" + ".npy"
|
|
||||||
file_data = file + "_data" + ".npy"
|
|
||||||
|
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
os.mkdir(path)
|
os.mkdir(path)
|
||||||
if (
|
if os.path.exists(file_slhwm) and os.path.exists(file_dli) and use_cache:
|
||||||
os.path.exists(file_start)
|
|
||||||
and os.path.exists(file_len)
|
|
||||||
and os.path.exists(file_data)
|
|
||||||
and os.path.exists(file_map)
|
|
||||||
):
|
|
||||||
print("Load from disk cache: " + file)
|
print("Load from disk cache: " + file)
|
||||||
self.ms_map = np.load(file_map)
|
slhwm = np.load(file_slhwm)
|
||||||
self.ms_data = np.load(file_data)
|
dli = np.load(file_dli)
|
||||||
self.ms_start = np.load(file_start)
|
self.ms_map = slhwm[:, 4:]
|
||||||
self.ms_len = np.load(file_len)
|
self.ms_data = dli[:, 0]
|
||||||
|
self.ms_start = slhwm[:, 0]
|
||||||
|
self.ms_len = slhwm[:, 1]
|
||||||
|
self.ms_level = dli[:, 1]
|
||||||
|
self.ms_idx = dli[:, 2].astype(np.uint32)
|
||||||
|
self.ms_height = slhwm[:, 2]
|
||||||
|
self.ms_weight = slhwm[:, 3]
|
||||||
print("Load end")
|
print("Load end")
|
||||||
else:
|
else:
|
||||||
print("Disk cache miss, build new one.")
|
print("Disk cache miss, build new one.")
|
||||||
|
|
||||||
mm = np.empty((size, max_subitem), dtype=np.int32)
|
map = np.empty((size, max_subitem), dtype=np.uint32)
|
||||||
|
|
||||||
index = np.arange(0, size)
|
index = np.arange(0, size)
|
||||||
mm = np.random.random((size, max_subitem))
|
map = np.random.random((size, max_subitem))
|
||||||
|
|
||||||
mask_zero = mm.copy()
|
mask_zero = map.copy()
|
||||||
mask_zero[:, 0] = 0.0
|
mask_zero[:, 0] = 0.0
|
||||||
mask_zero.sort(axis=1)
|
mask_zero.sort(axis=1)
|
||||||
thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1)
|
thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1)
|
||||||
mask_zero = mask_zero > thre
|
mask_zero = mask_zero > thre
|
||||||
|
|
||||||
item_sum = mm.sum(axis=1)
|
item_sum = map.sum(axis=1)
|
||||||
scale = (index / item_sum).reshape(-1, 1).repeat(max_subitem, axis=1)
|
scale = (index / item_sum).reshape(-1, 1).repeat(max_subitem, axis=1)
|
||||||
mm = mm * scale
|
map = map * scale
|
||||||
mm[mask_zero] = 0
|
|
||||||
|
|
||||||
mm[:vocab_size, 0] = np.arange(0, vocab_size)
|
map[mask_zero] = 0
|
||||||
mm[:vocab_size, 1:] = 0
|
|
||||||
mm = mm.astype(np.int32)
|
|
||||||
|
|
||||||
ms = [] # meaning sequence
|
map[:vocab_size, 0] = np.arange(0, vocab_size)
|
||||||
|
map[:vocab_size, 1:] = 0
|
||||||
|
map = map.astype(np.uint32)
|
||||||
|
|
||||||
|
ms_data = [] # meaning sequence
|
||||||
|
ms_level = [] # meaning level, vocab's level is 0
|
||||||
|
ms_idx = [] # meaning index of lowest level
|
||||||
ms_start = [] # meaning sequence start
|
ms_start = [] # meaning sequence start
|
||||||
ms_len = [] # meaning sequence length
|
ms_len = [] # meaning sequence length
|
||||||
|
ms_height = [] # meaning tree height
|
||||||
|
ms_weight = [] # meaning tree weight
|
||||||
index = 0
|
index = 0
|
||||||
for i in range(self.vocab_size):
|
for i in range(self.vocab_size):
|
||||||
ms.append([i])
|
ms_data.append([i])
|
||||||
|
ms_level.append([0])
|
||||||
|
ms_idx.append([0])
|
||||||
ms_start.append(index)
|
ms_start.append(index)
|
||||||
ms_len.append(1)
|
ms_len.append(1)
|
||||||
index = index + 1
|
index = index + 1
|
||||||
|
ms_height.append(0)
|
||||||
|
ms_weight.append(1)
|
||||||
|
|
||||||
for i in range(self.vocab_size, size):
|
for i in range(self.vocab_size, size):
|
||||||
m = mm[i]
|
m = map[i]
|
||||||
m = m[m > 0]
|
m = m[m > 0]
|
||||||
ma = []
|
ma = []
|
||||||
for newm in m.tolist():
|
ml = []
|
||||||
ma = ma + ms[newm]
|
mi = []
|
||||||
ms.append(ma)
|
for i, newm in enumerate(m.tolist()):
|
||||||
|
ma = ma + ms_data[newm]
|
||||||
|
ml = ml + [x + 1 for x in ms_level[newm]]
|
||||||
|
mi = mi + ([0xFFFFFFF0 + i] if newm < self.vocab_size else [n * 16 + i for n in ms_idx[newm]])
|
||||||
|
ms_data.append(ma)
|
||||||
ms_start.append(index)
|
ms_start.append(index)
|
||||||
ms_len.append(len(ma))
|
ms_len.append(len(ma))
|
||||||
|
ms_level.append(ml)
|
||||||
|
ms_idx.append(mi)
|
||||||
index = index + len(ma)
|
index = index + len(ma)
|
||||||
|
ms_height.append(max([-1] + [ms_height[sub_m] for sub_m in m.tolist()]) + 1)
|
||||||
|
ms_weight.append(sum(ms_weight[sub_m] for sub_m in m.tolist()))
|
||||||
|
|
||||||
ms_data = list(chain(*ms))
|
# offsets = [0, 0, 4, 8, 12, 16, 20, 24, 28]
|
||||||
np.save(file_map, np.array(mm).astype(np.int32))
|
# for idxmi, mi in enumerate(ms_idx):
|
||||||
np.save(file_data, np.array(ms_data).astype(np.int32))
|
# level = ms_level[idxmi]
|
||||||
np.save(file_start, np.array(ms_start).astype(np.int32))
|
# for idxnum, num in enumerate(mi):
|
||||||
np.save(file_len, np.array(ms_len).astype(np.int32))
|
# l = level[idxnum]
|
||||||
|
# elements = [(num >> offset) & 0xF for offset in offsets[l:0:-1]]
|
||||||
|
# num = (num >> (l * 4)) << (l * 4)
|
||||||
|
# num += sum(elem << (i * 4) for i, elem in enumerate(elements))
|
||||||
|
# mi[idxnum] = num
|
||||||
|
|
||||||
self.ms_map = mm
|
ms_data = np.array(list(chain(*ms_data))).astype(np.int32)
|
||||||
self.ms_data = ms_data
|
ms_level = np.array(list(chain(*ms_level))).astype(np.int32)
|
||||||
|
ms_idx = np.array(list(chain(*ms_idx))).astype(np.uint32)
|
||||||
|
|
||||||
|
d = np.ones(ms_idx.shape, dtype=np.uint32)
|
||||||
|
d = ((d * 0xFFFFFFFF) << (ms_level * 4)).astype(np.uint32)
|
||||||
|
ms_idx = (
|
||||||
|
((ms_idx & 0xF) << 28)
|
||||||
|
+ ((ms_idx & 0xF0) << 20)
|
||||||
|
+ ((ms_idx & 0xF00) << 12)
|
||||||
|
+ ((ms_idx & 0xF000) << 4)
|
||||||
|
+ ((ms_idx & 0xF0000) >> 4)
|
||||||
|
+ ((ms_idx & 0xF00000) >> 12)
|
||||||
|
+ ((ms_idx & 0xF000000) >> 20)
|
||||||
|
+ ((ms_idx & 0xF0000000) >> 28)
|
||||||
|
)
|
||||||
|
ms_idx = ((ms_idx >> ((8 - ms_level) * 4)) + d).astype(np.uint32)
|
||||||
|
|
||||||
|
ms_start = np.array(ms_start).astype(np.uint32)
|
||||||
|
ms_height = np.array(ms_height).astype(np.uint32)
|
||||||
|
ms_weight = np.array(ms_weight).astype(np.uint32)
|
||||||
|
ms_len = np.array(ms_len).astype(np.uint32)
|
||||||
|
ms_map = map.astype(np.uint32)
|
||||||
|
|
||||||
|
slhwm = np.concatenate(
|
||||||
|
(
|
||||||
|
ms_start.reshape((-1, 1)),
|
||||||
|
ms_len.reshape((-1, 1)),
|
||||||
|
ms_height.reshape((-1, 1)),
|
||||||
|
ms_weight.reshape((-1, 1)),
|
||||||
|
ms_map,
|
||||||
|
),
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
dli = np.stack((ms_data, ms_level, ms_idx.astype(np.int32)), axis=1)
|
||||||
|
|
||||||
|
np.save(file_slhwm, slhwm)
|
||||||
|
np.save(file_dli, dli)
|
||||||
|
|
||||||
|
self.ms_map = map # ms_map[i] = [sub(i),sub(i),sub(i),sub(i)...sub(i)]
|
||||||
|
self.ms_data = ms_data # map[i]=ms_data[ms_start[i]:ms_start[i]+ms_len[i]]
|
||||||
self.ms_start = ms_start
|
self.ms_start = ms_start
|
||||||
self.ms_len = ms_len
|
self.ms_len = ms_len
|
||||||
|
self.ms_level = ms_level
|
||||||
|
self.ms_idx = ms_idx
|
||||||
|
self.ms_height = ms_height
|
||||||
|
self.ms_weight = ms_weight
|
||||||
print("Disk cache build end.")
|
print("Disk cache build end.")
|
||||||
|
|
||||||
def get_sequence(self, meaning):
|
def get_sequence(self, meaning): # return sequence[meaning]
|
||||||
start = self.ms_start[meaning]
|
start = self.ms_start[meaning]
|
||||||
len = self.ms_len[meaning]
|
len = self.ms_len[meaning]
|
||||||
return self.ms_data[start : start + len]
|
return self.ms_data[start : start + len], self.ms_level[start : start + len], self.ms_idx[start : start + len]
|
||||||
|
|
||||||
def get_mapping(self, meaning):
|
def get_tree(self, meaning): # return meaning all sub items
|
||||||
mapping = {}
|
tree = {}
|
||||||
ms = self.ms_map[meaning]
|
ms = self.ms_map[meaning]
|
||||||
for m in ms[ms > 0].tolist():
|
for m in ms[ms > 0].tolist():
|
||||||
mapping[m] = self.get_mapping(m) if m >= self.vocab_size else m
|
tree[m] = self.get_tree(m) if m >= self.vocab_size else m
|
||||||
return mapping
|
return tree
|
||||||
|
|
||||||
def max_length(self):
|
def max_length(self):
|
||||||
return max(self.ms_len)
|
return max(self.ms_len)
|
||||||
|
|
||||||
|
def get_tree_str(tree, prefix):
|
||||||
|
if isinstance(tree, dict):
|
||||||
|
base = ""
|
||||||
|
last_is_dict = None
|
||||||
|
for key, value in tree.items():
|
||||||
|
new_prefix = (len(str(key)) + 2) * " " + prefix
|
||||||
|
dict_string = MeaningMap.get_tree_str(value, new_prefix)
|
||||||
|
if dict_string:
|
||||||
|
base += "\n" + prefix + str(key) + ": " + dict_string
|
||||||
|
last_is_dict = True
|
||||||
|
else:
|
||||||
|
base += "\n" + prefix + str(key) + " " if last_is_dict else str(key) + " "
|
||||||
|
last_is_dict = False
|
||||||
|
return base
|
||||||
|
return None
|
||||||
|
|
||||||
|
def token_frequency(tree, freq):
|
||||||
|
if isinstance(tree, dict):
|
||||||
|
for key, value in tree.items():
|
||||||
|
if key in freq:
|
||||||
|
freq[key] = freq[key] + 1
|
||||||
|
else:
|
||||||
|
freq[key] = 1
|
||||||
|
MeaningMap.token_frequency(value, freq)
|
||||||
|
|
||||||
|
|
||||||
class MeaningDataset(Dataset):
|
class MeaningDataset(Dataset):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
start=131072,
|
start=131072,
|
||||||
|
@ -124,25 +210,34 @@ class MeaningDataset(Dataset):
|
||||||
seed=42,
|
seed=42,
|
||||||
data=None,
|
data=None,
|
||||||
length=None,
|
length=None,
|
||||||
mapping=None,
|
tree=None,
|
||||||
|
level=None,
|
||||||
|
idx=None,
|
||||||
|
use_cache=True,
|
||||||
):
|
):
|
||||||
if data != None and length != None and mapping != None:
|
if data != None and length != None and tree != None and level != None and idx != None:
|
||||||
self.data = data
|
self.data = data
|
||||||
self.length = length
|
self.length = length
|
||||||
self.mapping = mapping
|
self.tree = tree
|
||||||
|
self.level = level
|
||||||
|
self.idx = idx
|
||||||
return
|
return
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
mm = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem) # 1048576
|
map = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem, use_cache=use_cache)
|
||||||
self.mapping = []
|
self.tree = []
|
||||||
self.data = []
|
self.data = []
|
||||||
|
self.level = []
|
||||||
|
self.idx = []
|
||||||
self.length = []
|
self.length = []
|
||||||
meanings = np.random.randint(start, end, size=(size))
|
meanings = np.random.randint(start, end, size=(size))
|
||||||
for m in meanings:
|
for m in meanings:
|
||||||
sq = mm.get_sequence(m)
|
d, l, i = map.get_sequence(m)
|
||||||
if len(sq) >= min_seq_len:
|
if len(d) >= min_seq_len:
|
||||||
self.mapping.append({m: mm.get_mapping(m)})
|
self.tree.append({m: map.get_tree(m)})
|
||||||
self.data.append(sq)
|
self.data.append(d)
|
||||||
self.length.append(len(sq))
|
self.level.append(l)
|
||||||
|
self.idx.append(i)
|
||||||
|
self.length.append(len(d))
|
||||||
|
|
||||||
unique, counts = np.unique(self.length, return_counts=True)
|
unique, counts = np.unique(self.length, return_counts=True)
|
||||||
print("----------------------------------------------------------------")
|
print("----------------------------------------------------------------")
|
||||||
|
@ -164,50 +259,34 @@ class MeaningDataset(Dataset):
|
||||||
output["input_ids"] = data
|
output["input_ids"] = data
|
||||||
output["labels"] = data.clone()
|
output["labels"] = data.clone()
|
||||||
output["token_type_ids"] = torch.zeros(data.shape)
|
output["token_type_ids"] = torch.zeros(data.shape)
|
||||||
|
output["tree"] = self.tree[idx]
|
||||||
|
output["level"] = self.level[idx]
|
||||||
|
output["idx"] = self.idx[idx]
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def get_batch(self, index_list): # must equal sequence length
|
def get_batch(self, idx_list): # must equal sequence length
|
||||||
data = [self.data[i] for i in index_list]
|
data = [self.data[i] for i in idx_list]
|
||||||
output = {}
|
output = {}
|
||||||
data = torch.tensor(np.stack(data, axis=0)).long()
|
data = torch.tensor(np.stack(data, axis=0)).long()
|
||||||
output["input_ids"] = data
|
output["input_ids"] = data
|
||||||
output["labels"] = data.clone()
|
output["labels"] = data.clone()
|
||||||
output["token_type_ids"] = torch.zeros(data.shape)
|
output["token_type_ids"] = torch.zeros(data.shape)
|
||||||
|
output["tree"] = [self.tree[i] for i in idx_list]
|
||||||
|
output["level"] = [self.level[i] for i in idx_list]
|
||||||
|
output["idx"] = [self.idx[i] for i in idx_list]
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def get_token_batch(self, index_list): # must equal sequence length
|
def get_token(self, idx): # must equal sequence length
|
||||||
return [self.data[i] for i in index_list]
|
return self.data[idx]
|
||||||
|
|
||||||
def print_token_batch(self, index_list): # must equal sequence length
|
def get_tree(self, idx):
|
||||||
data = [self.data[i] for i in index_list]
|
return self.tree[idx]
|
||||||
output = {}
|
|
||||||
data = torch.tensor(np.stack(data, axis=0)).long()
|
|
||||||
output["input_ids"] = data
|
|
||||||
output["labels"] = data.clone()
|
|
||||||
output["token_type_ids"] = torch.zeros(data.shape)
|
|
||||||
return output
|
|
||||||
|
|
||||||
def get_mapping_batch(self, index_list):
|
def print_tree(self, idx):
|
||||||
return [self.mapping[i] for i in index_list]
|
tokens = self.data[idx]
|
||||||
|
tree = self.get_tree(idx)
|
||||||
def __get_mapping_str__(map, prefix):
|
s = str(tokens) + "\n"
|
||||||
if isinstance(map, dict):
|
s += MeaningMap.get_tree_str(tree, "")
|
||||||
base = ""
|
|
||||||
for key, value in map.items():
|
|
||||||
base += prefix + str(key) + "\n"
|
|
||||||
base += MeaningDataset.__get_mapping_str__(value, prefix + " ")
|
|
||||||
return base
|
|
||||||
else:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
def print_mapping_batch(self, index_list):
|
|
||||||
tokens = self.get_token_batch(index_list)
|
|
||||||
map = self.get_mapping_batch(index_list)
|
|
||||||
s = "--------------------------------------------------------\n"
|
|
||||||
for i, m in enumerate(map):
|
|
||||||
s += str(tokens[i]) + "\n"
|
|
||||||
s += MeaningDataset.__get_mapping_str__(m, "")
|
|
||||||
s += "--------------------------------------------------------\n"
|
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def split(self, ratio):
|
def split(self, ratio):
|
||||||
|
@ -215,14 +294,38 @@ class MeaningDataset(Dataset):
|
||||||
middle = int(l * ratio)
|
middle = int(l * ratio)
|
||||||
d_shuffle = self.data.copy()
|
d_shuffle = self.data.copy()
|
||||||
l_shuffle = self.length.copy()
|
l_shuffle = self.length.copy()
|
||||||
m_shuffle = self.mapping.copy()
|
m_shuffle = self.tree.copy()
|
||||||
md1 = MeaningDataset(data=d_shuffle[:middle], length=l_shuffle[:middle], mapping=m_shuffle[:middle])
|
level_shuffle = self.level.copy()
|
||||||
md2 = MeaningDataset(data=d_shuffle[middle:], length=l_shuffle[middle:], mapping=m_shuffle[middle:])
|
i_shuffle = self.idx.copy()
|
||||||
|
md1 = MeaningDataset(
|
||||||
|
data=d_shuffle[:middle],
|
||||||
|
length=l_shuffle[:middle],
|
||||||
|
tree=m_shuffle[:middle],
|
||||||
|
level=level_shuffle[:middle],
|
||||||
|
idx=i_shuffle[:middle],
|
||||||
|
)
|
||||||
|
md2 = MeaningDataset(
|
||||||
|
data=d_shuffle[middle:],
|
||||||
|
length=l_shuffle[middle:],
|
||||||
|
tree=m_shuffle[middle:],
|
||||||
|
level=level_shuffle[middle:],
|
||||||
|
idx=i_shuffle[middle:],
|
||||||
|
)
|
||||||
return md1, md2
|
return md1, md2
|
||||||
|
|
||||||
|
def token_frequency(self):
|
||||||
|
freq = {}
|
||||||
|
for t in self.tree:
|
||||||
|
MeaningMap.token_frequency(t, freq)
|
||||||
|
return freq
|
||||||
|
|
||||||
|
def get_seq_mask(idx, level, index):
|
||||||
|
assert index < 15, "index must < 15"
|
||||||
|
assert level < 8, "level must < 8"
|
||||||
|
return [((int(i / (16**level)) & 0xF) == index) for i in idx]
|
||||||
|
|
||||||
|
|
||||||
class BatchGroupMeaningDataloader(Dataset):
|
class BatchGroupMeaningDataloader(Dataset):
|
||||||
|
|
||||||
def __init__(self, dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
|
def __init__(self, dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
@ -266,17 +369,28 @@ class BatchGroupMeaningDataloader(Dataset):
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
return self.dataset.get_batch(self.indexBatch[idx])
|
return self.dataset.get_batch(self.indexBatch[idx])
|
||||||
|
|
||||||
def mapping(self, idx):
|
def get_tree(self, idx):
|
||||||
return self.dataset.get_mapping_batch(self.indexBatch[idx])
|
return [self.dataset.get_tree(i) for i in self.indexBatch[idx]]
|
||||||
|
|
||||||
def print_mapping(self, idx):
|
def print_tree(self, idx):
|
||||||
return self.dataset.print_mapping_batch(self.indexBatch[idx])
|
idx_list = self.indexBatch[idx]
|
||||||
|
s = "--------------------------------------------------------\n"
|
||||||
|
for i in idx_list:
|
||||||
|
s += self.dataset.print_tree(i)
|
||||||
|
s += "--------------------------------------------------------\n"
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024)
|
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
|
||||||
train, val = md.split(0.95)
|
train, val = md.split(0.95)
|
||||||
|
fdaf = md.__getitem__(920)
|
||||||
|
print(md.print_tree(920))
|
||||||
|
print(md.idx[920])
|
||||||
|
fdasfe = MeaningDataset.get_seq_mask(md.idx[920], 1, 1)
|
||||||
|
print(fdasfe)
|
||||||
|
freq = md.token_frequency()
|
||||||
|
|
||||||
dl = BatchGroupMeaningDataloader(train, 2)
|
dl = BatchGroupMeaningDataloader(train, 2)
|
||||||
length = len(dl)
|
length = len(dl)
|
||||||
|
@ -285,9 +399,9 @@ if __name__ == "__main__":
|
||||||
ne2 = next(it)
|
ne2 = next(it)
|
||||||
ne3 = next(it)
|
ne3 = next(it)
|
||||||
|
|
||||||
map1 = dl.mapping(0)
|
map1 = dl.get_tree(0)
|
||||||
map2 = dl.mapping(1)
|
map2 = dl.get_tree(1)
|
||||||
print(dl.print_mapping(0))
|
print(dl.print_tree(0))
|
||||||
|
|
||||||
dl = DataLoader(
|
dl = DataLoader(
|
||||||
train,
|
train,
|
||||||
|
|
12
wit/train.py
12
wit/train.py
|
@ -17,7 +17,7 @@ pretrain_model_name = None # "qwen/Qwen-1_8B-Chat"
|
||||||
learning_rate = 0.0001
|
learning_rate = 0.0001
|
||||||
use_tril_attention_mask = None
|
use_tril_attention_mask = None
|
||||||
precision = "32-true" # "precision:bf16-mixed,16-mixed,32-true"
|
precision = "32-true" # "precision:bf16-mixed,16-mixed,32-true"
|
||||||
train_batch_size = 4
|
train_batch_size = 2
|
||||||
val_batch_size = 1
|
val_batch_size = 1
|
||||||
num_proc = 8
|
num_proc = 8
|
||||||
max_epochs = 1000
|
max_epochs = 1000
|
||||||
|
@ -25,14 +25,14 @@ strategy = "auto"
|
||||||
resume_from_ckpt_path = None
|
resume_from_ckpt_path = None
|
||||||
seed = 42
|
seed = 42
|
||||||
|
|
||||||
vocab_size = 1024
|
vocab_size = 256
|
||||||
level_ratio = 4
|
level_ratio = 6
|
||||||
level = 6
|
level = 4
|
||||||
dataset_level = 1
|
dataset_level = 1
|
||||||
|
|
||||||
hidden_size = 2048 # 128 1024 2048 32
|
hidden_size = 1024 # 128 1024 2048 32
|
||||||
num_attention_heads = 16 # 8 8 16
|
num_attention_heads = 16 # 8 8 16
|
||||||
num_hidden_layers = 12 # 6 12 24 3
|
num_hidden_layers = 3 # 6 12 24 3
|
||||||
|
|
||||||
name = "vocab_ratio_level_data_hidden_head_layer"
|
name = "vocab_ratio_level_data_hidden_head_layer"
|
||||||
ver = f"{vocab_size}" + "_" + f"{level_ratio}" + "_" + f"{level}" + "_" + f"{dataset_level}"
|
ver = f"{vocab_size}" + "_" + f"{level_ratio}" + "_" + f"{level}" + "_" + f"{dataset_level}"
|
||||||
|
|
Loading…
Reference in New Issue