Update mnist.

parent 855296be55
commit 30dd319d8c

binary/mnist.py  119
@@ -13,6 +13,7 @@ import math
 import torch.nn.functional as F
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
+from torch.profiler import profile, ProfilerActivity, record_function
 import datetime
 
 torch.manual_seed(1234)
@@ -54,19 +55,17 @@ class Lut(torch.autograd.Function):
         assert int(math.log2(weight.shape[-1])) == bits
 
         index = 2 ** torch.arange(bits - 1, -1, -1, device=input.device)
-        x = (input > 0).long()
-        x = x * index
-        ind = x.sum(dim=-1)
+        ind = ((input > 0).long() * index).sum(dim=-1)
 
-        row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1)
+        row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1).to(input.device)
         output = weight[row_indices, ind]
 
-        ctx.save_for_backward(input, weight, ind)
+        ctx.save_for_backward(input, weight, row_indices, ind)
         return output
 
     @staticmethod
     def backward(ctx, grad_output):
-        input, weight, ind = ctx.saved_tensors
+        input, weight, row_indices, ind = ctx.saved_tensors
         grad_input = grad_weight = None
 
         batch = input.shape[0]
@@ -80,7 +79,6 @@ class Lut(torch.autograd.Function):
             grad_weight.scatter_add_(1, ind_p, grad_output_p)
 
         if ctx.needs_input_grad[0]:
-            row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1)
             grad_input = grad_output * weight[row_indices, ind]
             grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
 
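Note on the indexing change above: a minimal standalone sketch (with made-up sizes; `bits` is the per-group input width and each group owns a `2**bits`-entry table, as the assert implies) of how the fused `ind` computation picks one LUT entry per group:

```python
import torch

batch, count, bits = 2, 3, 4                      # hypothetical sizes
input = torch.randn(batch, count, bits)
weight = torch.randn(count, 2 ** bits)            # one 2**bits-entry table per group

# pack each group's sign pattern into an integer index (same as the new forward)
index = 2 ** torch.arange(bits - 1, -1, -1, device=input.device)
ind = ((input > 0).long() * index).sum(dim=-1)    # [batch, count]

# gather one table entry per (sample, group); row_indices stays on input.device
row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1).to(input.device)
output = weight[row_indices, ind]                 # [batch, count]
print(output.shape)                               # torch.Size([2, 3])
```

Saving `row_indices` in `ctx` also lets `backward` reuse it instead of rebuilding it, which is what the second hunk removes.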
@@ -129,6 +127,7 @@ class LutGroup(nn.Module):
         # output [ batch, totalBits / groupBits * repeat ]
         batch = x.shape[0]
         x = x.view(batch, -1, self.groupBits)
+        if self.repeat > 1:
             x = x.repeat(1, self.repeat, 1)
         x = Lut.apply(x, self.weight)
         return x
@@ -143,14 +142,14 @@ class SimpleBNN(nn.Module):
         self.w = nn.Parameter(torch.randn(3, 784))
         self.b = nn.Parameter(torch.zeros(3, 784))
 
-        self.lut2_p = LutGroup(784, 4, 80) # pool2 => 14*14*4
-        self.lut3_p = LutGroup(80 * 14 * 14, 4, 1) # pool2 => 7*7*4
-        self.lut4_p = LutGroup(80 * 5 * 5 * 3 * 3, 9, 1) # conv 3 => 5*5*8
+        self.lut2_p = LutGroup(784, 4, 8) # pool2 => 14*14*4
+        self.lut3_p = LutGroup(8 * 14 * 14, 4, 1) # pool2 => 7*7*4
+        self.lut4_p = LutGroup(8 * 5 * 5 * 3 * 3, 9, 1) # conv 3 => 5*5*8
         # self.lut4 = LutGroup(8 * 5 * 5, 8, 8) # conv 3 => 5*5*8
-        self.lut5_p = LutGroup(80 * 3 * 3 * 3 * 3, 9, 1) # conv 3 => 3*3*16
+        self.lut5_p = LutGroup(8 * 3 * 3 * 3 * 3, 9, 1) # conv 3 => 3*3*16
         # self.lut5 = LutGroup(16 * 3 * 3, 16, 10) # conv 3 => 3*3*16
-        self.lut6_p = LutGroup(80 * 3 * 3, 9, 1) # conv 3 => 128
-        # self.lut7_p = LutGroup(8 * 16, 16, 10) # fc 128 => 80
+        self.lut6_p = LutGroup(8 * 3 * 3, 9, 10) # conv 3 => 128
+        # self.lut7_p = LutGroup(8 * 16, 16, 10) # fc 128 => 8
 
         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
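Rough size check for the resized LUT groups (a sketch assuming `LutGroup(totalBits, groupBits, repeat)` emits `totalBits / groupBits * repeat` bits, as the output-shape comment in `forward` states):

```python
# lut2_p: 784 input bits (28*28), 4-bit groups (2x2 unfold windows), repeated 8 times
out2 = 784 // 4 * 8              # 1568 = 8 * 14 * 14 -> x.view(batch, 8, 14, 14)

# lut3_p: 8 * 14 * 14 bits after unfold(k=2, s=2), 4-bit groups, no repeat
out3 = 8 * 14 * 14 // 4          # 392 = 8 * 7 * 7   -> x.view(batch, 8, 7, 7)

# lut4_p: unfold(k=3, s=1) of an 8x7x7 map gives 8*9 values per window over 5*5 windows
out4 = 8 * 5 * 5 * 3 * 3 // 9    # 200 = 8 * 5 * 5   -> x.reshape(batch, 8, 5, 5)

print(out2, out3, out4)          # 1568 392 200
```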
@@ -201,29 +200,41 @@ class SimpleBNN(nn.Module):
 
         #########################
 
+        # unfold
+        # raw output shape: [batch, channels * kH * kW, L]
+        # where L is the number of sliding windows
+
         x = x.view(batch, 1, 28, 28)
         x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
-        x = x.view(batch, -1)
+        # x = x.view(batch, 1, 4, -1)
+        # x = x.permute(0, 1, 3, 2)
+        x = x.reshape(batch, -1)
         x = self.lut2_p(x)
 
-        x = x.view(batch, 80, 14, 14)
+        x = x.view(batch, 8, 14, 14)
         x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
-        x = x.view(batch, -1)
+        # x = x.view(batch, 8, 4, -1)
+        # x = x.permute(0, 1, 3, 2)
+        x = x.reshape(batch, -1)
         x = self.lut3_p(x)
 
-        x = x.view(batch, 80, 7, 7)
+        x = x.view(batch, 8, 7, 7)
         x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
-        x = x.view(batch, -1)
+        # x = x.view(batch, 8, 9, -1)
+        # x = x.permute(0, 1, 3, 2)
+        x = x.reshape(batch, -1)
         x = self.lut4_p(x)
 
-        # x = x.view(batch, 8, 5, 5)
-        # x = x.permute(0, 2, 3, 1)
-        # x = x.reshape(batch, -1)
-        # x = self.lut4(x)
+        # x = x.view(batch, 8, 25)
+        # x = x.permute(0, 2, 1) # batch 25 8
+        # x = x.reshape(batch, -1) # batch 25*8
+        # x = self.lut4(x) # batch 8*25
 
-        x = x.view(batch, 80, 5, 5)
+        x = x.reshape(batch, 8, 5, 5)
         x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
-        x = x.view(batch, -1)
+        # x = x.view(batch, 8, 9, -1)
+        # x = x.permute(0, 1, 3, 2)
+        x = x.reshape(batch, -1)
         x = self.lut5_p(x)
 
         # x = x.view(batch, 16, 3, 3)
@@ -231,14 +242,13 @@ class SimpleBNN(nn.Module):
         # x = x.reshape(batch, -1)
         # x = self.lut5(x)
 
-        x = x.view(batch, 80, 3, 3)
+        x = x.view(batch, 8, 3, 3)
         x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
-        x = x.view(batch, -1)
+        # x = x.view(batch, 8, 9, -1)
+        # x = x.permute(0, 1, 3, 2)
+        x = x.reshape(batch, -1)
         x = self.lut6_p(x)
 
-        # x = x.view(batch, 8, 16) # 16 channel 8 value/channel
-        # x = self.lut7_p(x) # 10 * 8 value/channel
-
         # xx = 2 ** torch.arange(7, -1, -1).to(x.device)
         x = x.view(batch, -1, 8)
         # x = x * xx
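The commented-out `view`/`permute` pairs above tie into the new layout comments: `F.unfold` returns `[batch, channels * kH * kW, L]` with `L` the number of windows, so flattening it directly makes each LUT group span several windows rather than one. A small sketch of the two flattenings (illustrative only, using the 8x14x14 stage):

```python
import torch
import torch.nn.functional as F

batch = 2
x = torch.randn(batch, 8, 14, 14)

u = F.unfold(x, kernel_size=2, stride=2)   # [2, 8*2*2, 7*7] = [2, 32, 49]

# current path: flatten as-is, so 4 consecutive bits come from 4 different windows
flat = u.reshape(batch, -1)                # [2, 1568]

# commented-out variant: regroup so each 4-bit group is one 2x2 window of one channel
grouped = u.view(batch, 8, 4, -1).permute(0, 1, 3, 2).reshape(batch, -1)  # [2, 1568]

print(u.shape, flat.shape, grouped.shape)
```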
@@ -254,10 +264,18 @@ model = SimpleBNN().to(device)
 criterion = nn.CrossEntropyLoss()
 optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
 
-current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
-writer = SummaryWriter(f"log/{current_time}")
-hparam_dict = {"lr": lr, "batch_size": batch_size}
-writer.add_hparams(hparam_dict, {}, run_name=f"./")
+tbWriter = None
+
+
+def AddScalar(tag, value, epoch):
+    global tbWriter
+    if not tbWriter:
+        current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
+        tbWriter = SummaryWriter(f"log/{current_time}")
+        hparam_dict = {"lr": lr, "batch_size": batch_size}
+        tbWriter.add_hparams(hparam_dict, {}, run_name=f"./")
+
+    tbWriter.add_scalar(tag, value, epoch)
 
 
 def train(epoch):
@@ -266,18 +284,11 @@ def train(epoch):
         data, target = data.to(device), target.to(device)
         optimizer.zero_grad()
         output = model(data)
-        # output = output * 1.0
-        # output = F.softmax(output, dim=1)
-
-        # print(output.requires_grad)
-        # print(output.grad_fn)
-
-
         loss = criterion(output, target)
         loss.backward()
         optimizer.step()
-        writer.add_scalar("loss", loss, epoch)
-        if batch_idx % 512 == 0 and batch_idx > 0:
+        AddScalar("loss", loss, epoch)
+        if batch_idx % 512 == 0:
            print(
                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
@@ -298,17 +309,37 @@ def test(epoch):
 
     test_loss /= len(test_loader.dataset)
     accuracy = 100.0 * correct / len(test_loader.dataset)
-    writer.add_scalar("accuracy", accuracy, epoch)
+    AddScalar("accuracy", accuracy, epoch)
     print(
         f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
         f"({accuracy:.0f}%)\n"
    )
 
 
+def profiler():
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(device), target.to(device)
+        optimizer.zero_grad()
+        with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
+            with record_function("model_inference"):
+                output = model(data)
+                loss = criterion(output, target)
+                loss.backward()
+                optimizer.step()
+        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
+
+        if batch_idx > 5:
+            break
+
+
+# profiler()
+
 for epoch in range(1, 300):
     train(epoch)
     test(epoch)
 
 # torch.save(model.state_dict(), "mnist_cnn.pth")
 print("Model saved to mnist_cnn.pth")
-writer.close()
+
+if tbWriter:
+    tbWriter.close()
@@ -0,0 +1,13 @@
+
+
+## Problems
+
+1. In a chain of binary LUT layers
+   1. The channels of each convolution have no relationship to one another
+   2. Insert an intermediate layer that exchanges data between the channels and produces the same number of new channels
+   3. Approach 2 performs very poorly
+      1. It seems to break training; the training method may be wrong
+      2. The final classifier has 10 outputs; does it degrade whenever those 10 outputs depend on each other?
+2. The dimensions of the unfold output are wrong
+   1. When the LUT does not compute over the convolution kernels, it converges more easily, but accuracy is no higher
+   2. When the LUT does compute over the convolution kernels, it converges less easily, and accuracy is about the same
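One way item 1.2 could be realized is a simple channel-exchange (shuffle) step between LUT layers; the following is only a sketch of that idea under an assumed bit layout, not code from this commit:

```python
import torch

def channel_exchange(x: torch.Tensor, channels: int) -> torch.Tensor:
    # [batch, channels * n] -> interleave so each new group mixes all channels
    batch = x.shape[0]
    x = x.view(batch, channels, -1)      # [batch, C, n]
    x = x.transpose(1, 2).contiguous()   # [batch, n, C]
    return x.view(batch, -1)             # same size, channel-major -> position-major

x = torch.randn(4, 8 * 7 * 7)
print(channel_exchange(x, 8).shape)      # torch.Size([4, 392])
```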