Add loss and logger code.
This commit is contained in:
parent
9e8e92ae25
commit
cf726a5b9f
|
@ -0,0 +1,45 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn import CrossEntropyLoss
|
||||
import math
|
||||
|
||||
shift_logits = torch.zeros((16, 4096))
|
||||
shift_logits[:, 2] = 10.0
|
||||
shift_labels = (torch.ones(16) * 2).long()
|
||||
loss = CrossEntropyLoss()(shift_logits, shift_labels)
|
||||
print(loss)
|
||||
|
||||
|
||||
loss = nn.CrossEntropyLoss()
|
||||
input = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
target = torch.tensor([0]).long()
|
||||
output = loss(input, target)
|
||||
print(output)
|
||||
|
||||
|
||||
# torch.manual_seed(32)
|
||||
# criterion = nn.CrossEntropyLoss()
|
||||
# output = torch.randn(1, 5)
|
||||
# label = torch.ones(1, dtype=torch.long)*3
|
||||
# loss = criterion(output, label)
|
||||
|
||||
# print("网络输出为5类:")
|
||||
# print(output)
|
||||
# print("要计算label的类别:")
|
||||
# print(label)
|
||||
# print("计算loss的结果:")
|
||||
# print(loss)
|
||||
|
||||
# first = 0
|
||||
# for i in range(1):
|
||||
# first = -output[i][label[i]]
|
||||
# second = 0
|
||||
# for i in range(1):
|
||||
# for j in range(5):
|
||||
# second += math.exp(output[i][j])
|
||||
# res = 0
|
||||
# res = (first + math.log(second))
|
||||
# print("自己的计算结果:")
|
||||
# print(res)
|
|
@ -0,0 +1,25 @@
|
|||
from pytorch_lightning import loggers
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
|
||||
|
||||
class TBLogger(loggers.TensorBoardLogger):
|
||||
@rank_zero_only
|
||||
def log_metrics(self, metrics, step):
|
||||
metrics.pop("epoch", None)
|
||||
return super().log_metrics(metrics, step)
|
||||
|
||||
|
||||
class WBLogger(loggers.WandbLogger):
|
||||
|
||||
@rank_zero_only
|
||||
def log_metrics(self, metrics, step):
|
||||
metrics.pop("epoch", None)
|
||||
return super().log_metrics(metrics, step)
|
||||
|
||||
|
||||
class MLFLogger(loggers.MLFlowLogger):
|
||||
|
||||
@rank_zero_only
|
||||
def log_metrics(self, metrics, step):
|
||||
metrics.pop("epoch", None)
|
||||
return super().log_metrics(metrics, step)
|
Loading…
Reference in New Issue