Refine train.
This commit is contained in:
		
							parent
							
								
									69cb525ab0
								
							
						
					
					
						commit
						1b8007e1c3
					
				|  | @ -1,7 +1,6 @@ | ||||||
| import torchvision | import torchvision | ||||||
| import torch | import torch | ||||||
| from torch import nn | from torch import nn | ||||||
| import cv2 |  | ||||||
| import numpy as np | import numpy as np | ||||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||||
| 
 | 
 | ||||||
|  | @ -12,7 +11,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||||||
| weights_path = "weights" | weights_path = "weights" | ||||||
| model = resnet | model = resnet | ||||||
| 
 | 
 | ||||||
| model.train().cuda() | model.train().to(device) | ||||||
| 
 | 
 | ||||||
| criterion = nn.CrossEntropyLoss() | criterion = nn.CrossEntropyLoss() | ||||||
| optimizer = torch.optim.SGD(model.parameters(), lr=0.2, momentum=0.9, weight_decay=5e-4) | optimizer = torch.optim.SGD(model.parameters(), lr=0.2, momentum=0.9, weight_decay=5e-4) | ||||||
|  | @ -21,13 +20,26 @@ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20) | ||||||
| # with torch.no_grad(): | # with torch.no_grad(): | ||||||
| img_ori = np.ones([1, 3, 224, 224]) | img_ori = np.ones([1, 3, 224, 224]) | ||||||
| img_ori = np.float32(img_ori) / 255 | img_ori = np.float32(img_ori) / 255 | ||||||
| img_ori = torch.tensor(img_ori).cuda() | img_ori = torch.tensor(img_ori).to(device) | ||||||
| output = model(img_ori) | output = model(img_ori) | ||||||
| 
 | 
 | ||||||
| target = torch.ones([1]).to(torch.int64).cuda() | target = torch.ones([1]).to(torch.int64) | ||||||
|  | target = target.to(device) | ||||||
| 
 | 
 | ||||||
| optimizer.zero_grad() | optimizer.zero_grad() | ||||||
| loss = F.nll_loss(output, target) | loss = F.nll_loss(output, target) | ||||||
| loss.backward() | loss.backward() | ||||||
|  | 
 | ||||||
|  | params = list(model.parameters()) | ||||||
|  | named_params = dict(model.named_parameters()) | ||||||
|  | print(model) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # import visdom | ||||||
|  | # viz = visdom.Visdom() | ||||||
|  | # # viz.heatmap(img_ori) | ||||||
|  | # viz.image(img_ori) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| optimizer.step() | optimizer.step() | ||||||
| print(loss) | print(loss) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue