Add loss and logger code.
This commit is contained in:
		
							parent
							
								
									9e8e92ae25
								
							
						
					
					
						commit
						cf726a5b9f
					
				|  | @ -0,0 +1,45 @@ | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  | import torch.nn.functional as F | ||||||
|  | import torch.utils.checkpoint | ||||||
|  | from torch.nn import CrossEntropyLoss | ||||||
|  | import math | ||||||
|  | 
 | ||||||
|  | shift_logits = torch.zeros((16, 4096)) | ||||||
|  | shift_logits[:, 2] = 10.0 | ||||||
|  | shift_labels = (torch.ones(16) * 2).long() | ||||||
|  | loss = CrossEntropyLoss()(shift_logits, shift_labels) | ||||||
|  | print(loss) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | loss = nn.CrossEntropyLoss() | ||||||
|  | input = torch.tensor([[1.0, 2.0, 3.0]]) | ||||||
|  | target = torch.tensor([0]).long() | ||||||
|  | output = loss(input, target) | ||||||
|  | print(output) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # torch.manual_seed(32) | ||||||
|  | # criterion = nn.CrossEntropyLoss() | ||||||
|  | # output = torch.randn(1, 5) | ||||||
|  | # label = torch.ones(1, dtype=torch.long)*3 | ||||||
|  | # loss = criterion(output, label) | ||||||
|  | 
 | ||||||
|  | # print("网络输出为5类:") | ||||||
|  | # print(output) | ||||||
|  | # print("要计算label的类别:") | ||||||
|  | # print(label) | ||||||
|  | # print("计算loss的结果:") | ||||||
|  | # print(loss) | ||||||
|  | 
 | ||||||
|  | # first = 0 | ||||||
|  | # for i in range(1): | ||||||
|  | #     first = -output[i][label[i]] | ||||||
|  | # second = 0 | ||||||
|  | # for i in range(1): | ||||||
|  | #     for j in range(5): | ||||||
|  | #         second += math.exp(output[i][j]) | ||||||
|  | # res = 0 | ||||||
|  | # res = (first + math.log(second)) | ||||||
|  | # print("自己的计算结果:") | ||||||
|  | # print(res) | ||||||
|  | @ -0,0 +1,25 @@ | ||||||
|  | from pytorch_lightning import loggers | ||||||
|  | from pytorch_lightning.utilities import rank_zero_only | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class TBLogger(loggers.TensorBoardLogger): | ||||||
|  |     @rank_zero_only | ||||||
|  |     def log_metrics(self, metrics, step): | ||||||
|  |         metrics.pop("epoch", None) | ||||||
|  |         return super().log_metrics(metrics, step) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class WBLogger(loggers.WandbLogger): | ||||||
|  | 
 | ||||||
|  |     @rank_zero_only | ||||||
|  |     def log_metrics(self, metrics, step): | ||||||
|  |         metrics.pop("epoch", None) | ||||||
|  |         return super().log_metrics(metrics, step) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class MLFLogger(loggers.MLFlowLogger): | ||||||
|  | 
 | ||||||
|  |     @rank_zero_only | ||||||
|  |     def log_metrics(self, metrics, step): | ||||||
|  |         metrics.pop("epoch", None) | ||||||
|  |         return super().log_metrics(metrics, step) | ||||||
		Loading…
	
		Reference in New Issue