Update mnist to higher accuracy.

Colin 2025-06-13 20:55:43 +08:00
parent a68e5ba5ee
commit d8539b6b2b
2 changed files with 30 additions and 30 deletions

View File

@@ -22,7 +22,7 @@ np.random.seed(1234)
 torch.cuda.manual_seed_all(1234)
 BS = 16
-LR = 0.001
+LR = 0.01
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")
@@ -51,8 +51,7 @@ class Lut(torch.autograd.Function):
 ind = ((x > 0).long() * index).sum(dim=-1)
 output = torch.gather(weight, 0, ind)
-output = (output > 0).float()
-output = (output - 0.5) * 2.0
+output = ((output > 0).float() - 0.5) * 2.0
 ctx.groupBits = groupBits
 ctx.groupRepeat = groupRepeat
 ctx.batch = batch
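For orientation, a minimal sketch of the forward lookup this hunk touches (hypothetical shapes; the per-bit powers-of-two `index` is an assumption, since its definition is not shown in the hunk): each group of binary inputs is packed into an integer index, the matching LUT entry is gathered, and the result is binarized to ±1 in the single combined step introduced here.

```python
import torch

batch, group, groupBits = 2, 3, 4                # hypothetical sizes
x = torch.randn(batch, group, groupBits)         # pre-binarization inputs
weight = torch.randn(2 ** groupBits, group)      # one LUT row per bit pattern

# Assumption: bit i of a group contributes 2**i to the lookup index.
index = 2 ** torch.arange(groupBits)
ind = ((x > 0).long() * index).sum(dim=-1)       # (batch, group) integer indices
output = torch.gather(weight, 0, ind)            # gather the LUT entries
output = ((output > 0).float() - 0.5) * 2.0      # binarize to -1 / +1 in one step
print(output.shape)                              # torch.Size([2, 3])
```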
@@ -72,26 +71,13 @@ class Lut(torch.autograd.Function):
 grad_weight.scatter_add_(0, ind, grad_output)
 if ctx.needs_input_grad[0]:
-grad_input = grad_output * torch.gather(weight, 0, ind)
+# grad_input = grad_output * torch.gather(weight, 0, ind)
+grad_input = grad_output
 grad_input = grad_input.view(batch, -1, groupRepeat).sum(2)
 grad_input = grad_input.unsqueeze(-1).repeat(1, 1, groupBits)
-# print("grad_input.shape")
-# print(grad_output.shape)
-# print(grad_input.shape)
-# print(in_sign.shape)
-in_sign = ((input > 0).float() - 0.5) * 2.0
-grad_input = grad_input * in_sign
-grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
-# grad_input = grad_output
-# grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
-# output = output.unsqueeze(-1).repeat(1, 1, bits)
 # in_sign = ((input > 0).float() - 0.5) * 2.0
-# out_sign = ((output > 0).float() - 0.5) * 2.0
-# grad_sign = ((grad_input > 0).float() - 0.5) * 2.0
-# grad_input = grad_input * in_sign * (out_sign * grad_sign)
+# grad_input = grad_input * in_sign
 # grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
 # A dynamic adjustment coefficient is needed
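As a side note, a minimal sketch (hypothetical sizes, not the repository's code) of what the retained input-gradient path computes: grad_output is passed through unchanged, summed over the groupRepeat copies of each group, and then repeated across the groupBits inputs of that group, matching the sum/repeat rule described in the notes file below.

```python
import torch

batch, group, groupBits, groupRepeat = 2, 3, 4, 5       # hypothetical sizes
grad_output = torch.randn(batch, group * groupRepeat)   # one gradient per (group, repeat) output

# Sum over the repeated copies of each group ...
grad_input = grad_output.view(batch, -1, groupRepeat).sum(2)     # (batch, group)
# ... then hand the same gradient to every input bit of the group.
grad_input = grad_input.unsqueeze(-1).repeat(1, 1, groupBits)    # (batch, group, groupBits)
print(grad_input.shape)                                          # torch.Size([2, 3, 4])
```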
@@ -155,7 +141,7 @@ class LutGroup(nn.Module):
 def __init__(self, group, groupBits, groupRepeat=1):
 assert groupBits > 1
 super(LutGroup, self).__init__()
-self.weight = nn.Parameter(torch.ones(pow(2, groupBits), int(groupRepeat * group)))
+self.weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
 self.group = group
 self.groupBits = groupBits
 self.groupRepeat = groupRepeat
@@ -186,8 +172,8 @@ class LutCnn(nn.Module):
 group = int(len(self.batch_idx) / B / groupBits)
 self.lut = LutGroup(group, groupBits, channel_repeat)
 self.fc = fc
-if fc:
-self.lutc = LutGroup(group, channel_repeat * C, channel_repeat * C)
+if fc and channel_repeat > 1:
+self.lutc = LutGroup(group, channel_repeat, channel_repeat)

 def forward(self, x):
 B, C, H, W = self.input_shape
@@ -195,8 +181,8 @@ class LutCnn(nn.Module):
 x = x[(self.batch_idx, self.channel_idx, self.h_idx, self.w_idx)]
 x = x.view(B, -1, self.kernel_size * self.kernel_size)
 x = self.lut(x)
-if self.fc:
-x = x.view(B, self.channel_repeat * C, -1)
+if self.fc and self.channel_repeat > 1:
+x = x.view(B, self.channel_repeat, -1)
 x = x.permute(0, 2, 1)
 x = self.lutc(x)
 return x
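The changed branch reshapes the LUT output so that the channel_repeat copies end up in the last dimension before the channel-mixing LUT is applied; a shape-only sketch of that view/permute step (hypothetical sizes, and the size of the remaining dimension is an assumption):

```python
import torch

B, channel_repeat, rest = 2, 8, 49        # hypothetical sizes
x = torch.randn(B, channel_repeat * rest)

x = x.view(B, channel_repeat, -1)         # (B, channel_repeat, rest)
x = x.permute(0, 2, 1)                    # (B, rest, channel_repeat): group the copies per position
print(x.shape)                            # torch.Size([2, 49, 8])
```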
@@ -212,11 +198,11 @@ class SimpleBNN(nn.Module):
 self.b = nn.Parameter(torch.zeros(3, 784))
 # channel_repeat, input_shape, kernel_size, stride, dilation, fc
-self.lnn1 = LutCnn(80, (BS, 1, 28, 28), 2, 2, 1, False)
-self.lnn2 = LutCnn(1, (BS, 80, 14, 14), 2, 2, 1, False)
-self.lnn3 = LutCnn(1, (BS, 80, 7, 7), 3, 1, 1, False)
-self.lnn4 = LutCnn(1, (BS, 80, 5, 5), 3, 1, 1, False)
-self.lnn5 = LutCnn(1, (BS, 80, 3, 3), 3, 1, 1)
+self.lnn1 = LutCnn(8, (BS, 1, 28, 28), 2, 2, 1, False)
+self.lnn2 = LutCnn(1, (BS, 8, 14, 14), 2, 2, 1, False)
+self.lnn3 = LutCnn(1, (BS, 8, 7, 7), 3, 1, 1, False)
+self.lnn4 = LutCnn(1, (BS, 8, 5, 5), 3, 1, 1, False)
+self.lnn5 = LutCnn(10, (BS, 8, 3, 3), 3, 1, 1)
 self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
 self.conv2 = nn.Conv2d(10, 20, kernel_size=5)

View File

@@ -16,6 +16,18 @@
 1. For the repeat operation, grad_output needs to be summed
 2. For the bits->index operation, grad_output needs to be repeated
+### grad computation
+| grad_output | output | input | exp_grad (grad_input) |
+| --- | --- | --- | --- |
+| + | + | + | + |
+| + | - | + | - |
+| - | + | + | - |
+| - | - | + | + |
+| + | + | - | - |
+| + | - | - | + |
+| - | + | - | + |
+| - | - | - | - |
 ## Problems
 * In a chain of binary LUT networks, insert a layer that exchanges data between the channels to produce the same number of new channels
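One way to read the added table: exp_grad is simply the product of the three signs, sign(grad_output) · sign(output) · sign(input). A small check of that reading (my interpretation, not code from the repository):

```python
# Walk the eight sign combinations in the same order as the table above and
# print the sign product, which reproduces the exp_grad column.
for inp in (+1, -1):          # input sign: + for the first four rows, - for the last four
    for g in (+1, -1):        # grad_output sign
        for o in (+1, -1):    # output sign
            print(f"grad_output={g:+d} output={o:+d} input={inp:+d} -> exp_grad={g * o * inp:+d}")
```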
@@ -48,4 +60,6 @@
 * After moving the input Repeat from LutGroup into Lut:
   1. Training converges much faster; at best it basically converges within 3 epochs
   2. Much more stable; insensitive to the learning rate
   3. The backward of the Repeat is handled centrally by Lut instead of by PyTorch's automatic backward, which may have fixed some dimension-handling errors
+     1. After this change, reverting the grad_input computation to the original form grad_input = grad_output gives a very large accuracy improvement
+     2. The original form had possibly been hurt by some bugs in the code; now the learning-rate setting is no longer sensitive either, and neither is the weight initialization value
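For reference, the grad_input = grad_output rule the new note reverts to behaves like a straight-through-style estimator, in contrast to the sign-scaled variant that is commented out in the hunk above; a minimal illustration with made-up tensors (not the project's code):

```python
import torch

grad_output = torch.tensor([[ 0.3, -0.7],
                            [-0.2,  0.5]])
inputs      = torch.tensor([[ 1.0, -2.0],
                            [-0.5,  0.3]])

# Rule the note reverts to: pass the incoming gradient through unchanged.
grad_passthrough = grad_output.clone()

# Previously active variant (now commented out): scale by the input sign (+1 / -1).
in_sign = ((inputs > 0).float() - 0.5) * 2.0
grad_sign_scaled = grad_output * in_sign

print(grad_passthrough)
print(grad_sign_scaled)
```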