Update binary mnist.
This commit is contained in:
parent
3d2ff85fc0
commit
a68e5ba5ee
|
@ -22,7 +22,7 @@ np.random.seed(1234)
|
||||||
torch.cuda.manual_seed_all(1234)
|
torch.cuda.manual_seed_all(1234)
|
||||||
|
|
||||||
BS = 16
|
BS = 16
|
||||||
LR = 0.01
|
LR = 0.001
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
print(f"Using device: {device}")
|
print(f"Using device: {device}")
|
||||||
|
@ -56,12 +56,12 @@ class Lut(torch.autograd.Function):
|
||||||
ctx.groupBits = groupBits
|
ctx.groupBits = groupBits
|
||||||
ctx.groupRepeat = groupRepeat
|
ctx.groupRepeat = groupRepeat
|
||||||
ctx.batch = batch
|
ctx.batch = batch
|
||||||
ctx.save_for_backward(input, weight, ind)
|
ctx.save_for_backward(input, weight, ind, output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
input, weight, ind = ctx.saved_tensors
|
input, weight, ind, output = ctx.saved_tensors
|
||||||
groupBits = ctx.groupBits
|
groupBits = ctx.groupBits
|
||||||
groupRepeat = ctx.groupRepeat
|
groupRepeat = ctx.groupRepeat
|
||||||
batch = ctx.batch
|
batch = ctx.batch
|
||||||
|
@ -73,9 +73,15 @@ class Lut(torch.autograd.Function):
|
||||||
|
|
||||||
if ctx.needs_input_grad[0]:
|
if ctx.needs_input_grad[0]:
|
||||||
grad_input = grad_output * torch.gather(weight, 0, ind)
|
grad_input = grad_output * torch.gather(weight, 0, ind)
|
||||||
in_sign = ((input > 0).float() - 0.5) * 2.0
|
|
||||||
grad_input = grad_input.view(batch, -1, groupRepeat).sum(2)
|
grad_input = grad_input.view(batch, -1, groupRepeat).sum(2)
|
||||||
grad_input = grad_input.unsqueeze(-1).repeat(1, 1, groupBits)
|
grad_input = grad_input.unsqueeze(-1).repeat(1, 1, groupBits)
|
||||||
|
|
||||||
|
# print("grad_input.shape")
|
||||||
|
# print(grad_output.shape)
|
||||||
|
# print(grad_input.shape)
|
||||||
|
# print(in_sign.shape)
|
||||||
|
|
||||||
|
in_sign = ((input > 0).float() - 0.5) * 2.0
|
||||||
grad_input = grad_input * in_sign
|
grad_input = grad_input * in_sign
|
||||||
grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
|
grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
|
||||||
|
|
||||||
|
@ -93,6 +99,29 @@ class Lut(torch.autograd.Function):
|
||||||
|
|
||||||
# print(in_sign[0].detach().cpu().numpy())
|
# print(in_sign[0].detach().cpu().numpy())
|
||||||
# print(out_sign[0].detach().cpu().numpy())
|
# print(out_sign[0].detach().cpu().numpy())
|
||||||
|
|
||||||
|
# if weight.shape[0] > 512:
|
||||||
|
# print("weight")
|
||||||
|
# print(weight[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy())
|
||||||
|
# else:
|
||||||
|
# print("weight")
|
||||||
|
# print(weight.detach().cpu().numpy())
|
||||||
|
# print("input")
|
||||||
|
# print(input.detach().cpu().numpy())
|
||||||
|
# print("grad_output")
|
||||||
|
# print(grad_output.detach().cpu().numpy())
|
||||||
|
# if weight.shape[0] > 512:
|
||||||
|
# print("grad_weight")
|
||||||
|
# print(grad_weight[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy())
|
||||||
|
# else:
|
||||||
|
# print("grad_weight")
|
||||||
|
# print(grad_weight.detach().cpu().numpy())
|
||||||
|
# if ctx.needs_input_grad[0]:
|
||||||
|
# print("grad_input")
|
||||||
|
# print(grad_input.detach().cpu().numpy())
|
||||||
|
# print("==============================================")
|
||||||
|
# print("==============================================")
|
||||||
|
|
||||||
return grad_input, grad_weight, None, None, None
|
return grad_input, grad_weight, None, None, None
|
||||||
|
|
||||||
|
|
||||||
|
@ -245,7 +274,7 @@ class SimpleBNN(nn.Module):
|
||||||
x = self.lnn5(x)
|
x = self.lnn5(x)
|
||||||
|
|
||||||
# xx = 2 ** torch.arange(7, -1, -1).to(x.device)
|
# xx = 2 ** torch.arange(7, -1, -1).to(x.device)
|
||||||
x = x.view(batch, -1, 8)
|
x = x.view(batch, 10, -1)
|
||||||
# x = x * xx
|
# x = x * xx
|
||||||
# x = (x - 128.0) / 256.0
|
# x = (x - 128.0) / 256.0
|
||||||
x = x.sum(2)
|
x = x.sum(2)
|
||||||
|
@ -287,6 +316,10 @@ class SimpleLNN(nn.Module):
|
||||||
# print(self.lutg2.weight.detach().cpu().numpy())
|
# print(self.lutg2.weight.detach().cpu().numpy())
|
||||||
# print("=============================")
|
# print("=============================")
|
||||||
# print("=============================")
|
# print("=============================")
|
||||||
|
# print("self.lutg2.grad")
|
||||||
|
# print(self.lutg2.weight.grad.detach().cpu().numpy())
|
||||||
|
# print("=============================")
|
||||||
|
# print("=============================")
|
||||||
|
|
||||||
|
|
||||||
torch.autograd.set_detect_anomaly(True)
|
torch.autograd.set_detect_anomaly(True)
|
||||||
|
|
|
@ -12,15 +12,21 @@
|
||||||
1. 梯度大 → 损失曲面陡峭 → 微小变化导致损失剧烈波动
|
1. 梯度大 → 损失曲面陡峭 → 微小变化导致损失剧烈波动
|
||||||
2. 梯度大,微小变化就可以使得loss变化一个单位
|
2. 梯度大,微小变化就可以使得loss变化一个单位
|
||||||
2. 梯度大,和loss的关系越相关
|
2. 梯度大,和loss的关系越相关
|
||||||
|
6. input的梯度计算
|
||||||
|
1. 对于repeat的操作,需要对grad_output进行sum
|
||||||
|
2. 对于bits->index的操作,需要对grad_output进行repeat
|
||||||
|
|
||||||
## 问题
|
## 问题
|
||||||
|
|
||||||
* 在一串的binary lut网络中
|
* 在一串的binary lut网络中插入一层,交换各个channel之间的数据,生成新的相同数量的channel
|
||||||
1. 如果每个卷积的channel相互之间没有关系
|
1. 效果很差
|
||||||
2. 中间插入一层,交换各个channel之间的数据,生成新的相同数量的channel
|
|
||||||
3. 方法2的效果很差
|
|
||||||
1. 好像是破坏了训练,可能是训练的方法不对,梯度下降不适合这种模型
|
1. 好像是破坏了训练,可能是训练的方法不对,梯度下降不适合这种模型
|
||||||
2. 最终分类是10,10个输出之间有关系就会很差?
|
2. 最终分类是10,10个输出之间有关系就会很差?
|
||||||
|
2. 模型总是在把原来的信息进行repeat,不影响最终的精度,进行全连接就有动态选择就很差
|
||||||
|
3. 最后一层的repeat数量对精度的影响
|
||||||
|
1. >10 因为前面的数量不够,导致精度不如10
|
||||||
|
2. >1 and <10 因为10个输出结果中间有交叉数据(可能是最后一层交叉导致的),导致精度不如10
|
||||||
|
3. 为什么 最后一层,10 x 80 精度不如 1 x 80 ????
|
||||||
* LUT层梯度计算的问题
|
* LUT层梯度计算的问题
|
||||||
1. 发现LUT的反向计算grad_weight没有考虑weight本来的正负符号,grad表示的是>0的置信度
|
1. 发现LUT的反向计算grad_weight没有考虑weight本来的正负符号,grad表示的是>0的置信度
|
||||||
1. 考虑梯度符号之后,由于整个选择的梯度是一个,没有机会变换到别的
|
1. 考虑梯度符号之后,由于整个选择的梯度是一个,没有机会变换到别的
|
||||||
|
@ -31,6 +37,7 @@
|
||||||
3. grad input的目标是,要不要更换别的index
|
3. grad input的目标是,要不要更换别的index
|
||||||
1. 梯度的大小表示更换别的index的程度
|
1. 梯度的大小表示更换别的index的程度
|
||||||
2. 梯度正负无所谓,需要随机?
|
2. 梯度正负无所谓,需要随机?
|
||||||
|
4. repeat是选择不同的weight,index是同样的,如果repeat出来的loss sum等0,那么这个index的不能下降,梯度等0
|
||||||
|
|
||||||
* unfold输出的维度不对
|
* unfold输出的维度不对
|
||||||
1. LUT不是对卷积核进行计算,更容易收敛,但是精度没有更高
|
1. LUT不是对卷积核进行计算,更容易收敛,但是精度没有更高
|
||||||
|
@ -38,3 +45,7 @@
|
||||||
* 好像只有AdamW优化器可以优化参数,明显收敛
|
* 好像只有AdamW优化器可以优化参数,明显收敛
|
||||||
* LUT的输出进行二值化对精度有影响,大概93->81
|
* LUT的输出进行二值化对精度有影响,大概93->81
|
||||||
* LUT参数初始化为1.0,收敛速度非常快,好像比随机精度高,大概81->93
|
* LUT参数初始化为1.0,收敛速度非常快,好像比随机精度高,大概81->93
|
||||||
|
* 把input的Repeat从LutGroup移到Lut里面后
|
||||||
|
1. 训练的收敛速度快很多(最快3epoch基本能收敛)
|
||||||
|
2. 稳定性很大,对lr不敏感
|
||||||
|
3. Repeat的反向由Lut统一处理,而不是pytorch自动反向,可能修复了一些维度处理的错误
|
Loading…
Reference in New Issue