# Witllm/binary/mnist.py

import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import math
import torch.nn.functional as F
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import datetime
torch.manual_seed(1234)
np.random.seed(1234)
torch.cuda.manual_seed_all(1234)
batch_size = 16
lr = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]  # mean and std of the MNIST dataset
)
train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
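# Quick sanity-check sketch (not part of the training run): each batch from
# train_loader is a (data, target) pair with data of shape [batch_size, 1, 28, 28]
# and target of shape [batch_size]. Uncomment to verify locally.
# data, target = next(iter(train_loader))
# print(data.shape, target.shape)  # torch.Size([16, 1, 28, 28]) torch.Size([16])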
def to_binary_tensor(input_tensor, bits):
    # Round and clamp to [0, 2**bits - 1], then expand each value into `bits`
    # binary digits (most significant bit first).
    int_tensor = torch.round(input_tensor).clamp(0, 2**bits - 1).to(torch.int64)
    shifts = torch.arange(bits - 1, -1, -1, device=int_tensor.device)
    binary_bits = (int_tensor.unsqueeze(-1) >> shifts) & 1
    return binary_bits
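# Worked example (for illustration only, not called below): the value 6 with bits=4
# expands to its binary digits [0, 1, 1, 0], MSB first.
# >>> to_binary_tensor(torch.tensor([6.0]), 4)
# tensor([[0, 1, 1, 0]])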
class Lut(torch.autograd.Function):
    # input  [batch, count, bits]
    # weight [count, 2**bits]
    @staticmethod
    def forward(ctx, input, weight):
        batch = input.shape[0]
        count = input.shape[1]
        bits = input.shape[2]
        assert int(math.log2(weight.shape[-1])) == bits
        # Binarize the input by sign and pack each group of `bits` values into an
        # integer index, then look that index up in the per-position table.
        index = 2 ** torch.arange(bits - 1, -1, -1, device=input.device)
        x = (input > 0).long()
        x = x * index
        ind = x.sum(dim=-1)
        # Keep the row indices on the same device as the input to avoid a
        # device mismatch when indexing on CUDA.
        row_indices = torch.arange(count, device=input.device).unsqueeze(0).expand(batch, -1)
        output = weight[row_indices, ind]
        ctx.save_for_backward(input, weight, ind)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, ind = ctx.saved_tensors
        grad_input = grad_weight = None
        batch = input.shape[0]
        count = input.shape[1]
        bits = input.shape[2]
        if ctx.needs_input_grad[1]:
            # Accumulate the output gradient into the table entries that were
            # selected in the forward pass.
            grad_weight = torch.zeros_like(weight)
            ind_p = ind.permute(1, 0)
            grad_output_p = grad_output.permute(1, 0)
            grad_weight.scatter_add_(1, ind_p, grad_output_p)
        if ctx.needs_input_grad[0]:
            # Straight-through style gradient: broadcast the selected table value
            # back over the bits of each group.
            row_indices = torch.arange(count, device=input.device).unsqueeze(0).expand(batch, -1)
            grad_input = grad_output * weight[row_indices, ind]
            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
        return grad_input, grad_weight
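# A minimal illustration of the lookup (a sketch, not executed by the script):
# with bits=2, an input row whose two values are (positive, non-positive) packs to
# index 0b10 = 2, so position i reads weight[i, 2].
# inp = torch.tensor([[[1.0, -1.0]]])   # [batch=1, count=1, bits=2]
# w = torch.arange(4.0).view(1, 4)      # [count=1, 2**bits=4]
# out = Lut.apply(inp, w)               # tensor([[2.]])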
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.bn = nn.BatchNorm1d(320 * 4)
        self.fc1 = nn.Linear(160, 50)
        self.fc2 = nn.Linear(50, 10)
        self.pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()
        self.weight = nn.Parameter(torch.randn(160, pow(2, 8)))

    def forward(self, x):
        x = self.relu(self.pool(self.conv1(x)))
        x = self.relu(self.conv2(x))
        # Flatten the 20x8x8 feature map (1280 values), normalize, and regroup it
        # into 160 chunks of 8 bits for the lookup table.
        x = x.view(-1, 320 * 4)
        x = self.bn(x)
        x = x.view(-1, 160, 8)
        x = Lut.apply(x, self.weight)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
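# Shape trace for SimpleCNN (assuming 28x28 MNIST input): conv1 + pool -> 10x12x12,
# conv2 -> 20x8x8 = 1280 values = 160 groups of 8 bits, Lut -> 160, fc1 -> 50,
# fc2 -> 10 logits. This model is kept as an alternative; the script below trains
# SimpleBNN by default.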
class LutGroup(nn.Module):
    def __init__(self, totalBits, groupBits=0, repeat=1):
        if not groupBits:
            groupBits = totalBits
        super(LutGroup, self).__init__()
        assert (totalBits % groupBits) == 0
        self.weight = nn.Parameter(torch.randn(repeat * totalBits, pow(2, groupBits)))
        self.totalBits = totalBits
        self.groupBits = groupBits
        self.repeat = repeat

    def forward(self, x):
        # input  [batch, totalBits]
        # output [batch, totalBits / groupBits * repeat]
        batch = x.shape[0]
        x = x.view(batch, -1, self.groupBits)
        x = x.repeat(1, self.repeat, 1)
        x = Lut.apply(x, self.weight)
        return x
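# Shape sketch (for illustration): LutGroup(totalBits=8, groupBits=4, repeat=2) maps
# [batch, 8] -> [batch, 2, 4] groups, repeated to [batch, 4, 4]; Lut then returns
# [batch, 4], i.e. totalBits / groupBits * repeat outputs per sample.
# g = LutGroup(8, 4, 2)
# out = g(torch.randn(3, 8))   # out.shape == torch.Size([3, 4])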
class SimpleBNN(nn.Module):
    def __init__(self):
        super(SimpleBNN, self).__init__()
        # self.w = nn.Parameter(torch.randn(3, 784 * 8))
        # self.b = nn.Parameter(torch.zeros(3, 784 * 8))
        self.w = nn.Parameter(torch.randn(3, 784))
        self.b = nn.Parameter(torch.zeros(3, 784))
        self.lut2_p = LutGroup(784, 4, 2)  # pool2 => 14*14*2
        self.lut3_p = LutGroup(14 * 14 * 2, 4, 2)  # pool2 => 7*7*4
        self.lut4_p = LutGroup(3 * 3 * 5 * 5 * 4, 9, 2)  # conv 3 => 5*5*8
        # self.lut4 = LutGroup(5 * 5 * 8, 8, 8)  # conv 3 => 5*5*8
        self.lut5_p = LutGroup(3 * 3 * 3 * 3 * 8, 9, 2)  # conv 3 => 3*3*8*2
        # self.lut5 = LutGroup(3 * 3 * 8 * 2, 16, 16)  # conv 3 => 3*3*16
        self.lut6_p = LutGroup(3 * 3 * 16, 9, 5)  # conv 3 => 80
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        self.pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()
    def forward(self, x):
        batch = x.shape[0]
        x = x.view(batch, -1)

        # Map x from [-0.5, 0.5] to 0-255, then expand each value into 8 binary bits
        # x = (x * 256 + 128).clamp(0, 255).to(torch.uint8)
        # xx = torch.arange(7, -1, -1).to(x.device)
        # bits = (x.unsqueeze(-1) >> xx) & 1
        # x = bits.view(batch, -1)
        # x = x.float() - 0.5

        # x = (x > 0).float()
        # q = x * self.w[0] + self.b[0]
        # k = x * self.w[1] + self.b[1]
        # v = x * self.w[2] + self.b[2]
        # q = q.view(batch, -1, 1)
        # k = k.view(batch, 1, -1)
        # v = v.view(batch, -1, 1)
        # kq = q @ k
        # kqv = kq @ v
        # kqv = kqv.view(batch, -1, 8)
        # x = kqv

        #########################
        # # x = (x > 0) << xx
        # # x = x.sum(2)
        # # x = x.view(batch, 1, 28, 28)
        # # x = (x - 128.0) / 256.0
        # x = (x > 0).float()
        # x = x.view(batch, 1, 28, 28)
        # x = self.relu(self.pool(self.conv1(x)))
        # x = self.relu(self.pool(self.conv2(x)))
        # x = x.view(-1, 320)
        # x = self.relu(self.fc1(x))
        # x = self.fc2(x)
        #########################
        # Active path: repeated unfold + LutGroup stages over binarized activations.
        x = x.view(batch, 1, 28, 28)
        x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
        x = x.view(batch, -1)
        x = self.lut2_p(x)
        x = x.view(batch, 2, 14, 14)
        x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
        x = x.view(batch, -1)
        x = self.lut3_p(x)
        x = x.view(batch, 4, 7, 7)
        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
        x = x.view(batch, -1)
        x = self.lut4_p(x)
        # x = self.lut4(x)
        x = x.view(batch, 8, 5, 5)
        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
        x = x.view(batch, -1)
        x = self.lut5_p(x)
        # x = self.lut5(x)
        x = x.view(batch, 16, 3, 3)
        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
        x = x.view(batch, -1)
        x = self.lut6_p(x)
        # xx = 2 ** torch.arange(7, -1, -1).to(x.device)
        x = x.view(batch, -1, 8)
        # x = x * xx
        # x = (x - 128.0) / 256.0
        # Sum each group of 8 lookup outputs to produce the 10 class logits.
        x = x.sum(2)
        return x
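# Shape trace for SimpleBNN's active path (derived from the views above):
# [batch, 1, 28, 28] -> unfold(2, stride 2) + lut2_p -> [batch, 2, 14, 14]
#                    -> unfold(2, stride 2) + lut3_p -> [batch, 4, 7, 7]
#                    -> unfold(3) + lut4_p           -> [batch, 8, 5, 5]
#                    -> unfold(3) + lut5_p           -> [batch, 16, 3, 3]
#                    -> unfold(3) + lut6_p           -> [batch, 80]
#                    -> view [batch, 10, 8], sum over the last dim -> [batch, 10] logits.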
torch.autograd.set_detect_anomaly(True)
# model = SimpleCNN().to(device)
model = SimpleBNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
writer = SummaryWriter(f"log/{current_time}")
hparam_dict = {"lr": lr, "batch_size": batch_size}
writer.add_hparams(hparam_dict, {}, run_name=f"./")
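# Training curves land under log/<timestamp>; view them with, e.g.:
#   tensorboard --logdir log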
def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # output = output * 1.0
        # output = F.softmax(output, dim=1)
        # print(output.requires_grad)
        # print(output.grad_fn)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        writer.add_scalar("loss", loss, epoch)
        if batch_idx % 512 == 0:
            print(
                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
            )
def test(epoch):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = 100.0 * correct / len(test_loader.dataset)
    writer.add_scalar("accuracy", accuracy, epoch)
    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
        f"({accuracy:.0f}%)\n"
    )
for epoch in range(1, 300):
    train(epoch)
    test(epoch)

# torch.save(model.state_dict(), "mnist_cnn.pth")
# print("Model saved to mnist_cnn.pth")
writer.close()