Refine LUTCNN, keep accuracy at ~93%
This commit is contained in:
parent 5d03634595
commit 878c690ac4

binary/mnist.py (219 lines changed)
@@ -13,14 +13,16 @@ import math
import torch.nn.functional as F
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torch.profiler import profile, ProfilerActivity, record_function
from unfold import generate_unfold_index
import datetime

torch.manual_seed(1234)
np.random.seed(1234)
torch.cuda.manual_seed_all(1234)

batch_size = 16
lr = 0.001
BS = 16
LR = 0.001

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
@@ -32,59 +34,38 @@ transform = transforms.Compose(
train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)


def to_binary_tensor(input_tensor, bits):
    int_tensor = torch.round(input_tensor).clamp(0, 2**bits - 1).to(torch.int64)
    shifts = torch.arange(bits - 1, -1, -1, device=int_tensor.device)
    binary_bits = (int_tensor.unsqueeze(-1) >> shifts) & 1
    return binary_bits
train_loader = DataLoader(train_dataset, batch_size=BS, shuffle=True, drop_last=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False, drop_last=True)


class Lut(torch.autograd.Function):
    # input [batch,count,bits]
    # weight [count,2**bits]
    # input  [batch, group, bits ]
    # output [batch, group ]
    # weight [2**bits, group ]
    @staticmethod
    def forward(ctx, input, weight):
        batch = input.shape[0]
        count = input.shape[1]
        bits = input.shape[2]
        assert int(math.log2(weight.shape[-1])) == bits

        index = 2 ** torch.arange(bits - 1, -1, -1, device=input.device)
        x = (input > 0).long()
        x = x * index
        ind = x.sum(dim=-1)

        row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1)
        output = weight[row_indices, ind]

        ctx.save_for_backward(input, weight, ind)
    def forward(ctx, input, weight, index):
        ind = ((input > 0).long() * index).sum(dim=-1)
        output = torch.gather(weight, 0, ind)
        # output = (output > 0).float()
        # output = (output - 0.5) * 2.0
        ctx.save_for_backward(input, weight, ind, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, ind = ctx.saved_tensors
        input, weight, ind, output = ctx.saved_tensors
        grad_input = grad_weight = None

        batch = input.shape[0]
        count = input.shape[1]
        bits = input.shape[2]

        if ctx.needs_input_grad[1]:
            grad_weight = torch.zeros_like(weight)
            ind_p = ind.permute(1, 0)
            grad_output_p = grad_output.permute(1, 0)
            grad_weight.scatter_add_(1, ind_p, grad_output_p)
            grad_weight.scatter_add_(0, ind, grad_output)

        if ctx.needs_input_grad[0]:
            row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1)
            grad_input = grad_output * weight[row_indices, ind]
            grad_input = grad_output * torch.gather(weight, 0, ind)
            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)

        return grad_input, grad_weight
        return grad_input, grad_weight, None, None


class SimpleCNN(nn.Module):
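
A note on the rewritten Lut kernel above: it drops the per-row advanced indexing. The table is now laid out as [2**bits, group], the binary place values arrive as a precomputed index vector, and the lookup is a single torch.gather along dim 0; in backward, the input gradient is the looked-up weight scaled by grad_output and broadcast to every bit of the group, a straight-through-style surrogate for the non-differentiable thresholding. A minimal standalone sketch of the forward lookup, with illustrative sizes not taken from the commit:

import torch

batch, group, bits = 4, 6, 3                     # illustrative sizes
weight = torch.randn(2**bits, group)             # one 2**bits-entry table per group
index = 2 ** torch.arange(bits - 1, -1, -1)      # binary place values, as in LutGroup.index

x = torch.randn(batch, group, bits)              # pre-threshold activations
ind = ((x > 0).long() * index).sum(dim=-1)       # [batch, group], integers in [0, 2**bits)
out = torch.gather(weight, 0, ind)               # one table entry per (sample, group)
print(out.shape)                                 # torch.Size([4, 6])

Keeping the table contiguous along the group dimension is also what lets grad_weight be accumulated with a single scatter_add_ along dim 0 instead of the old permute-then-scatter pair.
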
@@ -114,23 +95,57 @@ class SimpleCNN(nn.Module):


class LutGroup(nn.Module):
    def __init__(self, totalBits, groupBits=0, repeat=1):
        if not groupBits:
            groupBits = totalBits
    def __init__(self, group, groupBits, groupRepeat=1):
        assert groupBits > 1
        super(LutGroup, self).__init__()
        assert (totalBits % groupBits) == 0
        self.weight = nn.Parameter(torch.randn(repeat * totalBits, pow(2, groupBits)))
        self.totalBits = totalBits
        self.weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
        self.group = group
        self.groupBits = groupBits
        self.repeat = repeat
        self.groupRepeat = groupRepeat
        self.index = nn.Parameter(2 ** torch.arange(groupBits - 1, -1, -1), requires_grad=False)

    def forward(self, x):
        # input   [ batch, totalBits ]
        # output  [ batch, totalBits / groupBits * repeat ]
        # input   [ batch, group * groupBits ]
        # output  [ batch, group * groupRepeat ]
        batch = x.shape[0]
        x = x.view(batch, -1, self.groupBits)
        x = x.repeat(1, self.repeat, 1)
        x = Lut.apply(x, self.weight)
        x = x.reshape(batch, -1, self.groupBits)
        if self.groupRepeat > 1:
            x = x.repeat(1, self.groupRepeat, 1)
        x = Lut.apply(x, self.weight, self.index)
        return x


class LutCnn(nn.Module):
    def __init__(self, channel_repeat, input_shape, kernel_size, stride, dilation, fc=False):
        super(LutCnn, self).__init__()
        B, C, H, W = input_shape
        self.input_shape = input_shape
        self.kernel_size = kernel_size
        self.channel_repeat = channel_repeat
        self.stride = stride
        self.dilation = dilation
        batch_idx, channel_idx, h_idx, w_idx = generate_unfold_index(input_shape, kernel_size, stride, dilation)
        self.batch_idx = nn.Parameter(batch_idx, requires_grad=False)
        self.channel_idx = nn.Parameter(channel_idx, requires_grad=False)
        self.h_idx = nn.Parameter(h_idx, requires_grad=False)
        self.w_idx = nn.Parameter(w_idx, requires_grad=False)
        groupBits = kernel_size * kernel_size
        group = int(len(self.batch_idx) / B / groupBits)
        self.lut = LutGroup(group, groupBits, channel_repeat)
        self.fc = fc
        if fc:
            self.lutc = LutGroup(group, channel_repeat * C, channel_repeat * C)

    def forward(self, x):
        B, C, H, W = self.input_shape
        x = x.view(self.input_shape)
        x = x[(self.batch_idx, self.channel_idx, self.h_idx, self.w_idx)]
        x = x.view(B, -1, self.kernel_size * self.kernel_size)
        x = self.lut(x)
        if self.fc:
            x = x.view(B, self.channel_repeat * C, -1)
            x = x.permute(0, 2, 1)
            x = self.lutc(x)
        return x


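
A note on LutCnn above: it precomputes the unfold indices once and keeps them as non-trainable parameters, so the im2col step in forward() is plain fancy indexing instead of an F.unfold call per batch. generate_unfold_index itself lives in unfold.py, which is not part of this diff; the sketch below is only a guess at its contract, inferred from how the indices are consumed (each consecutive run of kernel_size*kernel_size elements should be one single-channel patch):

import torch

def generate_unfold_index_sketch(input_shape, kernel_size, stride, dilation=1):
    # Hypothetical reimplementation for illustration; the real function is in unfold.py.
    B, C, H, W = input_shape
    out_h = (H - dilation * (kernel_size - 1) - 1) // stride + 1
    out_w = (W - dilation * (kernel_size - 1) - 1) // stride + 1
    b, c, oh, ow, kh, kw = torch.meshgrid(
        torch.arange(B), torch.arange(C),
        torch.arange(out_h), torch.arange(out_w),
        torch.arange(kernel_size), torch.arange(kernel_size),
        indexing="ij",
    )
    h = oh * stride + kh * dilation    # source row of each patch element
    w = ow * stride + kw * dilation    # source column of each patch element
    return b.reshape(-1), c.reshape(-1), h.reshape(-1), w.reshape(-1)

Under that layout len(batch_idx) is B*C*out_h*out_w*k*k, so the group count computed in the constructor, len(batch_idx) / B / (k*k), comes out to C*out_h*out_w patches per sample. Storing the indices as nn.Parameter(requires_grad=False) keeps them on the right device after model.to(device); register_buffer would serve the same purpose.
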
@@ -143,14 +158,12 @@ class SimpleBNN(nn.Module):
        self.w = nn.Parameter(torch.randn(3, 784))
        self.b = nn.Parameter(torch.zeros(3, 784))

        self.lut2_p = LutGroup(784, 4, 80)  # pool2 => 14*14*4
        self.lut3_p = LutGroup(80 * 14 * 14, 4, 1)  # pool2 => 7*7*4
        self.lut4_p = LutGroup(80 * 5 * 5 * 3 * 3, 9, 1)  # conv 3 => 5*5*8
        # self.lut4 = LutGroup(8 * 5 * 5, 8, 8)  # conv 3 => 5*5*8
        self.lut5_p = LutGroup(80 * 3 * 3 * 3 * 3, 9, 1)  # conv 3 => 3*3*16
        # self.lut5 = LutGroup(16 * 3 * 3, 16, 10)  # conv 3 => 3*3*16
        self.lut6_p = LutGroup(80 * 3 * 3, 9, 1)  # conv 3 => 128
        # self.lut7_p = LutGroup(8 * 16, 16, 10)  # fc 128 => 80
        # channel_repeat, input_shape, kernel_size, stride, dilation, fc
        self.lnn1 = LutCnn(80, (BS, 1, 28, 28), 2, 2, 1, False)
        self.lnn2 = LutCnn(1, (BS, 80, 14, 14), 2, 2, 1, False)
        self.lnn3 = LutCnn(1, (BS, 80, 7, 7), 3, 1, 1, False)
        self.lnn4 = LutCnn(1, (BS, 80, 5, 5), 3, 1, 1, False)
        self.lnn5 = LutCnn(10, (BS, 80, 3, 3), 3, 1, 1)

        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
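
A note on the LutCnn stack above: with kernel 2 / stride 2 twice and kernel 3 / stride 1 three times, the spatial size shrinks 28 -> 14 -> 7 -> 5 -> 3 -> 1, mirroring the unfold pipeline the old forward() built by hand. A comment-only trace of the widths, assuming BS = 16 and using group = C*out_h*out_w with output width group * channel_repeat:

# lnn1: (BS, 1, 28, 28),  k=2, s=2 -> 14*14 patches, repeat 80 -> (BS, 80*14*14)
# lnn2: (BS, 80, 14, 14), k=2, s=2 ->  7*7  patches, repeat  1 -> (BS, 80*7*7)
# lnn3: (BS, 80, 7, 7),   k=3, s=1 ->  5*5  patches, repeat  1 -> (BS, 80*5*5)
# lnn4: (BS, 80, 5, 5),   k=3, s=1 ->  3*3  patches, repeat  1 -> (BS, 80*3*3)
# lnn5: (BS, 80, 3, 3),   k=3, s=1 ->  1*1  patches, repeat 10 -> (BS, 800)

The unchanged tail of forward() then regroups those 800 values as x.view(batch, -1, 8), i.e. (BS, 100, 8), before the (partly elided) reduction to class scores. Note that input_shape is baked in with B = BS and LutCnn.forward() begins with x.view(self.input_shape), so every batch must contain exactly BS samples; that is presumably why both DataLoaders now pass drop_last=True.
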
@@ -159,9 +172,9 @@ class SimpleBNN(nn.Module):
        self.pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()

    def forward(self, x):
    def forward(self, x, t):
        batch = x.shape[0]
        x = x.view(batch, -1)
        # x = x.view(batch, -1)

        # Map x from [-0.5, 0.5] to 0..255, then expand each value into 8 binary bits
        # x = (x * 256 + 128).clamp(0, 255).to(torch.uint8)
@@ -201,43 +214,11 @@ class SimpleBNN(nn.Module):

        #########################

        x = x.view(batch, 1, 28, 28)
        x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
        x = x.view(batch, -1)
        x = self.lut2_p(x)

        x = x.view(batch, 80, 14, 14)
        x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
        x = x.view(batch, -1)
        x = self.lut3_p(x)

        x = x.view(batch, 80, 7, 7)
        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
        x = x.view(batch, -1)
        x = self.lut4_p(x)

        # x = x.view(batch, 8, 5, 5)
        # x = x.permute(0, 2, 3, 1)
        # x = x.reshape(batch, -1)
        # x = self.lut4(x)

        x = x.view(batch, 80, 5, 5)
        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
        x = x.view(batch, -1)
        x = self.lut5_p(x)

        # x = x.view(batch, 16, 3, 3)
        # x = x.permute(0, 2, 3, 1)
        # x = x.reshape(batch, -1)
        # x = self.lut5(x)

        x = x.view(batch, 80, 3, 3)
        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
        x = x.view(batch, -1)
        x = self.lut6_p(x)

        # x = x.view(batch, 8, 16)  # 16 channel   8 value/channel
        # x = self.lut7_p(x)  # 10 * 8 value/channel
        x = self.lnn1(x)
        x = self.lnn2(x)
        x = self.lnn3(x)
        x = self.lnn4(x)
        x = self.lnn5(x)

        # xx = 2 ** torch.arange(7, -1, -1).to(x.device)
        x = x.view(batch, -1, 8)
@@ -247,17 +228,29 @@ class SimpleBNN(nn.Module):

        return x

    def printWeight(self):
        pass


torch.autograd.set_detect_anomaly(True)
# model = SimpleCNN().to(device)
model = SimpleBNN().to(device)
# model = SimpleLNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
writer = SummaryWriter(f"log/{current_time}")
hparam_dict = {"lr": lr, "batch_size": batch_size}
writer.add_hparams(hparam_dict, {}, run_name=f"./")
tbWriter = None


def AddScalar(tag, value, epoch):
    global tbWriter
    if not tbWriter:
        current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
        tbWriter = SummaryWriter(f"log/{current_time}")
        hparam_dict = {"lr": LR, "batch_size": BS}
        tbWriter.add_hparams(hparam_dict, {}, run_name=f"./")

    tbWriter.add_scalar(tag, value, epoch)


def train(epoch):
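
A note on AddScalar above: it replaces the module-level SummaryWriter, materializing the writer and its hparams only on the first call, so a run that never reaches logging leaves no empty log/<timestamp> directory. Usage stays a one-liner; the values below are illustrative:

AddScalar("loss", 0.42, 1)        # first call creates log/<timestamp> and records hparams
AddScalar("accuracy", 93.0, 1)    # later calls reuse the same writer
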
@@ -265,19 +258,12 @@ def train(epoch):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # output = output * 1.0
        # output = F.softmax(output, dim=1)

        # print(output.requires_grad)
        # print(output.grad_fn)

        output = model(data, target)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        writer.add_scalar("loss", loss, epoch)

        if batch_idx % 512 == 0:
        AddScalar("loss", loss, epoch)
        if batch_idx % 1024 == 0 and batch_idx > 0:
            print(
                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
@@ -291,18 +277,19 @@ def test(epoch):
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            output = model(data, target)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100.0 * correct / len(test_loader.dataset)
    writer.add_scalar("accuracy", accuracy, epoch)
    AddScalar("accuracy", accuracy, epoch)
    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
        f"({accuracy:.0f}%)\n"
    )
    model.printWeight()


for epoch in range(1, 300):
@@ -311,4 +298,6 @@ for epoch in range(1, 300):

# torch.save(model.state_dict(), "mnist_cnn.pth")
print("Model saved to mnist_cnn.pth")
writer.close()

if tbWriter:
    tbWriter.close()