import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms

torch.manual_seed(1234)
np.random.seed(1234)
torch.cuda.manual_seed_all(1234)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = "cpu"
# torch.set_num_threads(16)
print(f"Using device: {device}")

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]  # mean and std of the MNIST dataset
)

train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)


def to_binary_tensor(input_tensor, bits):
    """Expand each value into its `bits`-wide binary representation (MSB first)."""
    int_tensor = torch.round(input_tensor).clamp(0, 2**bits - 1).to(torch.int64)
    shifts = torch.arange(bits - 1, -1, -1, device=int_tensor.device)
    binary_bits = (int_tensor.unsqueeze(-1) >> shifts) & 1
    return binary_bits


class Lut(torch.autograd.Function):
    """Bank of lookup tables with a surrogate gradient.

    input:  [batch, count, bits]
    weight: [count, 2**bits]
    output: [batch, count]
    """

    @staticmethod
    def forward(ctx, input, weight):
        batch = input.shape[0]
        count = input.shape[1]
        bits = input.shape[2]
        assert int(math.log2(weight.shape[-1])) == bits

        # Binarize the inputs and read each bit pattern as a LUT address (MSB first).
        index = 2 ** torch.arange(bits - 1, -1, -1, device=input.device)
        x = (input > 0).long()
        x = x * index
        ind = x.sum(dim=-1)

        row_indices = torch.arange(count, device=input.device).unsqueeze(0).expand(batch, -1)
        output = weight[row_indices, ind]
        ctx.save_for_backward(input, weight, ind)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, ind = ctx.saved_tensors
        grad_input = grad_weight = None
        batch = input.shape[0]
        count = input.shape[1]
        bits = input.shape[2]

        if ctx.needs_input_grad[1]:
            # Route each output gradient to the LUT entry that was addressed.
            grad_weight = torch.zeros_like(weight)
            ind_p = ind.permute(1, 0)
            grad_output_p = grad_output.permute(1, 0)
            grad_weight.scatter_add_(1, ind_p, grad_output_p)

        if ctx.needs_input_grad[0]:
            # Surrogate input gradient: the selected LUT value times the output
            # gradient, broadcast to every input bit (the hard threshold itself
            # has zero gradient almost everywhere).
            row_indices = torch.arange(count, device=input.device).unsqueeze(0).expand(batch, -1)
            grad_input = grad_output * weight[row_indices, ind]
            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)

        return grad_input, grad_weight


class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.bn = nn.BatchNorm1d(320 * 4)
        self.fc1 = nn.Linear(160, 50)
        self.fc2 = nn.Linear(50, 10)
        self.pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()
        self.weight = nn.Parameter(torch.randn(160, pow(2, 8)))

    def forward(self, x):
        x = self.relu(self.pool(self.conv1(x)))
        x = self.relu(self.conv2(x))
        x = x.view(-1, 320 * 4)  # 20 channels x 8 x 8 = 1280 features
        x = self.bn(x)
        x = x.view(-1, 160, 8)  # 160 LUTs with 8 address bits each
        x = Lut.apply(x, self.weight)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class LutGroup(nn.Module):
    """Split `bits` inputs into groups of `subbits`; each group feeds its own LUT."""

    def __init__(self, bits, subbits):
        super(LutGroup, self).__init__()
        assert (bits % subbits) == 0
        # One LUT per group of `subbits` inputs.
        self.weight = nn.Parameter(torch.randn(bits // subbits, pow(2, subbits)))
        self.bits = bits
        self.subbits = subbits

    def forward(self, x):
        batch = x.shape[0]
        x = x.view(batch, -1, self.subbits)
        x = Lut.apply(x, self.weight)
        return x


class LutParallel(nn.Module):
    """Apply `number` independent LUTs to one shared `bits`-wide input."""

    def __init__(self, bits, number):
        super(LutParallel, self).__init__()
        self.number = number
        self.weight = nn.Parameter(torch.randn(number, pow(2, bits)))

    def forward(self, x):
        x = x.unsqueeze(1).repeat(1, self.number, 1)
        x = Lut.apply(x, self.weight)
        return x
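
# A minimal sanity check for the custom Lut op (illustrative addition;
# `check_lut_gradients` is a hypothetical helper, not part of the original
# training pipeline). torch.autograd.gradcheck would fail here by design,
# because the op thresholds its input and substitutes a surrogate gradient
# for the true derivative, which is zero almost everywhere.
def check_lut_gradients(batch=2, count=3, bits=4):
    x = torch.randn(batch, count, bits, requires_grad=True)
    w = torch.randn(count, 2**bits, requires_grad=True)
    out = Lut.apply(x, w)  # [batch, count]
    out.sum().backward()
    assert out.shape == (batch, count)
    assert x.grad.shape == (batch, count, bits)
    assert w.grad.shape == (count, 2**bits)
    print("Lut forward/backward shapes OK")


# check_lut_gradients()  # uncomment to run the check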
class SimpleBNN(nn.Module):
    def __init__(self):
        super(SimpleBNN, self).__init__()
        # self.w = nn.Parameter(torch.randn(3, 784 * 8))
        # self.b = nn.Parameter(torch.zeros(3, 784 * 8))
        self.w = nn.Parameter(torch.randn(3, 784))
        self.b = nn.Parameter(torch.zeros(3, 784))
        self.lut1 = LutGroup(784 * 8, 8)
        self.lut2 = LutGroup(784, 4)
        self.lut3 = LutGroup(196, 4)
        self.lut4 = LutGroup(49, 7)
        self.lut5 = LutParallel(7, 10)
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        self.pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()

    def forward(self, x):
        batch = x.shape[0]
        x = x.view(batch, -1)

        # Map x from [-0.5, 0.5] to 0-255, then expand each pixel into 8 binary values.
        # x = (x * 256 + 128).clamp(0, 255).to(torch.uint8)
        # xx = torch.arange(7, -1, -1).to(x.device)
        # bits = (x.unsqueeze(-1) >> xx) & 1
        # x = bits.view(batch, -1)
        # x = x.float() - 0.5

        x = (x > 0).float()

        # Attention-style mixing of the binarized pixels.
        q = x * self.w[0] + self.b[0]
        k = x * self.w[1] + self.b[1]
        v = x * self.w[2] + self.b[2]
        q = q.view(batch, -1, 1)
        k = k.view(batch, 1, -1)
        v = v.view(batch, -1, 1)
        kq = q @ k  # [batch, 784, 784]
        kqv = kq @ v  # [batch, 784, 1]
        kqv = kqv.view(batch, -1, 8)
        x = kqv

        #########################
        # x = (x > 0) << xx
        # x = x.sum(2)
        # x = x.view(batch, 1, 28, 28)
        # x = (x - 128.0) / 256.0
        # x = x.view(batch, 1, 28, 28)
        # x = (x > 0).float()
        # x = self.relu(self.pool(self.conv1(x)))
        # x = self.relu(self.pool(self.conv2(x)))
        # x = x.view(-1, 320)
        # x = self.relu(self.fc1(x))
        # x = self.fc2(x)
        #########################

        # Binarize again and funnel through the LUT cascade: 784 -> 196 -> 49 -> 7 -> 10.
        x = (x > 0).float()
        x = x.view(batch, 196, 4)
        # x = self.lut1(x)
        x = self.lut2(x)
        # Transpose a 28x7 layout of the 196 outputs so the next stage's groups
        # mix bits coming from different LUTs.
        x = x.view(-1, 28, 7)
        x = x.permute(0, 2, 1)
        x = x.reshape(-1, 28 * 7)
        x = self.lut3(x)
        x = self.lut4(x)
        x = self.lut5(x)
        return x


torch.autograd.set_detect_anomaly(True)  # debugging aid; slows training

# model = SimpleCNN().to(device)
model = SimpleBNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # output = output * 1.0
        # output = F.softmax(output, dim=1)
        # print(output.requires_grad)
        # print(output.grad_fn)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(
                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                f"({100.0 * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
            )


def test():
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100.0 * correct / len(test_loader.dataset)
    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
        f"({accuracy:.0f}%)\n"
    )


for epoch in range(1, 30):
    train(epoch)
    test()

torch.save(model.state_dict(), "mnist_cnn.pth")
print("Model saved to mnist_cnn.pth")
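
# Illustrative addition (not part of the original training pipeline): after
# training, each LUT can be hardened into a plain truth table by thresholding
# its entries at zero, which is exactly how the next Lut stage reads them
# (`input > 0`). `export_truth_tables` is a hypothetical helper name.
def export_truth_tables(lut_group):
    # weight: [count, 2**subbits] -> one uint8 truth-table row per LUT
    with torch.no_grad():
        return (lut_group.weight > 0).cpu().numpy().astype(np.uint8)


# Example: export_truth_tables(model.lut2) has shape (196, 16), i.e. 196 LUTs
# with 4 address bits each.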