From 601c7f65101b6d3822f9e7dca5fce13c90fa47cd Mon Sep 17 00:00:00 2001
From: Colin
Date: Thu, 7 Mar 2024 16:30:37 +0800
Subject: [PATCH] Retest wit.

---
 wit/configuration_qwen.py | 4 ++--
 wit/lit_module.py         | 8 ++++----
 wit/modeling_wit.py       | 5 ++++-
 wit/train.py              | 8 ++++++--
 4 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/wit/configuration_qwen.py b/wit/configuration_qwen.py
index f2afc3e..828bd08 100644
--- a/wit/configuration_qwen.py
+++ b/wit/configuration_qwen.py
@@ -7,8 +7,8 @@ class QWenConfig:
     def __init__(self):
         self.vocab_size = 4096
-        self.hidden_size = 128  # 128 1024 2048
-        self.num_hidden_layers = 6  # 6 12 24
+        self.hidden_size = 128  # 128 1024 2048 32
+        self.num_hidden_layers = 6  # 6 12 24 3
         self.num_attention_heads = 8  # 8 8 16
         self.emb_dropout_prob = 0.0
         self.attn_dropout_prob = 0.0
diff --git a/wit/lit_module.py b/wit/lit_module.py
index 7e69a4b..1015be7 100644
--- a/wit/lit_module.py
+++ b/wit/lit_module.py
@@ -5,9 +5,6 @@
 import pytorch_lightning as pl
 import torch
 import torchmetrics
 
-# from utils import init_model
-# from custom_models.gpt2.modeling_gpt2 import GPT2LMHeadModel
-
 from modeling_wit import QWenLMHeadModel
 from configuration_qwen import QWenConfig
@@ -64,6 +61,8 @@ class LitModule(pl.LightningModule):
         label_mask = labels < self.vocab_size
         logits = logits[label_mask]
         labels = labels[label_mask]
+        # m = torch.max(logits, 1).indices.cpu().numpy()
+        # ll = labels.cpu().numpy()
         self.metric_accuracy.update(logits, labels)
         self.metric_loss.update(loss)
@@ -100,5 +99,6 @@ class LitModule(pl.LightningModule):
             mode="max",
             stopping_threshold=1,
         )
-        return [checkpoint_callback]
+        lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")
+        return [checkpoint_callback, lr_monitor]
         # return [checkpoint_callback, early_stop_callback]
diff --git a/wit/modeling_wit.py b/wit/modeling_wit.py
index ff10fbd..f25839e 100644
--- a/wit/modeling_wit.py
+++ b/wit/modeling_wit.py
@@ -245,7 +245,8 @@ class QwenRunner:
         rot_dim = freqs[0].shape[-1]
         cos, sin = freqs
         t_float = t.float()
-        t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:]
+        t_rot = t_float[..., :rot_dim]
+        t_pass = t_float[..., rot_dim:]
         t_rot = (t_rot * cos) + (self._rotate_half(t_rot) * sin)
         return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
@@ -351,6 +352,8 @@ class QwenRunner:
             mask = shift_labels < self.qwen.config.vocab_size
             shift_labels = shift_labels[mask]
             shift_logits = shift_logits[mask]
+            # m = torch.max(shift_logits, 1).indices.cpu().numpy()
+            # ll = shift_labels.cpu().numpy()
             loss = CrossEntropyLoss()(shift_logits, shift_labels)
         return lm_logits, loss
diff --git a/wit/train.py b/wit/train.py
index 03c46a2..f6062e1 100644
--- a/wit/train.py
+++ b/wit/train.py
@@ -35,7 +35,7 @@ vocab_size = 4096
 
 
 class SpecialDataset(Dataset):
-    def __init__(self, start=1, end=320, size=32768):
+    def __init__(self, start=1, end=320, size=32768):  # 1048576
         self.size = size
         self.features = []
         a = torch.randint(start, end, [size])
@@ -43,7 +43,11 @@ class SpecialDataset(Dataset):
         c = torch.randint(start, end, [size])
         d = torch.randint(start, end, [size])
         # self.data = torch.stack([a, b, a + b, a + b]).permute(1, 0)
-        self.data = torch.stack([a, a + a, a + a, a + a]).permute(1, 0)
+        self.data = torch.stack([a, a, a + a]).permute(1, 0)
+
+        # a = torch.randint(start, end, [size])
+        # self.data = torch.stack([a, a, a + a]).permute(1, 0)  # accuracy=0.5
+        # self.data = torch.stack([a, a + a, a]).permute(1, 0)  # accuracy=1
 
     def __len__(self):
         return self.size
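
Note on the wit/train.py change (a standalone sketch, not part of the patch): the new line
torch.stack([a, a, a + a]).permute(1, 0) produces one row per sample of the form [a, a, a + a],
so the model is trained to predict the sum of the two identical leading tokens. A tiny
illustration using only torch (the small size value below is for demonstration only):

import torch

start, end, size = 1, 320, 4  # tiny size, for illustration only
a = torch.randint(start, end, [size])
data = torch.stack([a, a, a + a]).permute(1, 0)  # shape [size, 3]
print(data)
# e.g. tensor([[ 17,  17,  34],
#              [203, 203, 406],
#              [ 88,  88, 176],
#              [  5,   5,  10]])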
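
Note on the apply_rotary_pos_emb hunk in wit/modeling_wit.py: the change only splits the tuple
assignment into two statements; the computation is unchanged. Below is a minimal sketch of the
same RoPE application, assuming _rotate_half is the usual half-rotation helper (the helper,
shapes, and frequency setup here are illustrative, not taken from the repo):

import torch

def rotate_half(x):
    # standard RoPE helper: swap the two halves and negate the second
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

seq_len, head_dim, rot_dim = 4, 8, 8
t = torch.randn(seq_len, head_dim)
inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2).float() / rot_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()

t_rot = t[..., :rot_dim]   # rotated part
t_pass = t[..., rot_dim:]  # passed through untouched
t_rot = (t_rot * cos) + (rotate_half(t_rot) * sin)
out = torch.cat((t_rot, t_pass), dim=-1)  # same shape as t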