diff --git a/binary/mnist.py b/binary/mnist.py
index 525a489..7015044 100644
--- a/binary/mnist.py
+++ b/binary/mnist.py
@@ -39,35 +39,43 @@ test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False, drop_last=T
 
 
 class Lut(torch.autograd.Function):
-    # input [batch, group, bits ]
-    # output [batch, group ]
+    # input [ batch, group * groupBits ]
+    # output [ batch, group * groupRepeat ]
     # weight [2**bits, group ]
 
     @staticmethod
-    def forward(ctx, input, weight, index):
-        ind = ((input > 0).long() * index).sum(dim=-1)
+    def forward(ctx, input, weight, index, groupBits, groupRepeat):
+        batch = input.shape[0]
+        x = input.reshape(batch, -1, groupBits)
+        if groupRepeat > 1:
+            x = x.repeat(1, groupRepeat, 1)
+
+        ind = ((x > 0).long() * index).sum(dim=-1)
         output = torch.gather(weight, 0, ind)
         output = (output > 0).float()
         output = (output - 0.5) * 2.0
+        ctx.groupBits = groupBits
+        ctx.groupRepeat = groupRepeat
+        ctx.batch = batch
         ctx.save_for_backward(input, weight, ind)
         return output
 
     @staticmethod
     def backward(ctx, grad_output):
         input, weight, ind = ctx.saved_tensors
+        groupBits = ctx.groupBits
+        groupRepeat = ctx.groupRepeat
+        batch = ctx.batch
         grad_input = grad_weight = None
-        bits = input.shape[2]
 
         if ctx.needs_input_grad[1]:
             grad_weight = torch.zeros_like(weight)
             grad_weight.scatter_add_(0, ind, grad_output)
 
         if ctx.needs_input_grad[0]:
-            grad_input = grad_output * torch.gather(weight, 0, ind)
-            # grad_input = grad_output
-            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
-            # output = output.unsqueeze(-1).repeat(1, 1, bits)
             in_sign = ((input > 0).float() - 0.5) * 2.0
+            grad_input = grad_input.view(batch, -1, groupRepeat).sum(2)
+            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, groupBits)
             grad_input = grad_input * in_sign
             grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
 
@@ -85,10 +93,7 @@ class Lut(torch.autograd.Function):
 
             # print(in_sign[0].detach().cpu().numpy())
             # print(out_sign[0].detach().cpu().numpy())
-            # print(grad_sign[0].detach().cpu().numpy())
-            # print(grad_input[0].detach().cpu().numpy())
-
-        return grad_input, grad_weight, None
+        return grad_input, grad_weight, None, None, None
 
 
 class SimpleCNN(nn.Module):
@@ -121,7 +126,7 @@ class LutGroup(nn.Module):
     def __init__(self, group, groupBits, groupRepeat=1):
         assert groupBits > 1
         super(LutGroup, self).__init__()
-        self.weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
+        self.weight = nn.Parameter(torch.ones(pow(2, groupBits), int(groupRepeat * group)))
         self.group = group
         self.groupBits = groupBits
         self.groupRepeat = groupRepeat
@@ -130,11 +135,7 @@ class LutGroup(nn.Module):
     def forward(self, x):
         # input [ batch, group * groupBits ]
         # output [ batch, group * groupRepeat ]
-        batch = x.shape[0]
-        x = x.reshape(batch, -1, self.groupBits)
-        if self.groupRepeat > 1:
-            x = x.repeat(1, self.groupRepeat, 1)
-        x = Lut.apply(x, self.weight, self.index)
+        x = Lut.apply(x, self.weight, self.index, self.groupBits, self.groupRepeat)
         return x
 
 
@@ -186,7 +187,7 @@ class SimpleBNN(nn.Module):
         self.lnn2 = LutCnn(1, (BS, 80, 14, 14), 2, 2, 1, False)
         self.lnn3 = LutCnn(1, (BS, 80, 7, 7), 3, 1, 1, False)
         self.lnn4 = LutCnn(1, (BS, 80, 5, 5), 3, 1, 1, False)
-        self.lnn5 = LutCnn(10, (BS, 80, 3, 3), 3, 1, 1)
+        self.lnn5 = LutCnn(1, (BS, 80, 3, 3), 3, 1, 1)
 
         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
@@ -282,10 +283,10 @@ class SimpleLNN(nn.Module):
         print(self.lutg1.weight.grad[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy())
         print("=============================")
         print("=============================")
-        # print("self.lutg2")
-        # print(self.lutg2.weight.detach().cpu().numpy())
-        # print("=============================")
-        # print("=============================")
+        print("self.lutg2")
+        print(self.lutg2.weight.detach().cpu().numpy())
+        print("=============================")
+        print("=============================")
 
 
 torch.autograd.set_detect_anomaly(True)
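
Reviewer note on the new Lut.forward() signature: each group of groupBits sign bits is packed into an integer row index and looked up in one column of `weight`, with the whole group axis tiled groupRepeat times before the lookup. The standalone sketch below reproduces just that indexing so the shape bookkeeping is easier to check; the contents of `index` (per-bit powers of two) and the example sizes are assumptions, since `index` is built outside this diff.

    import torch

    batch, group, groupBits, groupRepeat = 2, 3, 4, 2
    weight = torch.ones(2 ** groupBits, group * groupRepeat)    # LUT: one column per output unit
    index = 2 ** torch.arange(groupBits)                        # assumed per-bit weights 1, 2, 4, 8

    input = torch.randn(batch, group * groupBits)               # flat pre-binarized activations
    x = input.reshape(batch, -1, groupBits)                     # [batch, group, groupBits]
    if groupRepeat > 1:
        x = x.repeat(1, groupRepeat, 1)                         # [batch, groupRepeat * group, groupBits]

    ind = ((x > 0).long() * index).sum(dim=-1)                  # pack sign bits into a row index per group
    output = torch.gather(weight, 0, ind)                       # output[b, g] = weight[ind[b, g], g]
    output = ((output > 0).float() - 0.5) * 2.0                 # binarize to -1 / +1
    print(output.shape)                                         # torch.Size([2, 6])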