2024-03-05 15:54:03 +08:00
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import torch.utils.checkpoint
|
|
|
|
from torch.nn import CrossEntropyLoss
|
|
|
|
import math
|
2024-03-05 19:30:15 +08:00
|
|
|
import torchmetrics
|
2024-03-05 15:54:03 +08:00
|
|
|
|
2024-03-05 19:30:15 +08:00
|
|
|
# shift_logits = torch.zeros((16, 4096))
|
|
|
|
# shift_logits[:, 2] = 10.0
|
|
|
|
# shift_labels = (torch.ones(16) * 2).long()
|
|
|
|
# loss = CrossEntropyLoss()(shift_logits, shift_labels)
|
|
|
|
# print(loss)
|
|
|
|
|
|
|
|
|
|
|
|
# loss = nn.CrossEntropyLoss()
|
|
|
|
# input = torch.tensor([[1.0, 2.0, 3.0]])
|
|
|
|
# target = torch.tensor([0]).long()
|
|
|
|
# output = loss(input, target)
|
|
|
|
# print(output)
|
|
|
|
|
|
|
|
|
2024-03-20 22:27:28 +08:00
|
|
|
# target = torch.tensor([0, 1, 2])
|
|
|
|
# preds = torch.tensor([[0.1, 0.9, 0], [0.3, 10.1, 0.6], [0.2, 0.3, 0.9]])
|
|
|
|
# accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=3)
|
|
|
|
# accur = accuracy(preds, target)
|
2024-03-05 19:30:15 +08:00
|
|
|
|
|
|
|
|
|
|
|
# Sanity check for torchmetrics' multiclass Accuracy on random data.
# Random scores are compared against random labels, so the printed accuracy
# is expected to land near 1 / NUM_CLASSES.

NUM_CLASSES = 4  # shared by the metric, the logits width and the label range
NUM_SAMPLES = 128

# Fixed seed so repeated runs print the same numbers.
torch.manual_seed(32)

metric_accuracy = torchmetrics.Accuracy(
    task="multiclass",
    num_classes=NUM_CLASSES,
)

# Float scores of shape (NUM_SAMPLES, NUM_CLASSES) and integer labels in
# [0, NUM_CLASSES). Per the torchmetrics docs, float preds with an extra class
# dimension are reduced to class predictions via argmax.
shift_logits = torch.rand((NUM_SAMPLES, NUM_CLASSES))
shift_labels = torch.randint(0, NUM_CLASSES, size=(NUM_SAMPLES,))

# Calling the metric object runs its forward(). Per the torchmetrics docs,
# this returns the value for this batch while also accumulating internal
# state; metric_accuracy.compute() would give the running total.
accur = metric_accuracy(shift_logits, shift_labels)
print(accur.numpy())

# Duplicate the batch: every sample, and whether it was classified correctly,
# now appears exactly twice, so the batch accuracy printed here should equal
# the one printed above.
shift_logits = torch.cat((shift_logits, shift_logits), dim=0)
shift_labels = torch.cat((shift_labels, shift_labels), dim=0)

accur = metric_accuracy(shift_logits, shift_labels)
print(accur.numpy())
|
2024-03-05 15:54:03 +08:00
|
|
|
|
|
|
|
# torch.manual_seed(32)
|
|
|
|
# criterion = nn.CrossEntropyLoss()
|
|
|
|
# output = torch.randn(1, 5)
|
|
|
|
# label = torch.ones(1, dtype=torch.long)*3
|
|
|
|
# loss = criterion(output, label)
|
|
|
|
|
|
|
|
# print("网络输出为5类:")
|
|
|
|
# print(output)
|
|
|
|
# print("要计算label的类别:")
|
|
|
|
# print(label)
|
|
|
|
# print("计算loss的结果:")
|
|
|
|
# print(loss)
|
|
|
|
|
|
|
|
# first = 0
|
|
|
|
# for i in range(1):
|
|
|
|
# first = -output[i][label[i]]
|
|
|
|
# second = 0
|
|
|
|
# for i in range(1):
|
|
|
|
# for j in range(5):
|
|
|
|
# second += math.exp(output[i][j])
|
|
|
|
# res = 0
|
|
|
|
# res = (first + math.log(second))
|
|
|
|
# print("自己的计算结果:")
|
|
|
|
# print(res)
|