Retest wit.
This commit is contained in:
parent a70d12d04d
commit 601c7f6510
@@ -7,8 +7,8 @@
 class QWenConfig:
     def __init__(self):
         self.vocab_size = 4096
-        self.hidden_size = 128 # 128 1024 2048
-        self.num_hidden_layers = 6 # 6 12 24
+        self.hidden_size = 128 # 128 1024 2048 32
+        self.num_hidden_layers = 6 # 6 12 24 3
         self.num_attention_heads = 8 # 8 8 16
         self.emb_dropout_prob = 0.0
         self.attn_dropout_prob = 0.0
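Note: the hunk above only extends the sizing comments; the effective test configuration is unchanged. A rough sketch of what these fields imply, assuming the conventional per-head size of hidden_size / num_attention_heads (an assumption, since the attention code is not in this diff):

    from configuration_qwen import QWenConfig

    config = QWenConfig()
    assert config.hidden_size % config.num_attention_heads == 0
    head_dim = config.hidden_size // config.num_attention_heads  # 128 // 8 = 16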
@@ -5,9 +5,6 @@ import pytorch_lightning as pl
 import torch
 import torchmetrics
 
-# from utils import init_model
-# from custom_models.gpt2.modeling_gpt2 import GPT2LMHeadModel
-
 from modeling_wit import QWenLMHeadModel
 from configuration_qwen import QWenConfig
 
@@ -64,6 +61,8 @@ class LitModule(pl.LightningModule):
         label_mask = labels < self.vocab_size
         logits = logits[label_mask]
         labels = labels[label_mask]
+        # m = torch.max(logits, 1).indices.cpu().numpy()
+        # ll = labels.cpu().numpy()
         self.metric_accuracy.update(logits, labels)
         self.metric_loss.update(loss)
 
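Note: labels at or above vocab_size act as ignore sentinels here; masking both tensors keeps the metrics from scoring padded positions. A standalone sketch of the same pattern (shapes and the Accuracy task arguments are assumptions, not taken from this repo):

    import torch
    import torchmetrics

    vocab_size = 4096
    logits = torch.randn(6, vocab_size)              # [positions, vocab]
    labels = torch.tensor([1, 2, vocab_size + 100, 3, 4, 5])

    label_mask = labels < vocab_size                 # drops the sentinel position
    accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=vocab_size)
    accuracy.update(logits[label_mask], labels[label_mask])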
@@ -100,5 +99,6 @@ class LitModule(pl.LightningModule):
             mode="max",
             stopping_threshold=1,
         )
-        return [checkpoint_callback]
+        lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")
+        return [checkpoint_callback, lr_monitor]
         # return [checkpoint_callback, early_stop_callback]
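Note: LearningRateMonitor only records anything once the Trainer runs with a scheduler attached; a callback list like the one returned here is picked up by Lightning. A minimal usage sketch (Trainer arguments are illustrative):

    import pytorch_lightning as pl

    lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")
    trainer = pl.Trainer(max_epochs=1, callbacks=[lr_monitor])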
@@ -245,7 +245,8 @@ class QwenRunner:
         rot_dim = freqs[0].shape[-1]
         cos, sin = freqs
         t_float = t.float()
-        t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:]
+        t_rot = t_float[..., :rot_dim]
+        t_pass = t_float[..., rot_dim:]
         t_rot = (t_rot * cos) + (self._rotate_half(t_rot) * sin)
         return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
 
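Note: this is the usual rotary-embedding application: only the first rot_dim channels are rotated by (cos, sin) and the tail passes through untouched; the hunk just splits one tuple assignment into two lines for readability. A self-contained sketch, with _rotate_half written out under its standard RoPE definition (an assumption, since its body is not part of this diff):

    import torch

    def rotate_half(x):
        # standard RoPE helper: (x1, x2) -> (-x2, x1) on the last dim
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(t, cos, sin):
        rot_dim = cos.shape[-1]
        t_float = t.float()
        t_rot = t_float[..., :rot_dim]
        t_pass = t_float[..., rot_dim:]
        t_rot = (t_rot * cos) + (rotate_half(t_rot) * sin)
        return torch.cat((t_rot, t_pass), dim=-1).type_as(t)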
@@ -351,6 +352,8 @@ class QwenRunner:
         mask = shift_labels < self.qwen.config.vocab_size
         shift_labels = shift_labels[mask]
         shift_logits = shift_logits[mask]
+        # m = torch.max(shift_logits, 1).indices.cpu().numpy()
+        # ll = shift_labels.cpu().numpy()
         loss = CrossEntropyLoss()(shift_logits, shift_labels)
 
         return lm_logits, loss
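Note: the mask mirrors the LitModule change above: out-of-vocabulary labels are removed from both tensors before the cross-entropy. A sketch of the surrounding shift-and-flatten convention for a causal LM loss (the shift itself is assumed from standard practice; it is not shown in this hunk):

    import torch
    from torch.nn import CrossEntropyLoss

    vocab_size = 4096
    lm_logits = torch.randn(2, 8, vocab_size)          # [batch, seq, vocab]
    labels = torch.randint(0, vocab_size + 10, (2, 8)) # some sentinel labels

    shift_logits = lm_logits[..., :-1, :].reshape(-1, vocab_size)
    shift_labels = labels[..., 1:].reshape(-1)
    mask = shift_labels < vocab_size                   # drop sentinel labels
    loss = CrossEntropyLoss()(shift_logits[mask], shift_labels[mask])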
@@ -35,7 +35,7 @@ vocab_size = 4096
 
 
 class SpecialDataset(Dataset):
-    def __init__(self, start=1, end=320, size=32768):
+    def __init__(self, start=1, end=320, size=32768): # 1048576
         self.size = size
         self.features = []
         a = torch.randint(start, end, [size])
@@ -43,7 +43,11 @@ class SpecialDataset(Dataset):
         c = torch.randint(start, end, [size])
         d = torch.randint(start, end, [size])
         # self.data = torch.stack([a, b, a + b, a + b]).permute(1, 0)
-        self.data = torch.stack([a, a + a, a + a, a + a]).permute(1, 0)
+        self.data = torch.stack([a, a, a + a]).permute(1, 0)
+
+        # a = torch.randint(start, end, [size])
+        # self.data = torch.stack([a, a, a + a]).permute(1, 0) # accuracy=0.5
+        # self.data = torch.stack([a, a + a, a]).permute(1, 0) # accuracy=1
 
     def __len__(self):
         return self.size
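Note: after the stack/permute, every sample is a length-3 row (a, a, 2a), so the model only has to learn to double the previous token; the commented-out variants record how accuracy moved with each pattern. A tiny repro of the shape logic:

    import torch

    start, end, size = 1, 320, 4
    a = torch.randint(start, end, [size])
    data = torch.stack([a, a, a + a]).permute(1, 0)  # shape [size, 3]
    print(data)                                      # each row: (a, a, 2*a)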