Update mnist to higher accuracy.
parent a68e5ba5ee · commit d8539b6b2b
@@ -22,7 +22,7 @@ np.random.seed(1234)
 torch.cuda.manual_seed_all(1234)
 
 BS = 16
-LR = 0.001
+LR = 0.01
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")
@@ -51,8 +51,7 @@ class Lut(torch.autograd.Function):
 
         ind = ((x > 0).long() * index).sum(dim=-1)
         output = torch.gather(weight, 0, ind)
-        output = (output > 0).float()
-        output = (output - 0.5) * 2.0
+        output = ((output > 0).float() - 0.5) * 2.0
         ctx.groupBits = groupBits
         ctx.groupRepeat = groupRepeat
         ctx.batch = batch
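For orientation, here is a minimal standalone sketch of the lookup pattern this forward pass implements, with illustrative (assumed) tensor names and sizes: each group of groupBits binary activations is packed into an integer index, that index selects one learnable entry from the group's 2^groupBits-entry table, and the selected value is re-binarized to ±1 as in the new one-liner above.

```python
import torch

# Illustrative shapes (assumed): 8 samples, 3 groups, 4 bits per group.
bits = 4
x = torch.randn(8, 3, bits)                 # sign of each value = one input bit
table = torch.randn(2 ** bits, 3)           # one 2^bits-entry LUT per group

index = 2 ** torch.arange(bits)             # bit weights 1, 2, 4, 8
ind = ((x > 0).long() * index).sum(dim=-1)  # (8, 3) integers in [0, 2^bits)
out = table[ind, torch.arange(3)]           # look up entry ind[b, g] in group g's table
out = ((out > 0).float() - 0.5) * 2.0       # re-binarize to ±1
```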
@@ -72,26 +71,13 @@ class Lut(torch.autograd.Function):
             grad_weight.scatter_add_(0, ind, grad_output)
 
         if ctx.needs_input_grad[0]:
-            grad_input = grad_output * torch.gather(weight, 0, ind)
+            # grad_input = grad_output * torch.gather(weight, 0, ind)
+            grad_input = grad_output
             grad_input = grad_input.view(batch, -1, groupRepeat).sum(2)
             grad_input = grad_input.unsqueeze(-1).repeat(1, 1, groupBits)
 
-            # print("grad_input.shape")
-            # print(grad_output.shape)
-            # print(grad_input.shape)
-            # print(in_sign.shape)
 
-            in_sign = ((input > 0).float() - 0.5) * 2.0
-            grad_input = grad_input * in_sign
-            grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
 
-            # grad_input = grad_output
-            # grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
-            # output = output.unsqueeze(-1).repeat(1, 1, bits)
-            # in_sign = ((input > 0).float() - 0.5) * 2.0
-            # out_sign = ((output > 0).float() - 0.5) * 2.0
-            # grad_sign = ((grad_input > 0).float() - 0.5) * 2.0
-            # grad_input = grad_input * in_sign * (out_sign * grad_sign)
-            # grad_input = grad_input * in_sign
+            # grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
 
             # a dynamic adjustment factor is needed here
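For orientation, a minimal sketch of the gradient routing the new backward keeps: a straight-through-style pass in which grad_output is summed over the groupRepeat copies of each group and then broadcast to every input bit. Names and sizes below are illustrative assumptions.

```python
import torch

batch, groups, bits, repeat = 8, 3, 4, 5           # assumed sizes
grad_output = torch.randn(batch, groups * repeat)  # one gradient per LUT output

grad_input = grad_output                                   # straight-through: no gating by weight or input
grad_input = grad_input.view(batch, -1, repeat).sum(2)     # sum over the repeat copies -> (batch, groups)
grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)   # broadcast to every input bit -> (batch, groups, bits)
```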
@@ -155,7 +141,7 @@ class LutGroup(nn.Module):
     def __init__(self, group, groupBits, groupRepeat=1):
         assert groupBits > 1
         super(LutGroup, self).__init__()
-        self.weight = nn.Parameter(torch.ones(pow(2, groupBits), int(groupRepeat * group)))
+        self.weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
         self.group = group
         self.groupBits = groupBits
         self.groupRepeat = groupRepeat
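As a quick shape check of the parameter allocated above (example sizes assumed): one learnable entry per possible groupBits-bit pattern, for each of the groupRepeat * group lookup tables.

```python
import torch
import torch.nn as nn

group, groupBits, groupRepeat = 49, 4, 8   # assumed example sizes
weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
print(weight.shape)                        # torch.Size([16, 392])
```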
@@ -186,8 +172,8 @@ class LutCnn(nn.Module):
         group = int(len(self.batch_idx) / B / groupBits)
         self.lut = LutGroup(group, groupBits, channel_repeat)
         self.fc = fc
-        if fc:
-            self.lutc = LutGroup(group, channel_repeat * C, channel_repeat * C)
+        if fc and channel_repeat > 1:
+            self.lutc = LutGroup(group, channel_repeat, channel_repeat)
 
     def forward(self, x):
         B, C, H, W = self.input_shape
@@ -195,8 +181,8 @@ class LutCnn(nn.Module):
         x = x[(self.batch_idx, self.channel_idx, self.h_idx, self.w_idx)]
         x = x.view(B, -1, self.kernel_size * self.kernel_size)
         x = self.lut(x)
-        if self.fc:
-            x = x.view(B, self.channel_repeat * C, -1)
+        if self.fc and self.channel_repeat > 1:
+            x = x.view(B, self.channel_repeat, -1)
             x = x.permute(0, 2, 1)
             x = self.lutc(x)
         return x
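A small sketch (with assumed sizes) of what the fc branch's view/permute does before the channel-mixing LUT: it moves the channel_repeat copies of each position into the last dimension, so self.lutc consumes those copies as one group of input bits.

```python
import torch

B, channel_repeat, positions = 16, 8, 9          # assumed sizes
x = torch.randn(B, channel_repeat * positions)   # output of the first LUT, flattened

x = x.view(B, channel_repeat, -1)                # (B, channel_repeat, positions)
x = x.permute(0, 2, 1)                           # (B, positions, channel_repeat)
# each row of the last dim is now one group of channel_repeat "bits" for the mixing LUT
```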
@@ -212,11 +198,11 @@ class SimpleBNN(nn.Module):
         self.b = nn.Parameter(torch.zeros(3, 784))
 
         # channel_repeat, input_shape, kernel_size, stride, dilation, fc
-        self.lnn1 = LutCnn(80, (BS, 1, 28, 28), 2, 2, 1, False)
-        self.lnn2 = LutCnn(1, (BS, 80, 14, 14), 2, 2, 1, False)
-        self.lnn3 = LutCnn(1, (BS, 80, 7, 7), 3, 1, 1, False)
-        self.lnn4 = LutCnn(1, (BS, 80, 5, 5), 3, 1, 1, False)
-        self.lnn5 = LutCnn(1, (BS, 80, 3, 3), 3, 1, 1)
+        self.lnn1 = LutCnn(8, (BS, 1, 28, 28), 2, 2, 1, False)
+        self.lnn2 = LutCnn(1, (BS, 8, 14, 14), 2, 2, 1, False)
+        self.lnn3 = LutCnn(1, (BS, 8, 7, 7), 3, 1, 1, False)
+        self.lnn4 = LutCnn(1, (BS, 8, 5, 5), 3, 1, 1, False)
+        self.lnn5 = LutCnn(10, (BS, 8, 3, 3), 3, 1, 1)
 
         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
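For reference, the spatial sizes threaded through lnn1..lnn5 follow the usual valid-window formula out = (in - kernel) // stride + 1 (dilation is 1 throughout); a quick check with the values above:

```python
# Sanity check of the spatial sizes used in the LutCnn stack above.
def out_size(h, kernel, stride):
    return (h - kernel) // stride + 1

h = 28
for kernel, stride in [(2, 2), (2, 2), (3, 1), (3, 1), (3, 1)]:
    h = out_size(h, kernel, stride)
    print(h)   # 14, 7, 5, 3, 1
```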
@@ -16,6 +16,18 @@
         1. For the repeat operation, grad_output needs to be summed
         2. For the bits->index operation, grad_output needs to be repeated
+
+### grad computation
+
+grad_output    output    input    exp_grad   grad_input
+     +           +         +                     +
+     +           -         +                     -
+     -           +         +                     -
+     -           -         +                     +
+     +           +         -                     -
+     +           -         -                     +
+     -           +         -                     +
+     -           -         -                     -
 
 ## Issues
 
 * Insert a layer into the chain of binary LUT layers that exchanges data across channels and produces the same number of new channels
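One way to read the sign table above (my summary, not stated explicitly in the notes): in every row the expected gradient carries the product of the three signs, sign(grad_output) * sign(output) * sign(input). A tiny check:

```python
import torch

# Three rows from the table: (+,+,+) -> +, (+,-,+) -> -, (-,+,-) -> +
go  = torch.tensor([ 1.,  1., -1.])
out = torch.tensor([ 1., -1.,  1.])
inp = torch.tensor([ 1.,  1., -1.])
print(torch.sign(go) * torch.sign(out) * torch.sign(inp))   # tensor([ 1., -1.,  1.])
```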
@@ -49,3 +61,5 @@
     1. Training converges much faster (it can basically converge within 3 epochs at best)
     2. Much more stable, not sensitive to lr
     3. The backward pass for Repeat is handled uniformly by Lut instead of by PyTorch's automatic backward, which may have fixed some dimension-handling bugs
+        1. After this change, reverting the grad_input computation to the original form grad_input = grad_output gives a very large accuracy improvement
+        2. The original approach's problems were probably caused by some code bugs; the learning-rate setting is no longer sensitive, and neither is the weight initialization