"""Scratch script: sanity checks for cross-entropy loss and torchmetrics accuracy.

When executed, it scores random 4-class logits against random labels with
``torchmetrics.Accuracy`` and prints the resulting accuracies. Earlier one-off
experiments are kept further down as commented-out reference snippets.
"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
import torchmetrics

# --- Reference: CrossEntropyLoss on confident, correct logits -----------------
# shift_logits = torch.zeros((16, 4096))
# shift_logits[:, 2] = 10.0
# shift_labels = (torch.ones(16) * 2).long()
# loss = CrossEntropyLoss()(shift_logits, shift_labels)
# print(loss)

# --- Reference: CrossEntropyLoss on a single 3-class example -------------------
# loss = nn.CrossEntropyLoss()
# input = torch.tensor([[1.0, 2.0, 3.0]])
# target = torch.tensor([0]).long()
# output = loss(input, target)
# print(output)

# --- Reference: torchmetrics multiclass accuracy on a tiny example ------------
# target = torch.tensor([0, 1, 2])
# preds = torch.tensor([[0.1, 0.9, 0], [0.3, 10.1, 0.6], [0.2, 0.3, 0.9]])
# accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=3)
# accur = accuracy(preds, target)

NUM_CLASSES = 4
BATCH_SIZE = 128
SEED = 0

metric_accuracy = torchmetrics.Accuracy(
    task="multiclass",
    num_classes=NUM_CLASSES,
)

# A private generator makes the random batch reproducible without reseeding
# torch's global RNG for anything else in the process.
generator = torch.Generator().manual_seed(SEED)
shift_logits = torch.rand((BATCH_SIZE, NUM_CLASSES), generator=generator)
shift_labels = torch.randint(
    0, NUM_CLASSES, size=(BATCH_SIZE,), generator=generator
)

# Calling the metric scores this batch; per the torchmetrics docs it also
# accumulates state that compute() reports later.
accur = metric_accuracy(shift_logits, shift_labels)
print(accur.numpy())

# Duplicating every sample leaves the fraction of correct predictions, and
# therefore the batch accuracy, unchanged.
shift_logits = torch.cat((shift_logits, shift_logits), dim=0)
shift_labels = torch.cat((shift_labels, shift_labels), dim=0)
accur = metric_accuracy(shift_logits, shift_labels)
print(accur.numpy())

# Accuracy accumulated across both metric calls above.
print(metric_accuracy.compute().numpy())

# --- Reference: check CrossEntropyLoss against a hand-computed value ----------
# torch.manual_seed(32)
# criterion = nn.CrossEntropyLoss()
# output = torch.randn(1, 5)
# label = torch.ones(1, dtype=torch.long) * 3
# loss = criterion(output, label)
# print("Network output (5 classes):")
# print(output)
# print("Label class to score:")
# print(label)
# print("Computed loss:")
# print(loss)
# first = 0
# for i in range(1):
#     first = -output[i][label[i]]
# second = 0
# for i in range(1):
#     for j in range(5):
#         second += math.exp(output[i][j])
# res = 0
# res = first + math.log(second)
# print("Hand-computed result:")
# print(res)