From fdc8c657b31c24526ac730e2afbcb5ea39f58b5f Mon Sep 17 00:00:00 2001 From: Colin Date: Tue, 5 Mar 2024 19:30:15 +0800 Subject: [PATCH] Add accurancy in loss. --- test/loss.py | 43 +++++++++++++++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/test/loss.py b/test/loss.py index a5c5f78..8df1d6a 100644 --- a/test/loss.py +++ b/test/loss.py @@ -4,19 +4,42 @@ import torch.nn.functional as F import torch.utils.checkpoint from torch.nn import CrossEntropyLoss import math +import torchmetrics -shift_logits = torch.zeros((16, 4096)) -shift_logits[:, 2] = 10.0 -shift_labels = (torch.ones(16) * 2).long() -loss = CrossEntropyLoss()(shift_logits, shift_labels) -print(loss) +# shift_logits = torch.zeros((16, 4096)) +# shift_logits[:, 2] = 10.0 +# shift_labels = (torch.ones(16) * 2).long() +# loss = CrossEntropyLoss()(shift_logits, shift_labels) +# print(loss) -loss = nn.CrossEntropyLoss() -input = torch.tensor([[1.0, 2.0, 3.0]]) -target = torch.tensor([0]).long() -output = loss(input, target) -print(output) +# loss = nn.CrossEntropyLoss() +# input = torch.tensor([[1.0, 2.0, 3.0]]) +# target = torch.tensor([0]).long() +# output = loss(input, target) +# print(output) + + +target = torch.tensor([0, 1, 2]) +preds = torch.tensor([[0.1, 0.9, 0], [0.3, 10.1, 0.6], [0.2, 0.3, 0.9]]) +accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=3) +accur = accuracy(preds, target) + + +metric_accuracy = torchmetrics.Accuracy( + task="multiclass", + num_classes=4096, +) +shift_logits = torch.zeros((16, 2, 4096)) +shift_logits[:8, :, 2] = 10.0 +shift_labels = (torch.ones((16, 2)) * 2).long() + +label_mask = shift_labels != 4096 +shift_logits = shift_logits[label_mask] +shift_labels = shift_labels[label_mask] + +accur = metric_accuracy(shift_logits, shift_labels) +metric_accuracy.update(shift_logits, shift_labels) # torch.manual_seed(32)