From f98a951b58c4af3aac6e8057aab4d59265403861 Mon Sep 17 00:00:00 2001
From: Colin
Date: Thu, 22 May 2025 15:23:41 +0800
Subject: [PATCH] mnist: rework SimpleBNN into an unfold+LutGroup pipeline; add TensorBoard logging and AdamW

---
 binary/mnist.py | 140 ++++++++++++++++++++++++++++--------------------
 1 file changed, 83 insertions(+), 57 deletions(-)

diff --git a/binary/mnist.py b/binary/mnist.py
index d098e58..8f59a4a 100644
--- a/binary/mnist.py
+++ b/binary/mnist.py
@@ -12,15 +12,17 @@ from torch.utils.data import DataLoader
 import math
 import torch.nn.functional as F
 import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+import datetime
 
 torch.manual_seed(1234)
 np.random.seed(1234)
 torch.cuda.manual_seed_all(1234)
 
+batch_size = 16
+lr = 0.001
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-# device = "cpu"
-# torch.set_num_threads(16)
 print(f"Using device: {device}")
 
 transform = transforms.Compose(
@@ -30,7 +32,7 @@ transform = transforms.Compose(
 
 train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
 test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
 
-train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
+train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
 test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
 
@@ -112,28 +114,22 @@ class SimpleCNN(nn.Module):
 
 class LutGroup(nn.Module):
-    def __init__(self, bits, subbits):
+    def __init__(self, totalBits, groupBits=0, repeat=1):
+        if not groupBits:
+            groupBits = totalBits
         super(LutGroup, self).__init__()
-        assert (bits % subbits) == 0
-        self.weight = nn.Parameter(torch.randn(bits, pow(2, subbits)))
-        self.bits = bits
-        self.subbits = subbits
+        assert (totalBits % groupBits) == 0
+        self.weight = nn.Parameter(torch.randn(repeat * totalBits, pow(2, groupBits)))
+        self.totalBits = totalBits
+        self.groupBits = groupBits
+        self.repeat = repeat
 
     def forward(self, x):
+        # input [ batch, totalBits ]
+        # output [ batch, totalBits / groupBits * repeat ]
         batch = x.shape[0]
-        x = x.view(batch, -1, self.subbits)
-        x = Lut.apply(x, self.weight)
-        return x
-
-
-class LutParallel(nn.Module):
-    def __init__(self, bits, number):
-        super(LutParallel, self).__init__()
-        self.number = number
-        self.weight = nn.Parameter(torch.randn(number, pow(2, bits)))
-
-    def forward(self, x):
-        x = x.unsqueeze(1).repeat(1, self.number, 1)
+        x = x.view(batch, -1, self.groupBits)
+        x = x.repeat(1, self.repeat, 1)
         x = Lut.apply(x, self.weight)
         return x
 
@@ -147,11 +143,13 @@ class SimpleBNN(nn.Module):
         self.w = nn.Parameter(torch.randn(3, 784))
         self.b = nn.Parameter(torch.zeros(3, 784))
 
-        self.lut1 = LutGroup(784 * 8, 8)
-        self.lut2 = LutGroup(784, 4)
-        self.lut3 = LutGroup(196, 4)
-        self.lut4 = LutGroup(49, 7)
-        self.lut5 = LutParallel(7, 10)
+        self.lut2_p = LutGroup(784, 4, 2)  # pool2 => 14*14*2
+        self.lut3_p = LutGroup(14 * 14 * 2, 4, 2)  # pool2 => 7*7*4
+        self.lut4_p = LutGroup(3 * 3 * 5 * 5 * 4, 9, 2)  # conv 3 => 5*5*8
+        # self.lut4 = LutGroup(5 * 5 * 8, 8, 8)  # conv 3 => 5*5*8
+        self.lut5_p = LutGroup(3 * 3 * 3 * 3 * 8, 9, 2)  # conv 3 => 3*3*8*2
+        # self.lut5 = LutGroup(3 * 3 * 8 * 2, 16, 16)  # conv 3 => 3*3*16
+        self.lut6_p = LutGroup(3 * 3 * 16, 9, 5)  # conv 3 => 80
 
         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
@@ -171,28 +169,28 @@ class SimpleBNN(nn.Module):
         # x = bits.view(batch, -1)
         # x = x.float() - 0.5
 
-        x = (x > 0).float()
+        # x = (x > 0).float()
 
-        q = x * self.w[0] + self.b[0]
-        k = x * self.w[1] + self.b[1]
-        v = x * self.w[2] + self.b[2]
-        q = q.view(batch, -1, 1)
-        k = k.view(batch, 1, -1)
-        v = v.view(batch, -1, 1)
-        kq = q @ k
-        kqv = kq @ v
-        kqv = kqv.view(batch, -1, 8)
-        x = kqv
+        # q = x * self.w[0] + self.b[0]
+        # k = x * self.w[1] + self.b[1]
+        # v = x * self.w[2] + self.b[2]
+        # q = q.view(batch, -1, 1)
+        # k = k.view(batch, 1, -1)
+        # v = v.view(batch, -1, 1)
+        # kq = q @ k
+        # kqv = kq @ v
+        # kqv = kqv.view(batch, -1, 8)
+        # x = kqv
 
         #########################
 
-        # x = (x > 0) << xx
-        # x = x.sum(2)
-        # x = x.view(batch, 1, 28, 28)
-        # x = (x - 128.0) / 256.0
+        # # x = (x > 0) << xx
+        # # x = x.sum(2)
+        # # x = x.view(batch, 1, 28, 28)
+        # # x = (x - 128.0) / 256.0
 
-        # x = x.view(batch, 1, 28, 28)
         # x = (x > 0).float()
+        # x = x.view(batch, 1, 28, 28)
 
         # x = self.relu(self.pool(self.conv1(x)))
         # x = self.relu(self.pool((self.conv2(x))))
@@ -202,19 +200,39 @@ class SimpleBNN(nn.Module):
 
         #########################
 
-        x = (x > 0).float()
-        x = x.view(batch, 196, 4)
+        x = x.view(batch, 1, 28, 28)
+        x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
+        x = x.view(batch, -1)
+        x = self.lut2_p(x)
 
-        # x = self.lut1(x)
-        x = self.lut2(x)
+        x = x.view(batch, 2, 14, 14)
+        x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
+        x = x.view(batch, -1)
+        x = self.lut3_p(x)
 
-        x = x.view(-1, 28, 7)
-        x = x.permute(0, 2, 1)
-        x = x.reshape(-1, 28 * 7)
+        x = x.view(batch, 4, 7, 7)
+        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
+        x = x.view(batch, -1)
+        x = self.lut4_p(x)
+        # x = self.lut4(x)
+
+        x = x.view(batch, 8, 5, 5)
+        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
+        x = x.view(batch, -1)
+        x = self.lut5_p(x)
+        # x = self.lut5(x)
+
+        x = x.view(batch, 16, 3, 3)
+        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
+        x = x.view(batch, -1)
+        x = self.lut6_p(x)
+
+        # xx = 2 ** torch.arange(7, -1, -1).to(x.device)
+        x = x.view(batch, -1, 8)
+        # x = x * xx
+        # x = (x - 128.0) / 256.0
+        x = x.sum(2)
 
-        x = self.lut3(x)
-        x = self.lut4(x)
-        x = self.lut5(x)
         return x
 
 
@@ -222,7 +240,12 @@ torch.autograd.set_detect_anomaly(True)
 # model = SimpleCNN().to(device)
 model = SimpleBNN().to(device)
 criterion = nn.CrossEntropyLoss()
-optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+
+current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
+writer = SummaryWriter(f"log/{current_time}")
+hparam_dict = {"lr": lr, "batch_size": batch_size}
+writer.add_hparams(hparam_dict, {}, run_name="./")
 
 
 def train(epoch):
@@ -240,15 +263,16 @@ def train(epoch):
         loss = criterion(output, target)
         loss.backward()
         optimizer.step()
+        writer.add_scalar("loss", loss.item(), (epoch - 1) * len(train_loader) + batch_idx)
 
-        if batch_idx % 100 == 0:
+        if batch_idx % 512 == 0:
             print(
                 f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                 f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
            )
 
 
-def test():
+def test(epoch):
     model.eval()
     test_loss = 0
     correct = 0
@@ -262,15 +286,17 @@ def test():
 
     test_loss /= len(test_loader.dataset)
     accuracy = 100.0 * correct / len(test_loader.dataset)
+    writer.add_scalar("accuracy", accuracy, epoch)
     print(
         f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
         f"({accuracy:.0f}%)\n"
     )
 
 
-for epoch in range(1, 30):
+for epoch in range(1, 300):
     train(epoch)
-    test()
+    test(epoch)
 
 # torch.save(model.state_dict(), "mnist_cnn.pth")
 print("Model saved to mnist_cnn.pth")
+writer.close()
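
Note for reviewers (not part of the patch): the shape comments on the
lut*_p lines are easiest to verify by running the unfold arithmetic on
its own. Each stage is F.unfold (channels * kernel_area bits per output
position) followed by a LutGroup, whose output width is
totalBits / groupBits * repeat, per the forward() comment above. The
snippet below is a standalone walk-through of that bookkeeping;
lut_out_bits is a helper written here purely for illustration and does
not exist in binary/mnist.py.

    import torch
    import torch.nn.functional as F

    def lut_out_bits(total_bits, group_bits, repeat):
        # Output width of a LutGroup: one bit per group, `repeat` passes.
        return total_bits // group_bits * repeat

    n = F.unfold(torch.randn(1, 1, 28, 28), kernel_size=2, stride=2).numel()   # 784
    n = lut_out_bits(n, 4, 2)                     # lut2_p: 392 = 14*14*2

    n = F.unfold(torch.randn(1, 2, 14, 14), kernel_size=2, stride=2).numel()   # 392
    n = lut_out_bits(n, 4, 2)                     # lut3_p: 196 = 7*7*4

    n = F.unfold(torch.randn(1, 4, 7, 7), kernel_size=3, stride=1).numel()     # 900
    n = lut_out_bits(n, 9, 2)                     # lut4_p: 200 = 5*5*8

    n = F.unfold(torch.randn(1, 8, 5, 5), kernel_size=3, stride=1).numel()     # 648
    n = lut_out_bits(n, 9, 2)                     # lut5_p: 144 = 3*3*16

    n = F.unfold(torch.randn(1, 16, 3, 3), kernel_size=3, stride=1).numel()    # 144
    n = lut_out_bits(n, 9, 5)                     # lut6_p: 80
    print(n // 8)                                 # view(batch, -1, 8).sum(2) -> 10 scores

This is also why the lut2_p comment reads 14*14*2 rather than 14*14*4:
784 bits in, groups of 4, two repeats, matching x.view(batch, 2, 14, 14)
at the next stage.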
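
Lut.apply is defined earlier in binary/mnist.py and is untouched by this
patch, so the diff alone does not show how a group of bits becomes one
output bit. The sketch below is a hypothetical stand-in for the general
technique (a learned truth table trained with a straight-through
estimator), not the file's actual implementation: it assumes one table
row per group, whereas the weight allocated above has repeat * totalBits
rows, and it assumes inputs already in {0, 1} (this patch comments out
the explicit (x > 0) binarization, so the real Lut presumably thresholds
internally).

    import torch

    class LutSTE(torch.autograd.Function):
        # Hypothetical stand-in for Lut: one learned truth table per bit
        # group, trained with a straight-through estimator (STE).

        @staticmethod
        def forward(ctx, x, weight):
            # x: [batch, groups, group_bits] with entries in {0, 1}
            # weight: [groups, 2 ** group_bits] learned table entries
            bits = x.shape[-1]
            scale = 2 ** torch.arange(bits - 1, -1, -1, device=x.device)
            idx = (x.long() * scale).sum(-1)        # pack each bit group into a table index
            rows = torch.arange(weight.shape[0], device=x.device)
            val = weight[rows, idx]                 # [batch, groups] table lookup
            ctx.save_for_backward(idx, weight)
            ctx.bits = bits
            return (val > 0).float()                # binarize the looked-up entry

        @staticmethod
        def backward(ctx, grad_out):
            idx, weight = ctx.saved_tensors
            # Table gradient: route each group's gradient to the entry
            # that was actually read, summed over the batch.
            grad_w = torch.zeros(grad_out.shape[0], *weight.shape, device=weight.device)
            grad_w.scatter_(2, idx.unsqueeze(-1), grad_out.unsqueeze(-1))
            # STE input gradient: copy the group's gradient to every bit.
            grad_x = grad_out.unsqueeze(-1).expand(-1, -1, ctx.bits).contiguous()
            return grad_x, grad_w.sum(0)

    x = (torch.randn(16, 98, 4) > 0).float()     # batch of 98 groups of 4 bits
    w = torch.randn(98, 16, requires_grad=True)  # 2**4 = 16 entries per group
    y = LutSTE.apply(x, w)                       # [16, 98] output bits
    y.sum().backward()                           # gradients reach w through the lookup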