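"""Visualize the attention maps of a trained checkpoint.

Loads a LightModule checkpoint, registers an attention hook that renders each
layer's causal-masked Q@K^T map to ./temp/ as a PNG, then feeds a fixed token
sequence through ModelRunner.ChatTokens and prints the sorted logits and
token indices.
"""
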
import math
import sys

import numpy as np
import torch

from model.light_module import LightModule, ModelRunner

sys.path.append("..")
from tools import show

import dataset.dataset as ds

if __name__ == "__main__":

    # Earlier training runs, kept for reference; only the last path is loaded.
    # checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
    # checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
    # checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
    checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"

    # Load the trained model and seed torch/numpy from its training config.
    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
    qwen.eval()
    conf = qwen.config
    torch.manual_seed(conf.seed)
    np.random.seed(conf.seed)
    runner = ModelRunner(qwen.llm)

    def DumpQK(query, key, causal_mask, index):
        # Render the causal-masked Q@K^T map of layer `index` to ./temp/ as a PNG.
        size = query.shape[2]  # sequence length
        scale_factor = 1 / math.sqrt(query.size(-1))
        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        # Masked positions are zeroed (rather than filled with -inf before the
        # softmax), so rows are not strictly normalized; this is only used for
        # visualization.
        attn_mask = torch.ones(causal_mask.shape, dtype=query.dtype, device=query.device)
        attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
        attn_weight = attn_weight * attn_mask
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = attn_weight * attn_mask
        qk = attn_weight[0]  # first sample of the batch
        png_path = f"./temp/q@k_seq_{size}_layer_{index}.png"
        show.DumpTensorToImage(qk, png_path, GridValue=255)
        # qk_seq.append(qk)
        # qk_index = size
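
    # For reference only (not wired to the hook): a minimal sketch of the
    # textbook causal attention, which fills masked positions with -inf before
    # the softmax so that each row sums to 1. DumpQK above instead zeroes the
    # masked entries, which is simpler to render but leaves rows unnormalized.
    # The helper name is illustrative, not part of the model's API.
    def reference_masked_attention(query, key, causal_mask):
        scale_factor = 1 / math.sqrt(query.size(-1))
        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight = attn_weight.masked_fill(causal_mask.logical_not(), float("-inf"))
        return torch.softmax(attn_weight, dim=-1)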

    # Route every layer's attention through the dumper defined above.
    qwen.llm.hook_attention = DumpQK

    # Run a fixed token sequence without sampling and print the logits and
    # token indices, sorted by score.
    batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7]], dtype=torch.int64)
    sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)

    print(sorted_logits.detach().cpu().numpy())
    print(sorted_indices.detach().cpu().numpy())