# Witllm/binary/mnist.py

import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import math
import torch.nn.functional as F
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torch.profiler import profile, ProfilerActivity, record_function
from unfold import generate_unfold_index
import datetime
torch.manual_seed(1234)
np.random.seed(1234)
torch.cuda.manual_seed_all(1234)
BS = 16
LR = 0.01
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]  # mean and std of the MNIST dataset
)
train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BS, shuffle=True, drop_last=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False, drop_last=True)
class Lut(torch.autograd.Function):
    # input  [batch, group, bits]
    # output [batch, group]
    # weight [2**bits, group]
    @staticmethod
    def forward(ctx, input, weight, index):
        # index holds descending powers of two, so the sign bits of `input`
        # form an integer row index into `weight` for each group
        ind = ((input > 0).long() * index).sum(dim=-1)
        output = torch.gather(weight, 0, ind)
        output = (output > 0).float()
        output = (output - 0.5) * 2.0  # binarize to {-1, +1}
        ctx.save_for_backward(input, weight, ind, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, ind, output = ctx.saved_tensors
        grad_input = grad_weight = None
        bits = input.shape[2]
        if ctx.needs_input_grad[1]:
            grad_weight = torch.zeros_like(weight)
            grad_weight.scatter_add_(0, ind, grad_output)
        if ctx.needs_input_grad[0]:
            # grad_input = grad_output * torch.gather(weight, 0, ind)
            grad_input = grad_output
            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
            output = output.unsqueeze(-1).repeat(1, 1, bits)
            in_sign = ((input > 0).float() - 0.5) * 2.0
            grad_input = grad_input * in_sign
            # multiply by a random factor in [0.995, 1.005) to inject mild noise
            grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
            # grad_input = grad_output
            # grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
            # output = output.unsqueeze(-1).repeat(1, 1, bits)
            # in_sign = ((input > 0).float() - 0.5) * 2.0
            # out_sign = ((output > 0).float() - 0.5) * 2.0
            # grad_sign = ((grad_input > 0).float() - 0.5) * 2.0
            # grad_input = grad_input * in_sign * (out_sign * grad_sign)
            # grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
            # a dynamic scaling factor is needed here
            # so that training converges stably
            # print(in_sign[0].detach().cpu().numpy())
            # print(out_sign[0].detach().cpu().numpy())
            # print(grad_sign[0].detach().cpu().numpy())
            # print(grad_input[0].detach().cpu().numpy())
        # forward() has three inputs (input, weight, index), so backward()
        # must return exactly three gradients; index gets no gradient
        return grad_input, grad_weight, None
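
# Minimal usage sketch for Lut (hypothetical shapes, kept commented out so the
# script's behaviour is unchanged): 5 independent 3-bit LUTs over a batch of 16.
# _w = torch.randn(2 ** 3, 5)                     # weight [2**bits, group]
# _idx = 2 ** torch.arange(2, -1, -1)             # bit weights [4, 2, 1]
# _x = torch.randn(16, 5, 3, requires_grad=True)  # input  [batch, group, bits]
# _y = Lut.apply(_x, _w, _idx)                    # output [16, 5], entries in {-1.0, +1.0}
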
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.bn = nn.BatchNorm1d(320 * 4)
        self.fc1 = nn.Linear(160, 50)
        self.fc2 = nn.Linear(50, 10)
        self.pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()
        # Lut expects weight as [2**bits, group], i.e. [256, 160] for 160 8-bit LUTs
        self.weight = nn.Parameter(torch.randn(pow(2, 8), 160))
        self.index = nn.Parameter(2 ** torch.arange(7, -1, -1), requires_grad=False)

    def forward(self, x):
        x = self.relu(self.pool(self.conv1(x)))
        x = self.relu(self.conv2(x))
        x = x.view(-1, 320 * 4)
        x = self.bn(x)
        x = x.view(-1, 160, 8)
        x = Lut.apply(x, self.weight, self.index)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
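
# SimpleCNN is the unused hybrid baseline: two conv layers, then a single
# 8-bit LUT layer (160 LUTs) between BatchNorm and the classifier head.
# Commented smoke test (hypothetical names, assumes the shapes above):
# _m = SimpleCNN()
# _logits = _m(torch.randn(BS, 1, 28, 28))  # -> [BS, 10]
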
class LutGroup(nn.Module):
    def __init__(self, group, groupBits, groupRepeat=1):
        assert groupBits > 1
        super(LutGroup, self).__init__()
        self.weight = nn.Parameter(torch.ones(pow(2, groupBits), int(groupRepeat * group)))
        self.group = group
        self.groupBits = groupBits
        self.groupRepeat = groupRepeat
        self.index = nn.Parameter(2 ** torch.arange(groupBits - 1, -1, -1), requires_grad=False)

    def forward(self, x):
        # input  [batch, group * groupBits]
        # output [batch, group * groupRepeat]
        batch = x.shape[0]
        x = x.reshape(batch, -1, self.groupBits)
        if self.groupRepeat > 1:
            x = x.repeat(1, self.groupRepeat, 1)
        x = Lut.apply(x, self.weight, self.index)
        return x
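
# Example of the reshape semantics (hypothetical sizes, commented out):
# LutGroup(group=4, groupBits=2, groupRepeat=3) tiles the 4 input groups 3
# times, so each of the 12 columns of `weight` is an independent 2-bit LUT.
# _lg = LutGroup(4, 2, 3)
# _out = _lg(torch.randn(16, 8))  # [batch, group * groupBits] -> [16, 12]
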
class LutCnn(nn.Module):
    def __init__(self, channel_repeat, input_shape, kernel_size, stride, dilation, fc=False):
        super(LutCnn, self).__init__()
        B, C, H, W = input_shape
        self.input_shape = input_shape
        self.kernel_size = kernel_size
        self.channel_repeat = channel_repeat
        self.stride = stride
        self.dilation = dilation
        batch_idx, channel_idx, h_idx, w_idx = generate_unfold_index(input_shape, kernel_size, stride, dilation)
        self.batch_idx = nn.Parameter(batch_idx, requires_grad=False)
        self.channel_idx = nn.Parameter(channel_idx, requires_grad=False)
        self.h_idx = nn.Parameter(h_idx, requires_grad=False)
        self.w_idx = nn.Parameter(w_idx, requires_grad=False)
        groupBits = kernel_size * kernel_size
        group = int(len(self.batch_idx) / B / groupBits)
        self.lut = LutGroup(group, groupBits, channel_repeat)
        self.fc = fc
        if fc:
            self.lutc = LutGroup(group, channel_repeat * C, channel_repeat * C)

    def forward(self, x):
        B, C, H, W = self.input_shape
        x = x.view(self.input_shape)
        x = x[(self.batch_idx, self.channel_idx, self.h_idx, self.w_idx)]
        x = x.view(B, -1, self.kernel_size * self.kernel_size)
        x = self.lut(x)
        if self.fc:
            x = x.view(B, -1, self.channel_repeat)
            x = x.permute(0, 2, 1)
            x = self.lutc(x)
        return x
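
# LutCnn is a convolution-style LUT layer: generate_unfold_index (defined in
# unfold.py, not shown here) precomputes the gather indices of an im2col-style
# unfold, every kernel_size*kernel_size patch becomes one LUT input group, and
# channel_repeat plays the role of output channels. Assuming non-overlapping
# 2x2 patches, LutCnn(8, (BS, 1, 28, 28), 2, 2, 1) maps (BS, 1, 28, 28) to
# [BS, 8 * 14 * 14], which the next layer reads back as (BS, 8, 14, 14).
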
class SimpleBNN(nn.Module):
    def __init__(self):
        super(SimpleBNN, self).__init__()
        # self.w = nn.Parameter(torch.randn(3, 784 * 8))
        # self.b = nn.Parameter(torch.zeros(3, 784 * 8))
        self.w = nn.Parameter(torch.randn(3, 784))
        self.b = nn.Parameter(torch.zeros(3, 784))
        # channel_repeat, input_shape, kernel_size, stride, dilation, fc
        self.lnn1 = LutCnn(8, (BS, 1, 28, 28), 2, 2, 1, False)
        self.lnn2 = LutCnn(1, (BS, 8, 14, 14), 2, 2, 1, False)
        self.lnn3 = LutCnn(1, (BS, 8, 7, 7), 3, 1, 1, False)
        self.lnn4 = LutCnn(1, (BS, 8, 5, 5), 3, 1, 1, False)
        self.lnn5 = LutCnn(10, (BS, 8, 3, 3), 3, 1, 1)
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        self.pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        batch = x.shape[0]
        # x = x.view(batch, -1)

        # map x from [-0.5, 0.5] to 0..255, then expand each pixel into its 8 binary bits
        # x = (x * 256 + 128).clamp(0, 255).to(torch.uint8)
        # xx = torch.arange(7, -1, -1).to(x.device)
        # bits = (x.unsqueeze(-1) >> xx) & 1
        # x = bits.view(batch, -1)
        # x = x.float() - 0.5

        # x = (x > 0).float()
        # q = x * self.w[0] + self.b[0]
        # k = x * self.w[1] + self.b[1]
        # v = x * self.w[2] + self.b[2]
        # q = q.view(batch, -1, 1)
        # k = k.view(batch, 1, -1)
        # v = v.view(batch, -1, 1)
        # kq = q @ k
        # kqv = kq @ v
        # kqv = kqv.view(batch, -1, 8)
        # x = kqv

        #########################

        # # x = (x > 0) << xx
        # # x = x.sum(2)
        # # x = x.view(batch, 1, 28, 28)
        # # x = (x - 128.0) / 256.0

        # x = (x > 0).float()
        # x = x.view(batch, 1, 28, 28)
        # x = self.relu(self.pool(self.conv1(x)))
        # x = self.relu(self.pool((self.conv2(x))))
        # x = x.view(-1, 320)
        # x = self.relu(self.fc1(x))
        # x = self.fc2(x)

        #########################

        x = self.lnn1(x)
        x = self.lnn2(x)
        x = self.lnn3(x)
        x = self.lnn4(x)
        x = self.lnn5(x)

        # xx = 2 ** torch.arange(7, -1, -1).to(x.device)
        x = x.view(batch, -1, 8)
        # x = x * xx
        # x = (x - 128.0) / 256.0
        x = x.sum(2)  # each logit = sum of eight +/-1 LUT outputs

        return x

    def printWeight(self):
        pass
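
# Shape trace of the active SimpleBNN path (derived from the LutCnn configs
# above): (BS,1,28,28) -> lnn1 -> (BS,8,14,14) -> lnn2 -> (BS,8,7,7)
# -> lnn3 -> (BS,8,5,5) -> lnn4 -> (BS,8,3,3) -> lnn5 -> [BS, 10*8],
# and the final view(batch, -1, 8).sum(2) reduces each block of eight +/-1
# LUT outputs to one of 10 integer logits in [-8, 8].
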
class SimpleLNN(nn.Module):
    def __init__(self):
        super(SimpleLNN, self).__init__()
        # group, groupBits, groupRepeat
        self.lutg1 = LutGroup(1, 10, 4)
        self.lutg2 = LutGroup(1, 4, 10)

    def forward(self, x, t):
        batch = x.shape[0]
        # the image x is ignored; the input is a one-hot encoding of the label t
        x = torch.zeros_like(t).unsqueeze(-1).repeat(1, 10)
        x[torch.arange(0, batch), t] = 1
        x = self.lutg1(x)
        x = self.lutg2(x)
        return x

    def printWeight(self):
        # rows 1, 2, 4, ..., 512 are the only indices a one-hot 10-bit input can select
        print("self.lutg1")
        print(self.lutg1.weight[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy())
        print("=============================")
        print("=============================")
        print("self.lutg1.grad")
        print(self.lutg1.weight.grad[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy())
        print("=============================")
        print("=============================")
        # print("self.lutg2")
        # print(self.lutg2.weight.detach().cpu().numpy())
        # print("=============================")
        # print("=============================")
torch.autograd.set_detect_anomaly(True)
# model = SimpleCNN().to(device)
# model = SimpleBNN().to(device)
model = SimpleLNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

tbWriter = None

def AddScalar(tag, value, epoch):
    global tbWriter
    if not tbWriter:
        current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
        tbWriter = SummaryWriter(f"log/{current_time}")
        hparam_dict = {"lr": LR, "batch_size": BS}
        tbWriter.add_hparams(hparam_dict, {}, run_name="./")
    tbWriter.add_scalar(tag, value, epoch)

def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data, target)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        AddScalar("loss", loss, epoch)
        if batch_idx % 1024 == 0 and batch_idx > 0:
            print(
                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
            )

def test(epoch):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data, target)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = 100.0 * correct / len(test_loader.dataset)
    AddScalar("accuracy", accuracy, epoch)
    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
        f"({accuracy:.0f}%)\n"
    )
    model.printWeight()

def profiler():
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
            with record_function("model_inference"):
                output = model(data, target)  # the active models take (x, t)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        if batch_idx > 10:
            prof.export_chrome_trace("local.json")
            assert False

# profiler()

for epoch in range(1, 300):
    train(epoch)
    test(epoch)

# torch.save(model.state_dict(), "mnist_cnn.pth")
# print("Model saved to mnist_cnn.pth")

if tbWriter:
    tbWriter.close()