Retest wit.
parent a70d12d04d
commit 601c7f6510
@@ -7,8 +7,8 @@
 class QWenConfig:
     def __init__(self):
         self.vocab_size = 4096
-        self.hidden_size = 128 # 128 1024 2048
-        self.num_hidden_layers = 6 # 6 12 24
+        self.hidden_size = 128 # 128 1024 2048 32
+        self.num_hidden_layers = 6 # 6 12 24 3
         self.num_attention_heads = 8 # 8 8 16
         self.emb_dropout_prob = 0.0
         self.attn_dropout_prob = 0.0
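For orientation, the config above is where the retested model size (width 128, depth 6) is set. A minimal sketch of how it is typically consumed by the Lightning module imported below; the constructor argument of QWenLMHeadModel is an assumption, only the field names come from this diff:

from configuration_qwen import QWenConfig
from modeling_wit import QWenLMHeadModel

config = QWenConfig()
config.hidden_size = 128        # width used in this retest
config.num_hidden_layers = 6    # depth used in this retest
model = QWenLMHeadModel(config)  # assumed to take the config object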
@@ -5,9 +5,6 @@ import pytorch_lightning as pl
 import torch
 import torchmetrics
 
-# from utils import init_model
-# from custom_models.gpt2.modeling_gpt2 import GPT2LMHeadModel
-
 from modeling_wit import QWenLMHeadModel
 from configuration_qwen import QWenConfig
 
@@ -64,6 +61,8 @@ class LitModule(pl.LightningModule):
         label_mask = labels < self.vocab_size
         logits = logits[label_mask]
         labels = labels[label_mask]
+        # m = torch.max(logits, 1).indices.cpu().numpy()
+        # ll = labels.cpu().numpy()
         self.metric_accuracy.update(logits, labels)
         self.metric_loss.update(loss)
 
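The mask in this hunk drops label ids that fall outside the model vocabulary (e.g. special/pad ids) before the metric update. A self-contained sketch of that pattern; the metric constructor is an assumption, only the masking lines come from the diff:

import torch
import torchmetrics

vocab_size = 4096
metric_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=vocab_size)

logits = torch.randn(8, vocab_size)          # [tokens, vocab]
labels = torch.randint(0, vocab_size, (8,))
labels[-1] = vocab_size + 1                  # simulate an out-of-vocab pad/special id

label_mask = labels < vocab_size             # keep only in-vocab targets
logits = logits[label_mask]
labels = labels[label_mask]
metric_accuracy.update(logits, labels)
print(metric_accuracy.compute())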
@@ -100,5 +99,6 @@ class LitModule(pl.LightningModule):
             mode="max",
             stopping_threshold=1,
         )
-        return [checkpoint_callback]
+        lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")
+        return [checkpoint_callback, lr_monitor]
         # return [checkpoint_callback, early_stop_callback]
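This hunk adds a LearningRateMonitor next to the existing checkpoint callback. A sketch of how that wiring looks in PyTorch Lightning; the method name configure_callbacks and the ModelCheckpoint arguments are assumptions, only the lr_monitor line and the return list come from the diff:

import pytorch_lightning as pl

class LitModule(pl.LightningModule):
    def configure_callbacks(self):
        # Checkpoint on the monitored metric; monitor name is an assumption.
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            monitor="accuracy", mode="max", save_top_k=1
        )
        # Log the learning rate every optimizer step.
        lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")
        return [checkpoint_callback, lr_monitor]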
@@ -245,7 +245,8 @@ class QwenRunner:
         rot_dim = freqs[0].shape[-1]
         cos, sin = freqs
         t_float = t.float()
-        t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:]
+        t_rot = t_float[..., :rot_dim]
+        t_pass = t_float[..., rot_dim:]
         t_rot = (t_rot * cos) + (self._rotate_half(t_rot) * sin)
         return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
 
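This hunk only splits the rotary slice onto two lines; the math is unchanged. A standalone sketch of the rotary position embedding it implements; _rotate_half here follows the common Qwen/GPT-NeoX convention and the frequency setup is an assumption, while the apply_rotary body mirrors the diff:

import torch

def _rotate_half(x):
    # Split the rotary dims into halves and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(t, cos, sin, rot_dim):
    t_float = t.float()
    t_rot = t_float[..., :rot_dim]   # dimensions that receive the rotation
    t_pass = t_float[..., rot_dim:]  # dimensions passed through unchanged
    t_rot = (t_rot * cos) + (_rotate_half(t_rot) * sin)
    return torch.cat((t_rot, t_pass), dim=-1).type_as(t)

t = torch.randn(1, 4, 2, 8)                  # [batch, seq, heads, head_dim]
inv_freq = 1.0 / (10000 ** (torch.arange(0, 8, 2).float() / 8))
freqs = torch.outer(torch.arange(4).float(), inv_freq)   # [seq, head_dim/2]
emb = torch.cat((freqs, freqs), dim=-1)                   # [seq, head_dim]
cos, sin = emb.cos()[:, None, :], emb.sin()[:, None, :]   # broadcast over heads
out = apply_rotary(t, cos, sin, rot_dim=8)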
@@ -351,6 +352,8 @@ class QwenRunner:
         mask = shift_labels < self.qwen.config.vocab_size
         shift_labels = shift_labels[mask]
         shift_logits = shift_logits[mask]
+        # m = torch.max(shift_logits, 1).indices.cpu().numpy()
+        # ll = shift_labels.cpu().numpy()
         loss = CrossEntropyLoss()(shift_logits, shift_labels)
 
         return lm_logits, loss
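For context, this is the causal-LM loss path: logits and labels are shifted by one position, out-of-vocab label ids are masked out, and cross-entropy is taken over the rest. A sketch under those assumptions; the shift-by-one lines are the usual pattern, the mask and loss lines are from the diff:

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 4096
lm_logits = torch.randn(2, 16, vocab_size)       # [batch, seq, vocab]
labels = torch.randint(0, vocab_size, (2, 16))
labels[:, -1] = vocab_size + 1                   # simulate out-of-vocab pad ids

shift_logits = lm_logits[..., :-1, :].reshape(-1, vocab_size)
shift_labels = labels[..., 1:].reshape(-1)

mask = shift_labels < vocab_size                 # ignore padded/special positions
loss = CrossEntropyLoss()(shift_logits[mask], shift_labels[mask])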
@@ -35,7 +35,7 @@ vocab_size = 4096
 
 
 class SpecialDataset(Dataset):
-    def __init__(self, start=1, end=320, size=32768):
+    def __init__(self, start=1, end=320, size=32768): # 1048576
         self.size = size
         self.features = []
         a = torch.randint(start, end, [size])
@@ -43,7 +43,11 @@ class SpecialDataset(Dataset):
         c = torch.randint(start, end, [size])
         d = torch.randint(start, end, [size])
         # self.data = torch.stack([a, b, a + b, a + b]).permute(1, 0)
-        self.data = torch.stack([a, a + a, a + a, a + a]).permute(1, 0)
+        self.data = torch.stack([a, a, a + a]).permute(1, 0)
+
+        # a = torch.randint(start, end, [size])
+        # self.data = torch.stack([a, a, a + a]).permute(1, 0)  # accuracy=0.5
+        # self.data = torch.stack([a, a + a, a]).permute(1, 0)  # accuracy=1
 
     def __len__(self):
         return self.size
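After this change each sample row is the triple (a, a, 2a), i.e. the model sees two copies of the operand and is asked to produce the doubled value. A tiny sketch of what the new stacking yields; how __getitem__ splits a row into inputs and labels is not shown in the diff:

import torch

start, end, size = 1, 320, 4
a = torch.randint(start, end, [size])
data = torch.stack([a, a, a + a]).permute(1, 0)  # each row: [a_i, a_i, 2*a_i]
print(data.shape)  # torch.Size([4, 3])
print(data[0])     # e.g. tensor([ 57,  57, 114])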