From 924a0ca9b438b888a49087dbaffbf8953288f34b Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 9 Jun 2025 14:41:37 +0800 Subject: [PATCH] Fix and update binary network. --- binary/.gitignore | 1 + binary/mnist.py | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) create mode 100644 binary/.gitignore diff --git a/binary/.gitignore b/binary/.gitignore new file mode 100644 index 0000000..6320cd2 --- /dev/null +++ b/binary/.gitignore @@ -0,0 +1 @@ +data \ No newline at end of file diff --git a/binary/mnist.py b/binary/mnist.py index f65dbce..6bf74d5 100644 --- a/binary/mnist.py +++ b/binary/mnist.py @@ -22,7 +22,7 @@ np.random.seed(1234) torch.cuda.manual_seed_all(1234) BS = 16 -LR = 0.01 +LR = 0.001 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") @@ -63,8 +63,8 @@ class Lut(torch.autograd.Function): if ctx.needs_input_grad[0]: - # grad_input = grad_output * torch.gather(weight, 0, ind) - grad_input = grad_output + grad_input = grad_output * torch.gather(weight, 0, ind) + # grad_input = grad_output grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits) output = output.unsqueeze(-1).repeat(1, 1, bits) in_sign = ((input > 0).float() - 0.5) * 2.0 @@ -121,7 +121,7 @@ class LutGroup(nn.Module): def __init__(self, group, groupBits, groupRepeat=1): assert groupBits > 1 super(LutGroup, self).__init__() - self.weight = nn.Parameter(torch.ones(pow(2, groupBits), int(groupRepeat * group))) + self.weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group))) self.group = group self.groupBits = groupBits self.groupRepeat = groupRepeat @@ -166,7 +166,7 @@ class LutCnn(nn.Module): x = x.view(B, -1, self.kernel_size * self.kernel_size) x = self.lut(x) if self.fc: - x = x.view(B, -1, self.channel_repeat) + x = x.view(B, self.channel_repeat * C, -1) x = x.permute(0, 2, 1) x = self.lutc(x) return x