From 30dd319d8cb33ea78f4634313862056aeb65eeb9 Mon Sep 17 00:00:00 2001
From: Colin
Date: Mon, 26 May 2025 16:11:11 +0800
Subject: [PATCH] Update mnist.

---
 binary/mnist.py  | 121 +++++++++++++++++++++++++++++------------------
 binary/readme.md |  13 +++++
 2 files changed, 89 insertions(+), 45 deletions(-)
 create mode 100644 binary/readme.md

diff --git a/binary/mnist.py b/binary/mnist.py
index 728fb38..d911235 100644
--- a/binary/mnist.py
+++ b/binary/mnist.py
@@ -13,6 +13,7 @@ import math
 import torch.nn.functional as F
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
+from torch.profiler import profile, ProfilerActivity, record_function
 import datetime
 
 torch.manual_seed(1234)
@@ -54,19 +55,17 @@ class Lut(torch.autograd.Function):
         assert int(math.log2(weight.shape[-1])) == bits
 
         index = 2 ** torch.arange(bits - 1, -1, -1, device=input.device)
-        x = (input > 0).long()
-        x = x * index
-        ind = x.sum(dim=-1)
+        ind = ((input > 0).long() * index).sum(dim=-1)
 
-        row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1)
+        row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1).to(input.device)
         output = weight[row_indices, ind]
 
-        ctx.save_for_backward(input, weight, ind)
+        ctx.save_for_backward(input, weight, row_indices, ind)
         return output
 
     @staticmethod
     def backward(ctx, grad_output):
-        input, weight, ind = ctx.saved_tensors
+        input, weight, row_indices, ind = ctx.saved_tensors
         grad_input = grad_weight = None
 
         batch = input.shape[0]
@@ -80,7 +79,6 @@ class Lut(torch.autograd.Function):
             grad_weight.scatter_add_(1, ind_p, grad_output_p)
 
         if ctx.needs_input_grad[0]:
-            row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1)
            grad_input = grad_output * weight[row_indices, ind]
            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
 
@@ -129,7 +127,8 @@ class LutGroup(nn.Module):
         # output [ batch, totalBits / groupBits * repeat ]
         batch = x.shape[0]
         x = x.view(batch, -1, self.groupBits)
-        x = x.repeat(1, self.repeat, 1)
+        if self.repeat > 1:
+            x = x.repeat(1, self.repeat, 1)
         x = Lut.apply(x, self.weight)
         return x
 
@@ -143,14 +142,14 @@ class SimpleBNN(nn.Module):
         self.w = nn.Parameter(torch.randn(3, 784))
         self.b = nn.Parameter(torch.zeros(3, 784))
 
-        self.lut2_p = LutGroup(784, 4, 80) # pool2 => 14*14*4
-        self.lut3_p = LutGroup(80 * 14 * 14, 4, 1) # pool2 => 7*7*4
-        self.lut4_p = LutGroup(80 * 5 * 5 * 3 * 3, 9, 1) # conv 3 => 5*5*8
+        self.lut2_p = LutGroup(784, 4, 8) # pool2 => 14*14*4
+        self.lut3_p = LutGroup(8 * 14 * 14, 4, 1) # pool2 => 7*7*4
+        self.lut4_p = LutGroup(8 * 5 * 5 * 3 * 3, 9, 1) # conv 3 => 5*5*8
         # self.lut4 = LutGroup(8 * 5 * 5, 8, 8) # conv 3 => 5*5*8
-        self.lut5_p = LutGroup(80 * 3 * 3 * 3 * 3, 9, 1) # conv 3 => 3*3*16
+        self.lut5_p = LutGroup(8 * 3 * 3 * 3 * 3, 9, 1) # conv 3 => 3*3*16
         # self.lut5 = LutGroup(16 * 3 * 3, 16, 10) # conv 3 => 3*3*16
-        self.lut6_p = LutGroup(80 * 3 * 3, 9, 1) # conv 3 => 128
-        # self.lut7_p = LutGroup(8 * 16, 16, 10) # fc 128 => 80
+        self.lut6_p = LutGroup(8 * 3 * 3, 9, 10) # conv 3 => 128
+        # self.lut7_p = LutGroup(8 * 16, 16, 10) # fc 128 => 8
 
         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
@@ -201,29 +200,41 @@ class SimpleBNN(nn.Module):
 
         #########################
 
+        # unfold
+        # original output shape: [batch, channels * kH * kW, L]
+        # where L is the number of sliding windows
+
         x = x.view(batch, 1, 28, 28)
         x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
-        x = x.view(batch, -1)
+        # x = x.view(batch, 1, 4, -1)
+        # x = x.permute(0, 1, 3, 2)
+        x = x.reshape(batch, -1)
         x = self.lut2_p(x)
 
-        x = x.view(batch, 80, 14, 14)
+        x = x.view(batch, 8, 14, 14)
         x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
-        x = x.view(batch, -1)
+        # x = x.view(batch, 8, 4, -1)
+        # x = x.permute(0, 1, 3, 2)
+        x = x.reshape(batch, -1)
         x = self.lut3_p(x)
 
-        x = x.view(batch, 80, 7, 7)
+        x = x.view(batch, 8, 7, 7)
         x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
-        x = x.view(batch, -1)
+        # x = x.view(batch, 8, 9, -1)
+        # x = x.permute(0, 1, 3, 2)
+        x = x.reshape(batch, -1)
         x = self.lut4_p(x)
 
-        # x = x.view(batch, 8, 5, 5)
-        # x = x.permute(0, 2, 3, 1)
-        # x = x.reshape(batch, -1)
-        # x = self.lut4(x)
+        # x = x.view(batch, 8, 25)
+        # x = x.permute(0, 2, 1) # batch 25 8
+        # x = x.reshape(batch, -1) # batch 25*8
+        # x = self.lut4(x) # batch 8*25
 
-        x = x.view(batch, 80, 5, 5)
+        x = x.reshape(batch, 8, 5, 5)
         x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
-        x = x.view(batch, -1)
+        # x = x.view(batch, 8, 9, -1)
+        # x = x.permute(0, 1, 3, 2)
+        x = x.reshape(batch, -1)
         x = self.lut5_p(x)
 
         # x = x.view(batch, 16, 3, 3)
@@ -231,14 +242,13 @@ class SimpleBNN(nn.Module):
         # x = x.reshape(batch, -1)
         # x = self.lut5(x)
 
-        x = x.view(batch, 80, 3, 3)
+        x = x.view(batch, 8, 3, 3)
         x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
-        x = x.view(batch, -1)
+        # x = x.view(batch, 8, 9, -1)
+        # x = x.permute(0, 1, 3, 2)
+        x = x.reshape(batch, -1)
         x = self.lut6_p(x)
 
-        # x = x.view(batch, 8, 16) # 16 channel 8 value/channel
-        # x = self.lut7_p(x) # 10 * 8 value/channel
-
         # xx = 2 ** torch.arange(7, -1, -1).to(x.device)
         x = x.view(batch, -1, 8)
         # x = x * xx
@@ -254,10 +264,18 @@ model = SimpleBNN().to(device)
 criterion = nn.CrossEntropyLoss()
 optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
 
-current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
-writer = SummaryWriter(f"log/{current_time}")
-hparam_dict = {"lr": lr, "batch_size": batch_size}
-writer.add_hparams(hparam_dict, {}, run_name=f"./")
+tbWriter = None
+
+
+def AddScalar(tag, value, epoch):
+    global tbWriter
+    if not tbWriter:
+        current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
+        tbWriter = SummaryWriter(f"log/{current_time}")
+        hparam_dict = {"lr": lr, "batch_size": batch_size}
+        tbWriter.add_hparams(hparam_dict, {}, run_name=f"./")
+
+    tbWriter.add_scalar(tag, value, epoch)
 
 
 def train(epoch):
@@ -266,18 +284,11 @@ def train(epoch):
         data, target = data.to(device), target.to(device)
         optimizer.zero_grad()
         output = model(data)
-        # output = output * 1.0
-        # output = F.softmax(output, dim=1)
-
-        # print(output.requires_grad)
-        # print(output.grad_fn)
-
         loss = criterion(output, target)
         loss.backward()
         optimizer.step()
-        writer.add_scalar("loss", loss, epoch)
-
-        if batch_idx % 512 == 0:
+        AddScalar("loss", loss, epoch)
+        if batch_idx % 512 == 0 and batch_idx > 0:
             print(
                 f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                 f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
             )
@@ -298,17 +309,37 @@ def test(epoch):
 
     test_loss /= len(test_loader.dataset)
     accuracy = 100.0 * correct / len(test_loader.dataset)
-    writer.add_scalar("accuracy", accuracy, epoch)
+    AddScalar("accuracy", accuracy, epoch)
     print(
         f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
         f"({accuracy:.0f}%)\n"
     )
 
 
+def profiler():
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(device), target.to(device)
+        optimizer.zero_grad()
+        with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
+            with record_function("model_inference"):
+                output = model(data)
+        loss = criterion(output, target)
+        loss.backward()
+        optimizer.step()
+        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
+
+        if batch_idx > 5:
+            break
+
+
+# profiler()
+
 for epoch in range(1, 300):
     train(epoch)
     test(epoch)
 
 # torch.save(model.state_dict(), "mnist_cnn.pth")
 print("Model saved to mnist_cnn.pth")
-writer.close()
+
+if tbWriter:
+    tbWriter.close()
diff --git a/binary/readme.md b/binary/readme.md
new file mode 100644
index 0000000..bca99fc
--- /dev/null
+++ b/binary/readme.md
@@ -0,0 +1,13 @@
+
+
+## Problems
+
+1. In a chained binary LUT network
+   1. If the channels of each convolution have no interaction with each other
+   2. Insert a layer in the middle that exchanges data between the channels and generates the same number of new channels
+   3. Approach 2 works very poorly
+      1. It seems to break training; the training method may be wrong
+      2. The final classification has 10 classes; does it get this bad whenever the 10 outputs depend on each other?
+2. The unfold output dimensions are not right
+   1. When the LUT does not operate on the convolution kernel, it converges more easily, but accuracy is not higher
+   2. When the LUT does operate on the convolution kernel, it converges less easily, and accuracy is about the same
\ No newline at end of file
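
As background for the readme notes above, here is a minimal standalone sketch of what the patched `Lut.forward` computes: each group of `bits` activations is binarized by sign, packed into an integer index (most significant bit first), and used to select one learned table entry per group. The sizes are illustrative only and are not taken from the model.

```python
import torch

# Illustrative sizes only (not the model's): 2 samples, 3 groups of 4 bits,
# so every group owns a lookup table with 2**4 = 16 learned entries.
batch, count, bits = 2, 3, 4
x = torch.randn(batch, count, bits)     # pre-binarization activations
weight = torch.randn(count, 2 ** bits)  # one table per group

# Pack the sign pattern of each group into an integer, MSB first, mirroring
# `ind = ((input > 0).long() * index).sum(dim=-1)` in the patch.
index = 2 ** torch.arange(bits - 1, -1, -1)  # tensor([8, 4, 2, 1])
ind = ((x > 0).long() * index).sum(dim=-1)   # [batch, count], values in 0..15

# Gather one table entry per (sample, group), as Lut.forward does.
rows = torch.arange(count).unsqueeze(0).expand(batch, -1)
out = weight[rows, ind]                      # [batch, count]
print(out.shape)                             # torch.Size([2, 3])
```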
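
Readme point 2 ("the unfold output dimensions are not right") and the new comment in `forward` come down to the layout of `F.unfold`, which returns `[batch, C * kH * kW, L]` with `L` sliding windows. A plain `reshape(batch, -1)` keeps the window axis innermost, so a group of `groupBits` consecutive values handed to the LUT does not correspond to one window; the commented-out `view`/`permute` lines are the regrouping that would make it so. The sketch below, assuming the 8-channel 14x14 stage with kernel 2 and stride 2, shows the difference.

```python
import torch
import torch.nn.functional as F

# Sizes follow the 8-channel 14x14 stage of mnist.py (kernel 2, stride 2).
batch, C, H, W = 2, 8, 14, 14
x = torch.randn(batch, C, H, W)

# F.unfold output: [batch, C*kH*kW, L] with L = 7*7 = 49 windows.
u = F.unfold(x, kernel_size=2, stride=2)  # [2, 32, 49]

# Plain flatten: the window axis stays innermost, so consecutive groups of 4
# are (mostly) the same kernel position of neighbouring windows.
flat_plain = u.reshape(batch, -1)         # [2, 32 * 49]

# Regrouped layout (what the commented-out view/permute lines would do):
# move the window axis ahead of the kernel axis so every 4 consecutive
# values are the 2x2 patch of one window of one channel.
flat_grouped = (
    u.view(batch, C, 4, -1)               # [2, 8, 4, 49]
    .permute(0, 1, 3, 2)                  # [2, 8, 49, 4]
    .reshape(batch, -1)                   # [2, 8 * 49 * 4]
)

# First group of 4 in the regrouped layout == top-left 2x2 patch of channel 0.
assert torch.equal(
    flat_grouped.view(batch, C, -1, 4)[0, 0, 0],
    x[0, 0, :2, :2].reshape(-1),
)
```

Whether the regrouped layout actually trains better is exactly what points 2.1 and 2.2 of the readme compare.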