diff --git a/unsuper/conv1_weight.png b/unsuper/conv1_weight.png deleted file mode 100644 index aec14cf..0000000 Binary files a/unsuper/conv1_weight.png and /dev/null differ diff --git a/unsuper/dump1/conv1_output.png b/unsuper/dump1/conv1_output.png index 034f0a7..efc4466 100644 Binary files a/unsuper/dump1/conv1_output.png and b/unsuper/dump1/conv1_output.png differ diff --git a/unsuper/dump1/conv1_weight.png b/unsuper/dump1/conv1_weight.png index 4ee741a..7cf3c24 100644 Binary files a/unsuper/dump1/conv1_weight.png and b/unsuper/dump1/conv1_weight.png differ diff --git a/unsuper/dump1/conv2_output.png b/unsuper/dump1/conv2_output.png index fb9737a..a9b7470 100644 Binary files a/unsuper/dump1/conv2_output.png and b/unsuper/dump1/conv2_output.png differ diff --git a/unsuper/dump1/conv2_weight.png b/unsuper/dump1/conv2_weight.png index f9bd915..b364492 100644 Binary files a/unsuper/dump1/conv2_weight.png and b/unsuper/dump1/conv2_weight.png differ diff --git a/unsuper/dump1/conv2_weight_grad.png b/unsuper/dump1/conv2_weight_grad.png index 63e0b78..b1a36c2 100644 Binary files a/unsuper/dump1/conv2_weight_grad.png and b/unsuper/dump1/conv2_weight_grad.png differ diff --git a/unsuper/dump1/fc_output.png b/unsuper/dump1/fc_output.png index 8a0200f..aa91f6f 100644 Binary files a/unsuper/dump1/fc_output.png and b/unsuper/dump1/fc_output.png differ diff --git a/unsuper/dump1/fc_weight.png b/unsuper/dump1/fc_weight.png index 3d66203..43ef3f9 100644 Binary files a/unsuper/dump1/fc_weight.png and b/unsuper/dump1/fc_weight.png differ diff --git a/unsuper/dump1/fc_weight_grad.png b/unsuper/dump1/fc_weight_grad.png index f86c1dc..fa41304 100644 Binary files a/unsuper/dump1/fc_weight_grad.png and b/unsuper/dump1/fc_weight_grad.png differ diff --git a/unsuper/dump1/pool_output.png b/unsuper/dump1/pool_output.png index 49eb371..41f6562 100644 Binary files a/unsuper/dump1/pool_output.png and b/unsuper/dump1/pool_output.png differ diff --git a/unsuper/dump2/conv1_output.png b/unsuper/dump2/conv1_output.png index d13915c..c01819e 100644 Binary files a/unsuper/dump2/conv1_output.png and b/unsuper/dump2/conv1_output.png differ diff --git a/unsuper/dump2/conv1_weight.png b/unsuper/dump2/conv1_weight.png index 4ee741a..7cf3c24 100644 Binary files a/unsuper/dump2/conv1_weight.png and b/unsuper/dump2/conv1_weight.png differ diff --git a/unsuper/dump2/conv2_output.png b/unsuper/dump2/conv2_output.png index 58df023..6d62a35 100644 Binary files a/unsuper/dump2/conv2_output.png and b/unsuper/dump2/conv2_output.png differ diff --git a/unsuper/dump2/conv2_weight.png b/unsuper/dump2/conv2_weight.png index f9bd915..b364492 100644 Binary files a/unsuper/dump2/conv2_weight.png and b/unsuper/dump2/conv2_weight.png differ diff --git a/unsuper/dump2/conv2_weight_grad.png b/unsuper/dump2/conv2_weight_grad.png index 0a9ae3a..4efbca8 100644 Binary files a/unsuper/dump2/conv2_weight_grad.png and b/unsuper/dump2/conv2_weight_grad.png differ diff --git a/unsuper/dump2/fc_output.png b/unsuper/dump2/fc_output.png index d078719..e9675c0 100644 Binary files a/unsuper/dump2/fc_output.png and b/unsuper/dump2/fc_output.png differ diff --git a/unsuper/dump2/fc_weight.png b/unsuper/dump2/fc_weight.png index 3d66203..43ef3f9 100644 Binary files a/unsuper/dump2/fc_weight.png and b/unsuper/dump2/fc_weight.png differ diff --git a/unsuper/dump2/fc_weight_grad.png b/unsuper/dump2/fc_weight_grad.png index f7041e8..0e6c057 100644 Binary files a/unsuper/dump2/fc_weight_grad.png and b/unsuper/dump2/fc_weight_grad.png differ diff --git a/unsuper/dump2/pool_output.png b/unsuper/dump2/pool_output.png index c08cd8d..40efe5b 100644 Binary files a/unsuper/dump2/pool_output.png and b/unsuper/dump2/pool_output.png differ diff --git a/unsuper/minist.py b/unsuper/minist.py index 8a1dec2..4421692 100644 --- a/unsuper/minist.py +++ b/unsuper/minist.py @@ -125,14 +125,14 @@ for epoch in range(epochs): loss.backward() optimizer_unsuper.step() if (i + 1) % 100 == 0: - print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.8f}") + print(f"Epoch [{epoch+1}/{epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.8f}") # Train the model model.conv1.weight.requires_grad = False model.conv2.weight.requires_grad = True model.fc1.weight.requires_grad = True criterion = nn.CrossEntropyLoss() -optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.6) +optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.2) n_total_steps = len(train_loader) for epoch in range(num_epochs): for i, (images, labels) in enumerate(train_loader):