# Witllm/binary/mnist.py


import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # synchronous CUDA launches make error traces point at the failing op
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import math
import torch.nn.functional as F
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torch.profiler import profile, ProfilerActivity, record_function
from unfold import generate_unfold_index
import datetime
torch.manual_seed(1234)
np.random.seed(1234)
torch.cuda.manual_seed_all(1234)
BS = 16
LR = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]  # mean and std of the MNIST dataset
)
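# Because the LUT layers below only use the sign of their inputs, this
# normalization effectively sets the binarization threshold at the dataset mean (0.1307).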
train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BS, shuffle=True, drop_last=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False, drop_last=True)
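# LutCnn precomputes unfold indices for a fixed batch dimension of BS, so both
# loaders use batch_size=BS with drop_last=True to guarantee full batches.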


class Lut(torch.autograd.Function):
    # input  [batch, group, bits]
    # output [batch, group]
    # weight [2**bits, group]
    @staticmethod
    def forward(ctx, input, weight, index):
        # Interpret each group's sign pattern as an integer row index into the
        # learned table, then binarize the selected entry.
        ind = ((input > 0).long() * index).sum(dim=-1)
        output = torch.gather(weight, 0, ind)
        ctx.save_for_backward(input, weight, ind)
        output = (output > 0).float()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, ind = ctx.saved_tensors
        grad_input = grad_weight = None
        bits = input.shape[2]
        if ctx.needs_input_grad[1]:
            # Accumulate the output gradient into the table rows that were selected.
            grad_weight = torch.zeros_like(weight)
            grad_weight.scatter_add_(0, ind, grad_output)
        if ctx.needs_input_grad[0]:
            # Straight-through-style estimate: scale by the selected table entry
            # and replicate the gradient across all bits of the group.
            grad_input = grad_output * torch.gather(weight, 0, ind)
            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
        return grad_input, grad_weight, None
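
# A minimal sketch of one lookup through Lut.apply (illustrative values only;
# not executed by this script). With 2 bits the place values are [2, 1], so the
# sign pattern (+, -) selects row 0b10 = 2 of the table:
#
#   w = torch.randn(4, 1)                # [2**bits, group]
#   idx = 2 ** torch.arange(1, -1, -1)   # tensor([2, 1])
#   x = torch.tensor([[[0.3, -0.7]]])    # [batch=1, group=1, bits=2]
#   out = Lut.apply(x, w, idx)           # float(w[2, 0] > 0), shape [1, 1]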


class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.bn = nn.BatchNorm1d(320 * 4)
        self.fc1 = nn.Linear(160, 50)
        self.fc2 = nn.Linear(50, 10)
        self.pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()
        # LUT table laid out [2**bits, group], as Lut.forward expects.
        self.weight = nn.Parameter(torch.randn(pow(2, 8), 160))
        self.index = nn.Parameter(2 ** torch.arange(7, -1, -1), requires_grad=False)

    def forward(self, x):
        x = self.relu(self.pool(self.conv1(x)))
        x = self.relu(self.conv2(x))
        x = x.view(-1, 320 * 4)
        x = self.bn(x)
        x = x.view(-1, 160, 8)
        x = Lut.apply(x, self.weight, self.index)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class LutGroup(nn.Module):
    def __init__(self, group, groupBits, groupRepeat=1):
        super(LutGroup, self).__init__()
        assert groupBits > 1
        self.weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
        self.group = group
        self.groupBits = groupBits
        self.groupRepeat = groupRepeat
        # Binary place values, e.g. [8, 4, 2, 1] for groupBits = 4.
        self.index = nn.Parameter(2 ** torch.arange(groupBits - 1, -1, -1), requires_grad=False)

    def forward(self, x):
        # input  [batch, group * groupBits]
        # output [batch, group * groupRepeat]
        batch = x.shape[0]
        x = x.view(batch, -1, self.groupBits)
        if self.groupRepeat > 1:
            # Tile the groups so each repeat reads its own table columns.
            x = x.repeat(1, self.groupRepeat, 1)
        x = Lut.apply(x, self.weight, self.index)
        return x
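
# Shape example (assumed values, for orientation only): LutGroup(group=4,
# groupBits=2, groupRepeat=3) owns a [4, 12] table and maps [batch, 8] to
# [batch, 12], ordered as three repeats of the four input groups.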


class LutCnn(nn.Module):
    def __init__(self, output_c, input_shape, kernel_size, stride, dilation):
        super(LutCnn, self).__init__()
        B, C, H, W = input_shape
        self.input_shape = input_shape
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        # Precompute the gather indices of an unfold (im2col) over input_shape.
        batch_idx, channel_idx, h_idx, w_idx = generate_unfold_index(input_shape, kernel_size, stride, dilation)
        self.batch_idx = nn.Parameter(batch_idx, requires_grad=False)
        self.channel_idx = nn.Parameter(channel_idx, requires_grad=False)
        self.h_idx = nn.Parameter(h_idx, requires_grad=False)
        self.w_idx = nn.Parameter(w_idx, requires_grad=False)
        groupBits = kernel_size * kernel_size
        # One group per (channel, output position); integer division keeps the count exact.
        self.lut = LutGroup(len(self.batch_idx) // (B * groupBits), groupBits, output_c)

    def forward(self, x):
        B, C, H, W = self.input_shape
        x = x.view(self.input_shape)
        # Gather each kernel_size x kernel_size patch with the precomputed indices.
        x = x[(self.batch_idx, self.channel_idx, self.h_idx, self.w_idx)]
        x = x.view(B, -1, self.kernel_size * self.kernel_size)
        x = self.lut(x)
        return x
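
# In effect, LutCnn is a binary convolution whose kernel is a truth table: each
# kernel_size x kernel_size patch is reduced to its sign bits, and each of the
# output_c repeats looks those bits up in its own column of the table.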


class SimpleBNN(nn.Module):
    def __init__(self):
        super(SimpleBNN, self).__init__()
        # self.w = nn.Parameter(torch.randn(3, 784 * 8))
        # self.b = nn.Parameter(torch.zeros(3, 784 * 8))
        self.w = nn.Parameter(torch.randn(3, 784))
        self.b = nn.Parameter(torch.zeros(3, 784))

        # LutCnn(output_c, input_shape, kernel_size, stride, dilation)
        self.lnn1 = LutCnn(8, (BS, 1, 28, 28), 2, 2, 1)
        self.lnn2 = LutCnn(1, (BS, 8, 14, 14), 2, 2, 1)
        self.lnn3 = LutCnn(1, (BS, 8, 7, 7), 3, 1, 1)
        self.lnn4 = LutCnn(1, (BS, 8, 5, 5), 3, 1, 1)
        self.lnn5 = LutCnn(10, (BS, 8, 3, 3), 3, 1, 1)

        # self.lutg = LutGroup()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        self.pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()

    def forward(self, x):
        batch = x.shape[0]
        # x = x.view(batch, -1)

        # Map x from [-0.5, 0.5] to 0-255, then expand each pixel into its 8 binary bits.
        # x = (x * 256 + 128).clamp(0, 255).to(torch.uint8)
        # xx = torch.arange(7, -1, -1).to(x.device)
        # bits = (x.unsqueeze(-1) >> xx) & 1
        # x = bits.view(batch, -1)
        # x = x.float() - 0.5

        # x = (x > 0).float()
        # q = x * self.w[0] + self.b[0]
        # k = x * self.w[1] + self.b[1]
        # v = x * self.w[2] + self.b[2]
        # q = q.view(batch, -1, 1)
        # k = k.view(batch, 1, -1)
        # v = v.view(batch, -1, 1)
        # kq = q @ k
        # kqv = kq @ v
        # kqv = kqv.view(batch, -1, 8)
        # x = kqv

        #########################
        # # x = (x > 0) << xx
        # # x = x.sum(2)
        # # x = x.view(batch, 1, 28, 28)
        # # x = (x - 128.0) / 256.0
        # x = (x > 0).float()
        # x = x.view(batch, 1, 28, 28)
        # x = self.relu(self.pool(self.conv1(x)))
        # x = self.relu(self.pool((self.conv2(x))))
        # x = x.view(-1, 320)
        # x = self.relu(self.fc1(x))
        # x = self.fc2(x)
        #########################

        x = self.lnn1(x)  # [batch, 8 * 14 * 14]
        x = self.lnn2(x)  # [batch, 8 * 7 * 7]
        x = self.lnn3(x)  # [batch, 8 * 5 * 5]
        x = self.lnn4(x)  # [batch, 8 * 3 * 3]
        x = self.lnn5(x)  # [batch, 10 * 8]
        # xx = 2 ** torch.arange(7, -1, -1).to(x.device)
        x = x.view(batch, -1, 8)
        # x = x * xx
        # x = (x - 128.0) / 256.0
        x = x.sum(2)  # one score per class
        return x
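
# The final scores are integer counts (0-8) of active LUT outputs per class;
# CrossEntropyLoss below consumes these sums directly as logits.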
torch.autograd.set_detect_anomaly(True)  # debug aid: traps NaN/Inf in backward, at a large speed cost
# model = SimpleCNN().to(device)
model = SimpleBNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
tbWriter = None


def AddScalar(tag, value, epoch):
    global tbWriter
    if not tbWriter:
        current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
        tbWriter = SummaryWriter(f"log/{current_time}")
        hparam_dict = {"lr": LR, "batch_size": BS}
        tbWriter.add_hparams(hparam_dict, {}, run_name="./")
    tbWriter.add_scalar(tag, value, epoch)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        AddScalar("loss", loss, epoch)
        if batch_idx % 512 == 0 and batch_idx > 0:
            print(
                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
            )


def test(epoch):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = 100.0 * correct / len(test_loader.dataset)
    AddScalar("accuracy", accuracy, epoch)
    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
        f"({accuracy:.0f}%)\n"
    )


def profiler():
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
            with record_function("model_inference"):
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        if batch_idx > 10:
            prof.export_chrome_trace("local.json")
            assert False  # crude stop: bail out after profiling a few batches


# profiler()

for epoch in range(1, 300):
    train(epoch)
    test(epoch)

torch.save(model.state_dict(), "mnist_cnn.pth")
print("Model saved to mnist_cnn.pth")

if tbWriter:
    tbWriter.close()