Add train resnet test.
This commit is contained in:
parent
467c78d83d
commit
08f7b75efe
|
@ -0,0 +1,33 @@
|
|||
import torchvision
|
||||
import torch
|
||||
from torch import nn
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
|
||||
resnet = torchvision.models.resnet152(pretrained=False)
|
||||
resnet.fc = torch.nn.Linear(2048, 10)
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
weights_path = "weights"
|
||||
model = resnet
|
||||
|
||||
model.train().cuda()
|
||||
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.2, momentum=0.9, weight_decay=5e-4)
|
||||
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20)
|
||||
|
||||
# with torch.no_grad():
|
||||
img_ori = np.ones([1, 3, 224, 224])
|
||||
img_ori = np.float32(img_ori) / 255
|
||||
img_ori = torch.tensor(img_ori).cuda()
|
||||
output = model(img_ori)
|
||||
|
||||
target = torch.ones([1]).to(torch.int64).cuda()
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss = F.nll_loss(output, target)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
print(loss)
|
Loading…
Reference in New Issue