Refine train.
This commit is contained in:
parent
69cb525ab0
commit
1b8007e1c3
|
@ -1,7 +1,6 @@
|
||||||
import torchvision
|
import torchvision
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import cv2
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
@ -12,7 +11,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
weights_path = "weights"
|
weights_path = "weights"
|
||||||
model = resnet
|
model = resnet
|
||||||
|
|
||||||
model.train().cuda()
|
model.train().to(device)
|
||||||
|
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.2, momentum=0.9, weight_decay=5e-4)
|
optimizer = torch.optim.SGD(model.parameters(), lr=0.2, momentum=0.9, weight_decay=5e-4)
|
||||||
|
@ -21,13 +20,26 @@ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20)
|
||||||
# with torch.no_grad():
|
# with torch.no_grad():
|
||||||
img_ori = np.ones([1, 3, 224, 224])
|
img_ori = np.ones([1, 3, 224, 224])
|
||||||
img_ori = np.float32(img_ori) / 255
|
img_ori = np.float32(img_ori) / 255
|
||||||
img_ori = torch.tensor(img_ori).cuda()
|
img_ori = torch.tensor(img_ori).to(device)
|
||||||
output = model(img_ori)
|
output = model(img_ori)
|
||||||
|
|
||||||
target = torch.ones([1]).to(torch.int64).cuda()
|
target = torch.ones([1]).to(torch.int64)
|
||||||
|
target = target.to(device)
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss = F.nll_loss(output, target)
|
loss = F.nll_loss(output, target)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
params = list(model.parameters())
|
||||||
|
named_params = dict(model.named_parameters())
|
||||||
|
print(model)
|
||||||
|
|
||||||
|
|
||||||
|
# import visdom
|
||||||
|
# viz = visdom.Visdom()
|
||||||
|
# # viz.heatmap(img_ori)
|
||||||
|
# viz.image(img_ori)
|
||||||
|
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
print(loss)
|
print(loss)
|
||||||
|
|
Loading…
Reference in New Issue