Update binary mnist.

This commit is contained in:
Colin 2025-06-12 14:51:07 +08:00
parent 3d2ff85fc0
commit a68e5ba5ee
2 changed files with 54 additions and 10 deletions

View File

@ -22,7 +22,7 @@ np.random.seed(1234)
torch.cuda.manual_seed_all(1234)
BS = 16
LR = 0.01
LR = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
@ -56,12 +56,12 @@ class Lut(torch.autograd.Function):
ctx.groupBits = groupBits
ctx.groupRepeat = groupRepeat
ctx.batch = batch
ctx.save_for_backward(input, weight, ind)
ctx.save_for_backward(input, weight, ind, output)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, ind = ctx.saved_tensors
input, weight, ind, output = ctx.saved_tensors
groupBits = ctx.groupBits
groupRepeat = ctx.groupRepeat
batch = ctx.batch
@ -73,9 +73,15 @@ class Lut(torch.autograd.Function):
if ctx.needs_input_grad[0]:
grad_input = grad_output * torch.gather(weight, 0, ind)
in_sign = ((input > 0).float() - 0.5) * 2.0
grad_input = grad_input.view(batch, -1, groupRepeat).sum(2)
grad_input = grad_input.unsqueeze(-1).repeat(1, 1, groupBits)
# print("grad_input.shape")
# print(grad_output.shape)
# print(grad_input.shape)
# print(in_sign.shape)
in_sign = ((input > 0).float() - 0.5) * 2.0
grad_input = grad_input * in_sign
grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
@ -93,6 +99,29 @@ class Lut(torch.autograd.Function):
# print(in_sign[0].detach().cpu().numpy())
# print(out_sign[0].detach().cpu().numpy())
# if weight.shape[0] > 512:
# print("weight")
# print(weight[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy())
# else:
# print("weight")
# print(weight.detach().cpu().numpy())
# print("input")
# print(input.detach().cpu().numpy())
# print("grad_output")
# print(grad_output.detach().cpu().numpy())
# if weight.shape[0] > 512:
# print("grad_weight")
# print(grad_weight[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy())
# else:
# print("grad_weight")
# print(grad_weight.detach().cpu().numpy())
# if ctx.needs_input_grad[0]:
# print("grad_input")
# print(grad_input.detach().cpu().numpy())
# print("==============================================")
# print("==============================================")
return grad_input, grad_weight, None, None, None
@ -245,7 +274,7 @@ class SimpleBNN(nn.Module):
x = self.lnn5(x)
# xx = 2 ** torch.arange(7, -1, -1).to(x.device)
x = x.view(batch, -1, 8)
x = x.view(batch, 10, -1)
# x = x * xx
# x = (x - 128.0) / 256.0
x = x.sum(2)
@ -287,6 +316,10 @@ class SimpleLNN(nn.Module):
# print(self.lutg2.weight.detach().cpu().numpy())
# print("=============================")
# print("=============================")
# print("self.lutg2.grad")
# print(self.lutg2.weight.grad.detach().cpu().numpy())
# print("=============================")
# print("=============================")
torch.autograd.set_detect_anomaly(True)

View File

@ -12,15 +12,21 @@
1. 梯度大​​ → 损失曲面陡峭 → 微小变化导致损失剧烈波动
2. 梯度大微小变化就可以使得loss变化一个单位
2. 梯度大和loss的关系越相关
6. input的梯度计算
1. 对于repeat的操作需要对grad_output进行sum
2. 对于bits->index的操作需要对grad_output进行repeat
## 问题
* 在一串的binary lut网络中
1. 如果每个卷积的channel相互之间没有关系
2. 中间插入一层交换各个channel之间的数据生成新的相同数量的channel
3. 方法2的效果很差
* 在一串的binary lut网络中插入一层交换各个channel之间的数据生成新的相同数量的channel
1. 效果很差
1. 好像是破坏了训练,可能是训练的方法不对,梯度下降不适合这种模型
2. 最终分类是1010个输出之间有关系就会很差
2. 模型总是在把原来的信息进行repeat不影响最终的精度进行全连接就有动态选择就很差
3. 最后一层的repeat数量对精度的影响
1. >10 因为前面的数量不够导致精度不如10
2. >1 and <10 因为10个输出结果中间有交叉数据可能是最后一层交叉导致的导致精度不如10
3. 为什么 最后一层10 x 80 精度不如 1 x 80
* LUT层梯度计算的问题
1. 发现LUT的反向计算grad_weight没有考虑weight本来的正负符号grad表示的是>0的置信度
1. 考虑梯度符号之后,由于整个选择的梯度是一个,没有机会变换到别的
@ -31,10 +37,15 @@
3. grad input的目标是要不要更换别的index
1. 梯度的大小表示更换别的index的程度
2. 梯度正负无所谓,需要随机?
4. repeat是选择不同的weightindex是同样的如果repeat出来的loss sum等0那么这个index的不能下降梯度等0
* unfold输出的维度不对
1. LUT不是对卷积核进行计算,更容易收敛,但是精度没有更高
2. LUT不是对卷积核进行计算,不容易收敛,精度差不多
* 好像只有AdamW优化器可以优化参数明显收敛
* LUT的输出进行二值化对精度有影响大概93->81
* LUT参数初始化为1.0收敛速度非常快好像比随机精度高大概81->93
* LUT参数初始化为1.0收敛速度非常快好像比随机精度高大概81->93
* 把input的Repeat从LutGroup移到Lut里面后
1. 训练的收敛速度快很多最快3epoch基本能收敛
2. 稳定性很大对lr不敏感
3. Repeat的反向由Lut统一处理而不是pytorch自动反向可能修复了一些维度处理的错误