set the channel to 80 at first layer accuracy to 94.
This commit is contained in:
parent
3eb711a97e
commit
855296be55
|
@ -143,13 +143,14 @@ class SimpleBNN(nn.Module):
|
|||
self.w = nn.Parameter(torch.randn(3, 784))
|
||||
self.b = nn.Parameter(torch.zeros(3, 784))
|
||||
|
||||
self.lut2_p = LutGroup(784, 4, 2) # pool2 => 14*14*4
|
||||
self.lut3_p = LutGroup(14 * 14 * 2, 4, 2) # pool2 => 7*7*4
|
||||
self.lut4_p = LutGroup(3 * 3 * 5 * 5 * 4, 9, 2) # conv 3 => 5*5*8
|
||||
# self.lut4 = LutGroup(5 * 5 * 8, 8, 8) # conv 3 => 5*5*8
|
||||
self.lut5_p = LutGroup(3 * 3 * 3 * 3 * 8, 9, 2) # conv 3 => 3*3*8*2
|
||||
# self.lut5 = LutGroup(3 * 3 * 8 * 2, 16, 16) # conv 3 => 3*3*16
|
||||
self.lut6_p = LutGroup(3 * 3 * 16, 9, 5) # conv 3 => 80
|
||||
self.lut2_p = LutGroup(784, 4, 80) # pool2 => 14*14*4
|
||||
self.lut3_p = LutGroup(80 * 14 * 14, 4, 1) # pool2 => 7*7*4
|
||||
self.lut4_p = LutGroup(80 * 5 * 5 * 3 * 3, 9, 1) # conv 3 => 5*5*8
|
||||
# self.lut4 = LutGroup(8 * 5 * 5, 8, 8) # conv 3 => 5*5*8
|
||||
self.lut5_p = LutGroup(80 * 3 * 3 * 3 * 3, 9, 1) # conv 3 => 3*3*16
|
||||
# self.lut5 = LutGroup(16 * 3 * 3, 16, 10) # conv 3 => 3*3*16
|
||||
self.lut6_p = LutGroup(80 * 3 * 3, 9, 1) # conv 3 => 128
|
||||
# self.lut7_p = LutGroup(8 * 16, 16, 10) # fc 128 => 80
|
||||
|
||||
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
|
||||
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
|
||||
|
@ -205,28 +206,39 @@ class SimpleBNN(nn.Module):
|
|||
x = x.view(batch, -1)
|
||||
x = self.lut2_p(x)
|
||||
|
||||
x = x.view(batch, 2, 14, 14)
|
||||
x = x.view(batch, 80, 14, 14)
|
||||
x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
|
||||
x = x.view(batch, -1)
|
||||
x = self.lut3_p(x)
|
||||
|
||||
x = x.view(batch, 4, 7, 7)
|
||||
x = x.view(batch, 80, 7, 7)
|
||||
x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
|
||||
x = x.view(batch, -1)
|
||||
x = self.lut4_p(x)
|
||||
|
||||
# x = x.view(batch, 8, 5, 5)
|
||||
# x = x.permute(0, 2, 3, 1)
|
||||
# x = x.reshape(batch, -1)
|
||||
# x = self.lut4(x)
|
||||
|
||||
x = x.view(batch, 8, 5, 5)
|
||||
x = x.view(batch, 80, 5, 5)
|
||||
x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
|
||||
x = x.view(batch, -1)
|
||||
x = self.lut5_p(x)
|
||||
|
||||
# x = x.view(batch, 16, 3, 3)
|
||||
# x = x.permute(0, 2, 3, 1)
|
||||
# x = x.reshape(batch, -1)
|
||||
# x = self.lut5(x)
|
||||
|
||||
x = x.view(batch, 16, 3, 3)
|
||||
x = x.view(batch, 80, 3, 3)
|
||||
x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
|
||||
x = x.view(batch, -1)
|
||||
x = self.lut6_p(x)
|
||||
|
||||
# x = x.view(batch, 8, 16) # 16 channel 8 value/channel
|
||||
# x = self.lut7_p(x) # 10 * 8 value/channel
|
||||
|
||||
# xx = 2 ** torch.arange(7, -1, -1).to(x.device)
|
||||
x = x.view(batch, -1, 8)
|
||||
# x = x * xx
|
||||
|
|
Loading…
Reference in New Issue