diff --git a/wit/configuration_qwen.py b/wit/configuration_qwen.py
index e1725b1..f2afc3e 100644
--- a/wit/configuration_qwen.py
+++ b/wit/configuration_qwen.py
@@ -7,9 +7,9 @@ class QWenConfig:
 
     def __init__(self):
         self.vocab_size = 4096
-        self.hidden_size = 128  # 1024 2048
-        self.num_hidden_layers = 6  # 12 24
-        self.num_attention_heads = 8  # 8 16
+        self.hidden_size = 128  # 128 1024 2048
+        self.num_hidden_layers = 6  # 6 12 24
+        self.num_attention_heads = 8  # 8 8 16
         self.emb_dropout_prob = 0.0
         self.attn_dropout_prob = 0.0
         self.layer_norm_epsilon = 1e-6
diff --git a/wit/lit_module.py b/wit/lit_module.py
index 055b2aa..7e69a4b 100644
--- a/wit/lit_module.py
+++ b/wit/lit_module.py
@@ -30,9 +30,10 @@ class LitModule(pl.LightningModule):
         self.learning_rate = learning_rate
         self.use_tril_attention_mask = use_tril_attention_mask
         self.metric_loss = torchmetrics.MeanMetric()
+        self.vocab_size = self.llm.config.vocab_size
         self.metric_accuracy = torchmetrics.Accuracy(
             task="multiclass",
-            num_classes=self.llm.config.vocab_size,
+            num_classes=self.vocab_size,
         )
 
     @cache
@@ -57,13 +58,15 @@ class LitModule(pl.LightningModule):
     def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx):
         outputs, loss = self.llm(**batch, return_dict=True)
         logits = outputs[..., :-1, :]
+        logits = logits.contiguous().view(-1, logits.size(-1))
         labels = batch["labels"][..., 1:]
-
+        labels = labels.contiguous().view(-1)
+        label_mask = labels < self.vocab_size
+        logits = logits[label_mask]
+        labels = labels[label_mask]
+        self.metric_accuracy.update(logits, labels)
         self.metric_loss.update(loss)
-        label_mask = labels != 0
-        self.metric_accuracy.update(logits[label_mask], labels[label_mask])
-
 
     def on_validation_epoch_end(self) -> None:
         self.log("val_loss", self.metric_loss, rank_zero_only=True)
         self.log("accuracy", self.metric_accuracy, rank_zero_only=True)
diff --git a/wit/modeling_wit.py b/wit/modeling_wit.py
index 1122362..ff10fbd 100644
--- a/wit/modeling_wit.py
+++ b/wit/modeling_wit.py
@@ -348,11 +348,11 @@ class QwenRunner:
         shift_labels = labels[..., 1:].contiguous().view(-1)
         shift_logits = lm_logits[..., :-1, :].contiguous()
         shift_logits = shift_logits.view(-1, shift_logits.size(-1))
-        mask = shift_labels != 0
+        mask = shift_labels < self.qwen.config.vocab_size
         shift_labels = shift_labels[mask]
         shift_logits = shift_logits[mask]
-        loss_fct = CrossEntropyLoss()
-        loss = loss_fct(shift_logits, shift_labels)
+        loss = CrossEntropyLoss()(shift_logits, shift_labels)
+
         return lm_logits, loss
 
     def prepareInput(self, tokenizer, query, query_assistant, history, system):
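
Note on the masking change: the mask moves from `labels != 0` to `labels < vocab_size`, so token ID 0 is scored again and only out-of-range label IDs are excluded from the loss and accuracy. The sketch below is a minimal standalone illustration of that rule, assuming label IDs at or above vocab_size mark positions to ignore; the shapes and label values are made up for the example, not taken from the repo.

    import torch

    vocab_size = 4096  # matches QWenConfig.vocab_size in the diff

    # Illustrative batch: label IDs >= vocab_size stand for ignored positions.
    logits = torch.randn(2, 5, vocab_size)            # (batch, seq, vocab)
    labels = torch.tensor([[3, 7, 4096, 9, 2],
                           [1, 4096, 2, 5, 8]])

    # Shift for next-token prediction, then flatten, as in validation_step.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1)

    # Old rule: mask = shift_labels != 0 (which also dropped token ID 0 itself).
    # New rule: keep every in-vocabulary label, drop only out-of-range IDs.
    mask = shift_labels < vocab_size
    print(shift_logits[mask].shape)  # torch.Size([6, 4096]): 8 positions, 2 masked
    print(shift_labels[mask])        # tensor([7, 9, 2, 2, 5, 8])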