import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import math
import torch.nn.functional as F
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torch.profiler import profile, ProfilerActivity, record_function
from unfold import generate_unfold_index
import datetime

torch.manual_seed(1234)
np.random.seed(1234)
torch.cuda.manual_seed_all(1234)

BS = 16
LR = 0.01

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]  # mean and std of the MNIST dataset
)

train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BS, shuffle=True, drop_last=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False, drop_last=True)


class Lut(torch.autograd.Function):
    # input  [batch, group, bits]
    # output [batch, group]
    # weight [2**bits, group]

    @staticmethod
    def forward(ctx, input, weight, index):
        # Binarize the input, pack each group's bits into an address, and look it
        # up in the per-group truth table; the result is mapped to {-1, +1}.
        ind = ((input > 0).long() * index).sum(dim=-1)
        output = torch.gather(weight, 0, ind)
        output = (output > 0).float()
        output = (output - 0.5) * 2.0
        ctx.save_for_backward(input, weight, ind, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, ind, output = ctx.saved_tensors
        grad_input = grad_weight = None
        bits = input.shape[2]

        if ctx.needs_input_grad[1]:
            # Accumulate the output gradient into the addressed truth-table entries.
            grad_weight = torch.zeros_like(weight)
            grad_weight.scatter_add_(0, ind, grad_output)

        if ctx.needs_input_grad[0]:
            # grad_input = grad_output * torch.gather(weight, 0, ind)
            grad_input = grad_output
            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
            output = output.unsqueeze(-1).repeat(1, 1, bits)
            in_sign = ((input > 0).float() - 0.5) * 2.0
            grad_input = grad_input * in_sign
            grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)

            # grad_input = grad_output
            # grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
            # output = output.unsqueeze(-1).repeat(1, 1, bits)
            # in_sign = ((input > 0).float() - 0.5) * 2.0
            # out_sign = ((output > 0).float() - 0.5) * 2.0
            # grad_sign = ((grad_input > 0).float() - 0.5) * 2.0
            # grad_input = grad_input * in_sign * (out_sign * grad_sign)
            # grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)

            # A dynamic scaling coefficient is still needed here.
            # This variant converges stably.

            # print(in_sign[0].detach().cpu().numpy())
            # print(out_sign[0].detach().cpu().numpy())
            # print(grad_sign[0].detach().cpu().numpy())
            # print(grad_input[0].detach().cpu().numpy())

        return grad_input, grad_weight, None, None


class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.bn = nn.BatchNorm1d(320 * 4)
        self.fc1 = nn.Linear(160, 50)
        self.fc2 = nn.Linear(50, 10)
        self.pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()
        # Truth table laid out as [2**bits, group] and MSB-first bit weights,
        # matching the convention documented on Lut above.
        self.weight = nn.Parameter(torch.randn(pow(2, 8), 160))
        self.index = nn.Parameter(2 ** torch.arange(7, -1, -1), requires_grad=False)

    def forward(self, x):
        x = self.relu(self.pool(self.conv1(x)))
        x = self.relu(self.conv2(x))
        x = x.view(-1, 320 * 4)
        x = self.bn(x)
        x = x.view(-1, 160, 8)
        x = Lut.apply(x, self.weight, self.index)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
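# A minimal shape sketch (not part of the original training flow and never called
# by it): Lut.apply binarizes its input with `> 0`, packs each group's bits into
# an integer address via `index`, and looks that address up in a per-group truth
# table `weight` of shape [2**bits, group], returning values in {-1, +1}.
# The helper name `_lut_shape_example` and its arguments are illustrative only.
def _lut_shape_example(batch=4, group=5, bits=3):
    inp = torch.randn(batch, group, bits)          # only the sign of each entry matters
    weight = torch.randn(2**bits, group)           # one truth-table column per group
    index = 2 ** torch.arange(bits - 1, -1, -1)    # bit weights, MSB first (as in LutGroup)
    out = Lut.apply(inp, weight, index)
    assert out.shape == (batch, group)
    assert set(out.unique().tolist()) <= {-1.0, 1.0}
    return out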
class LutGroup(nn.Module):
    def __init__(self, group, groupBits, groupRepeat=1):
        assert groupBits > 1
        super(LutGroup, self).__init__()
        self.weight = nn.Parameter(torch.ones(pow(2, groupBits), int(groupRepeat * group)))
        self.group = group
        self.groupBits = groupBits
        self.groupRepeat = groupRepeat
        self.index = nn.Parameter(2 ** torch.arange(groupBits - 1, -1, -1), requires_grad=False)

    def forward(self, x):
        # input  [batch, group * groupBits]
        # output [batch, group * groupRepeat]
        batch = x.shape[0]
        x = x.reshape(batch, -1, self.groupBits)
        if self.groupRepeat > 1:
            x = x.repeat(1, self.groupRepeat, 1)
        x = Lut.apply(x, self.weight, self.index)
        return x


class LutCnn(nn.Module):
    def __init__(self, channel_repeat, input_shape, kernel_size, stride, dilation, fc=False):
        super(LutCnn, self).__init__()
        B, C, H, W = input_shape
        self.input_shape = input_shape
        self.kernel_size = kernel_size
        self.channel_repeat = channel_repeat
        self.stride = stride
        self.dilation = dilation
        # Precomputed indices that gather im2col-style patches from the input.
        batch_idx, channel_idx, h_idx, w_idx = generate_unfold_index(input_shape, kernel_size, stride, dilation)
        self.batch_idx = nn.Parameter(batch_idx, requires_grad=False)
        self.channel_idx = nn.Parameter(channel_idx, requires_grad=False)
        self.h_idx = nn.Parameter(h_idx, requires_grad=False)
        self.w_idx = nn.Parameter(w_idx, requires_grad=False)
        groupBits = kernel_size * kernel_size
        group = int(len(self.batch_idx) / B / groupBits)
        self.lut = LutGroup(group, groupBits, channel_repeat)
        self.fc = fc
        if fc:
            self.lutc = LutGroup(group, channel_repeat * C, channel_repeat * C)

    def forward(self, x):
        B, C, H, W = self.input_shape
        x = x.view(self.input_shape)
        x = x[(self.batch_idx, self.channel_idx, self.h_idx, self.w_idx)]
        x = x.view(B, -1, self.kernel_size * self.kernel_size)
        x = self.lut(x)
        if self.fc:
            x = x.view(B, -1, self.channel_repeat)
            x = x.permute(0, 2, 1)
            x = self.lutc(x)
        return x
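# The local `unfold` module is not included in this file. Based on how LutCnn
# consumes its output (four flat index tensors that gather im2col-style patches,
# later viewed as [B, C * out_h * out_w, kernel_size * kernel_size]), the
# imported generate_unfold_index is assumed to behave roughly like the sketch
# below. This is a reference only; the real implementation in unfold.py may differ.
def _generate_unfold_index_sketch(input_shape, kernel_size, stride, dilation=1):
    B, C, H, W = input_shape
    out_h = (H - dilation * (kernel_size - 1) - 1) // stride + 1
    out_w = (W - dilation * (kernel_size - 1) - 1) // stride + 1
    # Enumerate every (batch, channel, patch row, patch col, kernel row, kernel col)
    # so flattening yields batch-major order, then the group dimension
    # (C * out_h * out_w), then the kernel_size * kernel_size positions per patch.
    b, c, oh, ow, kh, kw = torch.meshgrid(
        torch.arange(B),
        torch.arange(C),
        torch.arange(out_h),
        torch.arange(out_w),
        torch.arange(kernel_size),
        torch.arange(kernel_size),
        indexing="ij",
    )
    h_idx = oh * stride + kh * dilation
    w_idx = ow * stride + kw * dilation
    return b.reshape(-1), c.reshape(-1), h_idx.reshape(-1), w_idx.reshape(-1)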
class SimpleBNN(nn.Module):
    def __init__(self):
        super(SimpleBNN, self).__init__()

        # self.w = nn.Parameter(torch.randn(3, 784 * 8))
        # self.b = nn.Parameter(torch.zeros(3, 784 * 8))
        self.w = nn.Parameter(torch.randn(3, 784))
        self.b = nn.Parameter(torch.zeros(3, 784))

        # channel_repeat, input_shape, kernel_size, stride, dilation, fc
        self.lnn1 = LutCnn(8, (BS, 1, 28, 28), 2, 2, 1, False)
        self.lnn2 = LutCnn(1, (BS, 8, 14, 14), 2, 2, 1, False)
        self.lnn3 = LutCnn(1, (BS, 8, 7, 7), 3, 1, 1, False)
        self.lnn4 = LutCnn(1, (BS, 8, 5, 5), 3, 1, 1, False)
        self.lnn5 = LutCnn(10, (BS, 8, 3, 3), 3, 1, 1)

        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        self.pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        batch = x.shape[0]
        # x = x.view(batch, -1)

        # Map x from [-0.5, 0.5] to 0..255, then expand each pixel into its 8 binary bits.
        # x = (x * 256 + 128).clamp(0, 255).to(torch.uint8)
        # xx = torch.arange(7, -1, -1).to(x.device)
        # bits = (x.unsqueeze(-1) >> xx) & 1
        # x = bits.view(batch, -1)
        # x = x.float() - 0.5

        # x = (x > 0).float()
        # q = x * self.w[0] + self.b[0]
        # k = x * self.w[1] + self.b[1]
        # v = x * self.w[2] + self.b[2]
        # q = q.view(batch, -1, 1)
        # k = k.view(batch, 1, -1)
        # v = v.view(batch, -1, 1)
        # kq = q @ k
        # kqv = kq @ v
        # kqv = kqv.view(batch, -1, 8)
        # x = kqv

        #########################

        # # x = (x > 0) << xx
        # # x = x.sum(2)
        # # x = x.view(batch, 1, 28, 28)
        # # x = (x - 128.0) / 256.0
        # x = (x > 0).float()
        # x = x.view(batch, 1, 28, 28)
        # x = self.relu(self.pool(self.conv1(x)))
        # x = self.relu(self.pool((self.conv2(x))))
        # x = x.view(-1, 320)
        # x = self.relu(self.fc1(x))
        # x = self.fc2(x)

        #########################

        x = self.lnn1(x)
        x = self.lnn2(x)
        x = self.lnn3(x)
        x = self.lnn4(x)
        x = self.lnn5(x)

        # xx = 2 ** torch.arange(7, -1, -1).to(x.device)
        x = x.view(batch, -1, 8)
        # x = x * xx
        # x = (x - 128.0) / 256.0
        x = x.sum(2)

        return x

    def printWeight(self):
        pass


class SimpleLNN(nn.Module):
    def __init__(self):
        super(SimpleLNN, self).__init__()
        # group, groupBits, groupRepeat
        self.lutg1 = LutGroup(1, 10, 4)
        self.lutg2 = LutGroup(1, 4, 10)

    def forward(self, x, t):
        batch = x.shape[0]

        # The image input is ignored; the target label is one-hot encoded and fed
        # through the two LUT groups instead.
        x = torch.zeros_like(t).unsqueeze(-1).repeat(1, 10)
        x[torch.arange(0, batch), t] = 1

        x = self.lutg1(x)
        x = self.lutg2(x)

        return x

    def printWeight(self):
        print("self.lutg1")
        print(self.lutg1.weight[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy())
        print("=============================")
        print("=============================")
        print("self.lutg1.grad")
        print(self.lutg1.weight.grad[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy())
        print("=============================")
        print("=============================")
        # print("self.lutg2")
        # print(self.lutg2.weight.detach().cpu().numpy())
        # print("=============================")
        # print("=============================")


torch.autograd.set_detect_anomaly(True)
# model = SimpleCNN().to(device)
# model = SimpleBNN().to(device)
model = SimpleLNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

tbWriter = None


def AddScalar(tag, value, epoch):
    global tbWriter
    if not tbWriter:
        current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
        tbWriter = SummaryWriter(f"log/{current_time}")
        hparam_dict = {"lr": LR, "batch_size": BS}
        tbWriter.add_hparams(hparam_dict, {}, run_name="./")
    tbWriter.add_scalar(tag, value, epoch)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data, target)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        AddScalar("loss", loss, epoch)
        if batch_idx % 1024 == 0 and batch_idx > 0:
            print(
                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
            )


def test(epoch):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data, target)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = 100.0 * correct / len(test_loader.dataset)
    AddScalar("accuracy", accuracy, epoch)
    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
        f"({accuracy:.0f}%)\n"
    )
    model.printWeight()


def profiler():
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
            with record_function("model_inference"):
                output = model(data, target)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        if batch_idx > 10:
            prof.export_chrome_trace("local.json")
            assert False


# profiler()

for epoch in range(1, 300):
    train(epoch)
    test(epoch)

# torch.save(model.state_dict(), "mnist_cnn.pth")
print("Model saved to mnist_cnn.pth")

if tbWriter:
    tbWriter.close()
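# Optional post-training step (a sketch only; the torch.save call above is
# commented out in this script, so nothing is actually written to mnist_cnn.pth).
# Persisting and restoring the learned LUT weights would follow the standard
# state_dict pattern:
# torch.save(model.state_dict(), "mnist_cnn.pth")
# restored = SimpleLNN().to(device)
# restored.load_state_dict(torch.load("mnist_cnn.pth", map_location=device))
# restored.eval()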