Update mnist.

Colin 2025-05-22 15:23:41 +08:00
parent f1124bc3b1
commit f98a951b58
1 changed file with 83 additions and 57 deletions


@@ -12,15 +12,17 @@ from torch.utils.data import DataLoader
import math
import torch.nn.functional as F
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import datetime
torch.manual_seed(1234)
np.random.seed(1234)
torch.cuda.manual_seed_all(1234)
batch_size = 16
lr = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = "cpu"
# torch.set_num_threads(16)
print(f"Using device: {device}")
transform = transforms.Compose(
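
Note on the seeding added in this hunk: seeding torch, numpy, and CUDA still leaves cuDNN free to pick non-deterministic kernels. If exact run-to-run reproducibility is wanted, the usual extra flags (not part of this commit) are:

    torch.backends.cudnn.deterministic = True   # force deterministic conv kernels
    torch.backends.cudnn.benchmark = False      # disable autotuning, which breaks determinism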
@@ -30,7 +32,7 @@ transform = transforms.Compose(
train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
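
A quick sanity check of the loader shapes under these settings (a sketch, not in the commit; assumes the transform keeps MNIST's 1x28x28 layout):

    images, labels = next(iter(train_loader))
    print(images.shape, labels.shape)   # torch.Size([16, 1, 28, 28]) torch.Size([16])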
@@ -112,28 +114,22 @@ class SimpleCNN(nn.Module):
class LutGroup(nn.Module):
def __init__(self, bits, subbits):
def __init__(self, totalBits, groupBits=0, repeat=1):
if not groupBits:
groupBits = totalBits
super(LutGroup, self).__init__()
assert (bits % subbits) == 0
self.weight = nn.Parameter(torch.randn(bits, pow(2, subbits)))
self.bits = bits
self.subbits = subbits
assert (totalBits % groupBits) == 0
self.weight = nn.Parameter(torch.randn(repeat * totalBits, pow(2, groupBits)))
self.totalBits = totalBits
self.groupBits = groupBits
self.repeat = repeat
def forward(self, x):
# input [ batch, totalBits ]
# output [ batch, totalBits / groupBits * repeat ]
batch = x.shape[0]
x = x.view(batch, -1, self.subbits)
x = Lut.apply(x, self.weight)
return x
class LutParallel(nn.Module):
def __init__(self, bits, number):
super(LutParallel, self).__init__()
self.number = number
self.weight = nn.Parameter(torch.randn(number, pow(2, bits)))
def forward(self, x):
x = x.unsqueeze(1).repeat(1, self.number, 1)
x = x.view(batch, -1, self.groupBits)
x = x.repeat(1, self.repeat, 1)
x = Lut.apply(x, self.weight)
return x
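
For review, the shape flow of the reworked forward pass. `Lut.apply` is a custom autograd Function defined elsewhere in this file, so a plain per-group table lookup stands in for it below; the real weight layout may differ, this only illustrates the grouping and the new `repeat` argument:

    import torch

    batch, totalBits, groupBits, repeat = 4, 784, 4, 2
    x = (torch.rand(batch, totalBits) > 0.5).float()    # [batch, totalBits]

    g = x.view(batch, -1, groupBits)                    # [batch, 196, 4]
    g = g.repeat(1, repeat, 1)                          # [batch, 392, 4]

    # stand-in for Lut.apply: each 4-bit group indexes its own 16-entry table
    groups = g.shape[1]
    table = torch.randn(groups, 2 ** groupBits)
    powers = 2 ** torch.arange(groupBits - 1, -1, -1, dtype=torch.float32)
    idx = (g * powers).sum(-1).long()                   # [batch, 392] integer indices
    out = table[torch.arange(groups), idx]              # [batch, 392] = totalBits/groupBits*repeat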
@@ -147,11 +143,13 @@ class SimpleBNN(nn.Module):
self.w = nn.Parameter(torch.randn(3, 784))
self.b = nn.Parameter(torch.zeros(3, 784))
self.lut1 = LutGroup(784 * 8, 8)
self.lut2 = LutGroup(784, 4)
self.lut3 = LutGroup(196, 4)
self.lut4 = LutGroup(49, 7)
self.lut5 = LutParallel(7, 10)
self.lut2_p = LutGroup(784, 4, 2) # pool2 => 14*14*2
self.lut3_p = LutGroup(14 * 14 * 2, 4, 2) # pool2 => 7*7*4
self.lut4_p = LutGroup(3 * 3 * 5 * 5 * 4, 9, 2) # conv 3 => 5*5*8
# self.lut4 = LutGroup(5 * 5 * 8, 8, 8) # conv 3 => 5*5*8
self.lut5_p = LutGroup(3 * 3 * 3 * 3 * 8, 9, 2) # conv 3 => 3*3*8*2
# self.lut5 = LutGroup(3 * 3 * 8 * 2, 16, 16) # conv 3 => 3*3*16
self.lut6_p = LutGroup(3 * 3 * 16, 9, 5) # conv 3 => 80
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
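
The sizes hard-coded in the new LutGroup layers follow the standard sliding-window output formula; a quick check of the spatial chain the forward pass assumes (28 -> 14 -> 7 -> 5 -> 3):

    def out_size(n, k, s=1, p=0, d=1):
        # floor((n + 2p - d(k-1) - 1) / s) + 1, as used by F.unfold and Conv2d
        return (n + 2 * p - d * (k - 1) - 1) // s + 1

    assert out_size(28, 2, s=2) == 14   # 2x2 pool stage
    assert out_size(14, 2, s=2) == 7    # 2x2 pool stage
    assert out_size(7, 3) == 5          # 3x3 window, stride 1
    assert out_size(5, 3) == 3          # 3x3 window, stride 1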
@@ -171,28 +169,28 @@ class SimpleBNN(nn.Module):
# x = bits.view(batch, -1)
# x = x.float() - 0.5
x = (x > 0).float()
# x = (x > 0).float()
q = x * self.w[0] + self.b[0]
k = x * self.w[1] + self.b[1]
v = x * self.w[2] + self.b[2]
q = q.view(batch, -1, 1)
k = k.view(batch, 1, -1)
v = v.view(batch, -1, 1)
kq = q @ k
kqv = kq @ v
kqv = kqv.view(batch, -1, 8)
x = kqv
# q = x * self.w[0] + self.b[0]
# k = x * self.w[1] + self.b[1]
# v = x * self.w[2] + self.b[2]
# q = q.view(batch, -1, 1)
# k = k.view(batch, 1, -1)
# v = v.view(batch, -1, 1)
# kq = q @ k
# kqv = kq @ v
# kqv = kqv.view(batch, -1, 8)
# x = kqv
#########################
# x = (x > 0) << xx
# x = x.sum(2)
# x = x.view(batch, 1, 28, 28)
# x = (x - 128.0) / 256.0
# # x = (x > 0) << xx
# # x = x.sum(2)
# # x = x.view(batch, 1, 28, 28)
# # x = (x - 128.0) / 256.0
# x = x.view(batch, 1, 28, 28)
# x = (x > 0).float()
# x = x.view(batch, 1, 28, 28)
# x = self.relu(self.pool(self.conv1(x)))
# x = self.relu(self.pool((self.conv2(x))))
@@ -202,19 +200,39 @@ class SimpleBNN(nn.Module):
#########################
x = (x > 0).float()
x = x.view(batch, 196, 4)
x = x.view(batch, 1, 28, 28)
x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
x = x.view(batch, -1)
x = self.lut2_p(x)
# x = self.lut1(x)
x = self.lut2(x)
x = x.view(batch, 2, 14, 14)
x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
x = x.view(batch, -1)
x = self.lut3_p(x)
x = x.view(-1, 28, 7)
x = x.permute(0, 2, 1)
x = x.reshape(-1, 28 * 7)
x = x.view(batch, 4, 7, 7)
x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
x = x.view(batch, -1)
x = self.lut4_p(x)
# x = self.lut4(x)
x = x.view(batch, 8, 5, 5)
x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
x = x.view(batch, -1)
x = self.lut5_p(x)
# x = self.lut5(x)
x = x.view(batch, 16, 3, 3)
x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
x = x.view(batch, -1)
x = self.lut6_p(x)
# xx = 2 ** torch.arange(7, -1, -1).to(x.device)
x = x.view(batch, -1, 8)
# x = x * xx
# x = (x - 128.0) / 256.0
x = x.sum(2)
x = self.lut3(x)
x = self.lut4(x)
x = self.lut5(x)
return x
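
The pooling stages above are built from F.unfold, which flattens each window into one column, so every spatial location becomes one bit-group for the following LUT. A minimal shape sketch of the first two stages:

    import torch
    import torch.nn.functional as F

    x = torch.randn(8, 1, 28, 28)
    print(F.unfold(x, kernel_size=2, stride=2).shape)   # [8, 4, 196] -> flattened to [8, 784] for lut2_p

    y = torch.randn(8, 2, 14, 14)
    print(F.unfold(y, kernel_size=2, stride=2).shape)   # [8, 8, 49] -> flattened to [8, 392] for lut3_p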
@@ -222,7 +240,12 @@ torch.autograd.set_detect_anomaly(True)
# model = SimpleCNN().to(device)
model = SimpleBNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
writer = SummaryWriter(f"log/{current_time}")
hparam_dict = {"lr": lr, "batch_size": batch_size}
writer.add_hparams(hparam_dict, {}, run_name=f"./")
def train(epoch):
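
On the add_hparams call above: in recent PyTorch versions it writes the hyperparameter summary into a timestamped subdirectory unless run_name is given, so run_name="./" keeps it inside the current writer's directory. The metrics dict is empty here; a variant (hypothetical, with best_accuracy tracked during testing) that attaches a final metric would be:

    writer.add_hparams(hparam_dict, {"hparam/accuracy": best_accuracy}, run_name="./")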
@@ -240,15 +263,16 @@ def train(epoch):
loss = criterion(output, target)
loss.backward()
optimizer.step()
writer.add_scalar("loss", loss, epoch)
if batch_idx % 100 == 0:
if batch_idx % 512 == 0:
print(
f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
)
def test():
def test(epoch):
model.eval()
test_loss = 0
correct = 0
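
One caveat on the new logging line in train(): add_scalar is called every batch but with epoch as the step, so all batches of an epoch share a single x-value in TensorBoard. A common alternative (a sketch, not part of this commit) is a global step:

    global_step = (epoch - 1) * len(train_loader) + batch_idx
    writer.add_scalar("loss", loss.item(), global_step)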
@@ -262,15 +286,17 @@ def test():
test_loss /= len(test_loader.dataset)
accuracy = 100.0 * correct / len(test_loader.dataset)
writer.add_scalar("accuracy", accuracy, epoch)
print(
f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
f"({accuracy:.0f}%)\n"
)
for epoch in range(1, 30):
for epoch in range(1, 300):
train(epoch)
test()
test(epoch)
# torch.save(model.state_dict(), "mnist_cnn.pth")
print("Model saved to mnist_cnn.pth")
writer.close()
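
The runs land under log/<MMDD-HHMMSS> per the writer setup above and can be inspected with:

    tensorboard --logdir log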