import math
import sys

import numpy as np
import torch

from model.light_module import LightModule
from model.light_module import ModelRunner

sys.path.append("..")
from tools import show

import dataset.dataset as ds

if __name__ == "__main__":
    checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt"

    # Load the trained model, switch to eval mode, and seed RNGs from its config.
    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
    qwen.eval()
    conf = qwen.config
    torch.manual_seed(conf.seed)
    np.random.seed(conf.seed)
    runner = ModelRunner(qwen.llm)

    def DumpQK(query, key, causal_mask, index):
        # Attention hook: recompute the q@k attention map for one layer and
        # dump it as an image under ./temp/.
        size = query.shape[2]
        scale_factor = 1 / math.sqrt(query.size(-1))
        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        # Multiplicative 0/1 mask built from the causal mask; masked positions
        # are zeroed both before and after the softmax.
        attn_mask = torch.ones(causal_mask.shape, dtype=query.dtype, device=query.device)
        attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
        attn_weight = attn_weight * attn_mask
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = attn_weight * attn_mask
        qk = attn_weight[0]
        prePath = f"./temp/q@k_seq_{size}_layer_{index}.png"
        show.DumpTensorToImage(qk, prePath, GridValue=255)
        # qk_seq.append(qk)
        # qk_index = size

    qwen.llm.hook_attention = DumpQK

    # Build the validation dataset and pick the first sample.
    val = ds.InitValDataset(conf).dataset
    md = val.meaning_dataset
    meaning_map = md.get_meaning_map()
    item = md.get_token(0)

    node = meaning_map.get_nodetree(md.get_meaning(0))
    # node.print()

    # Run one greedy decoding step and take the top-1 next token.
    batch = torch.tensor([item], dtype=torch.int64)
    sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
    next_token = sorted_indices.detach().cpu().numpy()[0][0]
    node.print()

    # batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7]], dtype=torch.int64)
    # sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
    # print(sorted_logits.detach().cpu().numpy())
    # print(sorted_indices.detach().cpu().numpy())
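
    # The DumpQK hook visualizes an attention map built with a multiplicative
    # 0/1 mask around the softmax. A minimal sketch, assuming the same
    # query/key/causal_mask shapes, of the conventional additive-mask
    # formulation (as used by standard scaled-dot-product attention), kept here
    # only for comparison with the dumped images:
    #
    #   attn_bias = torch.zeros(causal_mask.shape, dtype=query.dtype, device=query.device)
    #   attn_bias.masked_fill_(causal_mask.logical_not(), float("-inf"))
    #   attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale_factor + attn_bias, dim=-1)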