Update mnist.

parent f1124bc3b1
commit f98a951b58

Changed files: binary/mnist.py (140 changed lines)

@@ -12,15 +12,17 @@ from torch.utils.data import DataLoader
 import math
 import torch.nn.functional as F
 import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+import datetime
 
 torch.manual_seed(1234)
 np.random.seed(1234)
 torch.cuda.manual_seed_all(1234)
 
+batch_size = 16
+lr = 0.001
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-# device = "cpu"
-# torch.set_num_threads(16)
 print(f"Using device: {device}")
 
 transform = transforms.Compose(
@@ -30,7 +32,7 @@ transform = transforms.Compose(
 train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
 test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
 
-train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
+train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
 test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
 
 
@@ -112,28 +114,22 @@ class SimpleCNN(nn.Module):
 
 
 class LutGroup(nn.Module):
-    def __init__(self, bits, subbits):
+    def __init__(self, totalBits, groupBits=0, repeat=1):
+        if not groupBits:
+            groupBits = totalBits
         super(LutGroup, self).__init__()
-        assert (bits % subbits) == 0
-        self.weight = nn.Parameter(torch.randn(bits, pow(2, subbits)))
-        self.bits = bits
-        self.subbits = subbits
+        assert (totalBits % groupBits) == 0
+        self.weight = nn.Parameter(torch.randn(repeat * totalBits, pow(2, groupBits)))
+        self.totalBits = totalBits
+        self.groupBits = groupBits
+        self.repeat = repeat
 
     def forward(self, x):
+        # input   [ batch, totalBits ]
+        # output  [ batch, totalBits / groupBits * repeat ]
         batch = x.shape[0]
-        x = x.view(batch, -1, self.subbits)
-        x = Lut.apply(x, self.weight)
-        return x
-
-
-class LutParallel(nn.Module):
-    def __init__(self, bits, number):
-        super(LutParallel, self).__init__()
-        self.number = number
-        self.weight = nn.Parameter(torch.randn(number, pow(2, bits)))
-
-    def forward(self, x):
-        x = x.unsqueeze(1).repeat(1, self.number, 1)
+        x = x.view(batch, -1, self.groupBits)
+        x = x.repeat(1, self.repeat, 1)
         x = Lut.apply(x, self.weight)
         return x
 
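
A quick sanity check of the revised LutGroup sizing (a sketch, not part of the commit; it only exercises the constructor shown in this hunk, and would run inside binary/mnist.py since Lut.apply is defined elsewhere in that file):

    # Hypothetical usage mirroring the new constructor:
    g = LutGroup(totalBits=784, groupBits=4, repeat=2)
    # One 2**groupBits-entry LUT row per repeated bit:
    assert g.weight.shape == (2 * 784, 2 ** 4)   # (1568, 16)
    # forward: [batch, 784] -> view -> [batch, 196, 4] -> repeat -> [batch, 392, 4],
    # which Lut.apply reduces to [batch, 392] per the shape comments above.
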
@@ -147,11 +143,13 @@ class SimpleBNN(nn.Module):
         self.w = nn.Parameter(torch.randn(3, 784))
         self.b = nn.Parameter(torch.zeros(3, 784))
 
-        self.lut1 = LutGroup(784 * 8, 8)
-        self.lut2 = LutGroup(784, 4)
-        self.lut3 = LutGroup(196, 4)
-        self.lut4 = LutGroup(49, 7)
-        self.lut5 = LutParallel(7, 10)
+        self.lut2_p = LutGroup(784, 4, 2)  # pool2 => 14*14*4
+        self.lut3_p = LutGroup(14 * 14 * 2, 4, 2)  # pool2 => 7*7*4
+        self.lut4_p = LutGroup(3 * 3 * 5 * 5 * 4, 9, 2)  # conv 3 => 5*5*8
+        # self.lut4 = LutGroup(5 * 5 * 8, 8, 8)  # conv 3 => 5*5*8
+        self.lut5_p = LutGroup(3 * 3 * 3 * 3 * 8, 9, 2)  # conv 3 => 3*3*8*2
+        # self.lut5 = LutGroup(3 * 3 * 8 * 2, 16, 16)  # conv 3 => 3*3*16
+        self.lut6_p = LutGroup(3 * 3 * 16, 9, 5)  # conv 3 => 80
 
         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
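
These layer sizes line up with the unfold stages used in forward (see the walkthrough after that hunk), assuming LutGroup's documented output size totalBits / groupBits * repeat: lut2_p takes the 2x2-unfolded image, 784 bits, in 4-bit groups repeated twice, giving 784 / 4 * 2 = 392 = 2 * 14 * 14 outputs; lut4_p takes 3 * 3 * 5 * 5 * 4 = 900 bits in 9-bit groups repeated twice, giving 200 = 8 * 5 * 5; lut6_p takes 3 * 3 * 16 = 144 bits in 9-bit groups repeated five times, giving 80 = 10 * 8, i.e. eight bits per class.
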
@@ -171,28 +169,28 @@ class SimpleBNN(nn.Module):
         # x = bits.view(batch, -1)
         # x = x.float() - 0.5
 
-        x = (x > 0).float()
+        # x = (x > 0).float()
 
-        q = x * self.w[0] + self.b[0]
-        k = x * self.w[1] + self.b[1]
-        v = x * self.w[2] + self.b[2]
-        q = q.view(batch, -1, 1)
-        k = k.view(batch, 1, -1)
-        v = v.view(batch, -1, 1)
-        kq = q @ k
-        kqv = kq @ v
-        kqv = kqv.view(batch, -1, 8)
-        x = kqv
+        # q = x * self.w[0] + self.b[0]
+        # k = x * self.w[1] + self.b[1]
+        # v = x * self.w[2] + self.b[2]
+        # q = q.view(batch, -1, 1)
+        # k = k.view(batch, 1, -1)
+        # v = v.view(batch, -1, 1)
+        # kq = q @ k
+        # kqv = kq @ v
+        # kqv = kqv.view(batch, -1, 8)
+        # x = kqv
 
         #########################
 
-        # x = (x > 0) << xx
-        # x = x.sum(2)
-        # x = x.view(batch, 1, 28, 28)
-        # x = (x - 128.0) / 256.0
+        # # x = (x > 0) << xx
+        # # x = x.sum(2)
+        # # x = x.view(batch, 1, 28, 28)
+        # # x = (x - 128.0) / 256.0
 
-        # x = x.view(batch, 1, 28, 28)
         # x = (x > 0).float()
+        # x = x.view(batch, 1, 28, 28)
 
         # x = self.relu(self.pool(self.conv1(x)))
         # x = self.relu(self.pool((self.conv2(x))))
@@ -202,19 +200,39 @@ class SimpleBNN(nn.Module):
 
         #########################
 
-        x = (x > 0).float()
-        x = x.view(batch, 196, 4)
+        x = x.view(batch, 1, 28, 28)
+        x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
+        x = x.view(batch, -1)
+        x = self.lut2_p(x)
 
-        # x = self.lut1(x)
-        x = self.lut2(x)
+        x = x.view(batch, 2, 14, 14)
+        x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
+        x = x.view(batch, -1)
+        x = self.lut3_p(x)
 
-        x = x.view(-1, 28, 7)
-        x = x.permute(0, 2, 1)
-        x = x.reshape(-1, 28 * 7)
+        x = x.view(batch, 4, 7, 7)
+        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
+        x = x.view(batch, -1)
+        x = self.lut4_p(x)
+        # x = self.lut4(x)
+
+        x = x.view(batch, 8, 5, 5)
+        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
+        x = x.view(batch, -1)
+        x = self.lut5_p(x)
+        # x = self.lut5(x)
+
+        x = x.view(batch, 16, 3, 3)
+        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
+        x = x.view(batch, -1)
+        x = self.lut6_p(x)
+
+        # xx = 2 ** torch.arange(7, -1, -1).to(x.device)
+        x = x.view(batch, -1, 8)
+        # x = x * xx
+        # x = (x - 128.0) / 256.0
+        x = x.sum(2)
 
-        x = self.lut3(x)
-        x = self.lut4(x)
-        x = self.lut5(x)
         return x
 
 
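
For readers following the shape bookkeeping, here is a minimal walk-through of the new pipeline (a sketch using plain torch; the LutGroup output sizes assume the [batch, totalBits / groupBits * repeat] contract stated in its forward comments):

    import torch
    import torch.nn.functional as F

    batch = 16
    x = F.unfold(torch.randn(batch, 1, 28, 28), kernel_size=2, stride=2)
    assert x.shape == (batch, 4, 196)    # flattens to 784 bits for lut2_p
    # lut2_p = LutGroup(784, 4, 2): 784 // 4 * 2 = 392 bits -> view(batch, 2, 14, 14)
    x = F.unfold(torch.randn(batch, 2, 14, 14), kernel_size=2, stride=2)
    assert x.shape == (batch, 8, 49)     # 392 bits for lut3_p
    # lut3_p = LutGroup(392, 4, 2): 392 // 4 * 2 = 196 bits -> view(batch, 4, 7, 7)
    x = F.unfold(torch.randn(batch, 4, 7, 7), kernel_size=3, stride=1)
    assert x.shape == (batch, 36, 25)    # 900 bits for lut4_p
    # lut4_p = LutGroup(900, 9, 2): 900 // 9 * 2 = 200 bits -> view(batch, 8, 5, 5)
    x = F.unfold(torch.randn(batch, 8, 5, 5), kernel_size=3, stride=1)
    assert x.shape == (batch, 72, 9)     # 648 bits for lut5_p
    # lut5_p = LutGroup(648, 9, 2): 648 // 9 * 2 = 144 bits -> view(batch, 16, 3, 3)
    x = F.unfold(torch.randn(batch, 16, 3, 3), kernel_size=3, stride=1)
    assert x.shape == (batch, 144, 1)    # 144 bits for lut6_p
    # lut6_p = LutGroup(144, 9, 5): 144 // 9 * 5 = 80 = 10 classes * 8 bits;
    # view(batch, -1, 8).sum(2) then yields the [batch, 10] logits.
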
@@ -222,7 +240,12 @@ torch.autograd.set_detect_anomaly(True)
 # model = SimpleCNN().to(device)
 model = SimpleBNN().to(device)
 criterion = nn.CrossEntropyLoss()
-optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+
+current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
+writer = SummaryWriter(f"log/{current_time}")
+hparam_dict = {"lr": lr, "batch_size": batch_size}
+writer.add_hparams(hparam_dict, {}, run_name=f"./")
 
 
 def train(epoch):
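
With the writer set up this way, each run logs under log/<MMDD-HHMMSS> and can be inspected with the stock TensorBoard CLI:

    tensorboard --logdir log
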
@@ -240,15 +263,16 @@ def train(epoch):
         loss = criterion(output, target)
         loss.backward()
         optimizer.step()
+        writer.add_scalar("loss", loss, epoch)
 
-        if batch_idx % 100 == 0:
+        if batch_idx % 512 == 0:
             print(
                 f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                 f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
             )
 
 
-def test():
+def test(epoch):
     model.eval()
     test_loss = 0
     correct = 0
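
Note that writer.add_scalar("loss", loss, epoch) runs every batch but passes the epoch number as the TensorBoard global step, so all batches within an epoch write to the same step. A per-batch step is a common alternative (hypothetical variant, not in this commit):

    writer.add_scalar("loss", loss.item(), epoch * len(train_loader) + batch_idx)
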
@@ -262,15 +286,17 @@ def test():
 
     test_loss /= len(test_loader.dataset)
     accuracy = 100.0 * correct / len(test_loader.dataset)
+    writer.add_scalar("accuracy", accuracy, epoch)
     print(
         f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
         f"({accuracy:.0f}%)\n"
     )
 
 
-for epoch in range(1, 30):
+for epoch in range(1, 300):
     train(epoch)
-    test()
+    test(epoch)
 
 # torch.save(model.state_dict(), "mnist_cnn.pth")
 print("Model saved to mnist_cnn.pth")
+writer.close()