Refine meaning dataset.
This commit is contained in:
		
							parent
							
								
									2bc9e3b57e
								
							
						
					
					
						commit
						33d1e22655
					
				| 
						 | 
					@ -0,0 +1,67 @@
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					a = np.array([0, 1, 32 + 1, (32 + 1) * 16, 4, 5, 6, 7, 8, 8]).astype(np.uint32)
 | 
				
			||||||
 | 
					b = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 8]).astype(np.uint32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					d = np.ones(a.shape, dtype=np.uint32)
 | 
				
			||||||
 | 
					d = (d * 0xFFFFFFFF) << (b * 4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					c = a.astype(np.uint32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cc = (
 | 
				
			||||||
 | 
					    ((c & 0xF) << 28)
 | 
				
			||||||
 | 
					    + ((c & 0xF0) << 20)
 | 
				
			||||||
 | 
					    + ((c & 0xF00) << 12)
 | 
				
			||||||
 | 
					    + ((c & 0xF000) << 4)
 | 
				
			||||||
 | 
					    + ((c & 0xF0000) >> 4)
 | 
				
			||||||
 | 
					    + ((c & 0xF00000) >> 12)
 | 
				
			||||||
 | 
					    + ((c & 0xF000000) >> 20)
 | 
				
			||||||
 | 
					    + ((c & 0xF0000000) >> 28)
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					cc = (cc >> ((8 - b) * 4)) + d
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					print(cc[3] == 4294963218)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					b = np.ones((10)).astype(np.int32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_tree_str_new(tree, prefix):
 | 
				
			||||||
 | 
					    if isinstance(tree, dict):
 | 
				
			||||||
 | 
					        base = ""
 | 
				
			||||||
 | 
					        last_is_dict = None
 | 
				
			||||||
 | 
					        for key, value in tree.items():
 | 
				
			||||||
 | 
					            new_prefix = (len(str(key)) + 2) * " " + prefix
 | 
				
			||||||
 | 
					            dict_string = get_tree_str_new(value, new_prefix)
 | 
				
			||||||
 | 
					            if dict_string:
 | 
				
			||||||
 | 
					                base += "\n" + prefix + str(key) + ": " + dict_string
 | 
				
			||||||
 | 
					                last_is_dict = True
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                base += "\n" + prefix + str(key) + "  " if last_is_dict else str(key) + "  "
 | 
				
			||||||
 | 
					                last_is_dict = False
 | 
				
			||||||
 | 
					        return base
 | 
				
			||||||
 | 
					    return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					tree = {
 | 
				
			||||||
 | 
					    112377: {
 | 
				
			||||||
 | 
					        2944: {228: 228, 263: 263, 252: 252, 396: 396},
 | 
				
			||||||
 | 
					        10024: {
 | 
				
			||||||
 | 
					            1424: {189: 189, 209: 209, 200: 200, 102: 102, 178: 178, 22: 22, 9: 9},
 | 
				
			||||||
 | 
					            1053: 432,
 | 
				
			||||||
 | 
					            1350: {68: 68, 200: 200, 50: 50, 17: 17, 36: 36, 283: 283},
 | 
				
			||||||
 | 
					            7: 7,
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        18196: 322,
 | 
				
			||||||
 | 
					        13373: {
 | 
				
			||||||
 | 
					            1420: {99: 99, 189: 189, 163: 163},
 | 
				
			||||||
 | 
					            2109: {320: 320, 92: 92, 95: 95, 224: 224, 435: 435, 4: 4, 373: 373, 27: 27, 228: 228},
 | 
				
			||||||
 | 
					            708: 708,
 | 
				
			||||||
 | 
					            2196: {27: 27, 157: 157, 87: 87, 231: 231},
 | 
				
			||||||
 | 
					            401: 401,
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					print(get_tree_str_new(tree, ""))
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,38 @@
 | 
				
			||||||
 | 
					# meaning dataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					meaning数据集是一个模仿自然语言,以及抽象表达的数据集。
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## 概念
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					1. token表示最终体现的基本数据表达,类似单词。vocab_size表示代表token的数量。
 | 
				
			||||||
 | 
					2. meaning表示一种语义(符号),所有的meaning都由一个编号表达,编号越大表示语义越复杂
 | 
				
			||||||
 | 
					3. 所有的meaning都可以由更低标号表达
 | 
				
			||||||
 | 
					4. 从0到vocab_size的编号表示基本meaning,是不能被拆解的,也就是token
 | 
				
			||||||
 | 
					5. meaning通过一层层的向低编号的meaning进行组合替换,最终形成一个最底层是token的树形数据
 | 
				
			||||||
 | 
					6. level表示当前token相对于root meaning的距离
 | 
				
			||||||
 | 
					7. idx表示当前token在不同层的排序编号,每4位表示在一层里面的编号,低4位表示最低层级的index,高位无用的位用1填充
 | 
				
			||||||
 | 
					8. tree用于存储每个meaning的拆解的数据,使用字典表达一个树形结构
 | 
				
			||||||
 | 
					9. get_seq_mask返回一个sequence每个token在对应level是不是对应的index
 | 
				
			||||||
 | 
					10. meaning_height
 | 
				
			||||||
 | 
					11. meaning_weight
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					vocab_size = 256 meaning = 115200
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                         115200
 | 
				
			||||||
 | 
					                /          |           \
 | 
				
			||||||
 | 
					           10240          1100           12322
 | 
				
			||||||
 | 
					        /    |   \        /  \        /     |   \
 | 
				
			||||||
 | 
					    512     32   1201   245  233   3214    532    324
 | 
				
			||||||
 | 
					    / \          /   \             /  \     |     / \
 | 
				
			||||||
 | 
					  123 42      320     500        1231  23  324   93  176
 | 
				
			||||||
 | 
					              / \     / \        / \       / \
 | 
				
			||||||
 | 
					            176 11  255 129    129  99   211 111
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					sequence = 123 42 32 176 11 255 129 245 233 129 99 23 211 111 93 176
 | 
				
			||||||
 | 
					level    =  3   3  2  4   4  4   4   2   2   4   4  3  4   4   3  3
 | 
				
			||||||
 | 
					idx at 0 =  0   1  1  0   1  0   1   0   1   0   1  2  0   1   0  1
 | 
				
			||||||
 | 
					idx at 1 =  0   0  0  0   0  1   1   1   1   0   0  0  0   0   2  2
 | 
				
			||||||
 | 
					idx         0   1  1  0   1 16  17  16  17   0   1  2  0   1  32 33
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
| 
						 | 
					@ -11,8 +11,7 @@ from torch.utils.data import BatchSampler
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MeaningMap:
 | 
					class MeaningMap:
 | 
				
			||||||
 | 
					    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, use_cache=True):
 | 
				
			||||||
    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10):
 | 
					 | 
				
			||||||
        self.size = size
 | 
					        self.size = size
 | 
				
			||||||
        self.vocab_size = vocab_size
 | 
					        self.vocab_size = vocab_size
 | 
				
			||||||
        self.max_subitem = max_subitem
 | 
					        self.max_subitem = max_subitem
 | 
				
			||||||
| 
						 | 
					@ -20,99 +19,186 @@ class MeaningMap:
 | 
				
			||||||
        path = "./data/"
 | 
					        path = "./data/"
 | 
				
			||||||
        file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
 | 
					        file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
 | 
				
			||||||
        file = path + file
 | 
					        file = path + file
 | 
				
			||||||
        file_map = file + "_map" + ".npy"
 | 
					        file_slhwm = file + "_slhwm" + ".npy"
 | 
				
			||||||
        file_start = file + "_start" + ".npy"
 | 
					        file_dli = file + "_dli" + ".npy"
 | 
				
			||||||
        file_len = file + "_len" + ".npy"
 | 
					 | 
				
			||||||
        file_data = file + "_data" + ".npy"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if not os.path.exists(path):
 | 
					        if not os.path.exists(path):
 | 
				
			||||||
            os.mkdir(path)
 | 
					            os.mkdir(path)
 | 
				
			||||||
        if (
 | 
					        if os.path.exists(file_slhwm) and os.path.exists(file_dli) and use_cache:
 | 
				
			||||||
            os.path.exists(file_start)
 | 
					 | 
				
			||||||
            and os.path.exists(file_len)
 | 
					 | 
				
			||||||
            and os.path.exists(file_data)
 | 
					 | 
				
			||||||
            and os.path.exists(file_map)
 | 
					 | 
				
			||||||
        ):
 | 
					 | 
				
			||||||
            print("Load from disk cache: " + file)
 | 
					            print("Load from disk cache: " + file)
 | 
				
			||||||
            self.ms_map = np.load(file_map)
 | 
					            slhwm = np.load(file_slhwm)
 | 
				
			||||||
            self.ms_data = np.load(file_data)
 | 
					            dli = np.load(file_dli)
 | 
				
			||||||
            self.ms_start = np.load(file_start)
 | 
					            self.ms_map = slhwm[:, 4:]
 | 
				
			||||||
            self.ms_len = np.load(file_len)
 | 
					            self.ms_data = dli[:, 0]
 | 
				
			||||||
 | 
					            self.ms_start = slhwm[:, 0]
 | 
				
			||||||
 | 
					            self.ms_len = slhwm[:, 1]
 | 
				
			||||||
 | 
					            self.ms_level = dli[:, 1]
 | 
				
			||||||
 | 
					            self.ms_idx = dli[:, 2].astype(np.uint32)
 | 
				
			||||||
 | 
					            self.ms_height = slhwm[:, 2]
 | 
				
			||||||
 | 
					            self.ms_weight = slhwm[:, 3]
 | 
				
			||||||
            print("Load end")
 | 
					            print("Load end")
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            print("Disk cache miss, build new one.")
 | 
					            print("Disk cache miss, build new one.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            mm = np.empty((size, max_subitem), dtype=np.int32)
 | 
					            map = np.empty((size, max_subitem), dtype=np.uint32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            index = np.arange(0, size)
 | 
					            index = np.arange(0, size)
 | 
				
			||||||
            mm = np.random.random((size, max_subitem))
 | 
					            map = np.random.random((size, max_subitem))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            mask_zero = mm.copy()
 | 
					            mask_zero = map.copy()
 | 
				
			||||||
            mask_zero[:, 0] = 0.0
 | 
					            mask_zero[:, 0] = 0.0
 | 
				
			||||||
            mask_zero.sort(axis=1)
 | 
					            mask_zero.sort(axis=1)
 | 
				
			||||||
            thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1)
 | 
					            thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1)
 | 
				
			||||||
            mask_zero = mask_zero > thre
 | 
					            mask_zero = mask_zero > thre
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            item_sum = mm.sum(axis=1)
 | 
					            item_sum = map.sum(axis=1)
 | 
				
			||||||
            scale = (index / item_sum).reshape(-1, 1).repeat(max_subitem, axis=1)
 | 
					            scale = (index / item_sum).reshape(-1, 1).repeat(max_subitem, axis=1)
 | 
				
			||||||
            mm = mm * scale
 | 
					            map = map * scale
 | 
				
			||||||
            mm[mask_zero] = 0
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            mm[:vocab_size, 0] = np.arange(0, vocab_size)
 | 
					            map[mask_zero] = 0
 | 
				
			||||||
            mm[:vocab_size, 1:] = 0
 | 
					 | 
				
			||||||
            mm = mm.astype(np.int32)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            ms = []  # meaning sequence
 | 
					            map[:vocab_size, 0] = np.arange(0, vocab_size)
 | 
				
			||||||
 | 
					            map[:vocab_size, 1:] = 0
 | 
				
			||||||
 | 
					            map = map.astype(np.uint32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            ms_data = []  # meaning sequence
 | 
				
			||||||
 | 
					            ms_level = []  # meaning level, vocab's level is 0
 | 
				
			||||||
 | 
					            ms_idx = []  # meaning index of lowest level
 | 
				
			||||||
            ms_start = []  # meaning sequence start
 | 
					            ms_start = []  # meaning sequence start
 | 
				
			||||||
            ms_len = []  # meaning sequence length
 | 
					            ms_len = []  # meaning sequence length
 | 
				
			||||||
 | 
					            ms_height = []  # meaning tree height
 | 
				
			||||||
 | 
					            ms_weight = []  # meaning tree weight
 | 
				
			||||||
            index = 0
 | 
					            index = 0
 | 
				
			||||||
            for i in range(self.vocab_size):
 | 
					            for i in range(self.vocab_size):
 | 
				
			||||||
                ms.append([i])
 | 
					                ms_data.append([i])
 | 
				
			||||||
 | 
					                ms_level.append([0])
 | 
				
			||||||
 | 
					                ms_idx.append([0])
 | 
				
			||||||
                ms_start.append(index)
 | 
					                ms_start.append(index)
 | 
				
			||||||
                ms_len.append(1)
 | 
					                ms_len.append(1)
 | 
				
			||||||
                index = index + 1
 | 
					                index = index + 1
 | 
				
			||||||
 | 
					                ms_height.append(0)
 | 
				
			||||||
 | 
					                ms_weight.append(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            for i in range(self.vocab_size, size):
 | 
					            for i in range(self.vocab_size, size):
 | 
				
			||||||
                m = mm[i]
 | 
					                m = map[i]
 | 
				
			||||||
                m = m[m > 0]
 | 
					                m = m[m > 0]
 | 
				
			||||||
                ma = []
 | 
					                ma = []
 | 
				
			||||||
                for newm in m.tolist():
 | 
					                ml = []
 | 
				
			||||||
                    ma = ma + ms[newm]
 | 
					                mi = []
 | 
				
			||||||
                ms.append(ma)
 | 
					                for i, newm in enumerate(m.tolist()):
 | 
				
			||||||
 | 
					                    ma = ma + ms_data[newm]
 | 
				
			||||||
 | 
					                    ml = ml + [x + 1 for x in ms_level[newm]]
 | 
				
			||||||
 | 
					                    mi = mi + ([0xFFFFFFF0 + i] if newm < self.vocab_size else [n * 16 + i for n in ms_idx[newm]])
 | 
				
			||||||
 | 
					                ms_data.append(ma)
 | 
				
			||||||
                ms_start.append(index)
 | 
					                ms_start.append(index)
 | 
				
			||||||
                ms_len.append(len(ma))
 | 
					                ms_len.append(len(ma))
 | 
				
			||||||
 | 
					                ms_level.append(ml)
 | 
				
			||||||
 | 
					                ms_idx.append(mi)
 | 
				
			||||||
                index = index + len(ma)
 | 
					                index = index + len(ma)
 | 
				
			||||||
 | 
					                ms_height.append(max([-1] + [ms_height[sub_m] for sub_m in m.tolist()]) + 1)
 | 
				
			||||||
 | 
					                ms_weight.append(sum(ms_weight[sub_m] for sub_m in m.tolist()))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            ms_data = list(chain(*ms))
 | 
					            # offsets = [0, 0, 4, 8, 12, 16, 20, 24, 28]
 | 
				
			||||||
            np.save(file_map, np.array(mm).astype(np.int32))
 | 
					            # for idxmi, mi in enumerate(ms_idx):
 | 
				
			||||||
            np.save(file_data, np.array(ms_data).astype(np.int32))
 | 
					            #     level = ms_level[idxmi]
 | 
				
			||||||
            np.save(file_start, np.array(ms_start).astype(np.int32))
 | 
					            #     for idxnum, num in enumerate(mi):
 | 
				
			||||||
            np.save(file_len, np.array(ms_len).astype(np.int32))
 | 
					            #         l = level[idxnum]
 | 
				
			||||||
 | 
					            #         elements = [(num >> offset) & 0xF for offset in offsets[l:0:-1]]
 | 
				
			||||||
 | 
					            #         num = (num >> (l * 4)) << (l * 4)
 | 
				
			||||||
 | 
					            #         num += sum(elem << (i * 4) for i, elem in enumerate(elements))
 | 
				
			||||||
 | 
					            #         mi[idxnum] = num
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            self.ms_map = mm
 | 
					            ms_data = np.array(list(chain(*ms_data))).astype(np.int32)
 | 
				
			||||||
            self.ms_data = ms_data
 | 
					            ms_level = np.array(list(chain(*ms_level))).astype(np.int32)
 | 
				
			||||||
 | 
					            ms_idx = np.array(list(chain(*ms_idx))).astype(np.uint32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            d = np.ones(ms_idx.shape, dtype=np.uint32)
 | 
				
			||||||
 | 
					            d = ((d * 0xFFFFFFFF) << (ms_level * 4)).astype(np.uint32)
 | 
				
			||||||
 | 
					            ms_idx = (
 | 
				
			||||||
 | 
					                ((ms_idx & 0xF) << 28)
 | 
				
			||||||
 | 
					                + ((ms_idx & 0xF0) << 20)
 | 
				
			||||||
 | 
					                + ((ms_idx & 0xF00) << 12)
 | 
				
			||||||
 | 
					                + ((ms_idx & 0xF000) << 4)
 | 
				
			||||||
 | 
					                + ((ms_idx & 0xF0000) >> 4)
 | 
				
			||||||
 | 
					                + ((ms_idx & 0xF00000) >> 12)
 | 
				
			||||||
 | 
					                + ((ms_idx & 0xF000000) >> 20)
 | 
				
			||||||
 | 
					                + ((ms_idx & 0xF0000000) >> 28)
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            ms_idx = ((ms_idx >> ((8 - ms_level) * 4)) + d).astype(np.uint32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            ms_start = np.array(ms_start).astype(np.uint32)
 | 
				
			||||||
 | 
					            ms_height = np.array(ms_height).astype(np.uint32)
 | 
				
			||||||
 | 
					            ms_weight = np.array(ms_weight).astype(np.uint32)
 | 
				
			||||||
 | 
					            ms_len = np.array(ms_len).astype(np.uint32)
 | 
				
			||||||
 | 
					            ms_map = map.astype(np.uint32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            slhwm = np.concatenate(
 | 
				
			||||||
 | 
					                (
 | 
				
			||||||
 | 
					                    ms_start.reshape((-1, 1)),
 | 
				
			||||||
 | 
					                    ms_len.reshape((-1, 1)),
 | 
				
			||||||
 | 
					                    ms_height.reshape((-1, 1)),
 | 
				
			||||||
 | 
					                    ms_weight.reshape((-1, 1)),
 | 
				
			||||||
 | 
					                    ms_map,
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					                axis=1,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            dli = np.stack((ms_data, ms_level, ms_idx.astype(np.int32)), axis=1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            np.save(file_slhwm, slhwm)
 | 
				
			||||||
 | 
					            np.save(file_dli, dli)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self.ms_map = map  # ms_map[i] = [sub(i),sub(i),sub(i),sub(i)...sub(i)]
 | 
				
			||||||
 | 
					            self.ms_data = ms_data  # map[i]=ms_data[ms_start[i]:ms_start[i]+ms_len[i]]
 | 
				
			||||||
            self.ms_start = ms_start
 | 
					            self.ms_start = ms_start
 | 
				
			||||||
            self.ms_len = ms_len
 | 
					            self.ms_len = ms_len
 | 
				
			||||||
 | 
					            self.ms_level = ms_level
 | 
				
			||||||
 | 
					            self.ms_idx = ms_idx
 | 
				
			||||||
 | 
					            self.ms_height = ms_height
 | 
				
			||||||
 | 
					            self.ms_weight = ms_weight
 | 
				
			||||||
            print("Disk cache build end.")
 | 
					            print("Disk cache build end.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_sequence(self, meaning):
 | 
					    def get_sequence(self, meaning):  # return sequence[meaning]
 | 
				
			||||||
        start = self.ms_start[meaning]
 | 
					        start = self.ms_start[meaning]
 | 
				
			||||||
        len = self.ms_len[meaning]
 | 
					        len = self.ms_len[meaning]
 | 
				
			||||||
        return self.ms_data[start : start + len]
 | 
					        return self.ms_data[start : start + len], self.ms_level[start : start + len], self.ms_idx[start : start + len]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_mapping(self, meaning):
 | 
					    def get_tree(self, meaning):  # return meaning all sub items
 | 
				
			||||||
        mapping = {}
 | 
					        tree = {}
 | 
				
			||||||
        ms = self.ms_map[meaning]
 | 
					        ms = self.ms_map[meaning]
 | 
				
			||||||
        for m in ms[ms > 0].tolist():
 | 
					        for m in ms[ms > 0].tolist():
 | 
				
			||||||
            mapping[m] = self.get_mapping(m) if m >= self.vocab_size else m
 | 
					            tree[m] = self.get_tree(m) if m >= self.vocab_size else m
 | 
				
			||||||
        return mapping
 | 
					        return tree
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def max_length(self):
 | 
					    def max_length(self):
 | 
				
			||||||
        return max(self.ms_len)
 | 
					        return max(self.ms_len)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_tree_str(tree, prefix):
 | 
				
			||||||
 | 
					        if isinstance(tree, dict):
 | 
				
			||||||
 | 
					            base = ""
 | 
				
			||||||
 | 
					            last_is_dict = None
 | 
				
			||||||
 | 
					            for key, value in tree.items():
 | 
				
			||||||
 | 
					                new_prefix = (len(str(key)) + 2) * " " + prefix
 | 
				
			||||||
 | 
					                dict_string = MeaningMap.get_tree_str(value, new_prefix)
 | 
				
			||||||
 | 
					                if dict_string:
 | 
				
			||||||
 | 
					                    base += "\n" + prefix + str(key) + ": " + dict_string
 | 
				
			||||||
 | 
					                    last_is_dict = True
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    base += "\n" + prefix + str(key) + "  " if last_is_dict else str(key) + "  "
 | 
				
			||||||
 | 
					                    last_is_dict = False
 | 
				
			||||||
 | 
					            return base
 | 
				
			||||||
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def token_frequency(tree, freq):
 | 
				
			||||||
 | 
					        if isinstance(tree, dict):
 | 
				
			||||||
 | 
					            for key, value in tree.items():
 | 
				
			||||||
 | 
					                if key in freq:
 | 
				
			||||||
 | 
					                    freq[key] = freq[key] + 1
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    freq[key] = 1
 | 
				
			||||||
 | 
					                MeaningMap.token_frequency(value, freq)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MeaningDataset(Dataset):
 | 
					class MeaningDataset(Dataset):
 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					    def __init__(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
        start=131072,
 | 
					        start=131072,
 | 
				
			||||||
| 
						 | 
					@ -124,25 +210,34 @@ class MeaningDataset(Dataset):
 | 
				
			||||||
        seed=42,
 | 
					        seed=42,
 | 
				
			||||||
        data=None,
 | 
					        data=None,
 | 
				
			||||||
        length=None,
 | 
					        length=None,
 | 
				
			||||||
        mapping=None,
 | 
					        tree=None,
 | 
				
			||||||
 | 
					        level=None,
 | 
				
			||||||
 | 
					        idx=None,
 | 
				
			||||||
 | 
					        use_cache=True,
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        if data != None and length != None and mapping != None:
 | 
					        if data != None and length != None and tree != None and level != None and idx != None:
 | 
				
			||||||
            self.data = data
 | 
					            self.data = data
 | 
				
			||||||
            self.length = length
 | 
					            self.length = length
 | 
				
			||||||
            self.mapping = mapping
 | 
					            self.tree = tree
 | 
				
			||||||
 | 
					            self.level = level
 | 
				
			||||||
 | 
					            self.idx = idx
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
        np.random.seed(seed)
 | 
					        np.random.seed(seed)
 | 
				
			||||||
        mm = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem)  # 1048576
 | 
					        map = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem, use_cache=use_cache)
 | 
				
			||||||
        self.mapping = []
 | 
					        self.tree = []
 | 
				
			||||||
        self.data = []
 | 
					        self.data = []
 | 
				
			||||||
 | 
					        self.level = []
 | 
				
			||||||
 | 
					        self.idx = []
 | 
				
			||||||
        self.length = []
 | 
					        self.length = []
 | 
				
			||||||
        meanings = np.random.randint(start, end, size=(size))
 | 
					        meanings = np.random.randint(start, end, size=(size))
 | 
				
			||||||
        for m in meanings:
 | 
					        for m in meanings:
 | 
				
			||||||
            sq = mm.get_sequence(m)
 | 
					            d, l, i = map.get_sequence(m)
 | 
				
			||||||
            if len(sq) >= min_seq_len:
 | 
					            if len(d) >= min_seq_len:
 | 
				
			||||||
                self.mapping.append({m: mm.get_mapping(m)})
 | 
					                self.tree.append({m: map.get_tree(m)})
 | 
				
			||||||
                self.data.append(sq)
 | 
					                self.data.append(d)
 | 
				
			||||||
                self.length.append(len(sq))
 | 
					                self.level.append(l)
 | 
				
			||||||
 | 
					                self.idx.append(i)
 | 
				
			||||||
 | 
					                self.length.append(len(d))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        unique, counts = np.unique(self.length, return_counts=True)
 | 
					        unique, counts = np.unique(self.length, return_counts=True)
 | 
				
			||||||
        print("----------------------------------------------------------------")
 | 
					        print("----------------------------------------------------------------")
 | 
				
			||||||
| 
						 | 
					@ -164,50 +259,34 @@ class MeaningDataset(Dataset):
 | 
				
			||||||
        output["input_ids"] = data
 | 
					        output["input_ids"] = data
 | 
				
			||||||
        output["labels"] = data.clone()
 | 
					        output["labels"] = data.clone()
 | 
				
			||||||
        output["token_type_ids"] = torch.zeros(data.shape)
 | 
					        output["token_type_ids"] = torch.zeros(data.shape)
 | 
				
			||||||
 | 
					        output["tree"] = self.tree[idx]
 | 
				
			||||||
 | 
					        output["level"] = self.level[idx]
 | 
				
			||||||
 | 
					        output["idx"] = self.idx[idx]
 | 
				
			||||||
        return output
 | 
					        return output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_batch(self, index_list):  # must equal sequence length
 | 
					    def get_batch(self, idx_list):  # must equal sequence length
 | 
				
			||||||
        data = [self.data[i] for i in index_list]
 | 
					        data = [self.data[i] for i in idx_list]
 | 
				
			||||||
        output = {}
 | 
					        output = {}
 | 
				
			||||||
        data = torch.tensor(np.stack(data, axis=0)).long()
 | 
					        data = torch.tensor(np.stack(data, axis=0)).long()
 | 
				
			||||||
        output["input_ids"] = data
 | 
					        output["input_ids"] = data
 | 
				
			||||||
        output["labels"] = data.clone()
 | 
					        output["labels"] = data.clone()
 | 
				
			||||||
        output["token_type_ids"] = torch.zeros(data.shape)
 | 
					        output["token_type_ids"] = torch.zeros(data.shape)
 | 
				
			||||||
 | 
					        output["tree"] = [self.tree[i] for i in idx_list]
 | 
				
			||||||
 | 
					        output["level"] = [self.level[i] for i in idx_list]
 | 
				
			||||||
 | 
					        output["idx"] = [self.idx[i] for i in idx_list]
 | 
				
			||||||
        return output
 | 
					        return output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_token_batch(self, index_list):  # must equal sequence length
 | 
					    def get_token(self, idx):  # must equal sequence length
 | 
				
			||||||
        return [self.data[i] for i in index_list]
 | 
					        return self.data[idx]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def print_token_batch(self, index_list):  # must equal sequence length
 | 
					    def get_tree(self, idx):
 | 
				
			||||||
        data = [self.data[i] for i in index_list]
 | 
					        return self.tree[idx]
 | 
				
			||||||
        output = {}
 | 
					 | 
				
			||||||
        data = torch.tensor(np.stack(data, axis=0)).long()
 | 
					 | 
				
			||||||
        output["input_ids"] = data
 | 
					 | 
				
			||||||
        output["labels"] = data.clone()
 | 
					 | 
				
			||||||
        output["token_type_ids"] = torch.zeros(data.shape)
 | 
					 | 
				
			||||||
        return output
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_mapping_batch(self, index_list):
 | 
					    def print_tree(self, idx):
 | 
				
			||||||
        return [self.mapping[i] for i in index_list]
 | 
					        tokens = self.data[idx]
 | 
				
			||||||
 | 
					        tree = self.get_tree(idx)
 | 
				
			||||||
    def __get_mapping_str__(map, prefix):
 | 
					        s = str(tokens) + "\n"
 | 
				
			||||||
        if isinstance(map, dict):
 | 
					        s += MeaningMap.get_tree_str(tree, "")
 | 
				
			||||||
            base = ""
 | 
					 | 
				
			||||||
            for key, value in map.items():
 | 
					 | 
				
			||||||
                base += prefix + str(key) + "\n"
 | 
					 | 
				
			||||||
                base += MeaningDataset.__get_mapping_str__(value, prefix + "    ")
 | 
					 | 
				
			||||||
            return base
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            return ""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def print_mapping_batch(self, index_list):
 | 
					 | 
				
			||||||
        tokens = self.get_token_batch(index_list)
 | 
					 | 
				
			||||||
        map = self.get_mapping_batch(index_list)
 | 
					 | 
				
			||||||
        s = "--------------------------------------------------------\n"
 | 
					 | 
				
			||||||
        for i, m in enumerate(map):
 | 
					 | 
				
			||||||
            s += str(tokens[i]) + "\n"
 | 
					 | 
				
			||||||
            s += MeaningDataset.__get_mapping_str__(m, "")
 | 
					 | 
				
			||||||
            s += "--------------------------------------------------------\n"
 | 
					 | 
				
			||||||
        return s
 | 
					        return s
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def split(self, ratio):
 | 
					    def split(self, ratio):
 | 
				
			||||||
| 
						 | 
					@ -215,14 +294,38 @@ class MeaningDataset(Dataset):
 | 
				
			||||||
        middle = int(l * ratio)
 | 
					        middle = int(l * ratio)
 | 
				
			||||||
        d_shuffle = self.data.copy()
 | 
					        d_shuffle = self.data.copy()
 | 
				
			||||||
        l_shuffle = self.length.copy()
 | 
					        l_shuffle = self.length.copy()
 | 
				
			||||||
        m_shuffle = self.mapping.copy()
 | 
					        m_shuffle = self.tree.copy()
 | 
				
			||||||
        md1 = MeaningDataset(data=d_shuffle[:middle], length=l_shuffle[:middle], mapping=m_shuffle[:middle])
 | 
					        level_shuffle = self.level.copy()
 | 
				
			||||||
        md2 = MeaningDataset(data=d_shuffle[middle:], length=l_shuffle[middle:], mapping=m_shuffle[middle:])
 | 
					        i_shuffle = self.idx.copy()
 | 
				
			||||||
 | 
					        md1 = MeaningDataset(
 | 
				
			||||||
 | 
					            data=d_shuffle[:middle],
 | 
				
			||||||
 | 
					            length=l_shuffle[:middle],
 | 
				
			||||||
 | 
					            tree=m_shuffle[:middle],
 | 
				
			||||||
 | 
					            level=level_shuffle[:middle],
 | 
				
			||||||
 | 
					            idx=i_shuffle[:middle],
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        md2 = MeaningDataset(
 | 
				
			||||||
 | 
					            data=d_shuffle[middle:],
 | 
				
			||||||
 | 
					            length=l_shuffle[middle:],
 | 
				
			||||||
 | 
					            tree=m_shuffle[middle:],
 | 
				
			||||||
 | 
					            level=level_shuffle[middle:],
 | 
				
			||||||
 | 
					            idx=i_shuffle[middle:],
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        return md1, md2
 | 
					        return md1, md2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def token_frequency(self):
 | 
				
			||||||
 | 
					        freq = {}
 | 
				
			||||||
 | 
					        for t in self.tree:
 | 
				
			||||||
 | 
					            MeaningMap.token_frequency(t, freq)
 | 
				
			||||||
 | 
					        return freq
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_seq_mask(idx, level, index):
 | 
				
			||||||
 | 
					        assert index < 15, "index must < 15"
 | 
				
			||||||
 | 
					        assert level < 8, "level must < 8"
 | 
				
			||||||
 | 
					        return [((int(i / (16**level)) & 0xF) == index) for i in idx]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class BatchGroupMeaningDataloader(Dataset):
 | 
					class BatchGroupMeaningDataloader(Dataset):
 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
 | 
					    def __init__(self, dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
 | 
				
			||||||
        self.dataset = dataset
 | 
					        self.dataset = dataset
 | 
				
			||||||
        self.batch_size = batch_size
 | 
					        self.batch_size = batch_size
 | 
				
			||||||
| 
						 | 
					@ -266,17 +369,28 @@ class BatchGroupMeaningDataloader(Dataset):
 | 
				
			||||||
    def __getitem__(self, idx):
 | 
					    def __getitem__(self, idx):
 | 
				
			||||||
        return self.dataset.get_batch(self.indexBatch[idx])
 | 
					        return self.dataset.get_batch(self.indexBatch[idx])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def mapping(self, idx):
 | 
					    def get_tree(self, idx):
 | 
				
			||||||
        return self.dataset.get_mapping_batch(self.indexBatch[idx])
 | 
					        return [self.dataset.get_tree(i) for i in self.indexBatch[idx]]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def print_mapping(self, idx):
 | 
					    def print_tree(self, idx):
 | 
				
			||||||
        return self.dataset.print_mapping_batch(self.indexBatch[idx])
 | 
					        idx_list = self.indexBatch[idx]
 | 
				
			||||||
 | 
					        s = "--------------------------------------------------------\n"
 | 
				
			||||||
 | 
					        for i in idx_list:
 | 
				
			||||||
 | 
					            s += self.dataset.print_tree(i)
 | 
				
			||||||
 | 
					            s += "--------------------------------------------------------\n"
 | 
				
			||||||
 | 
					        return s
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024)
 | 
					    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
 | 
				
			||||||
    train, val = md.split(0.95)
 | 
					    train, val = md.split(0.95)
 | 
				
			||||||
 | 
					    fdaf = md.__getitem__(920)
 | 
				
			||||||
 | 
					    print(md.print_tree(920))
 | 
				
			||||||
 | 
					    print(md.idx[920])
 | 
				
			||||||
 | 
					    fdasfe = MeaningDataset.get_seq_mask(md.idx[920], 1, 1)
 | 
				
			||||||
 | 
					    print(fdasfe)
 | 
				
			||||||
 | 
					    freq = md.token_frequency()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    dl = BatchGroupMeaningDataloader(train, 2)
 | 
					    dl = BatchGroupMeaningDataloader(train, 2)
 | 
				
			||||||
    length = len(dl)
 | 
					    length = len(dl)
 | 
				
			||||||
| 
						 | 
					@ -285,9 +399,9 @@ if __name__ == "__main__":
 | 
				
			||||||
    ne2 = next(it)
 | 
					    ne2 = next(it)
 | 
				
			||||||
    ne3 = next(it)
 | 
					    ne3 = next(it)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    map1 = dl.mapping(0)
 | 
					    map1 = dl.get_tree(0)
 | 
				
			||||||
    map2 = dl.mapping(1)
 | 
					    map2 = dl.get_tree(1)
 | 
				
			||||||
    print(dl.print_mapping(0))
 | 
					    print(dl.print_tree(0))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    dl = DataLoader(
 | 
					    dl = DataLoader(
 | 
				
			||||||
        train,
 | 
					        train,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										12
									
								
								wit/train.py
								
								
								
								
							
							
						
						
									
										12
									
								
								wit/train.py
								
								
								
								
							| 
						 | 
					@ -17,7 +17,7 @@ pretrain_model_name = None  # "qwen/Qwen-1_8B-Chat"
 | 
				
			||||||
learning_rate = 0.0001
 | 
					learning_rate = 0.0001
 | 
				
			||||||
use_tril_attention_mask = None
 | 
					use_tril_attention_mask = None
 | 
				
			||||||
precision = "32-true"  # "precision:bf16-mixed,16-mixed,32-true"
 | 
					precision = "32-true"  # "precision:bf16-mixed,16-mixed,32-true"
 | 
				
			||||||
train_batch_size = 4
 | 
					train_batch_size = 2
 | 
				
			||||||
val_batch_size = 1
 | 
					val_batch_size = 1
 | 
				
			||||||
num_proc = 8
 | 
					num_proc = 8
 | 
				
			||||||
max_epochs = 1000
 | 
					max_epochs = 1000
 | 
				
			||||||
| 
						 | 
					@ -25,14 +25,14 @@ strategy = "auto"
 | 
				
			||||||
resume_from_ckpt_path = None
 | 
					resume_from_ckpt_path = None
 | 
				
			||||||
seed = 42
 | 
					seed = 42
 | 
				
			||||||
 | 
					
 | 
				
			||||||
vocab_size = 1024
 | 
					vocab_size = 256
 | 
				
			||||||
level_ratio = 4
 | 
					level_ratio = 6
 | 
				
			||||||
level = 6
 | 
					level = 4
 | 
				
			||||||
dataset_level = 1
 | 
					dataset_level = 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
hidden_size = 2048  # 128 1024 2048  32
 | 
					hidden_size = 1024  # 128 1024 2048  32
 | 
				
			||||||
num_attention_heads = 16  # 8 8 16
 | 
					num_attention_heads = 16  # 8 8 16
 | 
				
			||||||
num_hidden_layers = 12  # 6 12 24  3
 | 
					num_hidden_layers = 3  # 6 12 24  3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
name = "vocab_ratio_level_data_hidden_head_layer"
 | 
					name = "vocab_ratio_level_data_hidden_head_layer"
 | 
				
			||||||
ver = f"{vocab_size}" + "_" + f"{level_ratio}" + "_" + f"{level}" + "_" + f"{dataset_level}"
 | 
					ver = f"{vocab_size}" + "_" + f"{level_ratio}" + "_" + f"{level}" + "_" + f"{dataset_level}"
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue