# Witllm/binary/mnist.py

import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import math
import torch.nn.functional as F
import numpy as np
torch.manual_seed(1234)
np.random.seed(1234)
torch.cuda.manual_seed_all(1234)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]  # mean and std of the MNIST dataset
)
train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
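# Helper: round a float tensor, clamp it to [0, 2**bits - 1], and expand each value
# into its `bits`-bit binary representation (most significant bit first).
# Illustrative example: to_binary_tensor(torch.tensor([5.0]), 4) -> tensor([[0, 1, 0, 1]])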
def to_binary_tensor(input_tensor, bits):
int_tensor = torch.round(input_tensor).clamp(0, 2**bits - 1).to(torch.int64)
shifts = torch.arange(bits - 1, -1, -1, device=int_tensor.device)
binary_bits = (int_tensor.unsqueeze(-1) >> shifts) & 1
return binary_bits
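# MyLut: a custom autograd Function implementing a learnable lookup table.
# Forward: binarize the input (x > 0), pack each group of `bits` values into an
# integer index, and gather one weight per position from a (count, 2**bits) table.
# Backward: scatter-add the output gradient into the LUT entries that were selected;
# for the input, the looked-up weight is reused as a surrogate gradient and broadcast
# across the `bits` positions (the hard binarization itself has no true gradient).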
class MyLut(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight):
batch = input.shape[0]
count = input.shape[1]
bits = input.shape[2]
assert int(math.log2(weight.shape[-1])) == bits
index = 2 ** torch.arange(bits - 1, -1, -1, device=input.device)
x = (input > 0).long()
x = x * index
ind = x.sum(dim=-1)
        row_indices = torch.arange(count, device=input.device).unsqueeze(0).expand(batch, -1)  # keep index on the input's device
output = weight[row_indices, ind]
ctx.save_for_backward(input, weight, ind)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, ind = ctx.saved_tensors
grad_input = grad_weight = None
batch = input.shape[0]
count = input.shape[1]
bits = input.shape[2]
if ctx.needs_input_grad[1]:
grad_weight = torch.zeros_like(weight)
ind_p = ind.permute(1, 0)
grad_output_p = grad_output.permute(1, 0)
grad_weight.scatter_add_(1, ind_p, grad_output_p)
if ctx.needs_input_grad[0]:
            row_indices = torch.arange(count, device=input.device).unsqueeze(0).expand(batch, -1)  # keep index on the input's device
grad_input = grad_output * weight[row_indices, ind]
grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
return grad_input, grad_weight
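# Minimal shape sketch for MyLut (illustrative only, not part of the training script):
#   inp = torch.randn(4, 160, 8)        # (batch, count, bits)
#   w = torch.randn(160, 2 ** 8)        # one 2**bits-entry table per position
#   out = MyLut.apply(inp, w)           # -> shape (4, 160)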
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.bn = nn.BatchNorm1d(320 * 4)
self.fc1 = nn.Linear(160, 50)
self.fc2 = nn.Linear(50, 10)
self.pool = nn.MaxPool2d(2)
self.relu = nn.ReLU()
self.weight = nn.Parameter(torch.randn(160, pow(2, 8)))
def forward(self, x):
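        # Shape flow: 1x28x28 -> conv1+pool+relu -> 10x12x12 -> conv2+relu -> 20x8x8
        # = 1280 features; after BatchNorm the activations are reshaped to (batch, 160, 8)
        # so that every group of 8 activations selects one entry of the 160 LUTs.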
x = self.relu(self.pool(self.conv1(x)))
x = self.relu((self.conv2(x)))
x = x.view(-1, 320 * 4)
x = self.bn(x)
x = x.view(-1, 160, 8)
x = MyLut.apply(x, self.weight)
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x
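# The classes below (Lut, LutGroup, LutParallel, SimpleBNN) appear to be an experimental,
# purely LUT-based variant of the network. Lut binarizes its inputs with a
# straight-through estimator (hard 0/1 forward, identity gradient backward), packs the
# bits into an index, and maps it through learnable weight/bias tables.
# Note: SimpleBNN is not instantiated below (the line creating it is commented out).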
class Lut(nn.Module):
def __init__(self, bits):
super(Lut, self).__init__()
self.weight = nn.Parameter(torch.randn(pow(2, bits)))
self.bias = nn.Parameter(torch.randn(pow(2, bits)))
self.index = torch.pow(2, torch.arange(bits))
self.bits = bits
def forward(self, x):
        x = MyLut.apply(x, self.weight)  # MyLut.forward only accepts (input, weight); the extra bias argument would raise a TypeError
# tmp = torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))
# x = tmp + x - x.detach()
xx = (x > 0).float()
x = xx + (x - x.detach())
# print(xx.requires_grad)
# print(xx.grad_fn)
x = x * (self.index.to(x.device))
x = torch.sum(x, dim=-1)
        w = torch.gather(self.weight, 0, x.long())
        b = torch.gather(self.bias, 0, x.long())  # gather from bias, not from weight a second time
        x = w * x + b
# tmp = torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))
# x = tmp + x - x.detach()
xx = (x > 0).float()
x = xx + (x - x.detach())
x = x.view(-1, 1)
return x
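# LutGroup: splits its `bits` inputs into consecutive chunks of `subbits` and feeds each
# chunk to its own Lut, concatenating the per-chunk outputs along the feature dimension.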
class LutGroup(nn.Module):
def __init__(self, bits, subbits):
super(LutGroup, self).__init__()
assert (bits % subbits) == 0
self.lutlist = nn.ModuleList([Lut(subbits) for _ in range(int(bits / subbits))])
self.bits = bits
self.subbits = subbits
def forward(self, x):
ll = len(self.lutlist)
tmp = torch.empty((x.shape[0], 0), dtype=x.dtype, device=x.device)
start = 0
end = self.subbits
for i in range(ll):
tx = self.lutlist[i](x[:, start:end])
tmp = torch.cat((tmp, tx), dim=1)
start += self.subbits
end += self.subbits
return tmp
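# LutParallel: applies `number` independent Luts to the same `bits` inputs and
# concatenates their outputs, one column per Lut.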
class LutParallel(nn.Module):
def __init__(self, bits, number):
super(LutParallel, self).__init__()
self.lutlist = nn.ModuleList([Lut(bits) for _ in range(number)])
self.bits = bits
self.number = number
def forward(self, x):
tmp = torch.empty((x.shape[0], 0), dtype=x.dtype, device=x.device)
for i in range(self.number):
tx = self.lutlist[i](x)
tmp = torch.cat((tmp, tx), dim=1)
return tmp
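# SimpleBNN: a fully LUT-based pipeline. Each of the 784 pixels is expanded into 8 bits
# (6272 bits total), then successive LutGroup stages reduce 6272 -> 784 -> 98 -> 7,
# and LutParallel maps the final 7 bits to 10 class scores.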
class SimpleBNN(nn.Module):
def __init__(self):
        super(SimpleBNN, self).__init__()
self.w = nn.Parameter(torch.randn(3, 784 * 8))
self.b = nn.Parameter(torch.zeros(3, 784 * 8))
self.lut1 = LutGroup(784 * 8, 8)
self.lut2 = LutGroup(784, 8)
self.lut3 = LutGroup(98, 14)
self.lut4 = LutParallel(7, 10)
def forward(self, x):
batch = x.shape[0]
x = x.view(batch, -1)
        # map x from [-0.5, 0.5] to 0-255, then expand each value into its 8 binary bits
x = (x * 256 + 128).clamp(0, 255).to(torch.uint8)
xx = torch.arange(8).to(x.device)
bits = (x.unsqueeze(-1) >> xx) & 1
bits = bits.view(batch, -1)
x = bits
# q = x * self.w[0] + self.b[0]
# k = x * self.w[1] + self.b[1]
# v = x * self.w[2] + self.b[2]
# q = q.view(batch, -1, 1)
# k = k.view(batch, 1, -1)
# v = v.view(batch, -1, 1)
# kq = q @ k
# kqv = kq @ v
# kqv = kqv.view(batch, -1)
kqv = x
x = self.lut1(kqv)
x = self.lut2(x)
x = self.lut3(x)
x = self.lut4(x)
return x
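# Anomaly detection helps locate failing ops in the custom backward pass, at the cost of
# slower training; the rest is a standard cross-entropy + SGD training setup.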
torch.autograd.set_detect_anomaly(True)
model = SimpleCNN().to(device)
# model = SimpleBNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
def train(epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
# output = output * 1.0
# output = F.softmax(output, dim=1)
# print(output.requires_grad)
# print(output.grad_fn)
loss = criterion(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print(
f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
)
def test():
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += criterion(output, target).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
accuracy = 100.0 * correct / len(test_loader.dataset)
print(
f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
f"({accuracy:.0f}%)\n"
)
for epoch in range(1, 30):
train(epoch)
test()
# torch.save(model.state_dict(), "mnist_cnn.pth")
# print("Model saved to mnist_cnn.pth")