Regine minist unsuper.

This commit is contained in:
Colin 2025-02-17 14:27:15 +08:00
parent f74a5d29bd
commit cdee69bf54
1 changed files with 26 additions and 23 deletions

View File

@ -106,6 +106,31 @@ class ConvNet(nn.Module):
) )
def RandomImage():
images = torch.ones((1, 1, 5, 5), device=device)
type = random.randint(0, 3)
if type == 0:
rand = random.randint(0, 4)
images[:, :, rand, :] = images[:, :, rand, :] * 0.5
if type == 1:
rand = random.randint(0, 4)
images[:, :, :, rand] = images[:, :, :, rand] * 0.5
if type == 2:
images[:, :, 0, 0] = images[:, :, 0, 0] * 0.5
images[:, :, 1, 1] = images[:, :, 1, 1] * 0.5
images[:, :, 2, 2] = images[:, :, 2, 2] * 0.5
images[:, :, 3, 3] = images[:, :, 3, 3] * 0.5
images[:, :, 4, 4] = images[:, :, 4, 4] * 0.5
if type == 3:
randx = random.randint(1, 3)
randy = random.randint(1, 3)
images[:, :, randx, randy] = images[:, :, randx, randy] * 0.5
images[:, :, randx, randy + 1] = images[:, :, randx, randy + 1] * 0.5
images[:, :, randx, randy - 1] = images[:, :, randx, randy - 1] * 0.5
images[:, :, randx + 1, randy] = images[:, :, randx + 1, randy] * 0.5
images[:, :, randx - 1, randy] = images[:, :, randx - 1, randy] * 0.5
return images
model = ConvNet().to(device) model = ConvNet().to(device)
model.train() model.train()
@ -115,29 +140,7 @@ n_total_steps = len(train_loader)
for epoch in range(epochs): for epoch in range(epochs):
for i, (images, labels) in enumerate(train_loader): for i, (images, labels) in enumerate(train_loader):
images = images.to(device) images = images.to(device)
# images = RandomImage()
# images = torch.ones((1, 1, 5, 5), device=device)
# type = random.randint(0, 3)
# if type == 0:
# rand = random.randint(0, 4)
# images[:, :, rand, :] = images[:, :, rand, :] * 0.5
# if type == 1:
# rand = random.randint(0, 4)
# images[:, :, :, rand] = images[:, :, :, rand] * 0.5
# if type == 2:
# images[:, :, 0, 0] = images[:, :, 0, 0] * 0.5
# images[:, :, 1, 1] = images[:, :, 1, 1] * 0.5
# images[:, :, 2, 2] = images[:, :, 2, 2] * 0.5
# images[:, :, 3, 3] = images[:, :, 3, 3] * 0.5
# images[:, :, 4, 4] = images[:, :, 4, 4] * 0.5
# if type == 3:
# randx = random.randint(1, 3)
# randy = random.randint(1, 3)
# images[:, :, randx, randy] = images[:, :, randx, randy] * 0.5
# images[:, :, randx, randy + 1] = images[:, :, randx, randy + 1] * 0.5
# images[:, :, randx, randy - 1] = images[:, :, randx, randy - 1] * 0.5
# images[:, :, randx + 1, randy] = images[:, :, randx + 1, randy] * 0.5
# images[:, :, randx - 1, randy] = images[:, :, randx - 1, randy] * 0.5
outputs = model.forward_unsuper(images) outputs = model.forward_unsuper(images)
outputs = outputs.permute(0, 2, 3, 1) # 64 8 24 24 -> 64 24 24 8 outputs = outputs.permute(0, 2, 3, 1) # 64 8 24 24 -> 64 24 24 8