Refine the labels used for loss and accuracy.

Colin 2024-03-05 22:08:37 +08:00
parent fdc8c657b3
commit 11fc8f1d39
3 changed files with 14 additions and 11 deletions


@@ -7,9 +7,9 @@
 class QWenConfig:
     def __init__(self):
         self.vocab_size = 4096
-        self.hidden_size = 128  # 1024 2048
-        self.num_hidden_layers = 6  # 12 24
-        self.num_attention_heads = 8  # 8 16
+        self.hidden_size = 128  # 128 1024 2048
+        self.num_hidden_layers = 6  # 6 12 24
+        self.num_attention_heads = 8  # 8 8 16
         self.emb_dropout_prob = 0.0
         self.attn_dropout_prob = 0.0
         self.layer_norm_epsilon = 1e-6
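
Note: the revised comments record the full preset ladder for each field, with the active toy value listed first (hidden_size 128 → 1024 → 2048, num_hidden_layers 6 → 12 → 24, num_attention_heads 8 → 8 → 16). A hypothetical helper, not part of this commit, could apply those presets explicitly:

# Hypothetical sketch (not in this repo): apply one of the preset ladders
# recorded in the comments above to a QWenConfig instance. Preset names
# ("tiny"/"small"/"base") are illustrative assumptions.
PRESETS = {
    "tiny": dict(hidden_size=128, num_hidden_layers=6, num_attention_heads=8),
    "small": dict(hidden_size=1024, num_hidden_layers=12, num_attention_heads=8),
    "base": dict(hidden_size=2048, num_hidden_layers=24, num_attention_heads=16),
}

def apply_preset(config, name):
    for field, value in PRESETS[name].items():
        setattr(config, field, value)  # overwrite the toy defaults in place
    return config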


@@ -30,9 +30,10 @@ class LitModule(pl.LightningModule):
         self.learning_rate = learning_rate
         self.use_tril_attention_mask = use_tril_attention_mask
         self.metric_loss = torchmetrics.MeanMetric()
+        self.vocab_size = self.llm.config.vocab_size
         self.metric_accuracy = torchmetrics.Accuracy(
             task="multiclass",
-            num_classes=self.llm.config.vocab_size,
+            num_classes=self.vocab_size,
         )
 
     @cache
@@ -57,13 +58,15 @@ class LitModule(pl.LightningModule):
     def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx):
         outputs, loss = self.llm(**batch, return_dict=True)
         logits = outputs[..., :-1, :]
         logits = logits.contiguous().view(-1, logits.size(-1))
         labels = batch["labels"][..., 1:]
         labels = labels.contiguous().view(-1)
+        label_mask = labels < self.vocab_size
+        logits = logits[label_mask]
+        labels = labels[label_mask]
+        self.metric_accuracy.update(logits, labels)
         self.metric_loss.update(loss)
-        label_mask = labels != 0
-        self.metric_accuracy.update(logits[label_mask], labels[label_mask])
 
     def on_validation_epoch_end(self) -> None:
         self.log("val_loss", self.metric_loss, rank_zero_only=True)
         self.log("accuracy", self.metric_accuracy, rank_zero_only=True)


@@ -348,11 +348,11 @@ class QwenRunner:
         shift_labels = labels[..., 1:].contiguous().view(-1)
         shift_logits = lm_logits[..., :-1, :].contiguous()
         shift_logits = shift_logits.view(-1, shift_logits.size(-1))
-        mask = shift_labels != 0
+        mask = shift_labels < self.qwen.config.vocab_size
         shift_labels = shift_labels[mask]
         shift_logits = shift_logits[mask]
-        loss_fct = CrossEntropyLoss()
-        loss = loss_fct(shift_logits, shift_labels)
+        loss = CrossEntropyLoss()(shift_logits, shift_labels)
         return lm_logits, loss
 
     def prepareInput(self, tokenizer, query, query_assistant, history, system):
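
The same rule now drives the training loss. A minimal sketch, assuming standard PyTorch and the shapes implied by the diff:

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 4096
lm_logits = torch.randn(1, 4, vocab_size)  # [batch, seq, vocab]
labels = torch.tensor([[5, 0, 4096, 42]])  # 4096 acts as an ignore/pad id

# Shift so each position predicts the next token, then flatten.
shift_labels = labels[..., 1:].contiguous().view(-1)
shift_logits = lm_logits[..., :-1, :].contiguous().view(-1, vocab_size)

mask = shift_labels < vocab_size  # was `shift_labels != 0`; token id 0 now counts
loss = CrossEntropyLoss()(shift_logits[mask], shift_labels[mask])

When padding uses a single reserved id, CrossEntropyLoss(ignore_index=pad_id) would be equivalent; the explicit mask generalizes to any id at or above vocab_size.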