diff --git a/binary/mnist.py b/binary/mnist.py index a1ccf80..8749f81 100644 --- a/binary/mnist.py +++ b/binary/mnist.py @@ -22,7 +22,7 @@ np.random.seed(1234) torch.cuda.manual_seed_all(1234) BS = 16 -LR = 0.01 +LR = 0.001 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") @@ -56,12 +56,12 @@ class Lut(torch.autograd.Function): ctx.groupBits = groupBits ctx.groupRepeat = groupRepeat ctx.batch = batch - ctx.save_for_backward(input, weight, ind) + ctx.save_for_backward(input, weight, ind, output) return output @staticmethod def backward(ctx, grad_output): - input, weight, ind = ctx.saved_tensors + input, weight, ind, output = ctx.saved_tensors groupBits = ctx.groupBits groupRepeat = ctx.groupRepeat batch = ctx.batch @@ -73,9 +73,15 @@ class Lut(torch.autograd.Function): if ctx.needs_input_grad[0]: grad_input = grad_output * torch.gather(weight, 0, ind) - in_sign = ((input > 0).float() - 0.5) * 2.0 grad_input = grad_input.view(batch, -1, groupRepeat).sum(2) grad_input = grad_input.unsqueeze(-1).repeat(1, 1, groupBits) + + # print("grad_input.shape") + # print(grad_output.shape) + # print(grad_input.shape) + # print(in_sign.shape) + + in_sign = ((input > 0).float() - 0.5) * 2.0 grad_input = grad_input * in_sign grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0) @@ -93,6 +99,29 @@ class Lut(torch.autograd.Function): # print(in_sign[0].detach().cpu().numpy()) # print(out_sign[0].detach().cpu().numpy()) + + # if weight.shape[0] > 512: + # print("weight") + # print(weight[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy()) + # else: + # print("weight") + # print(weight.detach().cpu().numpy()) + # print("input") + # print(input.detach().cpu().numpy()) + # print("grad_output") + # print(grad_output.detach().cpu().numpy()) + # if weight.shape[0] > 512: + # print("grad_weight") + # print(grad_weight[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy()) + # else: + # print("grad_weight") + # print(grad_weight.detach().cpu().numpy()) + # if ctx.needs_input_grad[0]: + # print("grad_input") + # print(grad_input.detach().cpu().numpy()) + # print("==============================================") + # print("==============================================") + return grad_input, grad_weight, None, None, None @@ -245,7 +274,7 @@ class SimpleBNN(nn.Module): x = self.lnn5(x) # xx = 2 ** torch.arange(7, -1, -1).to(x.device) - x = x.view(batch, -1, 8) + x = x.view(batch, 10, -1) # x = x * xx # x = (x - 128.0) / 256.0 x = x.sum(2) @@ -287,6 +316,10 @@ class SimpleLNN(nn.Module): # print(self.lutg2.weight.detach().cpu().numpy()) # print("=============================") # print("=============================") + # print("self.lutg2.grad") + # print(self.lutg2.weight.grad.detach().cpu().numpy()) + # print("=============================") + # print("=============================") torch.autograd.set_detect_anomaly(True) diff --git a/binary/readme.md b/binary/readme.md index 47ca668..9d3e96d 100644 --- a/binary/readme.md +++ b/binary/readme.md @@ -12,15 +12,21 @@ 1. 梯度大​​ → 损失曲面陡峭 → 微小变化导致损失剧烈波动 2. 梯度大,微小变化就可以使得loss变化一个单位 2. 梯度大,和loss的关系越相关 + 6. input的梯度计算 + 1. 对于repeat的操作,需要对grad_output进行sum + 2. 对于bits->index的操作,需要对grad_output进行repeat ## 问题 -* 在一串的binary lut网络中 - 1. 如果每个卷积的channel相互之间没有关系 - 2. 中间插入一层,交换各个channel之间的数据,生成新的相同数量的channel - 3. 方法2的效果很差 +* 在一串的binary lut网络中插入一层,交换各个channel之间的数据,生成新的相同数量的channel + 1. 效果很差 1. 好像是破坏了训练,可能是训练的方法不对,梯度下降不适合这种模型 2. 最终分类是10,10个输出之间有关系就会很差? + 2. 模型总是在把原来的信息进行repeat,不影响最终的精度,进行全连接就有动态选择就很差 + 3. 最后一层的repeat数量对精度的影响 + 1. >10 因为前面的数量不够,导致精度不如10 + 2. >1 and <10 因为10个输出结果中间有交叉数据(可能是最后一层交叉导致的),导致精度不如10 + 3. 为什么 最后一层,10 x 80 精度不如 1 x 80 ???? * LUT层梯度计算的问题 1. 发现LUT的反向计算grad_weight没有考虑weight本来的正负符号,grad表示的是>0的置信度 1. 考虑梯度符号之后,由于整个选择的梯度是一个,没有机会变换到别的 @@ -31,10 +37,15 @@ 3. grad input的目标是,要不要更换别的index 1. 梯度的大小表示更换别的index的程度 2. 梯度正负无所谓,需要随机? + 4. repeat是选择不同的weight,index是同样的,如果repeat出来的loss sum等0,那么这个index的不能下降,梯度等0 * unfold输出的维度不对 1. LUT不是对卷积核进行计算,更容易收敛,但是精度没有更高 2. LUT不是对卷积核进行计算,不容易收敛,精度差不多 * 好像只有AdamW优化器可以优化参数,明显收敛 * LUT的输出进行二值化对精度有影响,大概93->81 -* LUT参数初始化为1.0,收敛速度非常快,好像比随机精度高,大概81->93 \ No newline at end of file +* LUT参数初始化为1.0,收敛速度非常快,好像比随机精度高,大概81->93 +* 把input的Repeat从LutGroup移到Lut里面后 + 1. 训练的收敛速度快很多(最快3epoch基本能收敛) + 2. 稳定性很大,对lr不敏感 + 3. Repeat的反向由Lut统一处理,而不是pytorch自动反向,可能修复了一些维度处理的错误 \ No newline at end of file