Update mnist to higher accuracy.

parent a68e5ba5ee
commit d8539b6b2b

@@ -22,7 +22,7 @@ np.random.seed(1234)
torch.cuda.manual_seed_all(1234)

BS = 16
LR = 0.001
LR = 0.01

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

@@ -51,8 +51,7 @@ class Lut(torch.autograd.Function):

ind = ((x > 0).long() * index).sum(dim=-1)
output = torch.gather(weight, 0, ind)
output = (output > 0).float()
output = (output - 0.5) * 2.0
output = ((output > 0).float() - 0.5) * 2.0
ctx.groupBits = groupBits
ctx.groupRepeat = groupRepeat
ctx.batch = batch
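
This hunk only folds the two-step binarization of the looked-up value into one expression. A minimal sketch of the forward path around it, assuming x has shape (batch, group, groupBits), weight has shape (2**groupBits, group), and index holds the bit weights 1, 2, 4, ... (the hunk does not show how index is built):

```python
import torch

def lut_forward_sketch(x, weight, groupBits):
    # pack the sign pattern of each group of bits into an integer table index
    index = 2 ** torch.arange(groupBits, device=x.device)
    ind = ((x > 0).long() * index).sum(dim=-1)      # (batch, group)
    # look up one learned table entry per group, then binarize it to +/-1
    output = torch.gather(weight, 0, ind)
    output = ((output > 0).float() - 0.5) * 2.0
    return output

x = torch.randn(5, 3, 2)                        # batch=5, group=3, groupBits=2
weight = torch.randn(2 ** 2, 3)                 # 4 bit patterns x 3 LUTs
print(lut_forward_sketch(x, weight, 2).shape)   # torch.Size([5, 3])
```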

@@ -72,26 +71,13 @@ class Lut(torch.autograd.Function):
grad_weight.scatter_add_(0, ind, grad_output)

if ctx.needs_input_grad[0]:
grad_input = grad_output * torch.gather(weight, 0, ind)
# grad_input = grad_output * torch.gather(weight, 0, ind)
grad_input = grad_output
grad_input = grad_input.view(batch, -1, groupRepeat).sum(2)
grad_input = grad_input.unsqueeze(-1).repeat(1, 1, groupBits)

# print("grad_input.shape")
# print(grad_output.shape)
# print(grad_input.shape)
# print(in_sign.shape)

in_sign = ((input > 0).float() - 0.5) * 2.0
grad_input = grad_input * in_sign
grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)

# grad_input = grad_output
# grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
# output = output.unsqueeze(-1).repeat(1, 1, bits)
# in_sign = ((input > 0).float() - 0.5) * 2.0
# out_sign = ((output > 0).float() - 0.5) * 2.0
# grad_sign = ((grad_input > 0).float() - 0.5) * 2.0
# grad_input = grad_input * in_sign * (out_sign * grad_sign)
# grad_input = grad_input * in_sign
# grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)

# a dynamic adjustment coefficient is needed
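
The gradient path this commit settles on (per the notes at the end) is a straight-through style estimate: grad_output is passed back unchanged, the groupRepeat copies of each group are summed, and the result is broadcast to every bit of the group. A minimal sketch under the assumption that grad_output has shape (batch, group * groupRepeat) with the repeat as the innermost factor:

```python
import torch

def lut_backward_input_sketch(grad_output, batch, groupRepeat, groupBits):
    grad_input = grad_output                                        # straight-through: no weight factor
    grad_input = grad_input.view(batch, -1, groupRepeat).sum(2)     # (batch, group): merge repeated copies
    grad_input = grad_input.unsqueeze(-1).repeat(1, 1, groupBits)   # (batch, group, groupBits): same value for every bit
    return grad_input

g = torch.randn(4, 3 * 5)                             # batch=4, group=3, groupRepeat=5
print(lut_backward_input_sketch(g, 4, 5, 3).shape)    # torch.Size([4, 3, 3])
```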

@@ -155,7 +141,7 @@ class LutGroup(nn.Module):
def __init__(self, group, groupBits, groupRepeat=1):
assert groupBits > 1
super(LutGroup, self).__init__()
self.weight = nn.Parameter(torch.ones(pow(2, groupBits), int(groupRepeat * group)))
self.weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
self.group = group
self.groupBits = groupBits
self.groupRepeat = groupRepeat
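
The visible change here is the lookup-table initialization, from all ones (torch.ones) to random values (torch.randn). The parameter layout stays the same: one row per possible groupBits-bit input pattern and one column per LUT instance (groupRepeat * group). A small illustration with assumed sizes:

```python
import torch
import torch.nn as nn

group, groupBits, groupRepeat = 3, 2, 5
weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
print(weight.shape)   # torch.Size([4, 15]): 4 bit patterns x 15 LUTs
```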

@@ -186,8 +172,8 @@ class LutCnn(nn.Module):
group = int(len(self.batch_idx) / B / groupBits)
self.lut = LutGroup(group, groupBits, channel_repeat)
self.fc = fc
if fc:
self.lutc = LutGroup(group, channel_repeat * C, channel_repeat * C)
if fc and channel_repeat > 1:
self.lutc = LutGroup(group, channel_repeat, channel_repeat)

def forward(self, x):
B, C, H, W = self.input_shape

@@ -195,8 +181,8 @@ class LutCnn(nn.Module):
x = x[(self.batch_idx, self.channel_idx, self.h_idx, self.w_idx)]
x = x.view(B, -1, self.kernel_size * self.kernel_size)
x = self.lut(x)
if self.fc:
x = x.view(B, self.channel_repeat * C, -1)
if self.fc and self.channel_repeat > 1:
x = x.view(B, self.channel_repeat, -1)
x = x.permute(0, 2, 1)
x = self.lutc(x)
return x
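
In the fc branch, the activations are reshaped and permuted so that, at every position, the channel_repeat values form one group that the second LutGroup (self.lutc) can recombine across channels. A shape-only sketch with assumed sizes (the real tensor comes from self.lut, not randn):

```python
import torch

B, channel_repeat, N = 16, 10, 9
x = torch.randn(B, channel_repeat * N)
x = x.view(B, channel_repeat, -1)    # (B, channel_repeat, N)
x = x.permute(0, 2, 1)               # (B, N, channel_repeat): one channel group per position
print(x.shape)                       # torch.Size([16, 9, 10])
# x would then go through the second LutGroup (self.lutc) to mix the channels
```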

@@ -212,11 +198,11 @@ class SimpleBNN(nn.Module):
self.b = nn.Parameter(torch.zeros(3, 784))

# channel_repeat, input_shape, kernel_size, stride, dilation, fc
self.lnn1 = LutCnn(80, (BS, 1, 28, 28), 2, 2, 1, False)
self.lnn2 = LutCnn(1, (BS, 80, 14, 14), 2, 2, 1, False)
self.lnn3 = LutCnn(1, (BS, 80, 7, 7), 3, 1, 1, False)
self.lnn4 = LutCnn(1, (BS, 80, 5, 5), 3, 1, 1, False)
self.lnn5 = LutCnn(1, (BS, 80, 3, 3), 3, 1, 1)
self.lnn1 = LutCnn(8, (BS, 1, 28, 28), 2, 2, 1, False)
self.lnn2 = LutCnn(1, (BS, 8, 14, 14), 2, 2, 1, False)
self.lnn3 = LutCnn(1, (BS, 8, 7, 7), 3, 1, 1, False)
self.lnn4 = LutCnn(1, (BS, 8, 5, 5), 3, 1, 1, False)
self.lnn5 = LutCnn(10, (BS, 8, 3, 3), 3, 1, 1)

self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
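
The new stack shrinks the channel_repeat of the first layer from 80 to 8 and gives the last layer a channel_repeat of 10. The spatial sizes passed to each layer (28, 14, 7, 5, 3) are consistent with plain non-padded patch extraction; a small check of that, assuming each LutCnn consumes kernel_size x kernel_size patches at the given stride:

```python
def out_size(h, kernel_size, stride):
    # standard output size for non-padded patch extraction
    return (h - kernel_size) // stride + 1

h = 28
for kernel_size, stride in [(2, 2), (2, 2), (3, 1), (3, 1), (3, 1)]:   # lnn1 .. lnn5
    h = out_size(h, kernel_size, stride)
    print(h)   # 14, 7, 5, 3, 1
```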

@@ -16,6 +16,18 @@
1. For the repeat operation, grad_output needs to be summed
2. For the bits->index operation, grad_output needs to be repeated

### grad computation

| grad_output | output | input | exp_grad | grad_input |
|-------------|--------|-------|----------|------------|
| + | + | + | + |  |
| + | - | + | - |  |
| - | + | + | - |  |
| - | - | + | + |  |
| + | + | - | - |  |
| + | - | - | + |  |
| - | + | - | + |  |
| - | - | - | - |  |
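
Read row by row, exp_grad equals the product of the three signs to its left: sign(exp_grad) = sign(grad_output) * sign(output) * sign(input). The snippet below is only a check of that reading, not code from this repo:

```python
# Assumed reading of the table: the expected gradient sign is the product of
# the signs of grad_output, output and input.
for inp in (+1, -1):
    for go, out in [(+1, +1), (+1, -1), (-1, +1), (-1, -1)]:
        exp_grad = go * out * inp
        print(f"grad_output={go:+d} output={out:+d} input={inp:+d} -> exp_grad={exp_grad:+d}")
```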

## Problems

* Insert a layer into the chain of binary LUT layers that exchanges data between the channels and produces the same number of new channels

@@ -48,4 +60,6 @@
* After moving the Repeat of the input from LutGroup into Lut:
  1. Training converges much faster (at best it basically converges within 3 epochs)
  2. Stability is much better, and it is not sensitive to the lr
  3. The backward for Repeat is handled by Lut itself rather than by pytorch's automatic backward, which may have fixed some dimension-handling errors
     1. After this change, reverting the grad_input computation to the original grad_input = grad_output gives a very large improvement in accuracy
     2. The problems with the original approach may have been caused by some bugs in the code; the learning rate setting is no longer sensitive, and neither is the weight initialization