From 855296be558b88c4e7707feac4f3923e01f37246 Mon Sep 17 00:00:00 2001 From: Colin Date: Sun, 25 May 2025 15:33:39 +0800 Subject: [PATCH] set the channel to 80 at first layer accuracy to 94. --- binary/mnist.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/binary/mnist.py b/binary/mnist.py index 8f59a4a..728fb38 100644 --- a/binary/mnist.py +++ b/binary/mnist.py @@ -143,13 +143,14 @@ class SimpleBNN(nn.Module): self.w = nn.Parameter(torch.randn(3, 784)) self.b = nn.Parameter(torch.zeros(3, 784)) - self.lut2_p = LutGroup(784, 4, 2) # pool2 => 14*14*4 - self.lut3_p = LutGroup(14 * 14 * 2, 4, 2) # pool2 => 7*7*4 - self.lut4_p = LutGroup(3 * 3 * 5 * 5 * 4, 9, 2) # conv 3 => 5*5*8 - # self.lut4 = LutGroup(5 * 5 * 8, 8, 8) # conv 3 => 5*5*8 - self.lut5_p = LutGroup(3 * 3 * 3 * 3 * 8, 9, 2) # conv 3 => 3*3*8*2 - # self.lut5 = LutGroup(3 * 3 * 8 * 2, 16, 16) # conv 3 => 3*3*16 - self.lut6_p = LutGroup(3 * 3 * 16, 9, 5) # conv 3 => 80 + self.lut2_p = LutGroup(784, 4, 80) # pool2 => 14*14*4 + self.lut3_p = LutGroup(80 * 14 * 14, 4, 1) # pool2 => 7*7*4 + self.lut4_p = LutGroup(80 * 5 * 5 * 3 * 3, 9, 1) # conv 3 => 5*5*8 + # self.lut4 = LutGroup(8 * 5 * 5, 8, 8) # conv 3 => 5*5*8 + self.lut5_p = LutGroup(80 * 3 * 3 * 3 * 3, 9, 1) # conv 3 => 3*3*16 + # self.lut5 = LutGroup(16 * 3 * 3, 16, 10) # conv 3 => 3*3*16 + self.lut6_p = LutGroup(80 * 3 * 3, 9, 1) # conv 3 => 128 + # self.lut7_p = LutGroup(8 * 16, 16, 10) # fc 128 => 80 self.conv1 = nn.Conv2d(1, 10, kernel_size=5) self.conv2 = nn.Conv2d(10, 20, kernel_size=5) @@ -205,28 +206,39 @@ class SimpleBNN(nn.Module): x = x.view(batch, -1) x = self.lut2_p(x) - x = x.view(batch, 2, 14, 14) + x = x.view(batch, 80, 14, 14) x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2) x = x.view(batch, -1) x = self.lut3_p(x) - x = x.view(batch, 4, 7, 7) + x = x.view(batch, 80, 7, 7) x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1) x = x.view(batch, -1) x = self.lut4_p(x) + + # x = x.view(batch, 8, 5, 5) + # x = x.permute(0, 2, 3, 1) + # x = x.reshape(batch, -1) # x = self.lut4(x) - x = x.view(batch, 8, 5, 5) + x = x.view(batch, 80, 5, 5) x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1) x = x.view(batch, -1) x = self.lut5_p(x) + + # x = x.view(batch, 16, 3, 3) + # x = x.permute(0, 2, 3, 1) + # x = x.reshape(batch, -1) # x = self.lut5(x) - x = x.view(batch, 16, 3, 3) + x = x.view(batch, 80, 3, 3) x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1) x = x.view(batch, -1) x = self.lut6_p(x) + # x = x.view(batch, 8, 16) # 16 channel 8 value/channel + # x = self.lut7_p(x) # 10 * 8 value/channel + # xx = 2 ** torch.arange(7, -1, -1).to(x.device) x = x.view(batch, -1, 8) # x = x * xx