Retest wit.

Colin 2024-03-07 16:30:37 +08:00
parent a70d12d04d
commit 601c7f6510
4 changed files with 16 additions and 9 deletions

View File

@@ -7,8 +7,8 @@
 class QWenConfig:
     def __init__(self):
         self.vocab_size = 4096
-        self.hidden_size = 128  # 128 1024 2048
-        self.num_hidden_layers = 6  # 6 12 24
+        self.hidden_size = 128  # 128 1024 2048 32
+        self.num_hidden_layers = 6  # 6 12 24 3
         self.num_attention_heads = 8  # 8 8 16
         self.emb_dropout_prob = 0.0
         self.attn_dropout_prob = 0.0
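The trailing comments track the preset sizes being toggled between debug and larger runs (128/1024/2048 hidden, 6/12/24 layers). A tiny sanity check one might sketch around such a config (the head-dim arithmetic is my assumption, not shown in the diff):

    config = QWenConfig()
    # multi-head attention needs hidden_size to split evenly across heads
    assert config.hidden_size % config.num_attention_heads == 0
    head_dim = config.hidden_size // config.num_attention_heads
    print(head_dim)  # 128 // 8 == 16 at this debug scale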

View File

@@ -5,9 +5,6 @@ import pytorch_lightning as pl
 import torch
 import torchmetrics
-# from utils import init_model
-# from custom_models.gpt2.modeling_gpt2 import GPT2LMHeadModel
 from modeling_wit import QWenLMHeadModel
 from configuration_qwen import QWenConfig
@@ -64,6 +61,8 @@ class LitModule(pl.LightningModule):
         label_mask = labels < self.vocab_size
         logits = logits[label_mask]
         labels = labels[label_mask]
+        # m = torch.max(logits, 1).indices.cpu().numpy()
+        # ll = labels.cpu().numpy()
         self.metric_accuracy.update(logits, labels)
         self.metric_loss.update(loss)
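The hunk above filters out label ids at or above vocab_size (used as an ignore marker) before updating the metrics. A self-contained sketch of that pattern, assuming metric_accuracy is a multiclass torchmetrics.Accuracy and using hypothetical shapes:

    import torch
    import torchmetrics

    vocab_size = 4096
    metric_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=vocab_size)

    logits = torch.randn(8, vocab_size)  # hypothetical per-token logits [N, vocab]
    labels = torch.tensor([1, 2, 3, vocab_size, 5, 6, vocab_size, 8])  # >= vocab_size marks ignored positions

    label_mask = labels < vocab_size  # drop ignored positions before the update
    metric_accuracy.update(logits[label_mask], labels[label_mask])
    print(metric_accuracy.compute())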
@@ -100,5 +99,6 @@ class LitModule(pl.LightningModule):
             mode="max",
             stopping_threshold=1,
         )
-        return [checkpoint_callback]
+        lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")
+        return [checkpoint_callback, lr_monitor]
         # return [checkpoint_callback, early_stop_callback]
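This hunk adds a LearningRateMonitor so the current learning rate is logged every step alongside the checkpoint callback. A minimal standalone sketch of the same wiring (the ModelCheckpoint settings here are placeholders, not from this repo):

    import pytorch_lightning as pl

    class LitModule(pl.LightningModule):
        def configure_callbacks(self):
            checkpoint_callback = pl.callbacks.ModelCheckpoint(
                monitor="accuracy", mode="max", save_top_k=1  # placeholder settings
            )
            lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")
            return [checkpoint_callback, lr_monitor]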

View File

@@ -245,7 +245,8 @@ class QwenRunner:
         rot_dim = freqs[0].shape[-1]
         cos, sin = freqs
         t_float = t.float()
-        t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:]
+        t_rot = t_float[..., :rot_dim]
+        t_pass = t_float[..., rot_dim:]
         t_rot = (t_rot * cos) + (self._rotate_half(t_rot) * sin)
         return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
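The change only unpacks the slice onto two lines; the math is the usual partial rotary position embedding. A sketch of the whole operation, assuming _rotate_half is the conventional half-swap helper (its body is not shown in this diff):

    import torch

    def rotate_half(x):
        # conventional RoPE helper: (x1, x2) -> (-x2, x1) along the last dim
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(t, freqs):
        cos, sin = freqs
        rot_dim = cos.shape[-1]
        t_float = t.float()
        t_rot = t_float[..., :rot_dim]   # slice that gets rotated
        t_pass = t_float[..., rot_dim:]  # slice passed through unchanged
        t_rot = (t_rot * cos) + (rotate_half(t_rot) * sin)
        return torch.cat((t_rot, t_pass), dim=-1).type_as(t)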
@@ -351,6 +352,8 @@ class QwenRunner:
         mask = shift_labels < self.qwen.config.vocab_size
         shift_labels = shift_labels[mask]
         shift_logits = shift_logits[mask]
+        # m = torch.max(shift_logits, 1).indices.cpu().numpy()
+        # ll = shift_labels.cpu().numpy()
         loss = CrossEntropyLoss()(shift_logits, shift_labels)
         return lm_logits, loss
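As in the Lightning module, labels at or above vocab_size are masked out before the loss. A self-contained sketch of the shift-and-mask causal LM loss this hunk sits inside (shapes and the ignore convention are hypothetical):

    import torch
    from torch.nn import CrossEntropyLoss

    vocab_size = 4096
    lm_logits = torch.randn(2, 8, vocab_size)          # [batch, seq, vocab]
    labels = torch.randint(0, vocab_size + 1, (2, 8))  # vocab_size as ignore marker

    # predict token t+1 from position t
    shift_logits = lm_logits[..., :-1, :].reshape(-1, vocab_size)
    shift_labels = labels[..., 1:].reshape(-1)

    mask = shift_labels < vocab_size                   # drop ignored positions
    loss = CrossEntropyLoss()(shift_logits[mask], shift_labels[mask])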

View File

@@ -35,7 +35,7 @@ vocab_size = 4096
 class SpecialDataset(Dataset):
-    def __init__(self, start=1, end=320, size=32768):
+    def __init__(self, start=1, end=320, size=32768):  # 1048576
         self.size = size
         self.features = []
         a = torch.randint(start, end, [size])
@@ -43,7 +43,11 @@ class SpecialDataset(Dataset):
         c = torch.randint(start, end, [size])
         d = torch.randint(start, end, [size])
         # self.data = torch.stack([a, b, a + b, a + b]).permute(1, 0)
-        self.data = torch.stack([a, a + a, a + a, a + a]).permute(1, 0)
+        self.data = torch.stack([a, a, a + a]).permute(1, 0)
+        # a = torch.randint(start, end, [size])
+        # self.data = torch.stack([a, a, a + a]).permute(1, 0)  # accuracy=0.5
+        # self.data = torch.stack([a, a + a, a]).permute(1, 0)  # accuracy=1

     def __len__(self):
         return self.size
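With the new stack, every row of self.data is [a, a, 2a], so the model only has to learn to double a token; the commented variants record how accuracy changed with the ordering. A quick look at what this construction produces (values are random per run; __getitem__ is not shown in the diff):

    import torch

    start, end, size = 1, 320, 4
    a = torch.randint(start, end, [size])
    data = torch.stack([a, a, a + a]).permute(1, 0)  # shape [size, 3], rows of [a, a, 2a]
    print(data)
    # e.g. tensor([[  7,   7,  14],
    #              [ 42,  42,  84], ...])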