diff --git a/unsuper/minist.py b/unsuper/minist.py index cdef917..f7350a5 100644 --- a/unsuper/minist.py +++ b/unsuper/minist.py @@ -106,6 +106,31 @@ class ConvNet(nn.Module): ) +def RandomImage(): + images = torch.ones((1, 1, 5, 5), device=device) + type = random.randint(0, 3) + if type == 0: + rand = random.randint(0, 4) + images[:, :, rand, :] = images[:, :, rand, :] * 0.5 + if type == 1: + rand = random.randint(0, 4) + images[:, :, :, rand] = images[:, :, :, rand] * 0.5 + if type == 2: + images[:, :, 0, 0] = images[:, :, 0, 0] * 0.5 + images[:, :, 1, 1] = images[:, :, 1, 1] * 0.5 + images[:, :, 2, 2] = images[:, :, 2, 2] * 0.5 + images[:, :, 3, 3] = images[:, :, 3, 3] * 0.5 + images[:, :, 4, 4] = images[:, :, 4, 4] * 0.5 + if type == 3: + randx = random.randint(1, 3) + randy = random.randint(1, 3) + images[:, :, randx, randy] = images[:, :, randx, randy] * 0.5 + images[:, :, randx, randy + 1] = images[:, :, randx, randy + 1] * 0.5 + images[:, :, randx, randy - 1] = images[:, :, randx, randy - 1] * 0.5 + images[:, :, randx + 1, randy] = images[:, :, randx + 1, randy] * 0.5 + images[:, :, randx - 1, randy] = images[:, :, randx - 1, randy] * 0.5 + return images + model = ConvNet().to(device) model.train() @@ -115,29 +140,7 @@ n_total_steps = len(train_loader) for epoch in range(epochs): for i, (images, labels) in enumerate(train_loader): images = images.to(device) - - # images = torch.ones((1, 1, 5, 5), device=device) - # type = random.randint(0, 3) - # if type == 0: - # rand = random.randint(0, 4) - # images[:, :, rand, :] = images[:, :, rand, :] * 0.5 - # if type == 1: - # rand = random.randint(0, 4) - # images[:, :, :, rand] = images[:, :, :, rand] * 0.5 - # if type == 2: - # images[:, :, 0, 0] = images[:, :, 0, 0] * 0.5 - # images[:, :, 1, 1] = images[:, :, 1, 1] * 0.5 - # images[:, :, 2, 2] = images[:, :, 2, 2] * 0.5 - # images[:, :, 3, 3] = images[:, :, 3, 3] * 0.5 - # images[:, :, 4, 4] = images[:, :, 4, 4] * 0.5 - # if type == 3: - # randx = random.randint(1, 3) - # randy = random.randint(1, 3) - # images[:, :, randx, randy] = images[:, :, randx, randy] * 0.5 - # images[:, :, randx, randy + 1] = images[:, :, randx, randy + 1] * 0.5 - # images[:, :, randx, randy - 1] = images[:, :, randx, randy - 1] * 0.5 - # images[:, :, randx + 1, randy] = images[:, :, randx + 1, randy] * 0.5 - # images[:, :, randx - 1, randy] = images[:, :, randx - 1, randy] * 0.5 + # images = RandomImage() outputs = model.forward_unsuper(images) outputs = outputs.permute(0, 2, 3, 1) # 64 8 24 24 -> 64 24 24 8