Update mnist.
This commit is contained in:
parent
855296be55
commit
30dd319d8c
121
binary/mnist.py
121
binary/mnist.py
|
@ -13,6 +13,7 @@ import math
|
|||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torch.profiler import profile, ProfilerActivity, record_function
|
||||
import datetime
|
||||
|
||||
torch.manual_seed(1234)
|
||||
|
@ -54,19 +55,17 @@ class Lut(torch.autograd.Function):
|
|||
assert int(math.log2(weight.shape[-1])) == bits
|
||||
|
||||
index = 2 ** torch.arange(bits - 1, -1, -1, device=input.device)
|
||||
x = (input > 0).long()
|
||||
x = x * index
|
||||
ind = x.sum(dim=-1)
|
||||
ind = ((input > 0).long() * index).sum(dim=-1)
|
||||
|
||||
row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1)
|
||||
row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1).to(input.device)
|
||||
output = weight[row_indices, ind]
|
||||
|
||||
ctx.save_for_backward(input, weight, ind)
|
||||
ctx.save_for_backward(input, weight, row_indices, ind)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, weight, ind = ctx.saved_tensors
|
||||
input, weight, row_indices, ind = ctx.saved_tensors
|
||||
grad_input = grad_weight = None
|
||||
|
||||
batch = input.shape[0]
|
||||
|
@ -80,7 +79,6 @@ class Lut(torch.autograd.Function):
|
|||
grad_weight.scatter_add_(1, ind_p, grad_output_p)
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1)
|
||||
grad_input = grad_output * weight[row_indices, ind]
|
||||
grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
|
||||
|
||||
|
@ -129,7 +127,8 @@ class LutGroup(nn.Module):
|
|||
# output [ batch, totalBits / groupBits * repeat ]
|
||||
batch = x.shape[0]
|
||||
x = x.view(batch, -1, self.groupBits)
|
||||
x = x.repeat(1, self.repeat, 1)
|
||||
if self.repeat > 1:
|
||||
x = x.repeat(1, self.repeat, 1)
|
||||
x = Lut.apply(x, self.weight)
|
||||
return x
|
||||
|
||||
|
@ -143,14 +142,14 @@ class SimpleBNN(nn.Module):
|
|||
self.w = nn.Parameter(torch.randn(3, 784))
|
||||
self.b = nn.Parameter(torch.zeros(3, 784))
|
||||
|
||||
self.lut2_p = LutGroup(784, 4, 80) # pool2 => 14*14*4
|
||||
self.lut3_p = LutGroup(80 * 14 * 14, 4, 1) # pool2 => 7*7*4
|
||||
self.lut4_p = LutGroup(80 * 5 * 5 * 3 * 3, 9, 1) # conv 3 => 5*5*8
|
||||
self.lut2_p = LutGroup(784, 4, 8) # pool2 => 14*14*4
|
||||
self.lut3_p = LutGroup(8 * 14 * 14, 4, 1) # pool2 => 7*7*4
|
||||
self.lut4_p = LutGroup(8 * 5 * 5 * 3 * 3, 9, 1) # conv 3 => 5*5*8
|
||||
# self.lut4 = LutGroup(8 * 5 * 5, 8, 8) # conv 3 => 5*5*8
|
||||
self.lut5_p = LutGroup(80 * 3 * 3 * 3 * 3, 9, 1) # conv 3 => 3*3*16
|
||||
self.lut5_p = LutGroup(8 * 3 * 3 * 3 * 3, 9, 1) # conv 3 => 3*3*16
|
||||
# self.lut5 = LutGroup(16 * 3 * 3, 16, 10) # conv 3 => 3*3*16
|
||||
self.lut6_p = LutGroup(80 * 3 * 3, 9, 1) # conv 3 => 128
|
||||
# self.lut7_p = LutGroup(8 * 16, 16, 10) # fc 128 => 80
|
||||
self.lut6_p = LutGroup(8 * 3 * 3, 9, 10) # conv 3 => 128
|
||||
# self.lut7_p = LutGroup(8 * 16, 16, 10) # fc 128 => 8
|
||||
|
||||
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
|
||||
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
|
||||
|
@ -201,29 +200,41 @@ class SimpleBNN(nn.Module):
|
|||
|
||||
#########################
|
||||
|
||||
# unfold
|
||||
# 原始输出 shape: [batch, channels * kH * kW, L]
|
||||
# 其中 L 是滑动窗口的数量
|
||||
|
||||
x = x.view(batch, 1, 28, 28)
|
||||
x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
|
||||
x = x.view(batch, -1)
|
||||
# x = x.view(batch, 1, 4, -1)
|
||||
# x = x.permute(0, 1, 3, 2)
|
||||
x = x.reshape(batch, -1)
|
||||
x = self.lut2_p(x)
|
||||
|
||||
x = x.view(batch, 80, 14, 14)
|
||||
x = x.view(batch, 8, 14, 14)
|
||||
x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
|
||||
x = x.view(batch, -1)
|
||||
# x = x.view(batch, 8, 4, -1)
|
||||
# x = x.permute(0, 1, 3, 2)
|
||||
x = x.reshape(batch, -1)
|
||||
x = self.lut3_p(x)
|
||||
|
||||
x = x.view(batch, 80, 7, 7)
|
||||
x = x.view(batch, 8, 7, 7)
|
||||
x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
|
||||
x = x.view(batch, -1)
|
||||
# x = x.view(batch, 8, 9, -1)
|
||||
# x = x.permute(0, 1, 3, 2)
|
||||
x = x.reshape(batch, -1)
|
||||
x = self.lut4_p(x)
|
||||
|
||||
# x = x.view(batch, 8, 5, 5)
|
||||
# x = x.permute(0, 2, 3, 1)
|
||||
# x = x.reshape(batch, -1)
|
||||
# x = self.lut4(x)
|
||||
# x = x.view(batch, 8, 25)
|
||||
# x = x.permute(0, 2, 1) # batch 25 8
|
||||
# x = x.reshape(batch, -1) # batch 25*8
|
||||
# x = self.lut4(x) # batch 8*25
|
||||
|
||||
x = x.view(batch, 80, 5, 5)
|
||||
x = x.reshape(batch, 8, 5, 5)
|
||||
x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
|
||||
x = x.view(batch, -1)
|
||||
# x = x.view(batch, 8, 9, -1)
|
||||
# x = x.permute(0, 1, 3, 2)
|
||||
x = x.reshape(batch, -1)
|
||||
x = self.lut5_p(x)
|
||||
|
||||
# x = x.view(batch, 16, 3, 3)
|
||||
|
@ -231,14 +242,13 @@ class SimpleBNN(nn.Module):
|
|||
# x = x.reshape(batch, -1)
|
||||
# x = self.lut5(x)
|
||||
|
||||
x = x.view(batch, 80, 3, 3)
|
||||
x = x.view(batch, 8, 3, 3)
|
||||
x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
|
||||
x = x.view(batch, -1)
|
||||
# x = x.view(batch, 8, 9, -1)
|
||||
# x = x.permute(0, 1, 3, 2)
|
||||
x = x.reshape(batch, -1)
|
||||
x = self.lut6_p(x)
|
||||
|
||||
# x = x.view(batch, 8, 16) # 16 channel 8 value/channel
|
||||
# x = self.lut7_p(x) # 10 * 8 value/channel
|
||||
|
||||
# xx = 2 ** torch.arange(7, -1, -1).to(x.device)
|
||||
x = x.view(batch, -1, 8)
|
||||
# x = x * xx
|
||||
|
@ -254,10 +264,18 @@ model = SimpleBNN().to(device)
|
|||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
|
||||
|
||||
current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
|
||||
writer = SummaryWriter(f"log/{current_time}")
|
||||
hparam_dict = {"lr": lr, "batch_size": batch_size}
|
||||
writer.add_hparams(hparam_dict, {}, run_name=f"./")
|
||||
tbWriter = None
|
||||
|
||||
|
||||
def AddScalar(tag, value, epoch):
|
||||
global tbWriter
|
||||
if not tbWriter:
|
||||
current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
|
||||
tbWriter = SummaryWriter(f"log/{current_time}")
|
||||
hparam_dict = {"lr": lr, "batch_size": batch_size}
|
||||
tbWriter.add_hparams(hparam_dict, {}, run_name=f"./")
|
||||
|
||||
tbWriter.add_scalar(tag, value, epoch)
|
||||
|
||||
|
||||
def train(epoch):
|
||||
|
@ -266,18 +284,11 @@ def train(epoch):
|
|||
data, target = data.to(device), target.to(device)
|
||||
optimizer.zero_grad()
|
||||
output = model(data)
|
||||
# output = output * 1.0
|
||||
# output = F.softmax(output, dim=1)
|
||||
|
||||
# print(output.requires_grad)
|
||||
# print(output.grad_fn)
|
||||
|
||||
loss = criterion(output, target)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
writer.add_scalar("loss", loss, epoch)
|
||||
|
||||
if batch_idx % 512 == 0:
|
||||
AddScalar("loss", loss, epoch)
|
||||
if batch_idx % 512 == 0 and batch_idx > 0:
|
||||
print(
|
||||
f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
|
||||
f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
|
||||
|
@ -298,17 +309,37 @@ def test(epoch):
|
|||
|
||||
test_loss /= len(test_loader.dataset)
|
||||
accuracy = 100.0 * correct / len(test_loader.dataset)
|
||||
writer.add_scalar("accuracy", accuracy, epoch)
|
||||
AddScalar("accuracy", accuracy, epoch)
|
||||
print(
|
||||
f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
|
||||
f"({accuracy:.0f}%)\n"
|
||||
)
|
||||
|
||||
|
||||
def profiler():
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
data, target = data.to(device), target.to(device)
|
||||
optimizer.zero_grad()
|
||||
with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
|
||||
with record_function("model_inference"):
|
||||
output = model(data)
|
||||
loss = criterion(output, target)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
|
||||
|
||||
if batch_idx > 5:
|
||||
break
|
||||
|
||||
|
||||
# profiler()
|
||||
|
||||
for epoch in range(1, 300):
|
||||
train(epoch)
|
||||
test(epoch)
|
||||
|
||||
# torch.save(model.state_dict(), "mnist_cnn.pth")
|
||||
print("Model saved to mnist_cnn.pth")
|
||||
writer.close()
|
||||
|
||||
if tbWriter:
|
||||
tbWriter.close()
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
|
||||
|
||||
## 问题
|
||||
|
||||
1. 在一串的binary lut网络中
|
||||
1. 如果每个卷积的channel相互之间没有关系
|
||||
2. 中间插入一层,交换各个channel之间的数据,生成新的相同数量的channel
|
||||
3. 方法2的效果很差
|
||||
1. 好像是破坏了训练,可能是训练的方法不对
|
||||
2. 最终分类是10,10个输出之间有关系就会很差?
|
||||
2. unfold输出的维度不对
|
||||
1. LUT不是对卷积核进行计算,更容易收敛,但是精度没有更高
|
||||
2. LUT不是对卷积核进行计算,不容易收敛,精度差不多
|
Loading…
Reference in New Issue