Update mnist to higher accuracy.

commit d8539b6b2b
parent a68e5ba5ee

@@ -22,7 +22,7 @@ np.random.seed(1234)
 torch.cuda.manual_seed_all(1234)

 BS = 16
-LR = 0.001
+LR = 0.01

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")

@@ -51,8 +51,7 @@ class Lut(torch.autograd.Function):

         ind = ((x > 0).long() * index).sum(dim=-1)
         output = torch.gather(weight, 0, ind)
-        output = (output > 0).float()
-        output = (output - 0.5) * 2.0
+        output = ((output > 0).float() - 0.5) * 2.0
         ctx.groupBits = groupBits
         ctx.groupRepeat = groupRepeat
         ctx.batch = batch

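To make the forward change easier to follow, here is a minimal standalone sketch of what `Lut.forward` computes, with shapes assumed and `groupRepeat` left out: each group of `groupBits` activations is packed into a LUT address, the learned entry is gathered, and the merged line binarizes it directly to ±1 instead of first mapping to {0, 1} and then rescaling.

```python
import torch

def lut_forward_sketch(x, weight):
    """Sketch of the Lut forward pass; shapes are assumptions, groupRepeat is omitted.

    x:      (batch, group, groupBits) activations whose sign carries the bit
    weight: (2 ** groupBits, group)   learnable LUT entries, one column per group
    """
    batch, group, groupBits = x.shape
    # pack each group of bits into an integer address: bit i contributes 2**i when x > 0
    index = 2 ** torch.arange(groupBits, device=x.device)
    ind = ((x > 0).long() * index).sum(dim=-1)   # (batch, group) LUT addresses
    output = torch.gather(weight, 0, ind)        # one looked-up entry per (sample, group)
    # single-step binarization from this commit: {<=0, >0} -> {-1.0, +1.0}
    return ((output > 0).float() - 0.5) * 2.0

x = torch.randn(4, 6, 2)                # batch=4, group=6, groupBits=2
w = torch.randn(2 ** 2, 6)
print(lut_forward_sketch(x, w).shape)   # torch.Size([4, 6])
```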
@@ -72,26 +71,13 @@ class Lut(torch.autograd.Function):
             grad_weight.scatter_add_(0, ind, grad_output)

         if ctx.needs_input_grad[0]:
-            grad_input = grad_output * torch.gather(weight, 0, ind)
+            # grad_input = grad_output * torch.gather(weight, 0, ind)
+            grad_input = grad_output
             grad_input = grad_input.view(batch, -1, groupRepeat).sum(2)
             grad_input = grad_input.unsqueeze(-1).repeat(1, 1, groupBits)

-            # print("grad_input.shape")
-            # print(grad_output.shape)
-            # print(grad_input.shape)
-            # print(in_sign.shape)
-
-            in_sign = ((input > 0).float() - 0.5) * 2.0
-            grad_input = grad_input * in_sign
-            grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
-
-            # grad_input = grad_output
-            # grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
-            # output = output.unsqueeze(-1).repeat(1, 1, bits)
             # in_sign = ((input > 0).float() - 0.5) * 2.0
-            # out_sign = ((output > 0).float() - 0.5) * 2.0
-            # grad_sign = ((grad_input > 0).float() - 0.5) * 2.0
-            # grad_input = grad_input * in_sign * (out_sign * grad_sign)
+            # grad_input = grad_input * in_sign
             # grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)

             # need a dynamic adjustment coefficient

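The backward rewrite above stops modulating the gradient by the looked-up weight and the input sign and simply passes `grad_output` through, keeping only the shape bookkeeping: sum over the `groupRepeat` copies, then broadcast the per-group gradient to each of its `groupBits` inputs. A small sketch of just that part, with sizes assumed:

```python
import torch

def lut_backward_input_sketch(grad_output, batch, groupRepeat, groupBits):
    """Straight-through style grad_input as in this commit; shapes are assumptions.

    grad_output: (batch, group * groupRepeat) gradient w.r.t. the LUT outputs
    returns:     (batch, group, groupBits)    gradient w.r.t. the packed inputs
    """
    grad_input = grad_output                                        # pass the incoming gradient through unchanged
    grad_input = grad_input.view(batch, -1, groupRepeat).sum(2)     # merge the groupRepeat copies of each group
    grad_input = grad_input.unsqueeze(-1).repeat(1, 1, groupBits)   # every input bit of a group gets the same grad
    return grad_input

g = torch.randn(4, 6 * 3)   # batch=4, group=6, groupRepeat=3
print(lut_backward_input_sketch(g, batch=4, groupRepeat=3, groupBits=2).shape)   # torch.Size([4, 6, 2])
```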
@@ -155,7 +141,7 @@ class LutGroup(nn.Module):
     def __init__(self, group, groupBits, groupRepeat=1):
         assert groupBits > 1
         super(LutGroup, self).__init__()
-        self.weight = nn.Parameter(torch.ones(pow(2, groupBits), int(groupRepeat * group)))
+        self.weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
         self.group = group
         self.groupBits = groupBits
         self.groupRepeat = groupRepeat

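A plausible reading of the init change (not stated in the commit): with `torch.ones` every LUT entry is positive, so after the ±1 binarization in the forward pass every output starts at +1 regardless of the input address, while `torch.randn` gives the entries random signs and breaks that symmetry. A tiny illustration:

```python
import torch

groupBits, group = 2, 6
ones_w  = torch.ones(2 ** groupBits, group)    # old init: every entry positive
randn_w = torch.randn(2 ** groupBits, group)   # new init: random signs

# after the forward binarization ((w > 0).float() - 0.5) * 2.0, an all-ones table
# maps every address to +1.0, so the layer initially ignores its input entirely
print(((ones_w > 0).float() - 0.5) * 2.0)      # all +1.0
print(((randn_w > 0).float() - 0.5) * 2.0)     # mixed +1.0 / -1.0
```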
@@ -186,8 +172,8 @@ class LutCnn(nn.Module):
         group = int(len(self.batch_idx) / B / groupBits)
         self.lut = LutGroup(group, groupBits, channel_repeat)
         self.fc = fc
-        if fc:
-            self.lutc = LutGroup(group, channel_repeat * C, channel_repeat * C)
+        if fc and channel_repeat > 1:
+            self.lutc = LutGroup(group, channel_repeat, channel_repeat)

     def forward(self, x):
         B, C, H, W = self.input_shape

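For the `group` computation above, assuming `batch_idx`/`channel_idx`/`h_idx`/`w_idx` enumerate every pixel of every `kernel_size × kernel_size` patch (unfold-style indexing; this is an assumption about code not shown in the diff), the arithmetic for the first layer of this model works out as:

```python
B, C, H, W = 16, 1, 28, 28                       # input_shape of lnn1
kernel_size, stride = 2, 2
groupBits = kernel_size * kernel_size            # 4 bits per patch
patches = (H // stride) * (W // stride)          # 14 * 14 = 196 patch positions per channel
index_len = B * C * patches * groupBits          # total number of gathered pixel indices
group = index_len // B // groupBits              # = C * 196: one LUT group per patch position
print(group)                                     # 196
```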
@@ -195,8 +181,8 @@ class LutCnn(nn.Module):
         x = x[(self.batch_idx, self.channel_idx, self.h_idx, self.w_idx)]
         x = x.view(B, -1, self.kernel_size * self.kernel_size)
         x = self.lut(x)
-        if self.fc:
-            x = x.view(B, self.channel_repeat * C, -1)
+        if self.fc and self.channel_repeat > 1:
+            x = x.view(B, self.channel_repeat, -1)
             x = x.permute(0, 2, 1)
             x = self.lutc(x)
         return x

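The `fc` branch is the channel-mixing step: the `channel_repeat` copies produced by `self.lut` are moved to the last axis and fed to `self.lutc`, so each group is mixed across its repeated channels by its own LUT. The new guard skips this when `channel_repeat == 1`, where there is nothing to mix. A shape sketch under assumed sizes (the exact layout of `self.lut`'s output is an assumption):

```python
import torch

B, channel_repeat, group = 16, 10, 9
x = torch.randn(B, channel_repeat * group)   # stand-in for the output of self.lut

if channel_repeat > 1:                       # the new guard: skip mixing for a single copy
    x = x.view(B, channel_repeat, -1)        # (B, channel_repeat, group)
    x = x.permute(0, 2, 1)                   # (B, group, channel_repeat): repeats become the LUT input bits
    # x would now go through self.lutc, a LutGroup over channel_repeat-bit groups
print(x.shape)                               # torch.Size([16, 9, 10])
```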
@@ -212,11 +198,11 @@ class SimpleBNN(nn.Module):
         self.b = nn.Parameter(torch.zeros(3, 784))

         # channel_repeat, input_shape, kernel_size, stride, dilation, fc
-        self.lnn1 = LutCnn(80, (BS, 1, 28, 28), 2, 2, 1, False)
-        self.lnn2 = LutCnn(1, (BS, 80, 14, 14), 2, 2, 1, False)
-        self.lnn3 = LutCnn(1, (BS, 80, 7, 7), 3, 1, 1, False)
-        self.lnn4 = LutCnn(1, (BS, 80, 5, 5), 3, 1, 1, False)
-        self.lnn5 = LutCnn(1, (BS, 80, 3, 3), 3, 1, 1)
+        self.lnn1 = LutCnn(8, (BS, 1, 28, 28), 2, 2, 1, False)
+        self.lnn2 = LutCnn(1, (BS, 8, 14, 14), 2, 2, 1, False)
+        self.lnn3 = LutCnn(1, (BS, 8, 7, 7), 3, 1, 1, False)
+        self.lnn4 = LutCnn(1, (BS, 8, 5, 5), 3, 1, 1, False)
+        self.lnn5 = LutCnn(10, (BS, 8, 3, 3), 3, 1, 1)

         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)

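The stack above chains five LutCnn blocks; their `input_shape` arguments only line up if each block's kernel/stride pair shrinks the spatial grid to the next block's size, and the commit also drops `channel_repeat` from 80 to 8 while giving the last block 10 repeats (presumably one per MNIST class). A quick check of the spatial sizes:

```python
def out_size(size, kernel, stride, dilation=1):
    # output size of an unpadded convolution / patch extraction
    return (size - dilation * (kernel - 1) - 1) // stride + 1

s = 28
for kernel, stride in [(2, 2), (2, 2), (3, 1), (3, 1), (3, 1)]:   # lnn1 .. lnn5
    s = out_size(s, kernel, stride)
    print(s)   # 14, 7, 5, 3, 1 -- each value matches the next LutCnn's input_shape
```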
@@ -16,6 +16,18 @@
 1. For the repeat operation, grad_output needs to be summed
 2. For the bits->index operation, grad_output needs to be repeated

+### grad computation
+
+grad_output  output  input  exp_grad (grad_input)
+     +         +       +        +
+     +         -       +        -
+     -         +       +        -
+     -         -       +        +
+     +         +       -        -
+     +         -       -        +
+     -         +       -        +
+     -         -       -        -
+
 ## Problems

 * Insert a layer into a chain of binary LUT networks that exchanges data between the channels, producing the same number of new channels

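The table added above reads as a sign rule: the expected input gradient is the product of the signs of grad_output, output, and input (that product interpretation is my reading of the table, not code from this commit). A one-loop check reproduces all eight rows:

```python
import itertools

# enumerate the eight sign combinations and recompute the exp_grad column
for g, o, i in itertools.product([+1, -1], repeat=3):
    exp_grad = g * o * i   # sign(grad_output) * sign(output) * sign(input)
    print(f"grad_output={g:+d} output={o:+d} input={i:+d} -> exp_grad={exp_grad:+d}")
```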
@@ -48,4 +60,6 @@
 * After moving the Repeat of the input from LutGroup into Lut:
   1. Training converges much faster (it can basically converge within 3 epochs at best)
   2. It is much more stable and not sensitive to the lr
   3. The backward pass of the Repeat is handled by Lut itself instead of PyTorch's automatic differentiation, which may have fixed some dimension-handling bugs
+  1. After this change, reverting the grad_input computation to the original form grad_input = grad_output turns out to improve accuracy a lot
+  2. The original approach may have been affected by some bugs in the code; the learning rate setting is no longer sensitive, and neither is the weight initialization