Update meaning dataset.
This commit is contained in:
		
							parent
							
								
									7560166b76
								
							
						
					
					
						commit
						7434427ec9
					
				| 
						 | 
					@ -7,14 +7,16 @@ meaning数据集是一个模仿自然语言,以及抽象表达的数据集。
 | 
				
			||||||
1. token表示最终体现的基本数据表达,类似单词。vocab_size表示代表token的数量。
 | 
					1. token表示最终体现的基本数据表达,类似单词。vocab_size表示代表token的数量。
 | 
				
			||||||
2. meaning表示一种语义(符号),所有的meaning都由一个编号表达,编号越大表示语义越复杂
 | 
					2. meaning表示一种语义(符号),所有的meaning都由一个编号表达,编号越大表示语义越复杂
 | 
				
			||||||
3. 所有的meaning都可以由更低标号表达
 | 
					3. 所有的meaning都可以由更低标号表达
 | 
				
			||||||
4. 从0到vocab_size的编号表示基本meaning,是不能被拆解的,也就是token
 | 
					4. 从0到(vocab_size-1)的编号表示基本meaning,是不能被拆解的,也就是token
 | 
				
			||||||
5. meaning通过一层层的向低编号的meaning进行组合替换,最终形成一个最底层是token的树形数据
 | 
					5. meaning通过一层层的向低编号的meaning进行组合替换,最终形成一个最底层是token的树形数据
 | 
				
			||||||
6. level表示当前token相对于root meaning的距离
 | 
					6. level表示当前token相对于root meaning的距离
 | 
				
			||||||
7. idx表示当前token在不同层的排序编号,每4位表示在一层里面的编号,低4位表示最低层级的index,高位无用的位用1填充
 | 
					7. rank_idx表示当前token在不同层的排序编号,每4位表示在一层里面的编号,低4位表示最低层级的rank_idx,高位无用的位用1填充
 | 
				
			||||||
 | 
					7. rank_all表示当前token在不同层的分子个数,每4位表示在一层里面的编号,低4位表示最低层级的rank_all,高位无用的位用1填充
 | 
				
			||||||
8. tree用于存储每个meaning的拆解的数据,使用字典表达一个树形结构
 | 
					8. tree用于存储每个meaning的拆解的数据,使用字典表达一个树形结构
 | 
				
			||||||
9. get_seq_mask返回一个sequence每个token在对应level是不是对应的index
 | 
					9. get_seq_mask返回一个sequence每个token在对应level是不是对应的index,level=0:最底层,index=-1:最后一个,index=0:第一个
 | 
				
			||||||
10. meaning_height
 | 
					10. meaning_height 当前meaning的总高度
 | 
				
			||||||
11. meaning_weight
 | 
					11. meaning_weight 当前meaning的总宽度
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
vocab_size = 256 meaning = 115200
 | 
					vocab_size = 256 meaning = 115200
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -8,6 +8,7 @@ from typing import Dict, Tuple
 | 
				
			||||||
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
 | 
					from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
from torch.utils.data import BatchSampler
 | 
					from torch.utils.data import BatchSampler
 | 
				
			||||||
 | 
					import copy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MeaningMap:
 | 
					class MeaningMap:
 | 
				
			||||||
| 
						 | 
					@ -18,7 +19,7 @@ class MeaningMap:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        path = "./data/"
 | 
					        path = "./data/"
 | 
				
			||||||
        file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
 | 
					        file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
 | 
				
			||||||
        file = path + file
 | 
					        file = path + file + ".npz"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if not os.path.exists(path):
 | 
					        if not os.path.exists(path):
 | 
				
			||||||
            os.mkdir(path)
 | 
					            os.mkdir(path)
 | 
				
			||||||
| 
						 | 
					@ -26,13 +27,14 @@ class MeaningMap:
 | 
				
			||||||
            print("Load from disk cache: " + file)
 | 
					            print("Load from disk cache: " + file)
 | 
				
			||||||
            loaded = np.load(file)
 | 
					            loaded = np.load(file)
 | 
				
			||||||
            slhwm = loaded["slhwm"]
 | 
					            slhwm = loaded["slhwm"]
 | 
				
			||||||
            dli = loaded["dli"]
 | 
					            dlra = loaded["dlra"]
 | 
				
			||||||
            self.ms_map = slhwm[:, 4:]
 | 
					            self.ms_map = slhwm[:, 4:]
 | 
				
			||||||
            self.ms_data = dli[:, 0]
 | 
					            self.ms_data = dlra[:, 0]
 | 
				
			||||||
            self.ms_start = slhwm[:, 0]
 | 
					            self.ms_start = slhwm[:, 0]
 | 
				
			||||||
            self.ms_len = slhwm[:, 1]
 | 
					            self.ms_len = slhwm[:, 1]
 | 
				
			||||||
            self.ms_level = dli[:, 1]
 | 
					            self.ms_level = dlra[:, 1]
 | 
				
			||||||
            self.ms_idx = dli[:, 2].astype(np.uint32)
 | 
					            self.ms_rank_idx = dlra[:, 2].astype(np.uint32)
 | 
				
			||||||
 | 
					            self.ms_rank_all = dlra[:, 3].astype(np.uint32)
 | 
				
			||||||
            self.ms_height = slhwm[:, 2]
 | 
					            self.ms_height = slhwm[:, 2]
 | 
				
			||||||
            self.ms_weight = slhwm[:, 3]
 | 
					            self.ms_weight = slhwm[:, 3]
 | 
				
			||||||
            print("Load end")
 | 
					            print("Load end")
 | 
				
			||||||
| 
						 | 
					@ -61,7 +63,8 @@ class MeaningMap:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            ms_data = []  # meaning sequence
 | 
					            ms_data = []  # meaning sequence
 | 
				
			||||||
            ms_level = []  # meaning level, vocab's level is 0
 | 
					            ms_level = []  # meaning level, vocab's level is 0
 | 
				
			||||||
            ms_idx = []  # meaning index of lowest level
 | 
					            ms_rank_idx = []  # meaning index of all level
 | 
				
			||||||
 | 
					            ms_rank_all = []  # meaning all of all level
 | 
				
			||||||
            ms_start = []  # meaning sequence start
 | 
					            ms_start = []  # meaning sequence start
 | 
				
			||||||
            ms_len = []  # meaning sequence length
 | 
					            ms_len = []  # meaning sequence length
 | 
				
			||||||
            ms_height = []  # meaning tree height
 | 
					            ms_height = []  # meaning tree height
 | 
				
			||||||
| 
						 | 
					@ -70,7 +73,8 @@ class MeaningMap:
 | 
				
			||||||
            for i in range(self.vocab_size):
 | 
					            for i in range(self.vocab_size):
 | 
				
			||||||
                ms_data.append(np.array([i]))
 | 
					                ms_data.append(np.array([i]))
 | 
				
			||||||
                ms_level.append(np.array([0]))
 | 
					                ms_level.append(np.array([0]))
 | 
				
			||||||
                ms_idx.append(np.array([0]))
 | 
					                ms_rank_idx.append(np.array([0]))
 | 
				
			||||||
 | 
					                ms_rank_all.append(np.array([0]))
 | 
				
			||||||
                ms_start.append(index)
 | 
					                ms_start.append(index)
 | 
				
			||||||
                ms_len.append(1)
 | 
					                ms_len.append(1)
 | 
				
			||||||
                ms_height.append(0)
 | 
					                ms_height.append(0)
 | 
				
			||||||
| 
						 | 
					@ -79,59 +83,70 @@ class MeaningMap:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            for i in range(self.vocab_size, size):
 | 
					            for i in range(self.vocab_size, size):
 | 
				
			||||||
                m = map[i]
 | 
					                m = map[i]
 | 
				
			||||||
                m = m[m >= 0]
 | 
					                m = m[m >= 0]  # donot cut off the map such as [0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                m_list = m.tolist()
 | 
					                m_list = m.tolist()
 | 
				
			||||||
 | 
					                m_len = len(m_list)
 | 
				
			||||||
                assert m_list, "map list can not be empty list"
 | 
					                assert m_list, "map list can not be empty list"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                ma = np.concatenate([ms_data[newm] for newm in m_list])
 | 
					                ma = np.concatenate([ms_data[newm] for newm in m_list])
 | 
				
			||||||
                ml = np.concatenate([ms_level[newm] + 1 for newm in m_list])
 | 
					                ml = np.concatenate([ms_level[newm] + 1 for newm in m_list])
 | 
				
			||||||
                mi = np.concatenate(
 | 
					                mr = np.concatenate(
 | 
				
			||||||
                    [
 | 
					                    [
 | 
				
			||||||
                        ([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_idx[newm] * 16 + i)
 | 
					                        ([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_rank_idx[newm] * 16 + i)
 | 
				
			||||||
                        for i, newm in enumerate(m_list)
 | 
					                        for i, newm in enumerate(m_list)
 | 
				
			||||||
                    ]
 | 
					                    ]
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
                ml = ml[ma > 0]
 | 
					                mrl = np.concatenate(
 | 
				
			||||||
                mi = mi[ma > 0]
 | 
					                    [
 | 
				
			||||||
                ma = ma[ma > 0]
 | 
					                        ([0xFFFFFFF0 + m_len] if newm < self.vocab_size else ms_rank_all[newm] * 16 + m_len)
 | 
				
			||||||
 | 
					                        for i, newm in enumerate(m_list)
 | 
				
			||||||
 | 
					                    ]
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                # ml = ml[ma > 0] # cut off the 0 token, such as [12,32,0,42,32]
 | 
				
			||||||
 | 
					                # mr = mr[ma > 0]
 | 
				
			||||||
 | 
					                # mrl = mrl[ma > 0]
 | 
				
			||||||
 | 
					                # ma = ma[ma > 0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                ms_data.append(ma)
 | 
					                ms_data.append(ma)
 | 
				
			||||||
                ms_level.append(ml)
 | 
					                ms_level.append(ml)
 | 
				
			||||||
                ms_idx.append(mi)
 | 
					                ms_rank_idx.append(mr)
 | 
				
			||||||
 | 
					                ms_rank_all.append(mrl)
 | 
				
			||||||
                ms_start.append(index)
 | 
					                ms_start.append(index)
 | 
				
			||||||
                ms_len.append(len(ma))
 | 
					                ms_len.append(len(ma))
 | 
				
			||||||
                ms_height.append(max([ms_height[sub_m] for sub_m in m_list]) + 1)
 | 
					                ms_height.append(max([ms_height[sub_m] for sub_m in m_list]) + 1)
 | 
				
			||||||
                ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list))
 | 
					                ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list))
 | 
				
			||||||
                index = index + len(ma)
 | 
					                index = index + len(ma)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # offsets = [0, 0, 4, 8, 12, 16, 20, 24, 28]
 | 
					 | 
				
			||||||
            # for idxmi, mi in enumerate(ms_idx):
 | 
					 | 
				
			||||||
            #     level = ms_level[idxmi]
 | 
					 | 
				
			||||||
            #     for idxnum, num in enumerate(mi):
 | 
					 | 
				
			||||||
            #         l = level[idxnum]
 | 
					 | 
				
			||||||
            #         elements = [(num >> offset) & 0xF for offset in offsets[l:0:-1]]
 | 
					 | 
				
			||||||
            #         num = (num >> (l * 4)) << (l * 4)
 | 
					 | 
				
			||||||
            #         num += sum(elem << (i * 4) for i, elem in enumerate(elements))
 | 
					 | 
				
			||||||
            #         mi[idxnum] = num
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            ms_data = np.array(list(chain(*ms_data))).astype(np.int32)
 | 
					            ms_data = np.array(list(chain(*ms_data))).astype(np.int32)
 | 
				
			||||||
            ms_level = np.array(list(chain(*ms_level))).astype(np.int32)
 | 
					            ms_level = np.array(list(chain(*ms_level))).astype(np.int32)
 | 
				
			||||||
            ms_idx = np.array(list(chain(*ms_idx))).astype(np.uint32)
 | 
					            ms_rank_idx = np.array(list(chain(*ms_rank_idx))).astype(np.uint32)
 | 
				
			||||||
 | 
					            ms_rank_all = np.array(list(chain(*ms_rank_all))).astype(np.uint32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            d = np.ones(ms_idx.shape, dtype=np.uint32)
 | 
					            d = np.ones(ms_rank_idx.shape, dtype=np.uint32)
 | 
				
			||||||
            d = ((d * 0xFFFFFFFF) << (ms_level * 4)).astype(np.uint32)
 | 
					            d = ((d * 0xFFFFFFFF) << (ms_level * 4)).astype(np.uint32)
 | 
				
			||||||
            ms_idx = (
 | 
					            ms_rank_idx = (
 | 
				
			||||||
                ((ms_idx & 0xF) << 28)
 | 
					                ((ms_rank_idx & 0xF) << 28)
 | 
				
			||||||
                + ((ms_idx & 0xF0) << 20)
 | 
					                + ((ms_rank_idx & 0xF0) << 20)
 | 
				
			||||||
                + ((ms_idx & 0xF00) << 12)
 | 
					                + ((ms_rank_idx & 0xF00) << 12)
 | 
				
			||||||
                + ((ms_idx & 0xF000) << 4)
 | 
					                + ((ms_rank_idx & 0xF000) << 4)
 | 
				
			||||||
                + ((ms_idx & 0xF0000) >> 4)
 | 
					                + ((ms_rank_idx & 0xF0000) >> 4)
 | 
				
			||||||
                + ((ms_idx & 0xF00000) >> 12)
 | 
					                + ((ms_rank_idx & 0xF00000) >> 12)
 | 
				
			||||||
                + ((ms_idx & 0xF000000) >> 20)
 | 
					                + ((ms_rank_idx & 0xF000000) >> 20)
 | 
				
			||||||
                + ((ms_idx & 0xF0000000) >> 28)
 | 
					                + ((ms_rank_idx & 0xF0000000) >> 28)
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            ms_idx = ((ms_idx >> ((8 - ms_level) * 4)) + d).astype(np.uint32)
 | 
					            ms_rank_idx = ((ms_rank_idx >> ((8 - ms_level) * 4)) + d).astype(np.uint32)
 | 
				
			||||||
 | 
					            ms_rank_all = (
 | 
				
			||||||
 | 
					                ((ms_rank_all & 0xF) << 28)
 | 
				
			||||||
 | 
					                + ((ms_rank_all & 0xF0) << 20)
 | 
				
			||||||
 | 
					                + ((ms_rank_all & 0xF00) << 12)
 | 
				
			||||||
 | 
					                + ((ms_rank_all & 0xF000) << 4)
 | 
				
			||||||
 | 
					                + ((ms_rank_all & 0xF0000) >> 4)
 | 
				
			||||||
 | 
					                + ((ms_rank_all & 0xF00000) >> 12)
 | 
				
			||||||
 | 
					                + ((ms_rank_all & 0xF000000) >> 20)
 | 
				
			||||||
 | 
					                + ((ms_rank_all & 0xF0000000) >> 28)
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            ms_rank_all = ((ms_rank_all >> ((8 - ms_level) * 4)) + d).astype(np.uint32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            ms_start = np.array(ms_start).astype(np.int32)
 | 
					            ms_start = np.array(ms_start).astype(np.int32)
 | 
				
			||||||
            ms_height = np.array(ms_height).astype(np.int32)
 | 
					            ms_height = np.array(ms_height).astype(np.int32)
 | 
				
			||||||
| 
						 | 
					@ -148,15 +163,17 @@ class MeaningMap:
 | 
				
			||||||
                ),
 | 
					                ),
 | 
				
			||||||
                axis=1,
 | 
					                axis=1,
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            dli = np.stack((ms_data, ms_level, ms_idx.astype(np.int32)), axis=1)
 | 
					            dlra = np.stack((ms_data, ms_level, ms_rank_idx.astype(np.int32), ms_rank_all.astype(np.int32)), axis=1)
 | 
				
			||||||
            np.savez(file, slhwm=slhwm, dli=dli)
 | 
					            np.savez(file, slhwm=slhwm, dlra=dlra)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self.ms_data = ms_data  # map[i]=ms_data[ms_start[i]:ms_start[i]+ms_len[i]]
 | 
				
			||||||
 | 
					            self.ms_level = ms_level
 | 
				
			||||||
 | 
					            self.ms_rank_idx = ms_rank_idx
 | 
				
			||||||
 | 
					            self.ms_rank_all = ms_rank_all
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            self.ms_map = map  # ms_map[i] = [sub(i),sub(i),sub(i),sub(i)...sub(i)]
 | 
					            self.ms_map = map  # ms_map[i] = [sub(i),sub(i),sub(i),sub(i)...sub(i)]
 | 
				
			||||||
            self.ms_data = ms_data  # map[i]=ms_data[ms_start[i]:ms_start[i]+ms_len[i]]
 | 
					 | 
				
			||||||
            self.ms_start = ms_start
 | 
					            self.ms_start = ms_start
 | 
				
			||||||
            self.ms_len = ms_len
 | 
					            self.ms_len = ms_len
 | 
				
			||||||
            self.ms_level = ms_level
 | 
					 | 
				
			||||||
            self.ms_idx = ms_idx
 | 
					 | 
				
			||||||
            self.ms_height = ms_height
 | 
					            self.ms_height = ms_height
 | 
				
			||||||
            self.ms_weight = ms_weight
 | 
					            self.ms_weight = ms_weight
 | 
				
			||||||
            print("Disk cache build end.")
 | 
					            print("Disk cache build end.")
 | 
				
			||||||
| 
						 | 
					@ -164,7 +181,12 @@ class MeaningMap:
 | 
				
			||||||
    def get_sequence(self, meaning):  # return sequence[meaning]
 | 
					    def get_sequence(self, meaning):  # return sequence[meaning]
 | 
				
			||||||
        start = self.ms_start[meaning]
 | 
					        start = self.ms_start[meaning]
 | 
				
			||||||
        len = self.ms_len[meaning]
 | 
					        len = self.ms_len[meaning]
 | 
				
			||||||
        return self.ms_data[start : start + len], self.ms_level[start : start + len], self.ms_idx[start : start + len]
 | 
					        return (
 | 
				
			||||||
 | 
					            self.ms_data[start : start + len],
 | 
				
			||||||
 | 
					            self.ms_level[start : start + len],
 | 
				
			||||||
 | 
					            self.ms_rank_idx[start : start + len],
 | 
				
			||||||
 | 
					            self.ms_rank_all[start : start + len],
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_tree(self, meaning):  # return meaning all sub items
 | 
					    def get_tree(self, meaning):  # return meaning all sub items
 | 
				
			||||||
        tree = {}
 | 
					        tree = {}
 | 
				
			||||||
| 
						 | 
					@ -203,73 +225,70 @@ class MeaningMap:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MeaningDataset(Dataset):
 | 
					class MeaningDataset(Dataset):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(
 | 
					    def __init__(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
        start=131072,
 | 
					        start,
 | 
				
			||||||
        end=1048576,
 | 
					        end,
 | 
				
			||||||
        size=32768,
 | 
					        size,
 | 
				
			||||||
        vocab_size=4096,
 | 
					        vocab_size,
 | 
				
			||||||
        max_subitem=10,
 | 
					        max_subitem=10,
 | 
				
			||||||
        min_seq_len=2,
 | 
					        min_seq_len=2,
 | 
				
			||||||
        seed=42,
 | 
					        seed=42,
 | 
				
			||||||
        data=None,
 | 
					 | 
				
			||||||
        length=None,
 | 
					 | 
				
			||||||
        tree=None,
 | 
					 | 
				
			||||||
        level=None,
 | 
					 | 
				
			||||||
        idx=None,
 | 
					 | 
				
			||||||
        use_cache=True,
 | 
					        use_cache=True,
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        if data != None and length != None and tree != None and level != None and idx != None:
 | 
					 | 
				
			||||||
            self.data = data
 | 
					 | 
				
			||||||
            self.length = length
 | 
					 | 
				
			||||||
            self.tree = tree
 | 
					 | 
				
			||||||
            self.level = level
 | 
					 | 
				
			||||||
            self.idx = idx
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
        np.random.seed(seed)
 | 
					        np.random.seed(seed)
 | 
				
			||||||
        map = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem, use_cache=use_cache)
 | 
					        map = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem, use_cache=use_cache)
 | 
				
			||||||
 | 
					        np.random.seed(seed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.tree = []
 | 
					        self.tree = []
 | 
				
			||||||
        self.data = []
 | 
					        self.seq = []
 | 
				
			||||||
        self.level = []
 | 
					        self.level = []
 | 
				
			||||||
        self.idx = []
 | 
					        self.rank_idx = []
 | 
				
			||||||
        self.length = []
 | 
					        self.rank_all = []
 | 
				
			||||||
 | 
					        self.seq_meaning = []
 | 
				
			||||||
 | 
					        self.m_height = map.ms_height
 | 
				
			||||||
 | 
					        self.m_weight = map.ms_weight
 | 
				
			||||||
        meanings = np.random.randint(start, end, size=(size))
 | 
					        meanings = np.random.randint(start, end, size=(size))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        seq_len = []
 | 
				
			||||||
        for m in meanings:
 | 
					        for m in meanings:
 | 
				
			||||||
            d, l, i = map.get_sequence(m)
 | 
					            d, l, i, a = map.get_sequence(m)
 | 
				
			||||||
            if len(d) >= min_seq_len:
 | 
					            if len(d) >= min_seq_len:
 | 
				
			||||||
                self.tree.append({m: map.get_tree(m)})
 | 
					                self.tree.append({m: map.get_tree(m)})
 | 
				
			||||||
                self.data.append(d)
 | 
					                self.seq.append(d)
 | 
				
			||||||
                self.level.append(l)
 | 
					                self.level.append(l)
 | 
				
			||||||
                self.idx.append(i)
 | 
					                self.rank_idx.append(i)
 | 
				
			||||||
                self.length.append(len(d))
 | 
					                self.rank_all.append(a)
 | 
				
			||||||
 | 
					                self.seq_meaning.append(m)
 | 
				
			||||||
 | 
					                seq_len.append(len(d))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        unique, counts = np.unique(self.length, return_counts=True)
 | 
					        unique, counts = np.unique(seq_len, return_counts=True)
 | 
				
			||||||
        print("----------------------------------------------------------------")
 | 
					        print("----------------------------------------------------------------")
 | 
				
			||||||
        print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))
 | 
					        print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))
 | 
				
			||||||
        print("MeaningDataset size:" + str(len(self.length)))
 | 
					        print("MeaningDataset size:" + str(len(seq_len)))
 | 
				
			||||||
        print("MeaningDataset max sequence length:" + str(max(unique)))
 | 
					        print("MeaningDataset max sequence length:" + str(max(unique)))
 | 
				
			||||||
        print("MeaningDataset most popular sequence length:" + str(unique[np.argmax(counts)]))
 | 
					        print("MeaningDataset most popular sequence length:" + str(unique[np.argmax(counts)]))
 | 
				
			||||||
        print("----------------------------------------------------------------")
 | 
					        print("----------------------------------------------------------------")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __len__(self):
 | 
					    def __len__(self):
 | 
				
			||||||
        return len(self.data)
 | 
					        return len(self.seq)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def len(self):
 | 
					    def len(self):
 | 
				
			||||||
        return len(self.data)
 | 
					        return len(self.seq)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __getitem__(self, idx):
 | 
					    def __getitem__(self, idx):
 | 
				
			||||||
        output = {}
 | 
					        output = {}
 | 
				
			||||||
        data = torch.tensor(self.data[idx]).long()
 | 
					        data = torch.tensor(self.seq[idx]).long()
 | 
				
			||||||
        output["input_ids"] = data
 | 
					        output["input_ids"] = data
 | 
				
			||||||
        output["labels"] = data.clone()
 | 
					        output["labels"] = data.clone()
 | 
				
			||||||
        output["token_type_ids"] = torch.zeros(data.shape)
 | 
					        output["token_type_ids"] = torch.zeros(data.shape)
 | 
				
			||||||
        output["tree"] = self.tree[idx]
 | 
					        output["tree"] = self.tree[idx]
 | 
				
			||||||
        output["level"] = self.level[idx]
 | 
					        output["level"] = self.level[idx]
 | 
				
			||||||
        output["idx"] = self.idx[idx]
 | 
					 | 
				
			||||||
        return output
 | 
					        return output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_batch(self, idx_list):  # must equal sequence length
 | 
					    def get_batch(self, idx_list):  # must equal sequence length
 | 
				
			||||||
        data = [self.data[i] for i in idx_list]
 | 
					        data = [self.seq[i] for i in idx_list]
 | 
				
			||||||
        output = {}
 | 
					        output = {}
 | 
				
			||||||
        data = torch.tensor(np.stack(data, axis=0)).long()
 | 
					        data = torch.tensor(np.stack(data, axis=0)).long()
 | 
				
			||||||
        output["input_ids"] = data
 | 
					        output["input_ids"] = data
 | 
				
			||||||
| 
						 | 
					@ -277,45 +296,35 @@ class MeaningDataset(Dataset):
 | 
				
			||||||
        output["token_type_ids"] = torch.zeros(data.shape)
 | 
					        output["token_type_ids"] = torch.zeros(data.shape)
 | 
				
			||||||
        output["tree"] = [self.tree[i] for i in idx_list]
 | 
					        output["tree"] = [self.tree[i] for i in idx_list]
 | 
				
			||||||
        output["level"] = [self.level[i] for i in idx_list]
 | 
					        output["level"] = [self.level[i] for i in idx_list]
 | 
				
			||||||
        output["idx"] = [self.idx[i] for i in idx_list]
 | 
					 | 
				
			||||||
        return output
 | 
					        return output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_token(self, idx):  # must equal sequence length
 | 
					    def get_token(self, idx):  # must equal sequence length
 | 
				
			||||||
        return self.data[idx]
 | 
					        return self.seq[idx]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_tree(self, idx):
 | 
					    def get_tree(self, idx):
 | 
				
			||||||
        return self.tree[idx]
 | 
					        return self.tree[idx]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def print_tree(self, idx):
 | 
					    def print_tree(self, idx):
 | 
				
			||||||
        tokens = self.data[idx]
 | 
					        tokens = self.seq[idx]
 | 
				
			||||||
        tree = self.get_tree(idx)
 | 
					        tree = self.get_tree(idx)
 | 
				
			||||||
        s = str(tokens) + "\n"
 | 
					        s = str(tokens) + "\n"
 | 
				
			||||||
        s += MeaningMap.get_tree_str(tree, "")
 | 
					        s += MeaningMap.get_tree_str(tree, "")
 | 
				
			||||||
        return s
 | 
					        return s
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def copy(self, start, end):
 | 
				
			||||||
 | 
					        new = copy.deepcopy(self)
 | 
				
			||||||
 | 
					        new.tree = new.tree[start:end]
 | 
				
			||||||
 | 
					        new.seq = new.seq[start:end]
 | 
				
			||||||
 | 
					        new.level = new.level[start:end]
 | 
				
			||||||
 | 
					        new.rank_idx = new.rank_idx[start:end]
 | 
				
			||||||
 | 
					        new.rank_all = new.rank_all[start:end]
 | 
				
			||||||
 | 
					        new.seq_meaning = new.seq_meaning[start:end]
 | 
				
			||||||
 | 
					        return new
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def split(self, ratio):
 | 
					    def split(self, ratio):
 | 
				
			||||||
        l = len(self.data)
 | 
					        l = self.len()
 | 
				
			||||||
        middle = int(l * ratio)
 | 
					        middle = int(l * ratio)
 | 
				
			||||||
        d_shuffle = self.data.copy()
 | 
					        return self.copy(0, middle), self.copy(middle, l)
 | 
				
			||||||
        l_shuffle = self.length.copy()
 | 
					 | 
				
			||||||
        m_shuffle = self.tree.copy()
 | 
					 | 
				
			||||||
        level_shuffle = self.level.copy()
 | 
					 | 
				
			||||||
        i_shuffle = self.idx.copy()
 | 
					 | 
				
			||||||
        md1 = MeaningDataset(
 | 
					 | 
				
			||||||
            data=d_shuffle[:middle],
 | 
					 | 
				
			||||||
            length=l_shuffle[:middle],
 | 
					 | 
				
			||||||
            tree=m_shuffle[:middle],
 | 
					 | 
				
			||||||
            level=level_shuffle[:middle],
 | 
					 | 
				
			||||||
            idx=i_shuffle[:middle],
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        md2 = MeaningDataset(
 | 
					 | 
				
			||||||
            data=d_shuffle[middle:],
 | 
					 | 
				
			||||||
            length=l_shuffle[middle:],
 | 
					 | 
				
			||||||
            tree=m_shuffle[middle:],
 | 
					 | 
				
			||||||
            level=level_shuffle[middle:],
 | 
					 | 
				
			||||||
            idx=i_shuffle[middle:],
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        return md1, md2
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def token_frequency(self):
 | 
					    def token_frequency(self):
 | 
				
			||||||
        freq = {}
 | 
					        freq = {}
 | 
				
			||||||
| 
						 | 
					@ -323,10 +332,12 @@ class MeaningDataset(Dataset):
 | 
				
			||||||
            MeaningMap.token_frequency(t, freq)
 | 
					            MeaningMap.token_frequency(t, freq)
 | 
				
			||||||
        return freq
 | 
					        return freq
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_seq_mask(idx, level, index):
 | 
					    def get_seq_mask(self, idx, level, index):
 | 
				
			||||||
        assert index < 15, "index must < 15"
 | 
					        assert index < 15, "index must < 15"
 | 
				
			||||||
        assert level < 8, "level must < 8"
 | 
					        assert level < 8, "level must < 8"
 | 
				
			||||||
        return [((int(i / (16**level)) & 0xF) == index) for i in idx]
 | 
					        rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF
 | 
				
			||||||
 | 
					        rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF
 | 
				
			||||||
 | 
					        return rank_idx == (rank_all + index if index < 0 else index)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class BatchGroupMeaningDataloader(Dataset):
 | 
					class BatchGroupMeaningDataloader(Dataset):
 | 
				
			||||||
| 
						 | 
					@ -335,11 +346,11 @@ class BatchGroupMeaningDataloader(Dataset):
 | 
				
			||||||
        self.batch_size = batch_size
 | 
					        self.batch_size = batch_size
 | 
				
			||||||
        self.drop_last = drop_last
 | 
					        self.drop_last = drop_last
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        length = dataset.length
 | 
					        seq_len = [len(s) for s in dataset.seq]
 | 
				
			||||||
        unique, counts = np.unique(length, return_counts=True)
 | 
					        unique, counts = np.unique(seq_len, return_counts=True)
 | 
				
			||||||
        gl = {}
 | 
					        gl = {}
 | 
				
			||||||
        for u in unique:
 | 
					        for u in unique:
 | 
				
			||||||
            gl[u] = np.where(length == u)[0]
 | 
					            gl[u] = np.where(seq_len == u)[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        lens = list(gl.keys())
 | 
					        lens = list(gl.keys())
 | 
				
			||||||
        gs = {}
 | 
					        gs = {}
 | 
				
			||||||
| 
						 | 
					@ -365,7 +376,7 @@ class BatchGroupMeaningDataloader(Dataset):
 | 
				
			||||||
            index = index[index_shuffle]
 | 
					            index = index[index_shuffle]
 | 
				
			||||||
        self.indexBatch = index
 | 
					        self.indexBatch = index
 | 
				
			||||||
        print("Dataloader batch size:" + str(batch_size) + " count:" + str(len(index)))
 | 
					        print("Dataloader batch size:" + str(batch_size) + " count:" + str(len(index)))
 | 
				
			||||||
        print("Dataloader total:" + str(len(length)) + " drop:" + str(len(length) - len(index) * batch_size))
 | 
					        print("Dataloader total:" + str(len(seq_len)) + " drop:" + str(len(seq_len) - len(index) * batch_size))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __len__(self):
 | 
					    def __len__(self):
 | 
				
			||||||
        return len(self.indexBatch)
 | 
					        return len(self.indexBatch)
 | 
				
			||||||
| 
						 | 
					@ -387,229 +398,109 @@ class BatchGroupMeaningDataloader(Dataset):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=True)
 | 
					    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
 | 
				
			||||||
    train, val = md.split(0.95)
 | 
					    train, val = md.split(0.95)
 | 
				
			||||||
    fdaf = md.__getitem__(920)
 | 
					    fdaf = md.__getitem__(920)
 | 
				
			||||||
    print(md.print_tree(920))
 | 
					    print(md.print_tree(920))
 | 
				
			||||||
    print(md.idx[920])
 | 
					    print(md.rank_idx[920])
 | 
				
			||||||
    mask = MeaningDataset.get_seq_mask(md.idx[920], 1, 1)
 | 
					    print(md.rank_all[920])
 | 
				
			||||||
 | 
					    mask = md.get_seq_mask(920, 0, -1)
 | 
				
			||||||
    print(mask)
 | 
					    print(mask)
 | 
				
			||||||
    assert mask == [
 | 
					    mask = md.get_seq_mask(920, 1, 0)
 | 
				
			||||||
        False,
 | 
					    print(mask)
 | 
				
			||||||
        True,
 | 
					    mask = md.get_seq_mask(920, 1, -1)
 | 
				
			||||||
        True,
 | 
					    print(mask)
 | 
				
			||||||
        True,
 | 
					    mask = md.get_seq_mask(920, 1, 1)
 | 
				
			||||||
        True,
 | 
					    print(mask)
 | 
				
			||||||
        True,
 | 
					    assert all(
 | 
				
			||||||
        True,
 | 
					        np.equal(
 | 
				
			||||||
        True,
 | 
					            mask[0:57],
 | 
				
			||||||
        True,
 | 
					            np.array(
 | 
				
			||||||
        True,
 | 
					                [
 | 
				
			||||||
        True,
 | 
					                    False,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        True,
 | 
					                    False,
 | 
				
			||||||
        True,
 | 
					                    False,
 | 
				
			||||||
        True,
 | 
					                    True,
 | 
				
			||||||
        True,
 | 
					                    True,
 | 
				
			||||||
        True,
 | 
					                    True,
 | 
				
			||||||
        True,
 | 
					                    True,
 | 
				
			||||||
        True,
 | 
					                    True,
 | 
				
			||||||
        False,
 | 
					                    True,
 | 
				
			||||||
        False,
 | 
					                    True,
 | 
				
			||||||
        False,
 | 
					                    True,
 | 
				
			||||||
        False,
 | 
					                    True,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        True,
 | 
					                    True,
 | 
				
			||||||
        True,
 | 
					                    False,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        True,
 | 
					                    True,
 | 
				
			||||||
        True,
 | 
					                    True,
 | 
				
			||||||
        True,
 | 
					                    True,
 | 
				
			||||||
        True,
 | 
					                    True,
 | 
				
			||||||
        True,
 | 
					                    False,
 | 
				
			||||||
        True,
 | 
					                    False,
 | 
				
			||||||
        True,
 | 
					                    False,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        False,
 | 
					                    True,
 | 
				
			||||||
        False,
 | 
					                    True,
 | 
				
			||||||
        False,
 | 
					                    True,
 | 
				
			||||||
        False,
 | 
					                    True,
 | 
				
			||||||
        False,
 | 
					                    True,
 | 
				
			||||||
        False,
 | 
					                    True,
 | 
				
			||||||
        False,
 | 
					                    True,
 | 
				
			||||||
        False,
 | 
					                    True,
 | 
				
			||||||
        False,
 | 
					                    True,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        False,
 | 
					                    False,
 | 
				
			||||||
        True,
 | 
					                    False,
 | 
				
			||||||
        True,
 | 
					                    False,
 | 
				
			||||||
        True,
 | 
					                    False,
 | 
				
			||||||
        True,
 | 
					                ]
 | 
				
			||||||
        True,
 | 
					            ),
 | 
				
			||||||
        False,
 | 
					        )
 | 
				
			||||||
        False,
 | 
					    ), "False"
 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        True,
 | 
					 | 
				
			||||||
        True,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        True,
 | 
					 | 
				
			||||||
        True,
 | 
					 | 
				
			||||||
        True,
 | 
					 | 
				
			||||||
        True,
 | 
					 | 
				
			||||||
        True,
 | 
					 | 
				
			||||||
        True,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
        False,
 | 
					 | 
				
			||||||
    ], "False"
 | 
					 | 
				
			||||||
    freq = md.token_frequency()
 | 
					    freq = md.token_frequency()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    dl = BatchGroupMeaningDataloader(train, 2)
 | 
					    dl = BatchGroupMeaningDataloader(train, 2)
 | 
				
			||||||
    length = len(dl)
 | 
					    # length = len(dl)
 | 
				
			||||||
    it = iter(dl)
 | 
					    # it = iter(dl)
 | 
				
			||||||
    ne1 = next(it)
 | 
					    # ne1 = next(it)
 | 
				
			||||||
    ne2 = next(it)
 | 
					    # ne2 = next(it)
 | 
				
			||||||
    ne3 = next(it)
 | 
					    # ne3 = next(it)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    map1 = dl.get_tree(0)
 | 
					    # map1 = dl.get_tree(0)
 | 
				
			||||||
    map2 = dl.get_tree(1)
 | 
					    # map2 = dl.get_tree(1)
 | 
				
			||||||
    print(dl.print_tree(0))
 | 
					    # print(dl.print_tree(0))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    dl = DataLoader(
 | 
					    # dl = DataLoader(
 | 
				
			||||||
        train,
 | 
					    #     train,
 | 
				
			||||||
        num_workers=1,
 | 
					    #     num_workers=1,
 | 
				
			||||||
        persistent_workers=True,
 | 
					    #     persistent_workers=True,
 | 
				
			||||||
        shuffle=False,
 | 
					    #     shuffle=False,
 | 
				
			||||||
    )
 | 
					    # )
 | 
				
			||||||
    it = iter(dl)
 | 
					    # it = iter(dl)
 | 
				
			||||||
    ne1 = next(it)
 | 
					    # ne1 = next(it)
 | 
				
			||||||
    ne2 = next(it)
 | 
					    # ne2 = next(it)
 | 
				
			||||||
    ne3 = next(it)
 | 
					    # ne3 = next(it)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for i in range(10):
 | 
					    # for i in range(10):
 | 
				
			||||||
        print(next(it)["input_ids"].numpy().tolist())
 | 
					    #     print(next(it)["input_ids"].numpy().tolist())
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue