Update mnist.

Colin 2025-05-26 16:11:11 +08:00
parent 855296be55
commit 30dd319d8c
2 changed files with 89 additions and 45 deletions


@@ -13,6 +13,7 @@ import math
import torch.nn.functional as F
import numpy as np
from torch.utils.tensorboard import SummaryWriter
+ from torch.profiler import profile, ProfilerActivity, record_function
import datetime
torch.manual_seed(1234)
@@ -54,19 +55,17 @@ class Lut(torch.autograd.Function):
assert int(math.log2(weight.shape[-1])) == bits
index = 2 ** torch.arange(bits - 1, -1, -1, device=input.device)
- x = (input > 0).long()
- x = x * index
- ind = x.sum(dim=-1)
+ ind = ((input > 0).long() * index).sum(dim=-1)
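# e.g. bits=4: index = [8, 4, 2, 1]; a sign pattern (+, -, +, +) gives bits 1011 -> ind = 8 + 2 + 1 = 11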
- row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1)
+ row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1).to(input.device)
output = weight[row_indices, ind]
- ctx.save_for_backward(input, weight, ind)
+ ctx.save_for_backward(input, weight, row_indices, ind)
return output
@staticmethod
def backward(ctx, grad_output):
- input, weight, ind = ctx.saved_tensors
+ input, weight, row_indices, ind = ctx.saved_tensors
grad_input = grad_weight = None
batch = input.shape[0]
@@ -80,7 +79,6 @@ class Lut(torch.autograd.Function):
grad_weight.scatter_add_(1, ind_p, grad_output_p)
if ctx.needs_input_grad[0]:
- row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1)
grad_input = grad_output * weight[row_indices, ind]
grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
@@ -129,6 +127,7 @@ class LutGroup(nn.Module):
# output [ batch, totalBits / groupBits * repeat ]
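# e.g. lut2_p = LutGroup(784, 4, 8): input [batch, 784] -> output [batch, 784 / 4 * 8] = [batch, 1568]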
batch = x.shape[0]
x = x.view(batch, -1, self.groupBits)
if self.repeat > 1:
x = x.repeat(1, self.repeat, 1)
x = Lut.apply(x, self.weight)
return x
@@ -143,14 +142,14 @@ class SimpleBNN(nn.Module):
self.w = nn.Parameter(torch.randn(3, 784))
self.b = nn.Parameter(torch.zeros(3, 784))
- self.lut2_p = LutGroup(784, 4, 80) # pool2 => 14*14*4
- self.lut3_p = LutGroup(80 * 14 * 14, 4, 1) # pool2 => 7*7*4
- self.lut4_p = LutGroup(80 * 5 * 5 * 3 * 3, 9, 1) # conv 3 => 5*5*8
+ self.lut2_p = LutGroup(784, 4, 8) # pool2 => 14*14*4
+ self.lut3_p = LutGroup(8 * 14 * 14, 4, 1) # pool2 => 7*7*4
+ self.lut4_p = LutGroup(8 * 5 * 5 * 3 * 3, 9, 1) # conv 3 => 5*5*8
# self.lut4 = LutGroup(8 * 5 * 5, 8, 8) # conv 3 => 5*5*8
- self.lut5_p = LutGroup(80 * 3 * 3 * 3 * 3, 9, 1) # conv 3 => 3*3*16
+ self.lut5_p = LutGroup(8 * 3 * 3 * 3 * 3, 9, 1) # conv 3 => 3*3*16
# self.lut5 = LutGroup(16 * 3 * 3, 16, 10) # conv 3 => 3*3*16
- self.lut6_p = LutGroup(80 * 3 * 3, 9, 1) # conv 3 => 128
- # self.lut7_p = LutGroup(8 * 16, 16, 10) # fc 128 => 80
+ self.lut6_p = LutGroup(8 * 3 * 3, 9, 10) # conv 3 => 128
+ # self.lut7_p = LutGroup(8 * 16, 16, 10) # fc 128 => 8
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
@@ -201,29 +200,41 @@ class SimpleBNN(nn.Module):
#########################
# unfold
# raw output shape: [batch, channels * kH * kW, L]
# where L is the number of sliding windows
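# e.g. a [batch, 1, 28, 28] input with kernel_size=2, stride=2 -> [batch, 1*2*2, 14*14] = [batch, 4, 196]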
x = x.view(batch, 1, 28, 28)
x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
- x = x.view(batch, -1)
+ # x = x.view(batch, 1, 4, -1)
+ # x = x.permute(0, 1, 3, 2)
+ x = x.reshape(batch, -1)
x = self.lut2_p(x)
- x = x.view(batch, 80, 14, 14)
+ x = x.view(batch, 8, 14, 14)
x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
- x = x.view(batch, -1)
+ # x = x.view(batch, 8, 4, -1)
+ # x = x.permute(0, 1, 3, 2)
+ x = x.reshape(batch, -1)
x = self.lut3_p(x)
- x = x.view(batch, 80, 7, 7)
+ x = x.view(batch, 8, 7, 7)
x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
- x = x.view(batch, -1)
+ # x = x.view(batch, 8, 9, -1)
+ # x = x.permute(0, 1, 3, 2)
+ x = x.reshape(batch, -1)
x = self.lut4_p(x)
# x = x.view(batch, 8, 5, 5)
# x = x.permute(0, 2, 3, 1)
# x = x.reshape(batch, -1)
# x = self.lut4(x)
# x = x.view(batch, 8, 25)
# x = x.permute(0, 2, 1) # batch 25 8
# x = x.reshape(batch, -1) # batch 25*8
# x = self.lut4(x) # batch 8*25
- x = x.view(batch, 80, 5, 5)
+ x = x.reshape(batch, 8, 5, 5)
x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
- x = x.view(batch, -1)
+ # x = x.view(batch, 8, 9, -1)
+ # x = x.permute(0, 1, 3, 2)
+ x = x.reshape(batch, -1)
x = self.lut5_p(x)
# x = x.view(batch, 16, 3, 3)
@@ -231,14 +242,13 @@ class SimpleBNN(nn.Module):
# x = x.reshape(batch, -1)
# x = self.lut5(x)
- x = x.view(batch, 80, 3, 3)
+ x = x.view(batch, 8, 3, 3)
x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
- x = x.view(batch, -1)
+ # x = x.view(batch, 8, 9, -1)
+ # x = x.permute(0, 1, 3, 2)
+ x = x.reshape(batch, -1)
x = self.lut6_p(x)
# x = x.view(batch, 8, 16) # 16 channel 8 value/channel
# x = self.lut7_p(x) # 10 * 8 value/channel
# xx = 2 ** torch.arange(7, -1, -1).to(x.device)
x = x.view(batch, -1, 8)
# x = x * xx
@@ -254,10 +264,18 @@ model = SimpleBNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+ tbWriter = None
+ def AddScalar(tag, value, epoch):
+ global tbWriter
+ if not tbWriter:
current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
- writer = SummaryWriter(f"log/{current_time}")
+ tbWriter = SummaryWriter(f"log/{current_time}")
hparam_dict = {"lr": lr, "batch_size": batch_size}
- writer.add_hparams(hparam_dict, {}, run_name=f"./")
+ tbWriter.add_hparams(hparam_dict, {}, run_name=f"./")
+ tbWriter.add_scalar(tag, value, epoch)
def train(epoch):
@@ -266,18 +284,11 @@ def train(epoch):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
- # output = output * 1.0
- # output = F.softmax(output, dim=1)
- # print(output.requires_grad)
- # print(output.grad_fn)
loss = criterion(output, target)
loss.backward()
optimizer.step()
- writer.add_scalar("loss", loss, epoch)
- if batch_idx % 512 == 0:
+ AddScalar("loss", loss, epoch)
+ if batch_idx % 512 == 0 and batch_idx > 0:
print(
f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
@@ -298,17 +309,37 @@ def test(epoch):
test_loss /= len(test_loader.dataset)
accuracy = 100.0 * correct / len(test_loader.dataset)
- writer.add_scalar("accuracy", accuracy, epoch)
+ AddScalar("accuracy", accuracy, epoch)
print(
f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
f"({accuracy:.0f}%)\n"
)
+ def profiler():
+ for batch_idx, (data, target) in enumerate(train_loader):
+ data, target = data.to(device), target.to(device)
+ optimizer.zero_grad()
+ with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
+ with record_function("model_inference"):
+ output = model(data)
+ loss = criterion(output, target)
+ loss.backward()
+ optimizer.step()
+ print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
+ if batch_idx > 5:
+ break
+ # profiler()
for epoch in range(1, 300):
train(epoch)
test(epoch)
# torch.save(model.state_dict(), "mnist_cnn.pth")
print("Model saved to mnist_cnn.pth")
- writer.close()
+ if tbWriter:
+ tbWriter.close()

binary/readme.md (new file)

@@ -0,0 +1,13 @@
## Issues

1. In a chain of binary LUT networks:
   1. if the channels of each convolution have no relationship to one another,
   2. a layer can be inserted in between that exchanges data among the channels and generates the same number of new channels (see the sketch below),
   3. but this approach (point 2) works very poorly:
      1. it seems to break training; possibly the training method is wrong,
      2. the final classification has 10 outputs, and as soon as those 10 outputs depend on one another the results become very poor.
2. The dimension ordering of the unfold output is wrong:
   1. when the LUT does not compute over the convolution kernel window, it converges more easily, but accuracy is no higher;
   2. when the LUT does not compute over the convolution kernel window, it does not converge easily, and accuracy is about the same.
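
For illustration only (not part of this commit): a minimal sketch of the cross-channel exchange idea in point 1.2, assuming the activations are laid out as `[batch, channels * bits]`. The `ChannelExchange` name is hypothetical, and the layer is just a fixed permutation rather than a learned module.

```python
import torch
import torch.nn as nn

class ChannelExchange(nn.Module):
    """Regroup bits so that each new group takes one bit from every channel."""

    def __init__(self, channels, bits):
        super().__init__()
        self.channels = channels
        self.bits = bits

    def forward(self, x):
        batch = x.shape[0]
        # [batch, channels * bits] -> [batch, channels, bits]
        x = x.view(batch, self.channels, self.bits)
        # swap the channel and bit axes: bit i of every channel lands in group i
        x = x.permute(0, 2, 1)
        # flatten back; the total number of bits is unchanged
        return x.reshape(batch, -1)

# usage sketch: mix 8 channels of 9 bits each before the next LutGroup
x = (torch.randn(4, 8 * 9) > 0).float()
y = ChannelExchange(channels=8, bits=9)(x)
print(y.shape)  # torch.Size([4, 72])
```

Each downstream LUT group then sees bits drawn from several different channels, which is the kind of exchange that point 1.2 reports as hurting training in practice.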