diff --git a/finetune/embedding.py b/finetune/embedding.py new file mode 100644 index 0000000..b0e385d --- /dev/null +++ b/finetune/embedding.py @@ -0,0 +1,34 @@ +from transformers import AutoTokenizer, AutoModel +import torch +import torch.nn.functional as F + +#Mean Pooling - Take attention mask into account for correct averaging +def mean_pooling(model_output, attention_mask): + token_embeddings = model_output[0] #First element of model_output contains all token embeddings + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) + + +# Sentences we want sentence embeddings for +sentences = ['This is an example sentence', 'Each sentence is converted'] + +# Load model from HuggingFace Hub +tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2') +model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2') + +# Tokenize sentences +encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt') + +# Compute token embeddings +with torch.no_grad(): + model_output = model(**encoded_input) + +# Perform pooling +sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask']) + +# Normalize embeddings +sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1) + +print("Sentence embeddings:") +print(sentence_embeddings) +print(sentence_embeddings.cpu().numpy()) diff --git a/tools/show.py b/tools/show.py index 813d415..e6bf0ff 100644 --- a/tools/show.py +++ b/tools/show.py @@ -9,7 +9,14 @@ import os from pathlib import Path +def toTensor(tensor): + if not torch.is_tensor(tensor): + tensor = torch.tensor(tensor) + return tensor + + def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None, GridValue=None, Value2Log=False): + tensor = toTensor(tensor) if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3: raise ("Error input dims") if ("." not in name) or (name.split(".")[-1] not in {"jpg", "png", "bmp"}): @@ -20,6 +27,7 @@ def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None, if len(tensor.shape) == 3: channel = tensor.shape[0] x = math.ceil((channel) ** 0.5) + y = math.ceil((x * x) / channel) calc = tensor.reshape((channel, tensor.shape[1] * tensor.shape[2])) if not Contrast: tensormax = calc.max(1)[0] @@ -33,11 +41,11 @@ def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None, calc = calc.reshape((channel, tensor.shape[1], tensor.shape[2])) if not GridValue: GridValue = 128.0 - calc = F.pad(calc, (0, 0, 0, 0, 0, x * x - channel), mode="constant", value=GridValue) - calc = calc.reshape((x, x, tensor.shape[1], tensor.shape[2])) + calc = F.pad(calc, (0, 0, 0, 0, 0, x * y - channel), mode="constant", value=GridValue) + calc = calc.reshape((y, x, tensor.shape[1], tensor.shape[2])) calc = F.pad(calc, (0, 1, 0, 1, 0, 0), mode="constant", value=GridValue) tensor = calc.permute((0, 2, 1, 3)) - tensor = tensor.reshape((x * tensor.shape[1], x * tensor.shape[3])) + tensor = tensor.reshape((y * tensor.shape[1], x * tensor.shape[3])) DumpTensorToImage(tensor, name, forceSquare=False, scale=scale, Contrast=[0.0, 255.0], GridValue=GridValue) return @@ -78,6 +86,7 @@ def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None, def DumpTensorToLog(tensor, name="log"): + tensor = toTensor(tensor) tensor_mean = torch.mean(tensor).cpu().detach().numpy() tensor_abs_mean = torch.mean(torch.abs(tensor)).cpu().detach().numpy() tensor_range = (torch.max(tensor) - torch.min(tensor)).cpu().detach().numpy() @@ -92,6 +101,7 @@ def DumpTensorToLog(tensor, name="log"): def DumpTensorToFile(tensor, name="tensor.pt"): + tensor = toTensor(tensor) torch.save(tensor.cpu(), name) diff --git a/wit/configuration.py b/wit/configuration.py index 1d45674..d512c36 100644 --- a/wit/configuration.py +++ b/wit/configuration.py @@ -38,10 +38,10 @@ class ModelConfig: class MeaningDatasetConfig: def __init__(self): - self.level_ratio = 5 - self.level = 5 - self.dataset_level = 3 + self.start = 10000 + self.size = 4000 self.min_subitem = 2 + self.max_subitem = 10 self.val_mask_level = None self.val_mask_idx = None diff --git a/wit/dataset/dataset.py b/wit/dataset/dataset.py index bdf6e85..edfcf9e 100644 --- a/wit/dataset/dataset.py +++ b/wit/dataset/dataset.py @@ -31,12 +31,12 @@ def InitDataset(config): if config.dataset.name == "meaning": c = config.dataset.meaning vocab = config.model_config.vocab_size - start = vocab * (c.level_ratio**c.level) - size = vocab * int((c.level_ratio**c.dataset_level)) + start = c.start + size = c.size path = "./data/" - trainfile = path + f"MeaningDataset_train_v{size}_s{start}_s{size}_lr{c.level_ratio}_ms{c.min_subitem}.pt" - valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_lr{c.level_ratio}_ms{c.min_subitem}.pt" + trainfile = path + f"MeaningDataset_train_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt" + valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt" if not os.path.exists(path): os.mkdir(path) if os.path.exists(trainfile) and os.path.exists(valfile): @@ -48,7 +48,7 @@ def InitDataset(config): val_dataset.set_mask(c.val_mask_level, c.val_mask_idx) print(f"INFO: Load dataset end") else: - raw_dataset = MeaningDataset(start, start + size, vocab, None, c.level_ratio, c.min_subitem) + raw_dataset = MeaningDataset(start, start + size, vocab, None, c.max_subitem, c.min_subitem) raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx) train_dataset, val_dataset = raw_dataset.split(0.9) torch.save(train_dataset, trainfile) @@ -80,11 +80,11 @@ def InitValDataset(config): if config.dataset.name == "meaning": c = config.dataset.meaning vocab = config.model_config.vocab_size - start = vocab * (c.level_ratio**c.level) - size = vocab * int((c.level_ratio**c.dataset_level)) + start = c.start + size = c.size path = "./data/" - valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_lr{c.level_ratio}_ms{c.min_subitem}.pt" + valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt" if not os.path.exists(path): os.mkdir(path) if os.path.exists(valfile): @@ -93,7 +93,7 @@ def InitValDataset(config): val_dataset.set_mask(c.val_mask_level, c.val_mask_idx) print(f"INFO: Load dataset end") else: - raw_dataset = MeaningDataset(start, start + size, vocab, None, c.level_ratio, c.min_subitem) + raw_dataset = MeaningDataset(start, start + size, vocab, None, c.max_subitem, c.min_subitem) raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx) train_dataset, val_dataset = raw_dataset.split(0.9) torch.save(val_dataset, valfile) diff --git a/wit/dataset/meaning_dataset.py b/wit/dataset/meaning_dataset.py index 5c33928..6f04434 100644 --- a/wit/dataset/meaning_dataset.py +++ b/wit/dataset/meaning_dataset.py @@ -105,12 +105,13 @@ class MeaningMap: index = index + 1 for i in range(self.vocab_size, size): - m = map[i] + m = map[i] # 当前meaning的拆分的分支 m = m[m >= 0] # donot cut off the map such as [0] - m_len = len(m) + m_len = len(m) # 当前meaning的拆分的分支个数 m_list = m.tolist() assert m_list, "map list can not be empty list" + # 获取每个子meaning的start和end,并且生成序列组合成当前meaning完整的叶index( current -> common 两个变化的level距离 + def get_level_change(self, meaning): + def level_change(ms_map, meaning, current_to_common, common_to_current): ms = ms_map[meaning] - mlist.append(meaning) - mlist.append(-255) # level down marker for m in ms[ms >= 0].tolist(): if m >= self.vocab_size: - get_tree_list(ms_map, m, mlist) + common_to_current[-1] = common_to_current[-1] + 1 + level_change(ms_map, m, current_to_common, common_to_current) else: - mlist.append(m) - mlist.append(-1) # level up marker + current_to_common.append(0) + common_to_current.append(0) + current_to_common[-2] = current_to_common[-2] + 1 - meaninglist = [] - get_tree_list(self.ms_map, meaning, meaninglist) - return meaninglist + common_to_current = [] + common_to_current.append(1) + current_to_common = [] + current_to_common.append(0) + level_change(self.ms_map, meaning, current_to_common, common_to_current) + current_to_common = current_to_common[:-1] + common_to_current = common_to_current[:-1] + return current_to_common, common_to_current + + # 根据meaning的层级结构范围一个二位的数组,表示所有token跟前面token是否有关系 + def get_relation_table(self, meaning): + current_to_common, common_to_current = self.get_level_change(meaning) + width = len(current_to_common) + relation = np.zeros((width, width), dtype=int) + relation[0, 0] = 1 + for i in range(1, width, 1): + if i == width - 2: + print(1) + ori = current_to_common[i] - common_to_current[i] + start_index = width + for s in range(i - 1, -1, -1): + if ori < 0: + break + ori = ori - common_to_current[s] + current_to_common[s] + start_index = s + relation[i, start_index : i + 1] = 1.0 + return relation def max_length(self): return max(self.ms_len) @@ -430,6 +457,7 @@ class BatchGroupMeaningDataloader(Dataset): self, batch_size=1, num_workers=num_workers, collate_fn=BatchGroupMeaningDataloader.detection_collate ) + if __name__ == "__main__": md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False) diff --git a/wit/doc/meaning_dataset.md b/wit/doc/meaning_dataset.md index 772147c..47e6294 100644 --- a/wit/doc/meaning_dataset.md +++ b/wit/doc/meaning_dataset.md @@ -12,13 +12,20 @@ meaning数据集是一个模仿自然语言,以及抽象表达的数据集。 6. level表示当前token相对于root meaning的距离 7. rank 8. rank_idx表示当前token在不同层的排序编号,每4位表示在一层里面的编号,低4位表示最低层级的rank_idx,高位无用的位用1填充 -9. rank_all表示当前token在不同层的分子个数,每4位表示在一层里面的编号,低4位表示最低层级的rank_all,高位无用的位用1填充 +9. rank_all表示当前token所在的不同层的总的分支个数,每4位表示在一层里面的个数,低4位表示最低层级的rank_all,高位无用的位用1填充 10. tree用于存储每个meaning的拆解的数据,使用字典表达一个树形结构 11. get_seq_mask返回一个sequence每个token在对应level是不是对应的index,level=0:最底层,index=-1:最后一个,index=0:第一个 12. meaning_height 当前meaning的总高度 13. meaning_weight 当前meaning的总宽度 14. val_mask_level val_mask_idx: 表示用于训练的token的mask,val_mask_level=[0, 1, 2] val_mask_idx=[0, 0, -1]表示只有是第0层第0个,而且是第1层第0个,第2层最后一个的token,才参与训练 +## code + +``` +vocab = config.model_config.vocab_size +start 数据集的样本开始meaning +size 数据集的样本个数 +``` ``` vocab_size = 256 meaning = 115200 @@ -33,11 +40,13 @@ vocab_size = 256 meaning = 115200 / \ / \ / \ / \ 176 11 255 129 129 99 211 111 -sequence = 123 42 32 176 11 255 129 245 233 129 99 23 211 111 93 176 -level = 3 3 2 4 4 4 4 2 2 4 4 3 4 4 3 3 -idx at 0 = 0 1 1 0 1 0 1 0 1 0 1 2 0 1 0 1 -idx at 1 = 0 0 0 0 0 1 1 1 1 0 0 0 0 0 2 2 -idx 0 1 1 0 1 16 17 16 17 0 1 2 0 1 32 33 +sequence = 123 42 32 176 11 255 129 245 233 129 99 23 211 111 93 176 +level = 3 3 2 4 4 4 4 2 2 4 4 3 4 4 3 3 +rank_idx = 0 1 1 0 1 16 17 16 17 0 1 2 0 1 32 33 + idx at0 = 0 1 1 0 1 0 1 0 1 0 1 2 0 1 0 1 + idx at1 = 0 0 0 0 0 1 1 1 1 0 0 0 0 0 2 2 + +rank_all = ``` diff --git a/wit/doc/q@k_seq_47_layer_0.png b/wit/doc/q@k_seq_47_layer_0.png new file mode 100644 index 0000000..fa207f3 Binary files /dev/null and b/wit/doc/q@k_seq_47_layer_0.png differ diff --git a/wit/doc/train_meaning_dataset.md b/wit/doc/train_meaning_dataset.md index bdff996..eddcc43 100644 --- a/wit/doc/train_meaning_dataset.md +++ b/wit/doc/train_meaning_dataset.md @@ -10,4 +10,11 @@ ## 不同模型深度对结果的影响 6层相对于3层没有提升的原因,可能是数据集太小,3层已经能完全拟合 -![alt text](model_level_number.png) \ No newline at end of file +![alt text](model_level_number.png) + +## qk图解释 + +1. key[10] = 1000.0 +2. 每一行数据(像素)表示一个新的token,和前面所有token的关系 + +![alt text](q@k_seq_47_layer_0.png) \ No newline at end of file diff --git a/wit/inference.py b/wit/inference.py index 1555494..ad242e7 100644 --- a/wit/inference.py +++ b/wit/inference.py @@ -8,7 +8,9 @@ import dataset.dataset as ds if __name__ == "__main__": - checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt" + # checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt" + # checkpoint_path = "log/bigger/version_4/checkpoints/epoch=81-step=64288.ckpt" + checkpoint_path = "log/bigger/version_8/checkpoints/epoch=14-step=67455.ckpt" qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path) qwen.eval() @@ -19,18 +21,18 @@ if __name__ == "__main__": runner = ModelRunner(qwen.llm) - # batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7, 14, 13, 1, 12, 13]], dtype=torch.int64) - # sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False) - # print(sorted_logits.detach().cpu().numpy()) - # print(sorted_indices.detach().cpu().numpy()) - val = ds.InitValDataset(conf).dataset md = val.meaning_dataset map = md.get_meaning_map() - item = md.get_token(0) - node = map.get_nodetree(md.get_meaning(0)) - # node.print() + # seq:844 + # seq:849 + # seq:991 + # seq:995 + + node = map.get_nodetree(995) + item, l, rank_idx, rank_all = map.get_sequence(995) + print("len of seq:" + str(len(item))) for i in range(1, len(item)): itemm = [item[:i]] diff --git a/wit/query_block_output.py b/wit/query_block_output.py index 98b9d96..85039f5 100644 --- a/wit/query_block_output.py +++ b/wit/query_block_output.py @@ -15,7 +15,7 @@ import dataset.dataset as ds if __name__ == "__main__": - checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt" + checkpoint_path = "log/bigger/version_8/checkpoints/epoch=14-step=67455.ckpt" qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path) qwen.eval() @@ -25,6 +25,7 @@ if __name__ == "__main__": runner = ModelRunner(qwen.llm) def DumpQK(query, key, causal_mask, index): + global relation_table size = query.shape[2] scale_factor = 1 / math.sqrt(query.size(-1)) attn_weight = query @ key.transpose(-2, -1) * scale_factor @@ -35,28 +36,44 @@ if __name__ == "__main__": attn_weight = attn_weight * attn_mask qk = attn_weight[0] prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png" + qk = qk.cpu() + qk = torch.cat((qk, relation_table.unsqueeze(0)), dim=0) show.DumpTensorToImage(qk, prePath, GridValue=255) # qk_seq.append(qk) # qk_index = size qwen.llm.hook_attention = DumpQK - val = ds.InitValDataset(conf).dataset md = val.meaning_dataset map = md.get_meaning_map() - item = md.get_token(0) - node = map.get_nodetree(md.get_meaning(0)) - # node.print() + # seq:844 + # seq:849 + # seq:991 + # seq:995 + meaning = 995 + node = map.get_nodetree(meaning) + current_to_common, common_to_current = map.get_level_change(meaning) + + node.print() + print(current_to_common) + print(common_to_current) + relation_table = map.get_relation_table(meaning) + # prePath = "./temp/" + "q@k_seq_" + "_layer_" + ".png" + # show.DumpTensorToImage(relation_table, prePath, GridValue=255) + relation_table = torch.tensor(relation_table) + + item, level, rank_idx, rank_all = map.get_sequence(meaning) + print(item) + print(level) + print(rank_idx) + print(rank_all) + print("len of seq:" + str(len(item))) batch = torch.tensor([item], dtype=torch.int64) sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False) next_token = sorted_indices.detach().cpu().numpy()[0][0] - node.print() - - - # batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7]], dtype=torch.int64) # sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False) diff --git a/wit/query_meaning_freq.py b/wit/query_meaning_freq.py index 6e46aa1..87087c1 100644 --- a/wit/query_meaning_freq.py +++ b/wit/query_meaning_freq.py @@ -25,13 +25,13 @@ if __name__ == "__main__": loader = train_dataloader.dataset map = loader.meaning_dataset.get_meaning_map() - trees = {} + seqs = {} for batch in loader: for m in batch["meaning"]: - trees[m] = map.get_tree(m) + seqs[m] = map.get_sequence(m) while True: m = int(input("input meaning: ")) total = 0 - for tree in trees.values(): - total = total + tree.count(m) + for seq in seqs.values(): + total = total + seq.count(m) print(f"meaning of {m} count as {total}") diff --git a/wit/train.py b/wit/train.py index 41280f9..9021e32 100644 --- a/wit/train.py +++ b/wit/train.py @@ -20,7 +20,7 @@ if __name__ == "__main__": conf.learning_rate = 0.001 conf.use_tril_attention_mask = None conf.precision = "16-mixed" # "precision:bf16-mixed,16-mixed,32-true" - conf.train_batch_size = 16 + conf.train_batch_size = 32 conf.val_batch_size = 2 conf.num_proc = 8 conf.max_epochs = 1000 @@ -29,18 +29,18 @@ if __name__ == "__main__": conf.seed = 42 conf.dataloader_works = 2 - conf.dataset.meaning.level_ratio = 5 - conf.dataset.meaning.level = 2 - conf.dataset.meaning.dataset_level = 5 + conf.dataset.meaning.start = 800 + conf.dataset.meaning.size = 200000 conf.dataset.meaning.min_subitem = 2 + conf.dataset.meaning.max_subitem = 4 conf.dataset.meaning.val_mask_level = [0, 1, 2] conf.dataset.meaning.val_mask_idx = [0, 0, -1] config.vocab_size = 32 - config.hidden_size = 128 # 128 1024 2048 32 - config.intermediate_size = 256 - config.num_hidden_layers = 3 # 6 12 24 3 - config.num_attention_heads = 8 # 8 8 16 + config.hidden_size = 256 # 128 1024 2048 32 + config.intermediate_size = 512 + config.num_hidden_layers = 4 # 6 12 24 3 + config.num_attention_heads = 4 # 8 8 16 torch.manual_seed(conf.seed) np.random.seed(conf.seed)