Update mnist.
This commit is contained in:
parent
f1124bc3b1
commit
f98a951b58
140
binary/mnist.py
140
binary/mnist.py
|
@ -12,15 +12,17 @@ from torch.utils.data import DataLoader
|
||||||
import math
|
import math
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
import datetime
|
||||||
|
|
||||||
torch.manual_seed(1234)
|
torch.manual_seed(1234)
|
||||||
np.random.seed(1234)
|
np.random.seed(1234)
|
||||||
torch.cuda.manual_seed_all(1234)
|
torch.cuda.manual_seed_all(1234)
|
||||||
|
|
||||||
|
batch_size = 16
|
||||||
|
lr = 0.001
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
# device = "cpu"
|
|
||||||
# torch.set_num_threads(16)
|
|
||||||
print(f"Using device: {device}")
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
transform = transforms.Compose(
|
transform = transforms.Compose(
|
||||||
|
@ -30,7 +32,7 @@ transform = transforms.Compose(
|
||||||
train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
|
train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
|
||||||
test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
|
test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
|
||||||
|
|
||||||
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
|
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
||||||
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
|
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
|
||||||
|
|
||||||
|
|
||||||
|
@ -112,28 +114,22 @@ class SimpleCNN(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class LutGroup(nn.Module):
|
class LutGroup(nn.Module):
|
||||||
def __init__(self, bits, subbits):
|
def __init__(self, totalBits, groupBits=0, repeat=1):
|
||||||
|
if not groupBits:
|
||||||
|
groupBits = totalBits
|
||||||
super(LutGroup, self).__init__()
|
super(LutGroup, self).__init__()
|
||||||
assert (bits % subbits) == 0
|
assert (totalBits % groupBits) == 0
|
||||||
self.weight = nn.Parameter(torch.randn(bits, pow(2, subbits)))
|
self.weight = nn.Parameter(torch.randn(repeat * totalBits, pow(2, groupBits)))
|
||||||
self.bits = bits
|
self.totalBits = totalBits
|
||||||
self.subbits = subbits
|
self.groupBits = groupBits
|
||||||
|
self.repeat = repeat
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
# input [ batch, totalBits ]
|
||||||
|
# output [ batch, totalBits / groupBits * repeat ]
|
||||||
batch = x.shape[0]
|
batch = x.shape[0]
|
||||||
x = x.view(batch, -1, self.subbits)
|
x = x.view(batch, -1, self.groupBits)
|
||||||
x = Lut.apply(x, self.weight)
|
x = x.repeat(1, self.repeat, 1)
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class LutParallel(nn.Module):
|
|
||||||
def __init__(self, bits, number):
|
|
||||||
super(LutParallel, self).__init__()
|
|
||||||
self.number = number
|
|
||||||
self.weight = nn.Parameter(torch.randn(number, pow(2, bits)))
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = x.unsqueeze(1).repeat(1, self.number, 1)
|
|
||||||
x = Lut.apply(x, self.weight)
|
x = Lut.apply(x, self.weight)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@ -147,11 +143,13 @@ class SimpleBNN(nn.Module):
|
||||||
self.w = nn.Parameter(torch.randn(3, 784))
|
self.w = nn.Parameter(torch.randn(3, 784))
|
||||||
self.b = nn.Parameter(torch.zeros(3, 784))
|
self.b = nn.Parameter(torch.zeros(3, 784))
|
||||||
|
|
||||||
self.lut1 = LutGroup(784 * 8, 8)
|
self.lut2_p = LutGroup(784, 4, 2) # pool2 => 14*14*4
|
||||||
self.lut2 = LutGroup(784, 4)
|
self.lut3_p = LutGroup(14 * 14 * 2, 4, 2) # pool2 => 7*7*4
|
||||||
self.lut3 = LutGroup(196, 4)
|
self.lut4_p = LutGroup(3 * 3 * 5 * 5 * 4, 9, 2) # conv 3 => 5*5*8
|
||||||
self.lut4 = LutGroup(49, 7)
|
# self.lut4 = LutGroup(5 * 5 * 8, 8, 8) # conv 3 => 5*5*8
|
||||||
self.lut5 = LutParallel(7, 10)
|
self.lut5_p = LutGroup(3 * 3 * 3 * 3 * 8, 9, 2) # conv 3 => 3*3*8*2
|
||||||
|
# self.lut5 = LutGroup(3 * 3 * 8 * 2, 16, 16) # conv 3 => 3*3*16
|
||||||
|
self.lut6_p = LutGroup(3 * 3 * 16, 9, 5) # conv 3 => 80
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
|
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
|
||||||
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
|
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
|
||||||
|
@ -171,28 +169,28 @@ class SimpleBNN(nn.Module):
|
||||||
# x = bits.view(batch, -1)
|
# x = bits.view(batch, -1)
|
||||||
# x = x.float() - 0.5
|
# x = x.float() - 0.5
|
||||||
|
|
||||||
x = (x > 0).float()
|
# x = (x > 0).float()
|
||||||
|
|
||||||
q = x * self.w[0] + self.b[0]
|
# q = x * self.w[0] + self.b[0]
|
||||||
k = x * self.w[1] + self.b[1]
|
# k = x * self.w[1] + self.b[1]
|
||||||
v = x * self.w[2] + self.b[2]
|
# v = x * self.w[2] + self.b[2]
|
||||||
q = q.view(batch, -1, 1)
|
# q = q.view(batch, -1, 1)
|
||||||
k = k.view(batch, 1, -1)
|
# k = k.view(batch, 1, -1)
|
||||||
v = v.view(batch, -1, 1)
|
# v = v.view(batch, -1, 1)
|
||||||
kq = q @ k
|
# kq = q @ k
|
||||||
kqv = kq @ v
|
# kqv = kq @ v
|
||||||
kqv = kqv.view(batch, -1, 8)
|
# kqv = kqv.view(batch, -1, 8)
|
||||||
x = kqv
|
# x = kqv
|
||||||
|
|
||||||
#########################
|
#########################
|
||||||
|
|
||||||
# x = (x > 0) << xx
|
# # x = (x > 0) << xx
|
||||||
# x = x.sum(2)
|
# # x = x.sum(2)
|
||||||
# x = x.view(batch, 1, 28, 28)
|
# # x = x.view(batch, 1, 28, 28)
|
||||||
# x = (x - 128.0) / 256.0
|
# # x = (x - 128.0) / 256.0
|
||||||
|
|
||||||
# x = x.view(batch, 1, 28, 28)
|
|
||||||
# x = (x > 0).float()
|
# x = (x > 0).float()
|
||||||
|
# x = x.view(batch, 1, 28, 28)
|
||||||
|
|
||||||
# x = self.relu(self.pool(self.conv1(x)))
|
# x = self.relu(self.pool(self.conv1(x)))
|
||||||
# x = self.relu(self.pool((self.conv2(x))))
|
# x = self.relu(self.pool((self.conv2(x))))
|
||||||
|
@ -202,19 +200,39 @@ class SimpleBNN(nn.Module):
|
||||||
|
|
||||||
#########################
|
#########################
|
||||||
|
|
||||||
x = (x > 0).float()
|
x = x.view(batch, 1, 28, 28)
|
||||||
x = x.view(batch, 196, 4)
|
x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
|
||||||
|
x = x.view(batch, -1)
|
||||||
|
x = self.lut2_p(x)
|
||||||
|
|
||||||
# x = self.lut1(x)
|
x = x.view(batch, 2, 14, 14)
|
||||||
x = self.lut2(x)
|
x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
|
||||||
|
x = x.view(batch, -1)
|
||||||
|
x = self.lut3_p(x)
|
||||||
|
|
||||||
x = x.view(-1, 28, 7)
|
x = x.view(batch, 4, 7, 7)
|
||||||
x = x.permute(0, 2, 1)
|
x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
|
||||||
x = x.reshape(-1, 28 * 7)
|
x = x.view(batch, -1)
|
||||||
|
x = self.lut4_p(x)
|
||||||
|
# x = self.lut4(x)
|
||||||
|
|
||||||
|
x = x.view(batch, 8, 5, 5)
|
||||||
|
x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
|
||||||
|
x = x.view(batch, -1)
|
||||||
|
x = self.lut5_p(x)
|
||||||
|
# x = self.lut5(x)
|
||||||
|
|
||||||
|
x = x.view(batch, 16, 3, 3)
|
||||||
|
x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
|
||||||
|
x = x.view(batch, -1)
|
||||||
|
x = self.lut6_p(x)
|
||||||
|
|
||||||
|
# xx = 2 ** torch.arange(7, -1, -1).to(x.device)
|
||||||
|
x = x.view(batch, -1, 8)
|
||||||
|
# x = x * xx
|
||||||
|
# x = (x - 128.0) / 256.0
|
||||||
|
x = x.sum(2)
|
||||||
|
|
||||||
x = self.lut3(x)
|
|
||||||
x = self.lut4(x)
|
|
||||||
x = self.lut5(x)
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@ -222,7 +240,12 @@ torch.autograd.set_detect_anomaly(True)
|
||||||
# model = SimpleCNN().to(device)
|
# model = SimpleCNN().to(device)
|
||||||
model = SimpleBNN().to(device)
|
model = SimpleBNN().to(device)
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
|
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
|
||||||
|
|
||||||
|
current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
|
||||||
|
writer = SummaryWriter(f"log/{current_time}")
|
||||||
|
hparam_dict = {"lr": lr, "batch_size": batch_size}
|
||||||
|
writer.add_hparams(hparam_dict, {}, run_name=f"./")
|
||||||
|
|
||||||
|
|
||||||
def train(epoch):
|
def train(epoch):
|
||||||
|
@ -240,15 +263,16 @@ def train(epoch):
|
||||||
loss = criterion(output, target)
|
loss = criterion(output, target)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
writer.add_scalar("loss", loss, epoch)
|
||||||
|
|
||||||
if batch_idx % 100 == 0:
|
if batch_idx % 512 == 0:
|
||||||
print(
|
print(
|
||||||
f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
|
f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
|
||||||
f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
|
f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test():
|
def test(epoch):
|
||||||
model.eval()
|
model.eval()
|
||||||
test_loss = 0
|
test_loss = 0
|
||||||
correct = 0
|
correct = 0
|
||||||
|
@ -262,15 +286,17 @@ def test():
|
||||||
|
|
||||||
test_loss /= len(test_loader.dataset)
|
test_loss /= len(test_loader.dataset)
|
||||||
accuracy = 100.0 * correct / len(test_loader.dataset)
|
accuracy = 100.0 * correct / len(test_loader.dataset)
|
||||||
|
writer.add_scalar("accuracy", accuracy, epoch)
|
||||||
print(
|
print(
|
||||||
f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
|
f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
|
||||||
f"({accuracy:.0f}%)\n"
|
f"({accuracy:.0f}%)\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
for epoch in range(1, 30):
|
for epoch in range(1, 300):
|
||||||
train(epoch)
|
train(epoch)
|
||||||
test()
|
test(epoch)
|
||||||
|
|
||||||
# torch.save(model.state_dict(), "mnist_cnn.pth")
|
# torch.save(model.state_dict(), "mnist_cnn.pth")
|
||||||
print("Model saved to mnist_cnn.pth")
|
print("Model saved to mnist_cnn.pth")
|
||||||
|
writer.close()
|
||||||
|
|
Loading…
Reference in New Issue