Set the channel count to 80 at the first layer; accuracy rises to 94.

This commit is contained in:
Colin 2025-05-25 15:33:39 +08:00
parent 3eb711a97e
commit 855296be55
1 changed file with 23 additions and 11 deletions

View File

@ -143,13 +143,14 @@ class SimpleBNN(nn.Module):
self.w = nn.Parameter(torch.randn(3, 784))
self.b = nn.Parameter(torch.zeros(3, 784))
self.lut2_p = LutGroup(784, 4, 2) # pool2 => 14*14*4
self.lut3_p = LutGroup(14 * 14 * 2, 4, 2) # pool2 => 7*7*4
self.lut4_p = LutGroup(3 * 3 * 5 * 5 * 4, 9, 2) # conv 3 => 5*5*8
# self.lut4 = LutGroup(5 * 5 * 8, 8, 8) # conv 3 => 5*5*8
self.lut5_p = LutGroup(3 * 3 * 3 * 3 * 8, 9, 2) # conv 3 => 3*3*8*2
# self.lut5 = LutGroup(3 * 3 * 8 * 2, 16, 16) # conv 3 => 3*3*16
self.lut6_p = LutGroup(3 * 3 * 16, 9, 5) # conv 3 => 80
self.lut2_p = LutGroup(784, 4, 80) # pool2 => 14*14*4
self.lut3_p = LutGroup(80 * 14 * 14, 4, 1) # pool2 => 7*7*4
self.lut4_p = LutGroup(80 * 5 * 5 * 3 * 3, 9, 1) # conv 3 => 5*5*8
# self.lut4 = LutGroup(8 * 5 * 5, 8, 8) # conv 3 => 5*5*8
self.lut5_p = LutGroup(80 * 3 * 3 * 3 * 3, 9, 1) # conv 3 => 3*3*16
# self.lut5 = LutGroup(16 * 3 * 3, 16, 10) # conv 3 => 3*3*16
self.lut6_p = LutGroup(80 * 3 * 3, 9, 1) # conv 3 => 128
# self.lut7_p = LutGroup(8 * 16, 16, 10) # fc 128 => 80
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
@ -205,28 +206,39 @@ class SimpleBNN(nn.Module):
x = x.view(batch, -1)
x = self.lut2_p(x)
x = x.view(batch, 2, 14, 14)
x = x.view(batch, 80, 14, 14)
x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
x = x.view(batch, -1)
x = self.lut3_p(x)
x = x.view(batch, 4, 7, 7)
x = x.view(batch, 80, 7, 7)
x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
x = x.view(batch, -1)
x = self.lut4_p(x)
# x = x.view(batch, 8, 5, 5)
# x = x.permute(0, 2, 3, 1)
# x = x.reshape(batch, -1)
# x = self.lut4(x)
x = x.view(batch, 8, 5, 5)
x = x.view(batch, 80, 5, 5)
x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
x = x.view(batch, -1)
x = self.lut5_p(x)
# x = x.view(batch, 16, 3, 3)
# x = x.permute(0, 2, 3, 1)
# x = x.reshape(batch, -1)
# x = self.lut5(x)
x = x.view(batch, 16, 3, 3)
x = x.view(batch, 80, 3, 3)
x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
x = x.view(batch, -1)
x = self.lut6_p(x)
# x = x.view(batch, 8, 16) # 16 channel 8 value/channel
# x = self.lut7_p(x) # 10 * 8 value/channel
# xx = 2 ** torch.arange(7, -1, -1).to(x.device)
x = x.view(batch, -1, 8)
# x = x * xx