Add loss and logger code.
This commit is contained in:
		
							parent
							
								
									9e8e92ae25
								
							
						
					
					
						commit
						cf726a5b9f
					
				|  | @ -0,0 +1,45 @@ | |||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| import torch.utils.checkpoint | ||||
| from torch.nn import CrossEntropyLoss | ||||
| import math | ||||
| 
 | ||||
| shift_logits = torch.zeros((16, 4096)) | ||||
| shift_logits[:, 2] = 10.0 | ||||
| shift_labels = (torch.ones(16) * 2).long() | ||||
| loss = CrossEntropyLoss()(shift_logits, shift_labels) | ||||
| print(loss) | ||||
| 
 | ||||
| 
 | ||||
| loss = nn.CrossEntropyLoss() | ||||
| input = torch.tensor([[1.0, 2.0, 3.0]]) | ||||
| target = torch.tensor([0]).long() | ||||
| output = loss(input, target) | ||||
| print(output) | ||||
| 
 | ||||
| 
 | ||||
| # torch.manual_seed(32) | ||||
| # criterion = nn.CrossEntropyLoss() | ||||
| # output = torch.randn(1, 5) | ||||
| # label = torch.ones(1, dtype=torch.long)*3 | ||||
| # loss = criterion(output, label) | ||||
| 
 | ||||
| # print("网络输出为5类:") | ||||
| # print(output) | ||||
| # print("要计算label的类别:") | ||||
| # print(label) | ||||
| # print("计算loss的结果:") | ||||
| # print(loss) | ||||
| 
 | ||||
| # first = 0 | ||||
| # for i in range(1): | ||||
| #     first = -output[i][label[i]] | ||||
| # second = 0 | ||||
| # for i in range(1): | ||||
| #     for j in range(5): | ||||
| #         second += math.exp(output[i][j]) | ||||
| # res = 0 | ||||
| # res = (first + math.log(second)) | ||||
| # print("自己的计算结果:") | ||||
| # print(res) | ||||
|  | @ -0,0 +1,25 @@ | |||
| from pytorch_lightning import loggers | ||||
| from pytorch_lightning.utilities import rank_zero_only | ||||
| 
 | ||||
| 
 | ||||
| class TBLogger(loggers.TensorBoardLogger): | ||||
|     @rank_zero_only | ||||
|     def log_metrics(self, metrics, step): | ||||
|         metrics.pop("epoch", None) | ||||
|         return super().log_metrics(metrics, step) | ||||
| 
 | ||||
| 
 | ||||
| class WBLogger(loggers.WandbLogger): | ||||
| 
 | ||||
|     @rank_zero_only | ||||
|     def log_metrics(self, metrics, step): | ||||
|         metrics.pop("epoch", None) | ||||
|         return super().log_metrics(metrics, step) | ||||
| 
 | ||||
| 
 | ||||
| class MLFLogger(loggers.MLFlowLogger): | ||||
| 
 | ||||
|     @rank_zero_only | ||||
|     def log_metrics(self, metrics, step): | ||||
|         metrics.pop("epoch", None) | ||||
|         return super().log_metrics(metrics, step) | ||||
		Loading…
	
		Reference in New Issue