Add train resnet test.
This commit is contained in:
		
							parent
							
								
									467c78d83d
								
							
						
					
					
						commit
						08f7b75efe
					
				| 
						 | 
					@ -0,0 +1,33 @@
 | 
				
			||||||
 | 
					import torchvision
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch import nn
 | 
				
			||||||
 | 
					import cv2
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import torch.nn.functional as F
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					resnet = torchvision.models.resnet152(pretrained=False)
 | 
				
			||||||
 | 
					resnet.fc = torch.nn.Linear(2048, 10)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 | 
				
			||||||
 | 
					weights_path = "weights"
 | 
				
			||||||
 | 
					model = resnet
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					model.train().cuda()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					criterion = nn.CrossEntropyLoss()
 | 
				
			||||||
 | 
					optimizer = torch.optim.SGD(model.parameters(), lr=0.2, momentum=0.9, weight_decay=5e-4)
 | 
				
			||||||
 | 
					scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# with torch.no_grad():
 | 
				
			||||||
 | 
					img_ori = np.ones([1, 3, 224, 224])
 | 
				
			||||||
 | 
					img_ori = np.float32(img_ori) / 255
 | 
				
			||||||
 | 
					img_ori = torch.tensor(img_ori).cuda()
 | 
				
			||||||
 | 
					output = model(img_ori)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					target = torch.ones([1]).to(torch.int64).cuda()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					optimizer.zero_grad()
 | 
				
			||||||
 | 
					loss = F.nll_loss(output, target)
 | 
				
			||||||
 | 
					loss.backward()
 | 
				
			||||||
 | 
					optimizer.step()
 | 
				
			||||||
 | 
					print(loss)
 | 
				
			||||||
		Loading…
	
		Reference in New Issue