diff --git a/test/train.py b/test/train.py new file mode 100644 index 0000000..5778d04 --- /dev/null +++ b/test/train.py @@ -0,0 +1,33 @@ +import torchvision +import torch +from torch import nn +import cv2 +import numpy as np +import torch.nn.functional as F + +resnet = torchvision.models.resnet152(pretrained=False) +resnet.fc = torch.nn.Linear(2048, 10) + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +weights_path = "weights" +model = resnet + +model.train().cuda() + +criterion = nn.CrossEntropyLoss() +optimizer = torch.optim.SGD(model.parameters(), lr=0.2, momentum=0.9, weight_decay=5e-4) +scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20) + +# with torch.no_grad(): +img_ori = np.ones([1, 3, 224, 224]) +img_ori = np.float32(img_ori) / 255 +img_ori = torch.tensor(img_ori).cuda() +output = model(img_ori) + +target = torch.ones([1]).to(torch.int64).cuda() + +optimizer.zero_grad() +loss = F.nll_loss(output, target) +loss.backward() +optimizer.step() +print(loss)