From 878c690ac468d292f3a32a2e04d13457360a1d8f Mon Sep 17 00:00:00 2001
From: Colin
Date: Mon, 9 Jun 2025 15:56:03 +0800
Subject: [PATCH] Refine LUTCNN, keep accuracy at ~93%

---
 binary/mnist.py | 219 +++++++++++++++++++++++-------------------------
 1 file changed, 104 insertions(+), 115 deletions(-)

diff --git a/binary/mnist.py b/binary/mnist.py
index 728fb38..dbad614 100644
--- a/binary/mnist.py
+++ b/binary/mnist.py
@@ -13,14 +13,16 @@ import math
 import torch.nn.functional as F
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
+from torch.profiler import profile, ProfilerActivity, record_function
+from unfold import generate_unfold_index
 import datetime

 torch.manual_seed(1234)
 np.random.seed(1234)
 torch.cuda.manual_seed_all(1234)

-batch_size = 16
-lr = 0.001
+BS = 16
+LR = 0.001

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")
@@ -32,59 +34,38 @@ transform = transforms.Compose(
 train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
 test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)

-train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
-test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
-
-
-def to_binary_tensor(input_tensor, bits):
-    int_tensor = torch.round(input_tensor).clamp(0, 2**bits - 1).to(torch.int64)
-    shifts = torch.arange(bits - 1, -1, -1, device=int_tensor.device)
-    binary_bits = (int_tensor.unsqueeze(-1) >> shifts) & 1
-    return binary_bits
+train_loader = DataLoader(train_dataset, batch_size=BS, shuffle=True, drop_last=True, num_workers=4)
+test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False, drop_last=True)


 class Lut(torch.autograd.Function):
-    # input [batch,count,bits]
-    # weight [count,2**bits]
+    # input  [batch, group, bits]
+    # output [batch, group]
+    # weight [2**bits, group]
     @staticmethod
-    def forward(ctx, input, weight):
-        batch = input.shape[0]
-        count = input.shape[1]
-        bits = input.shape[2]
-        assert int(math.log2(weight.shape[-1])) == bits
-
-        index = 2 ** torch.arange(bits - 1, -1, -1, device=input.device)
-        x = (input > 0).long()
-        x = x * index
-        ind = x.sum(dim=-1)
-
-        row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1)
-        output = weight[row_indices, ind]
-
-        ctx.save_for_backward(input, weight, ind)
+    def forward(ctx, input, weight, index):
+        ind = ((input > 0).long() * index).sum(dim=-1)
+        output = torch.gather(weight, 0, ind)
+        # output = (output > 0).float()
+        # output = (output - 0.5) * 2.0
+        ctx.save_for_backward(input, weight, ind, output)
         return output

     @staticmethod
     def backward(ctx, grad_output):
-        input, weight, ind = ctx.saved_tensors
+        input, weight, ind, output = ctx.saved_tensors
         grad_input = grad_weight = None
-
-        batch = input.shape[0]
-        count = input.shape[1]
         bits = input.shape[2]

         if ctx.needs_input_grad[1]:
             grad_weight = torch.zeros_like(weight)
-            ind_p = ind.permute(1, 0)
-            grad_output_p = grad_output.permute(1, 0)
-            grad_weight.scatter_add_(1, ind_p, grad_output_p)
+            grad_weight.scatter_add_(0, ind, grad_output)

         if ctx.needs_input_grad[0]:
-            row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1)
-            grad_input = grad_output * weight[row_indices, ind]
+            grad_input = grad_output * torch.gather(weight, 0, ind)
             grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)

-        return grad_input, grad_weight
+        return grad_input, grad_weight, None, None


 class SimpleCNN(nn.Module):
@@ -114,23 +95,57 @@ class SimpleCNN(nn.Module):


 class LutGroup(nn.Module):
-    def __init__(self, totalBits, groupBits=0, repeat=1):
-        if not groupBits:
-            groupBits = totalBits
+    def __init__(self, group, groupBits, groupRepeat=1):
+        assert groupBits > 1
         super(LutGroup, self).__init__()
-        assert (totalBits % groupBits) == 0
-        self.weight = nn.Parameter(torch.randn(repeat * totalBits, pow(2, groupBits)))
-        self.totalBits = totalBits
+        self.weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
+        self.group = group
         self.groupBits = groupBits
-        self.repeat = repeat
+        self.groupRepeat = groupRepeat
+        self.index = nn.Parameter(2 ** torch.arange(groupBits - 1, -1, -1), requires_grad=False)

     def forward(self, x):
-        # input [ batch, totalBits ]
-        # output [ batch, totalBits / groupBits * repeat ]
+        # input  [ batch, group * groupBits ]
+        # output [ batch, group * groupRepeat ]
         batch = x.shape[0]
-        x = x.view(batch, -1, self.groupBits)
-        x = x.repeat(1, self.repeat, 1)
-        x = Lut.apply(x, self.weight)
+        x = x.reshape(batch, -1, self.groupBits)
+        if self.groupRepeat > 1:
+            x = x.repeat(1, self.groupRepeat, 1)
+        x = Lut.apply(x, self.weight, self.index)
+        return x
+
+
+class LutCnn(nn.Module):
+    def __init__(self, channel_repeat, input_shape, kernel_size, stride, dilation, fc=False):
+        super(LutCnn, self).__init__()
+        B, C, H, W = input_shape
+        self.input_shape = input_shape
+        self.kernel_size = kernel_size
+        self.channel_repeat = channel_repeat
+        self.stride = stride
+        self.dilation = dilation
+        batch_idx, channel_idx, h_idx, w_idx = generate_unfold_index(input_shape, kernel_size, stride, dilation)
+        self.batch_idx = nn.Parameter(batch_idx, requires_grad=False)
+        self.channel_idx = nn.Parameter(channel_idx, requires_grad=False)
+        self.h_idx = nn.Parameter(h_idx, requires_grad=False)
+        self.w_idx = nn.Parameter(w_idx, requires_grad=False)
+        groupBits = kernel_size * kernel_size
+        group = int(len(self.batch_idx) / B / groupBits)
+        self.lut = LutGroup(group, groupBits, channel_repeat)
+        self.fc = fc
+        if fc:
+            self.lutc = LutGroup(group, channel_repeat * C, channel_repeat * C)
+
+    def forward(self, x):
+        B, C, H, W = self.input_shape
+        x = x.view(self.input_shape)
+        x = x[(self.batch_idx, self.channel_idx, self.h_idx, self.w_idx)]
+        x = x.view(B, -1, self.kernel_size * self.kernel_size)
+        x = self.lut(x)
+        if self.fc:
+            x = x.view(B, self.channel_repeat * C, -1)
+            x = x.permute(0, 2, 1)
+            x = self.lutc(x)
         return x


@@ -143,14 +158,12 @@ class SimpleBNN(nn.Module):
         self.w = nn.Parameter(torch.randn(3, 784))
         self.b = nn.Parameter(torch.zeros(3, 784))

-        self.lut2_p = LutGroup(784, 4, 80)  # pool2 => 14*14*4
-        self.lut3_p = LutGroup(80 * 14 * 14, 4, 1)  # pool2 => 7*7*4
-        self.lut4_p = LutGroup(80 * 5 * 5 * 3 * 3, 9, 1)  # conv 3 => 5*5*8
-        # self.lut4 = LutGroup(8 * 5 * 5, 8, 8)  # conv 3 => 5*5*8
-        self.lut5_p = LutGroup(80 * 3 * 3 * 3 * 3, 9, 1)  # conv 3 => 3*3*16
-        # self.lut5 = LutGroup(16 * 3 * 3, 16, 10)  # conv 3 => 3*3*16
-        self.lut6_p = LutGroup(80 * 3 * 3, 9, 1)  # conv 3 => 128
-        # self.lut7_p = LutGroup(8 * 16, 16, 10)  # fc 128 => 80
+        # channel_repeat, input_shape, kernel_size, stride, dilation, fc
+        self.lnn1 = LutCnn(80, (BS, 1, 28, 28), 2, 2, 1, False)
+        self.lnn2 = LutCnn(1, (BS, 80, 14, 14), 2, 2, 1, False)
+        self.lnn3 = LutCnn(1, (BS, 80, 7, 7), 3, 1, 1, False)
+        self.lnn4 = LutCnn(1, (BS, 80, 5, 5), 3, 1, 1, False)
+        self.lnn5 = LutCnn(10, (BS, 80, 3, 3), 3, 1, 1)

         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
@@ -159,9 +172,9 @@ class SimpleBNN(nn.Module):
         self.pool = nn.MaxPool2d(2)
         self.relu = nn.ReLU()

-    def forward(self, x):
+    def forward(self, x, t):
         batch = x.shape[0]
-        x = x.view(batch, -1)
+        # x = x.view(batch, -1)

         # Map x from [-0.5, 0.5] to 0-255, then expand it into 8 binary values
         # x = (x * 256 + 128).clamp(0, 255).to(torch.uint8)
@@ -201,43 +214,11 @@ class SimpleBNN(nn.Module):

         #########################

-        x = x.view(batch, 1, 28, 28)
-        x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
-        x = x.view(batch, -1)
-        x = self.lut2_p(x)
-
-        x = x.view(batch, 80, 14, 14)
-        x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
-        x = x.view(batch, -1)
-        x = self.lut3_p(x)
-
-        x = x.view(batch, 80, 7, 7)
-        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
-        x = x.view(batch, -1)
-        x = self.lut4_p(x)
-
-        # x = x.view(batch, 8, 5, 5)
-        # x = x.permute(0, 2, 3, 1)
-        # x = x.reshape(batch, -1)
-        # x = self.lut4(x)
-
-        x = x.view(batch, 80, 5, 5)
-        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
-        x = x.view(batch, -1)
-        x = self.lut5_p(x)
-
-        # x = x.view(batch, 16, 3, 3)
-        # x = x.permute(0, 2, 3, 1)
-        # x = x.reshape(batch, -1)
-        # x = self.lut5(x)
-
-        x = x.view(batch, 80, 3, 3)
-        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
-        x = x.view(batch, -1)
-        x = self.lut6_p(x)
-
-        # x = x.view(batch, 8, 16)  # 16 channel 8 value/channel
-        # x = self.lut7_p(x)  # 10 * 8 value/channel
+        x = self.lnn1(x)
+        x = self.lnn2(x)
+        x = self.lnn3(x)
+        x = self.lnn4(x)
+        x = self.lnn5(x)

         # xx = 2 ** torch.arange(7, -1, -1).to(x.device)
         x = x.view(batch, -1, 8)
@@ -247,17 +228,29 @@ class SimpleBNN(nn.Module):

         return x

+    def printWeight(self):
+        pass
+

 torch.autograd.set_detect_anomaly(True)
 # model = SimpleCNN().to(device)
 model = SimpleBNN().to(device)
+# model = SimpleLNN().to(device)
 criterion = nn.CrossEntropyLoss()
-optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

-current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
-writer = SummaryWriter(f"log/{current_time}")
-hparam_dict = {"lr": lr, "batch_size": batch_size}
-writer.add_hparams(hparam_dict, {}, run_name=f"./")
+tbWriter = None
+
+
+def AddScalar(tag, value, epoch):
+    global tbWriter
+    if not tbWriter:
+        current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
+        tbWriter = SummaryWriter(f"log/{current_time}")
+        hparam_dict = {"lr": LR, "batch_size": BS}
+        tbWriter.add_hparams(hparam_dict, {}, run_name=f"./")
+
+    tbWriter.add_scalar(tag, value, epoch)


 def train(epoch):
@@ -265,19 +258,12 @@ def train(epoch):
     for batch_idx, (data, target) in enumerate(train_loader):
         data, target = data.to(device), target.to(device)
         optimizer.zero_grad()
-        output = model(data)
-        # output = output * 1.0
-        # output = F.softmax(output, dim=1)
-
-        # print(output.requires_grad)
-        # print(output.grad_fn)
-
+        output = model(data, target)
         loss = criterion(output, target)
         loss.backward()
         optimizer.step()
-        writer.add_scalar("loss", loss, epoch)
-
-        if batch_idx % 512 == 0:
+        AddScalar("loss", loss, epoch)
+        if batch_idx % 1024 == 0 and batch_idx > 0:
             print(
                 f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                 f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
@@ -291,18 +277,19 @@ def test(epoch):
     with torch.no_grad():
         for data, target in test_loader:
             data, target = data.to(device), target.to(device)
-            output = model(data)
+            output = model(data, target)
             test_loss += criterion(output, target).item()
             pred = output.argmax(dim=1, keepdim=True)
             correct += pred.eq(target.view_as(pred)).sum().item()

     test_loss /= len(test_loader.dataset)
     accuracy = 100.0 * correct / len(test_loader.dataset)
-    writer.add_scalar("accuracy", accuracy, epoch)
+    AddScalar("accuracy", accuracy, epoch)
     print(
         f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
         f"({accuracy:.0f}%)\n"
     )
+    model.printWeight()


 for epoch in range(1, 300):
@@ -311,4 +298,6 @@ for epoch in range(1, 300):

 # torch.save(model.state_dict(), "mnist_cnn.pth")
 print("Model saved to mnist_cnn.pth")
-writer.close()
+
+if tbWriter:
+    tbWriter.close()
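
Note for reviewers: the patch imports generate_unfold_index from a new unfold module that is not included in this diff. The sketch below is an assumption about what that helper computes, inferred only from its call sites in LutCnn (flat advanced-indexing tensors equivalent to F.unfold patch extraction with padding 0, ordered batch, then channel, then patch, then kernel element); the actual unfold.py may differ.

# Assumed sketch of unfold.generate_unfold_index (not part of this patch).
# It precomputes index tensors so that
#     x[(batch_idx, channel_idx, h_idx, w_idx)].view(B, -1, k * k)
# yields one k*k window per row, which matches LutCnn's
#     group = len(batch_idx) / B / (k * k).
import torch


def generate_unfold_index(input_shape, kernel_size, stride, dilation):
    B, C, H, W = input_shape
    k = kernel_size
    out_h = (H - dilation * (k - 1) - 1) // stride + 1
    out_w = (W - dilation * (k - 1) - 1) // stride + 1

    # spatial coordinates of every element of every patch: (out_h, out_w, k, k)
    h0 = torch.arange(out_h) * stride
    w0 = torch.arange(out_w) * stride
    kh = torch.arange(k) * dilation
    kw = torch.arange(k) * dilation
    h_idx = (h0[:, None, None, None] + kh[None, None, :, None]).expand(out_h, out_w, k, k).reshape(-1)
    w_idx = (w0[None, :, None, None] + kw[None, None, None, :]).expand(out_h, out_w, k, k).reshape(-1)

    # broadcast over batch and channel, keeping batch-major, then channel-major order
    n = h_idx.numel()
    batch_idx = torch.arange(B).repeat_interleave(C * n)
    channel_idx = torch.arange(C).repeat_interleave(n).repeat(B)
    h_idx = h_idx.repeat(B * C)
    w_idx = w_idx.repeat(B * C)
    return batch_idx, channel_idx, h_idx, w_idx

With this ordering, LutCnn.forward's x.view(B, -1, kernel_size * kernel_size) groups the bits of each k x k window of a single channel into one LUT lookup, which is the layout LutGroup expects.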