Add model.conv1.weight normal after update grad.
This commit is contained in:
parent
6a0b47c674
commit
5b2cd4da61
|
@ -1,2 +1,4 @@
|
||||||
dump1
|
dump1
|
||||||
dump2
|
dump2
|
||||||
|
*.png
|
||||||
|
*.log
|
|
@ -5,15 +5,23 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F # Add this line
|
import torch.nn.functional as F # Add this line
|
||||||
import torchvision
|
import torchvision
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
|
||||||
sys.path.append("..")
|
sys.path.append("..")
|
||||||
from tools import show
|
from tools import show
|
||||||
|
|
||||||
seed = 4321
|
seed = 42
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
torch.cuda.manual_seed_all(seed)
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
# device = torch.device("cpu")
|
||||||
# device = torch.device("mps")
|
# device = torch.device("mps")
|
||||||
|
|
||||||
num_epochs = 1
|
num_epochs = 1
|
||||||
|
@ -60,15 +68,12 @@ class ConvNet(nn.Module):
|
||||||
|
|
||||||
def printFector(self, x, label, dir=""):
|
def printFector(self, x, label, dir=""):
|
||||||
show.DumpTensorToImage(x.view(-1, x.shape[2], x.shape[3]), dir + "/input_image.png", Contrast=[0, 1.0])
|
show.DumpTensorToImage(x.view(-1, x.shape[2], x.shape[3]), dir + "/input_image.png", Contrast=[0, 1.0])
|
||||||
# show.DumpTensorToLog(x, "input_image.log")
|
|
||||||
|
|
||||||
w = self.normal_conv1_weight()
|
w = self.normal_conv1_weight()
|
||||||
x = torch.conv2d(x, w)
|
x = torch.conv2d(x, w)
|
||||||
show.DumpTensorToImage(w.view(-1, w.shape[2], w.shape[3]), dir + "/conv1_weight.png")
|
show.DumpTensorToImage(w.view(-1, w.shape[2], w.shape[3]), dir + "/conv1_weight.png")
|
||||||
# show.DumpTensorToLog(w, "conv1_weight.log")
|
|
||||||
|
|
||||||
show.DumpTensorToImage(x.view(-1, x.shape[2], x.shape[3]), dir + "/conv1_output.png")
|
show.DumpTensorToImage(x.view(-1, x.shape[2], x.shape[3]), dir + "/conv1_output.png")
|
||||||
# show.DumpTensorToLog(x, "conv1_output.png")
|
|
||||||
|
|
||||||
x = self.pool(x)
|
x = self.pool(x)
|
||||||
x = self.conv2(x)
|
x = self.conv2(x)
|
||||||
|
@ -132,16 +137,17 @@ for epoch in range(epochs):
|
||||||
model.conv1.weight.grad = None
|
model.conv1.weight.grad = None
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
model.conv1.weight.data = model.conv1.weight.data - model.conv1.weight.grad * 10000
|
model.conv1.weight.data = model.conv1.weight.data - model.conv1.weight.grad * 1000
|
||||||
|
model.conv1.weight.data = model.normal_conv1_weight()
|
||||||
|
|
||||||
if (i + 1) % 100 == 0:
|
if (i + 1) % 100 == 0:
|
||||||
print(f"Epoch [{epoch+1}/{epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.8f}")
|
print(f"Epoch [{epoch+1}/{epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.8f}")
|
||||||
|
|
||||||
show.DumpTensorToImage(images.view(-1, images.shape[2], images.shape[3]), "input_image.png", Contrast=[0, 1.0])
|
show.DumpTensorToImage(images.view(-1, images.shape[2], images.shape[3]), "input_image.png", Contrast=[0, 1.0])
|
||||||
g = model.conv1.weight.grad
|
g = model.conv1.weight.grad
|
||||||
show.DumpTensorToImage(g.view(-1, g.shape[2], g.shape[3]).cpu(), "conv1_weight_grad.png")
|
show.DumpTensorToImage(g.view(-1, g.shape[2], g.shape[3]).cpu(), "conv1_weight_grad.png", Value2Log=True)
|
||||||
w = model.conv1.weight.data
|
w = model.conv1.weight.data
|
||||||
show.DumpTensorToImage(w.view(-1, w.shape[2], w.shape[3]), "conv1_weight_update.png")
|
show.DumpTensorToImage(w.view(-1, w.shape[2], w.shape[3]), "conv1_weight_update.png", Value2Log=True)
|
||||||
|
|
||||||
# model.conv1.weight.data = torch.rand(model.conv1.weight.data.shape, device=device)
|
# model.conv1.weight.data = torch.rand(model.conv1.weight.data.shape, device=device)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue