Retest wit.
commit 601c7f6510
parent a70d12d04d

@@ -7,8 +7,8 @@
 class QWenConfig:
     def __init__(self):
         self.vocab_size = 4096
-        self.hidden_size = 128  # 128 1024 2048
-        self.num_hidden_layers = 6  # 6 12 24
+        self.hidden_size = 128  # 128 1024 2048  32
+        self.num_hidden_layers = 6  # 6 12 24  3
         self.num_attention_heads = 8  # 8 8 16
         self.emb_dropout_prob = 0.0
         self.attn_dropout_prob = 0.0
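
For orientation, a minimal sketch of what this toy-scale config implies; the per-head width below is derived by me and is not stated in the diff:

    # Sketch: per-head channel count implied by the settings above.
    from configuration_qwen import QWenConfig

    config = QWenConfig()
    head_dim = config.hidden_size // config.num_attention_heads  # 128 // 8
    print(head_dim)  # 16 channels per attention head at this toy scale
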
@@ -5,9 +5,6 @@ import pytorch_lightning as pl
 import torch
 import torchmetrics
 
-# from utils import init_model
-# from custom_models.gpt2.modeling_gpt2 import GPT2LMHeadModel
-
 from modeling_wit import QWenLMHeadModel
 from configuration_qwen import QWenConfig
 
@@ -64,6 +61,8 @@ class LitModule(pl.LightningModule):
         label_mask = labels < self.vocab_size
         logits = logits[label_mask]
         labels = labels[label_mask]
+        # m = torch.max(logits, 1).indices.cpu().numpy()
+        # ll = labels.cpu().numpy()
         self.metric_accuracy.update(logits, labels)
         self.metric_loss.update(loss)
 
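
The pattern kept here filters out label ids at or above vocab_size before updating the torchmetrics state. A standalone sketch, with the shapes and the padding sentinel chosen by me for illustration:

    import torch
    import torchmetrics

    vocab_size = 4096
    metric_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=vocab_size)

    logits = torch.randn(10, vocab_size)          # assumed shape: [tokens, vocab]
    labels = torch.randint(0, vocab_size, (10,))
    labels[:3] = vocab_size                       # pretend positions 0-2 are padding

    label_mask = labels < vocab_size              # same filter as in the hunk
    metric_accuracy.update(logits[label_mask], labels[label_mask])
    print(metric_accuracy.compute())
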
@@ -100,5 +99,6 @@ class LitModule(pl.LightningModule):
             mode="max",
             stopping_threshold=1,
         )
-        return [checkpoint_callback]
+        lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")
+        return [checkpoint_callback, lr_monitor]
         # return [checkpoint_callback, early_stop_callback]
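
The change adds a LearningRateMonitor next to the existing checkpoint callback. A minimal sketch of the resulting configure_callbacks; the monitored metric key and checkpoint arguments are my assumptions:

    import pytorch_lightning as pl

    class LitModule(pl.LightningModule):
        def configure_callbacks(self):
            checkpoint_callback = pl.callbacks.ModelCheckpoint(
                monitor="accuracy",  # assumed metric key
                mode="max",
            )
            # Log the learning rate at every optimizer step alongside checkpoints.
            lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")
            return [checkpoint_callback, lr_monitor]
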
@@ -245,7 +245,8 @@ class QwenRunner:
         rot_dim = freqs[0].shape[-1]
         cos, sin = freqs
         t_float = t.float()
-        t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:]
+        t_rot = t_float[..., :rot_dim]
+        t_pass = t_float[..., rot_dim:]
         t_rot = (t_rot * cos) + (self._rotate_half(t_rot) * sin)
         return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
 
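
This hunk splits a tuple assignment into two lines without changing behavior: only the first rot_dim channels are rotated, the rest pass through unchanged. A self-contained sketch of the same rotary-embedding step, with a local rotate_half standing in for the class's _rotate_half and all shapes chosen by me:

    import torch

    def rotate_half(x):
        # Split the last dimension in two and rotate: (x1, x2) -> (-x2, x1).
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(t, cos, sin):
        rot_dim = cos.shape[-1]
        t_float = t.float()
        t_rot = t_float[..., :rot_dim]    # channels that get rotated
        t_pass = t_float[..., rot_dim:]   # channels passed through unchanged
        t_rot = (t_rot * cos) + (rotate_half(t_rot) * sin)
        return torch.cat((t_rot, t_pass), dim=-1).type_as(t)

    # Toy usage: 4 positions, head_dim 16, rotate only the first 8 channels.
    pos = torch.arange(4).float()
    inv_freq = 1.0 / (10000 ** (torch.arange(0, 8, 2).float() / 8))
    angles = torch.outer(pos, inv_freq)            # [4, 4]
    emb = torch.cat((angles, angles), dim=-1)      # [4, 8]
    q = torch.randn(4, 16)
    print(apply_rotary_pos_emb(q, emb.cos(), emb.sin()).shape)  # [4, 16]
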
@@ -351,6 +352,8 @@ class QwenRunner:
             mask = shift_labels < self.qwen.config.vocab_size
             shift_labels = shift_labels[mask]
             shift_logits = shift_logits[mask]
+            # m = torch.max(shift_logits, 1).indices.cpu().numpy()
+            # ll = shift_labels.cpu().numpy()
             loss = CrossEntropyLoss()(shift_logits, shift_labels)
 
         return lm_logits, loss
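
As in the LitModule hunk above, labels at or above vocab_size are masked out before the loss, and the new lines are commented-out debug helpers. A minimal sketch of the shift-and-mask cross-entropy, with the batch shape and padding sentinel assumed by me:

    import torch
    from torch.nn import CrossEntropyLoss

    vocab_size = 4096
    batch, seq = 2, 8
    lm_logits = torch.randn(batch, seq, vocab_size)
    labels = torch.randint(0, vocab_size, (batch, seq))
    labels[:, -2:] = vocab_size  # pretend the tail positions are padding

    # Predict token i+1 from position i, then drop masked labels.
    shift_logits = lm_logits[..., :-1, :].reshape(-1, vocab_size)
    shift_labels = labels[..., 1:].reshape(-1)
    mask = shift_labels < vocab_size
    loss = CrossEntropyLoss()(shift_logits[mask], shift_labels[mask])
    print(loss)
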
@@ -35,7 +35,7 @@ vocab_size = 4096
 
 
 class SpecialDataset(Dataset):
-    def __init__(self, start=1, end=320, size=32768):
+    def __init__(self, start=1, end=320, size=32768):  # 1048576
         self.size = size
         self.features = []
         a = torch.randint(start, end, [size])
@@ -43,7 +43,11 @@ class SpecialDataset(Dataset):
         c = torch.randint(start, end, [size])
         d = torch.randint(start, end, [size])
         # self.data = torch.stack([a, b, a + b, a + b]).permute(1, 0)
-        self.data = torch.stack([a, a + a, a + a, a + a]).permute(1, 0)
+        self.data = torch.stack([a, a, a + a]).permute(1, 0)
+
+        # a = torch.randint(start, end, [size])
+        # self.data = torch.stack([a, a, a + a]).permute(1, 0)  # accuracy=0.5
+        # self.data = torch.stack([a, a + a, a]).permute(1, 0)  # accuracy=1
 
     def __len__(self):
         return self.size
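
Each sample row becomes [a, a, a + a], and the commented variants record the accuracies observed for other token orderings. A quick sketch of what the dataset yields; the __getitem__ is my addition, since the hunk does not show one:

    import torch
    from torch.utils.data import Dataset

    class SpecialDataset(Dataset):
        def __init__(self, start=1, end=320, size=32768):
            self.size = size
            a = torch.randint(start, end, [size])
            # Every row: a value, the value again, then its double.
            self.data = torch.stack([a, a, a + a]).permute(1, 0)

        def __len__(self):
            return self.size

        def __getitem__(self, idx):  # assumed accessor for illustration
            return self.data[idx]

    ds = SpecialDataset(size=4)
    print(ds[0])  # e.g. tensor([ 17,  17,  34])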