Witllm/test/train.py

46 lines
1.0 KiB
Python

import torchvision
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
weights_path = "weights"

# ResNet-152 built without pretrained weights; its final fully connected
# layer is replaced by a 2048 -> 10 linear head.
resnet = torchvision.models.resnet152(pretrained=False)
resnet.fc = nn.Linear(2048, 10)

# `model` is an alias for the same module object as `resnet`.
model = resnet
model.train()  # training mode (affects layers such as batch norm)
model.to(device)  # nn.Module.to moves the parameters in place

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.2,
    momentum=0.9,
    weight_decay=5e-4,
)
# Step-wise learning-rate schedule with a period of 20 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20)
# Run a single forward/backward/update step on one synthetic sample.
# Gradients are needed for the update, so this is deliberately NOT wrapped
# in torch.no_grad().

# Synthetic input: one 3x224x224 image whose every value is 1/255 (float32).
img_ori = np.ones((1, 3, 224, 224), dtype=np.float32) / 255
img_ori = torch.tensor(img_ori).to(device)

# Class index 1 as the label for the single sample; class indices must be
# int64 for the classification loss.
target = torch.ones(1, dtype=torch.int64, device=device)

# Clear any stale gradients before computing new ones.
optimizer.zero_grad()
output = model(img_ori)

# The model's final layer is a plain Linear, so `output` holds raw logits.
# F.nll_loss expects log-probabilities and would silently produce a
# meaningless (unbounded) value here; the cross-entropy criterion applies
# log-softmax internally and is the correct loss for logits.
loss = criterion(output, target)
loss.backward()

# Handles on the parameters (by position and by name) for inspection; taken
# after backward() so their .grad fields are populated.
params = list(model.parameters())
named_params = dict(model.named_parameters())
print(model)

# Optional visualisation of the input (requires the visdom package):
# import visdom
# viz = visdom.Visdom()
# # viz.heatmap(img_ori)
# viz.image(img_ori)

# NOTE(review): `scheduler` is created earlier in the file but never stepped;
# for this one-step script that has no effect on the update below.
optimizer.step()
print(loss)