# Witllm/test/loss.py
# Scratch experiments: nn.CrossEntropyLoss on toy inputs and torchmetrics
# multiclass accuracy on token-level logits.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
import math
import torchmetrics
# shift_logits = torch.zeros((16, 4096))
# shift_logits[:, 2] = 10.0
# shift_labels = (torch.ones(16) * 2).long()
# loss = CrossEntropyLoss()(shift_logits, shift_labels)
# print(loss)
# loss = nn.CrossEntropyLoss()
# input = torch.tensor([[1.0, 2.0, 3.0]])
# target = torch.tensor([0]).long()
# output = loss(input, target)
# print(output)
# --- Sanity check: torchmetrics multiclass accuracy on a tiny hand-made batch ---
# Row-wise arg-max of `preds` is [1, 1, 2] against targets [0, 1, 2],
# i.e. 2 of the 3 rows are predicted correctly.
target = torch.tensor([0, 1, 2])
preds = torch.tensor([[0.1, 0.9, 0], [0.3, 10.1, 0.6], [0.2, 0.3, 0.9]])
accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=3)
accur = accuracy(preds, target)
print("toy accuracy:", accur)

# --- Token-level accuracy over a (batch, seq, num_classes) logits tensor ---
# Number of classes: used both as the logits' last dimension and as the
# metric's num_classes, so the two cannot drift apart.
NUM_CLASSES = 4096
# Label value excluded by the mask below. NOTE(review): this is one past the
# last valid class id, presumably a padding/ignore id -- confirm against the
# training code. With the labels built here, no position is actually masked.
IGNORE_LABEL = NUM_CLASSES
metric_accuracy = torchmetrics.Accuracy(
    task="multiclass",
    num_classes=NUM_CLASSES,
)
shift_logits = torch.zeros((16, 2, NUM_CLASSES))
# The first 8 sequences strongly favour class 2. The remaining 8 keep all-zero
# logits; torch.argmax returns the first index among ties, so those positions
# presumably score as predictions of class 0.
shift_logits[:8, :, 2] = 10.0
shift_labels = (torch.ones((16, 2)) * 2).long()
label_mask = shift_labels != IGNORE_LABEL
# Boolean indexing keeps the unmasked positions and flattens (batch, seq) into
# a single axis: logits become (N, NUM_CLASSES) and labels become (N,).
shift_logits = shift_logits[label_mask]
shift_labels = shift_labels[label_mask]
# Calling the metric (forward) both returns this batch's accuracy AND
# accumulates the batch into the metric state. A further update() with the
# same tensors would count the batch twice, so none is issued.
accur = metric_accuracy(shift_logits, shift_labels)
print("batch accuracy:", accur)
print("accumulated accuracy:", metric_accuracy.compute())
# Manual check: cross-entropy = -logit[label] + log(sum_j exp(logit[j])).
# torch.manual_seed(32)
# criterion = nn.CrossEntropyLoss()
# output = torch.randn(1, 5)
# label = torch.ones(1, dtype=torch.long)*3
# loss = criterion(output, label)
# print("Network output over 5 classes:")
# print(output)
# print("Target label class:")
# print(label)
# print("Loss computed by CrossEntropyLoss:")
# print(loss)
# first = 0
# for i in range(1):
#     first = -output[i][label[i]]
# second = 0
# for i in range(1):
#     for j in range(5):
#         second += math.exp(output[i][j])
# res = 0
# res = (first + math.log(second))
# print("Manually computed result:")
# print(res)