Add accurancy in loss.
This commit is contained in:
parent
cf726a5b9f
commit
fdc8c657b3
43
test/loss.py
43
test/loss.py
|
@ -4,19 +4,42 @@ import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
import math
|
import math
|
||||||
|
import torchmetrics
|
||||||
|
|
||||||
shift_logits = torch.zeros((16, 4096))
|
# shift_logits = torch.zeros((16, 4096))
|
||||||
shift_logits[:, 2] = 10.0
|
# shift_logits[:, 2] = 10.0
|
||||||
shift_labels = (torch.ones(16) * 2).long()
|
# shift_labels = (torch.ones(16) * 2).long()
|
||||||
loss = CrossEntropyLoss()(shift_logits, shift_labels)
|
# loss = CrossEntropyLoss()(shift_logits, shift_labels)
|
||||||
print(loss)
|
# print(loss)
|
||||||
|
|
||||||
|
|
||||||
loss = nn.CrossEntropyLoss()
|
# loss = nn.CrossEntropyLoss()
|
||||||
input = torch.tensor([[1.0, 2.0, 3.0]])
|
# input = torch.tensor([[1.0, 2.0, 3.0]])
|
||||||
target = torch.tensor([0]).long()
|
# target = torch.tensor([0]).long()
|
||||||
output = loss(input, target)
|
# output = loss(input, target)
|
||||||
print(output)
|
# print(output)
|
||||||
|
|
||||||
|
|
||||||
|
target = torch.tensor([0, 1, 2])
|
||||||
|
preds = torch.tensor([[0.1, 0.9, 0], [0.3, 10.1, 0.6], [0.2, 0.3, 0.9]])
|
||||||
|
accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=3)
|
||||||
|
accur = accuracy(preds, target)
|
||||||
|
|
||||||
|
|
||||||
|
metric_accuracy = torchmetrics.Accuracy(
|
||||||
|
task="multiclass",
|
||||||
|
num_classes=4096,
|
||||||
|
)
|
||||||
|
shift_logits = torch.zeros((16, 2, 4096))
|
||||||
|
shift_logits[:8, :, 2] = 10.0
|
||||||
|
shift_labels = (torch.ones((16, 2)) * 2).long()
|
||||||
|
|
||||||
|
label_mask = shift_labels != 4096
|
||||||
|
shift_logits = shift_logits[label_mask]
|
||||||
|
shift_labels = shift_labels[label_mask]
|
||||||
|
|
||||||
|
accur = metric_accuracy(shift_logits, shift_labels)
|
||||||
|
metric_accuracy.update(shift_logits, shift_labels)
|
||||||
|
|
||||||
|
|
||||||
# torch.manual_seed(32)
|
# torch.manual_seed(32)
|
||||||
|
|
Loading…
Reference in New Issue