diff --git a/binary/mnist.py b/binary/mnist.py
new file mode 100644
index 0000000..db6361b
--- /dev/null
+++ b/binary/mnist.py
@@ -0,0 +1,278 @@
+import os
+
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torchvision
+from torchvision import transforms
+from torch.utils.data import DataLoader
+import math
+import torch.nn.functional as F
+import numpy as np
+
+torch.manual_seed(1234)
+np.random.seed(1234)
+torch.cuda.manual_seed_all(1234)
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
+
+transform = transforms.Compose(
+    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]  # mean and std of the MNIST dataset
+)
+
+train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
+test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
+
+train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
+test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
+
+
+def to_binary_tensor(input_tensor, bits):
+    # expand each integer value into its `bits` binary digits, MSB first
+    int_tensor = torch.round(input_tensor).clamp(0, 2**bits - 1).to(torch.int64)
+    shifts = torch.arange(bits - 1, -1, -1, device=int_tensor.device)
+    binary_bits = (int_tensor.unsqueeze(-1) >> shifts) & 1
+    return binary_bits
+
+
+class MyLut(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, weight):
+        # input: (batch, count, bits); weight: (count, 2**bits)
+        batch = input.shape[0]
+        count = input.shape[1]
+        bits = input.shape[2]
+        assert int(math.log2(weight.shape[-1])) == bits
+
+        # binarize the input and pack each group of bits into a table index
+        index = 2 ** torch.arange(bits - 1, -1, -1, device=input.device)
+        x = (input > 0).long()
+        x = x * index
+        ind = x.sum(dim=-1)
+
+        row_indices = torch.arange(count, device=input.device).unsqueeze(0).expand(batch, -1)
+        output = weight[row_indices, ind]
+
+        ctx.save_for_backward(input, weight, ind)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, weight, ind = ctx.saved_tensors
+        grad_input = grad_weight = None
+
+        batch = input.shape[0]
+        count = input.shape[1]
+        bits = input.shape[2]
+
+        if ctx.needs_input_grad[1]:
+            # accumulate the output gradient into the table entries that were selected
+            grad_weight = torch.zeros_like(weight)
+            ind_p = ind.permute(1, 0)
+            grad_output_p = grad_output.permute(1, 0)
+            grad_weight.scatter_add_(1, ind_p, grad_output_p)
+
+        if ctx.needs_input_grad[0]:
+            # surrogate gradient: the true gradient of a table lookup w.r.t. its
+            # input bits is zero almost everywhere, so broadcast the selected
+            # table value to every bit of the group instead
+            row_indices = torch.arange(count, device=input.device).unsqueeze(0).expand(batch, -1)
+            grad_input = grad_output * weight[row_indices, ind]
+            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
+
+        return grad_input, grad_weight
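+
+
+# Minimal usage sketch of MyLut (commented out; the shapes mirror SimpleCNN
+# below and are only an example): 160 independent lookup tables, each indexed
+# by 8 binarized inputs, so each table holds 2**8 = 256 entries.
+# demo_in = torch.randn(16, 160, 8)        # (batch, count, bits)
+# demo_w = torch.randn(160, 2**8)          # (count, 2**bits)
+# demo_out = MyLut.apply(demo_in, demo_w)  # -> (batch, count) = (16, 160)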
+
+
+class SimpleCNN(nn.Module):
+    def __init__(self):
+        super(SimpleCNN, self).__init__()
+        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
+        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
+        self.bn = nn.BatchNorm1d(320 * 4)
+        self.fc1 = nn.Linear(160, 50)
+        self.fc2 = nn.Linear(50, 10)
+        self.pool = nn.MaxPool2d(2)
+        self.relu = nn.ReLU()
+        # one 256-entry table per group of 8 activations: 320 * 4 / 8 = 160 groups
+        self.weight = nn.Parameter(torch.randn(160, pow(2, 8)))
+
+    def forward(self, x):
+        x = self.relu(self.pool(self.conv1(x)))
+        x = self.relu(self.conv2(x))
+        x = x.view(-1, 320 * 4)
+        x = self.bn(x)
+        x = x.view(-1, 160, 8)
+
+        x = MyLut.apply(x, self.weight)
+
+        x = self.relu(self.fc1(x))
+        x = self.fc2(x)
+        return x
+
+
+class Lut(nn.Module):
+    def __init__(self, bits):
+        super(Lut, self).__init__()
+        self.weight = nn.Parameter(torch.randn(pow(2, bits)))
+        self.bias = nn.Parameter(torch.randn(pow(2, bits)))
+        # LSB-first place values for packing bits into a table index
+        self.register_buffer("index", torch.pow(2, torch.arange(bits)))
+        self.bits = bits
+
+    def forward(self, x):
+        # straight-through estimator: the forward pass uses hard 0/1 bits,
+        # the backward pass lets the gradient flow through unchanged
+        xx = (x > 0).float()
+        x = xx + (x - x.detach())
+
+        # pack the bit group into a single table index
+        x = x * self.index.to(x.device)
+        x = torch.sum(x, dim=-1)
+
+        w = torch.gather(self.weight, 0, x.long())
+        b = torch.gather(self.bias, 0, x.long())
+        x = w * x + b
+
+        xx = (x > 0).float()
+        x = xx + (x - x.detach())
+
+        x = x.view(-1, 1)
+        return x
+
+
+class LutGroup(nn.Module):
+    def __init__(self, bits, subbits):
+        super(LutGroup, self).__init__()
+        assert (bits % subbits) == 0
+        self.lutlist = nn.ModuleList([Lut(subbits) for _ in range(bits // subbits)])
+        self.bits = bits
+        self.subbits = subbits
+
+    def forward(self, x):
+        # feed each subbits-wide slice of the input to its own Lut and
+        # concatenate the one-column outputs
+        outputs = []
+        start = 0
+        for lut in self.lutlist:
+            outputs.append(lut(x[:, start : start + self.subbits]))
+            start += self.subbits
+        return torch.cat(outputs, dim=1)
+
+
+class LutParallel(nn.Module):
+    def __init__(self, bits, number):
+        super(LutParallel, self).__init__()
+        self.lutlist = nn.ModuleList([Lut(bits) for _ in range(number)])
+        self.bits = bits
+        self.number = number
+
+    def forward(self, x):
+        # every Lut sees the full input; their outputs are concatenated
+        return torch.cat([lut(x) for lut in self.lutlist], dim=1)
+
+
+class SimpleBNN(nn.Module):
+    def __init__(self):
+        super(SimpleBNN, self).__init__()
+        self.w = nn.Parameter(torch.randn(3, 784 * 8))
+        self.b = nn.Parameter(torch.zeros(3, 784 * 8))
+        self.lut1 = LutGroup(784 * 8, 8)  # 6272 bits -> 784
+        self.lut2 = LutGroup(784, 8)  # 784 -> 98
+        self.lut3 = LutGroup(98, 14)  # 98 -> 7
+        self.lut4 = LutParallel(7, 10)  # 7 -> 10 class scores
+
+    def forward(self, x):
+        batch = x.shape[0]
+        x = x.view(batch, -1)
+
+        # map x to 0-255 (values outside the range are clamped), then expand
+        # each pixel into its 8 binary digits, LSB first
+        x = (x * 256 + 128).clamp(0, 255).to(torch.uint8)
+        xx = torch.arange(8).to(x.device)
+        bits = (x.unsqueeze(-1) >> xx) & 1
+        bits = bits.view(batch, -1)
+
+        x = bits
+        # q = x * self.w[0] + self.b[0]
+        # k = x * self.w[1] + self.b[1]
+        # v = x * self.w[2] + self.b[2]
+
+        # q = q.view(batch, -1, 1)
+        # k = k.view(batch, 1, -1)
+        # v = v.view(batch, -1, 1)
+        # kq = q @ k
+        # kqv = kq @ v
+        # kqv = kqv.view(batch, -1)
+        kqv = x
+
+        x = self.lut1(kqv)
+        x = self.lut2(x)
+        x = self.lut3(x)
+        x = self.lut4(x)
+        return x
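+
+
+# Bit-plane sketch (commented out; the value 200 is only an example): a pixel
+# value of 200 = 0b11001000 expands LSB-first to [0, 0, 0, 1, 0, 0, 1, 1],
+# matching the arange-based shift in SimpleBNN.forward above.
+# px = torch.tensor([200], dtype=torch.uint8)
+# planes = (px.unsqueeze(-1) >> torch.arange(8)) & 1
+# print(planes)  # tensor([[0, 0, 0, 1, 0, 0, 1, 1]])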
+
+
+torch.autograd.set_detect_anomaly(True)
+model = SimpleCNN().to(device)
+# model = SimpleBNN().to(device)
+criterion = nn.CrossEntropyLoss()
+optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+
+
+def train(epoch):
+    model.train()
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(device), target.to(device)
+        optimizer.zero_grad()
+        output = model(data)
+        loss = criterion(output, target)
+        loss.backward()
+        optimizer.step()
+
+        if batch_idx % 100 == 0:
+            print(
+                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
+                f"({100.0 * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
+            )
+
+
+def test():
+    model.eval()
+    test_loss = 0
+    correct = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(device), target.to(device)
+            output = model(data)
+            test_loss += criterion(output, target).item()
+            pred = output.argmax(dim=1, keepdim=True)
+            correct += pred.eq(target.view_as(pred)).sum().item()
+
+    # criterion returns the per-batch mean, so average over batches
+    test_loss /= len(test_loader)
+    accuracy = 100.0 * correct / len(test_loader.dataset)
+    print(
+        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
+        f"({accuracy:.0f}%)\n"
+    )
+
+
+for epoch in range(1, 30):
+    train(epoch)
+    test()
+
+torch.save(model.state_dict(), "mnist_cnn.pth")
+print("Model saved to mnist_cnn.pth")
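+
+# Minimal inference sketch (commented out; reuses the file name saved above):
+# reload the trained weights and classify a single test image.
+# model.load_state_dict(torch.load("mnist_cnn.pth"))
+# model.eval()
+# with torch.no_grad():
+#     sample, label = test_dataset[0]
+#     pred = model(sample.unsqueeze(0).to(device)).argmax(dim=1).item()
+#     print(f"predicted {pred}, actual {label}")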