Refine label used.
parent fdc8c657b3
commit 11fc8f1d39
@@ -7,9 +7,9 @@
 class QWenConfig:
     def __init__(self):
         self.vocab_size = 4096
-        self.hidden_size = 128 # 1024 2048
-        self.num_hidden_layers = 6 # 12 24
-        self.num_attention_heads = 8 # 8 16
+        self.hidden_size = 128 # 128 1024 2048
+        self.num_hidden_layers = 6 # 6 12 24
+        self.num_attention_heads = 8 # 8 8 16
         self.emb_dropout_prob = 0.0
         self.attn_dropout_prob = 0.0
         self.layer_norm_epsilon = 1e-6
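For orientation only, a minimal sketch of how the defaults above fit together, assuming QWenConfig is imported from this module unchanged; the 128-wide hidden size with 8 heads implies a 16-dimensional attention head:

# Hypothetical usage of the config above; not part of the commit.
config = QWenConfig()
head_dim = config.hidden_size // config.num_attention_heads  # 128 // 8 == 16
assert config.hidden_size % config.num_attention_heads == 0  # heads must divide hidden_size
print(config.vocab_size, config.num_hidden_layers, head_dim)  # 4096 6 16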
@@ -30,9 +30,10 @@ class LitModule(pl.LightningModule):
         self.learning_rate = learning_rate
         self.use_tril_attention_mask = use_tril_attention_mask
         self.metric_loss = torchmetrics.MeanMetric()
+        self.vocab_size = self.llm.config.vocab_size
         self.metric_accuracy = torchmetrics.Accuracy(
             task="multiclass",
-            num_classes=self.llm.config.vocab_size,
+            num_classes=self.vocab_size,
         )

     @cache
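The cached self.vocab_size now sizes the multiclass accuracy metric. A standalone sketch of that torchmetrics usage, with toy shapes chosen here for illustration:

import torch
import torchmetrics

vocab_size = 4096  # same value QWenConfig sets above
metric_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=vocab_size)

logits = torch.randn(8, vocab_size)          # (tokens, vocab) scores for one batch
labels = torch.randint(0, vocab_size, (8,))  # (tokens,) target ids
metric_accuracy.update(logits, labels)       # accumulate per batch
epoch_accuracy = metric_accuracy.compute()   # aggregate at epoch end, then reset()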
@@ -57,13 +58,15 @@ class LitModule(pl.LightningModule):
     def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx):
         outputs, loss = self.llm(**batch, return_dict=True)
         logits = outputs[..., :-1, :]
         logits = logits.contiguous().view(-1, logits.size(-1))
         labels = batch["labels"][..., 1:]
+
         labels = labels.contiguous().view(-1)
+        label_mask = labels < self.vocab_size
+        logits = logits[label_mask]
+        labels = labels[label_mask]
+        self.metric_accuracy.update(logits, labels)
         self.metric_loss.update(loss)
-
-        label_mask = labels != 0
-        self.metric_accuracy.update(logits[label_mask], labels[label_mask])

     def on_validation_epoch_end(self) -> None:
         self.log("val_loss", self.metric_loss, rank_zero_only=True)
         self.log("accuracy", self.metric_accuracy, rank_zero_only=True)
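The validation step now keeps a position only when its label id is below vocab_size, rather than merely non-zero, so out-of-range ids can mark ignored positions while label 0 still counts. A toy sketch of that masking (the shapes and the use of vocab_size as the ignore id are illustrative assumptions):

import torch

vocab_size = 8
logits = torch.randn(6, vocab_size)                           # flattened (tokens, vocab) scores
labels = torch.tensor([3, 0, 7, vocab_size, 5, vocab_size])   # ids >= vocab_size mark ignored slots

label_mask = labels < vocab_size   # keeps label 0, drops the out-of-range ids
logits = logits[label_mask]        # (4, vocab_size)
labels = labels[label_mask]        # (4,)
# metric_accuracy.update(logits, labels) now sees only real vocabulary targets.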
@@ -348,11 +348,11 @@ class QwenRunner:
         shift_labels = labels[..., 1:].contiguous().view(-1)
         shift_logits = lm_logits[..., :-1, :].contiguous()
         shift_logits = shift_logits.view(-1, shift_logits.size(-1))
-        mask = shift_labels != 0
+        mask = shift_labels < self.qwen.config.vocab_size
         shift_labels = shift_labels[mask]
         shift_logits = shift_logits[mask]
-        loss_fct = CrossEntropyLoss()
-        loss = loss_fct(shift_logits, shift_labels)
+        loss = CrossEntropyLoss()(shift_logits, shift_labels)
+
         return lm_logits, loss

     def prepareInput(self, tokenizer, query, query_assistant, history, system):
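The loss path applies the same bound before the cross entropy. A self-contained toy sketch of the shift-and-mask computation above (batch size, sequence length, and the padding id are assumptions for illustration):

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 8
lm_logits = torch.randn(1, 5, vocab_size)           # (batch, seq, vocab)
labels = torch.tensor([[2, 4, vocab_size, 1, 3]])   # id >= vocab_size marks a masked position

shift_labels = labels[..., 1:].contiguous().view(-1)
shift_logits = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))

mask = shift_labels < vocab_size   # replaces the previous shift_labels != 0 test
loss = CrossEntropyLoss()(shift_logits[mask], shift_labels[mask])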