diff --git a/binary/mnist.py b/binary/mnist.py
index 6bf74d5..728fb38 100644
--- a/binary/mnist.py
+++ b/binary/mnist.py
@@ -13,16 +13,14 @@ import math
 import torch.nn.functional as F
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
-from torch.profiler import profile, ProfilerActivity, record_function
-from unfold import generate_unfold_index
 import datetime
 
 torch.manual_seed(1234)
 np.random.seed(1234)
 torch.cuda.manual_seed_all(1234)
 
-BS = 16
-LR = 0.001
+batch_size = 16
+lr = 0.001
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")
@@ -34,61 +32,59 @@ transform = transforms.Compose(
 train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
 test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
 
-train_loader = DataLoader(train_dataset, batch_size=BS, shuffle=True, drop_last=True, num_workers=4)
-test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False, drop_last=True)
+train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
+
+
+def to_binary_tensor(input_tensor, bits):
+    int_tensor = torch.round(input_tensor).clamp(0, 2**bits - 1).to(torch.int64)
+    shifts = torch.arange(bits - 1, -1, -1, device=int_tensor.device)
+    binary_bits = (int_tensor.unsqueeze(-1) >> shifts) & 1
+    return binary_bits
 
 
 class Lut(torch.autograd.Function):
-    # input [batch, group, bits ]
-    # output [batch, group ]
-    # weight [2**bits, group ]
+    # input [batch, count, bits]
+    # weight [count, 2**bits]
 
     @staticmethod
-    def forward(ctx, input, weight, index):
-        ind = ((input > 0).long() * index).sum(dim=-1)
-        output = torch.gather(weight, 0, ind)
-        output = (output > 0).float()
-        output = (output - 0.5) * 2.0
-        ctx.save_for_backward(input, weight, ind, output)
+    def forward(ctx, input, weight):
+        batch = input.shape[0]
+        count = input.shape[1]
+        bits = input.shape[2]
+        assert int(math.log2(weight.shape[-1])) == bits
+
+        index = 2 ** torch.arange(bits - 1, -1, -1, device=input.device)
+        x = (input > 0).long()
+        x = x * index
+        ind = x.sum(dim=-1)
+
+        row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1)
+        output = weight[row_indices, ind]
+
+        ctx.save_for_backward(input, weight, ind)
         return output
 
     @staticmethod
     def backward(ctx, grad_output):
-        input, weight, ind, output = ctx.saved_tensors
+        input, weight, ind = ctx.saved_tensors
         grad_input = grad_weight = None
+
+        batch = input.shape[0]
+        count = input.shape[1]
         bits = input.shape[2]
 
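+        # dL/dweight: each grad_output element is accumulated into the table entry
+        # (row = LUT position, column = forward index ind) that produced the output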
         if ctx.needs_input_grad[1]:
             grad_weight = torch.zeros_like(weight)
-            grad_weight.scatter_add_(0, ind, grad_output)
+            ind_p = ind.permute(1, 0)
+            grad_output_p = grad_output.permute(1, 0)
+            grad_weight.scatter_add_(1, ind_p, grad_output_p)
 
         if ctx.needs_input_grad[0]:
-
-            grad_input = grad_output * torch.gather(weight, 0, ind)
-            # grad_input = grad_output
+            row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1)
+            grad_input = grad_output * weight[row_indices, ind]
             grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
-            output = output.unsqueeze(-1).repeat(1, 1, bits)
-            in_sign = ((input > 0).float() - 0.5) * 2.0
-            grad_input = grad_input * in_sign
-            grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
-            # grad_input = grad_output
-            # grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
-            # output = output.unsqueeze(-1).repeat(1, 1, bits)
-            # in_sign = ((input > 0).float() - 0.5) * 2.0
-            # out_sign = ((output > 0).float() - 0.5) * 2.0
-            # grad_sign = ((grad_input > 0).float() - 0.5) * 2.0
-            # grad_input = grad_input * in_sign * (out_sign * grad_sign)
-            # grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
-
-            # a dynamic scaling factor is needed here
-            # so that training converges stably
-
-            # print(in_sign[0].detach().cpu().numpy())
-            # print(out_sign[0].detach().cpu().numpy())
-            # print(grad_sign[0].detach().cpu().numpy())
-            # print(grad_input[0].detach().cpu().numpy())
-
-        return grad_input, grad_weight, None, None
+        return grad_input, grad_weight
 
 
 class SimpleCNN(nn.Module):
@@ -118,57 +114,23 @@ class SimpleCNN(nn.Module):
 
 
 class LutGroup(nn.Module):
-    def __init__(self, group, groupBits, groupRepeat=1):
-        assert groupBits > 1
+    def __init__(self, totalBits, groupBits=0, repeat=1):
+        if not groupBits:
+            groupBits = totalBits
         super(LutGroup, self).__init__()
-        self.weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
-        self.group = group
+        assert (totalBits % groupBits) == 0
+        self.weight = nn.Parameter(torch.randn(repeat * totalBits, pow(2, groupBits)))
+        self.totalBits = totalBits
         self.groupBits = groupBits
-        self.groupRepeat = groupRepeat
-        self.index = nn.Parameter(2 ** torch.arange(groupBits - 1, -1, -1), requires_grad=False)
+        self.repeat = repeat
 
     def forward(self, x):
-        # input [ batch, group * groupBits ]
-        # output [ batch, group * groupRepeat ]
+        # input [ batch, totalBits ]
+        # output [ batch, totalBits / groupBits * repeat ]
         batch = x.shape[0]
-        x = x.reshape(batch, -1, self.groupBits)
-        if self.groupRepeat > 1:
-            x = x.repeat(1, self.groupRepeat, 1)
-        x = Lut.apply(x, self.weight, self.index)
-        return x
-
-
-class LutCnn(nn.Module):
-    def __init__(self, channel_repeat, input_shape, kernel_size, stride, dilation, fc=False):
-        super(LutCnn, self).__init__()
-        B, C, H, W = input_shape
-        self.input_shape = input_shape
-        self.kernel_size = kernel_size
-        self.channel_repeat = channel_repeat
-        self.stride = stride
-        self.dilation = dilation
-        batch_idx, channel_idx, h_idx, w_idx = generate_unfold_index(input_shape, kernel_size, stride, dilation)
-        self.batch_idx = nn.Parameter(batch_idx, requires_grad=False)
-        self.channel_idx = nn.Parameter(channel_idx, requires_grad=False)
-        self.h_idx = nn.Parameter(h_idx, requires_grad=False)
-        self.w_idx = nn.Parameter(w_idx, requires_grad=False)
-        groupBits = kernel_size * kernel_size
-        group = int(len(self.batch_idx) / B / groupBits)
-        self.lut = LutGroup(group, groupBits, channel_repeat)
-        self.fc = fc
-        if fc:
-            self.lutc = LutGroup(group, channel_repeat * C, channel_repeat * C)
-
-    def forward(self, x):
-        B, C, H, W = self.input_shape
-        x = x.view(self.input_shape)
-        x = x[(self.batch_idx, self.channel_idx, self.h_idx, self.w_idx)]
-        x = x.view(B, -1, self.kernel_size * self.kernel_size)
-        x = self.lut(x)
-        if self.fc:
-            x = x.view(B, self.channel_repeat * C, -1)
-            x = x.permute(0, 2, 1)
-            x = self.lutc(x)
+        x = x.view(batch, -1, self.groupBits)
+        x = x.repeat(1, self.repeat, 1)
+        x = Lut.apply(x, self.weight)
         return x
 
 
@@ -181,12 +143,14 @@ class SimpleBNN(nn.Module):
         self.w = nn.Parameter(torch.randn(3, 784))
         self.b = nn.Parameter(torch.zeros(3, 784))
 
-        # channel_repeat, input_shape, kernel_size, stride, dilation, fc
-        self.lnn1 = LutCnn(8, (BS, 1, 28, 28), 2, 2, 1, False)
-        self.lnn2 = LutCnn(1, (BS, 8, 14, 14), 2, 2, 1, False)
-        self.lnn3 = LutCnn(1, (BS, 8, 7, 7), 3, 1, 1, False)
-        self.lnn4 = LutCnn(1, (BS, 8, 5, 5), 3, 1, 1, False)
-        self.lnn5 = LutCnn(10, (BS, 8, 3, 3), 3, 1, 1)
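+        # LutGroup(totalBits, groupBits, repeat): groups of groupBits binary inputs each
+        # index a 2**groupBits-entry table; output width is totalBits / groupBits * repeat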
+        self.lut2_p = LutGroup(784, 4, 80)  # pool2 => 14*14*4
+        self.lut3_p = LutGroup(80 * 14 * 14, 4, 1)  # pool2 => 7*7*4
+        self.lut4_p = LutGroup(80 * 5 * 5 * 3 * 3, 9, 1)  # conv 3 => 5*5*8
+        # self.lut4 = LutGroup(8 * 5 * 5, 8, 8)  # conv 3 => 5*5*8
+        self.lut5_p = LutGroup(80 * 3 * 3 * 3 * 3, 9, 1)  # conv 3 => 3*3*16
+        # self.lut5 = LutGroup(16 * 3 * 3, 16, 10)  # conv 3 => 3*3*16
+        self.lut6_p = LutGroup(80 * 3 * 3, 9, 1)  # conv 3 => 128
+        # self.lut7_p = LutGroup(8 * 16, 16, 10)  # fc 128 => 80
 
         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
@@ -195,9 +159,9 @@ class SimpleBNN(nn.Module):
         self.pool = nn.MaxPool2d(2)
         self.relu = nn.ReLU()
 
-    def forward(self, x, t):
+    def forward(self, x):
         batch = x.shape[0]
-        # x = x.view(batch, -1)
+        x = x.view(batch, -1)
 
         # map x from [-0.5, 0.5] to 0-255, then expand it into 8 binary values
         # x = (x * 256 + 128).clamp(0, 255).to(torch.uint8)
@@ -237,11 +201,43 @@ class SimpleBNN(nn.Module):
 
         #########################
 
-        x = self.lnn1(x)
-        x = self.lnn2(x)
-        x = self.lnn3(x)
-        x = self.lnn4(x)
-        x = self.lnn5(x)
+        x = x.view(batch, 1, 28, 28)
+        x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
+        x = x.view(batch, -1)
+        x = self.lut2_p(x)
+
+        x = x.view(batch, 80, 14, 14)
+        x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
+        x = x.view(batch, -1)
+        x = self.lut3_p(x)
+
+        x = x.view(batch, 80, 7, 7)
+        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
+        x = x.view(batch, -1)
+        x = self.lut4_p(x)
+
+        # x = x.view(batch, 8, 5, 5)
+        # x = x.permute(0, 2, 3, 1)
+        # x = x.reshape(batch, -1)
+        # x = self.lut4(x)
+
+        x = x.view(batch, 80, 5, 5)
+        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
+        x = x.view(batch, -1)
+        x = self.lut5_p(x)
+
+        # x = x.view(batch, 16, 3, 3)
+        # x = x.permute(0, 2, 3, 1)
+        # x = x.reshape(batch, -1)
+        # x = self.lut5(x)
+
+        x = x.view(batch, 80, 3, 3)
+        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
+        x = x.view(batch, -1)
+        x = self.lut6_p(x)
+
+        # x = x.view(batch, 8, 16)  # 16 channel 8 value/channel
+        # x = self.lut7_p(x)  # 10 * 8 value/channel
 
         # xx = 2 ** torch.arange(7, -1, -1).to(x.device)
         x = x.view(batch, -1, 8)
@@ -251,62 +247,17 @@ class SimpleBNN(nn.Module):
 
         return x
 
-    def printWeight(self):
-        pass
-
-
-class SimpleLNN(nn.Module):
-    def __init__(self):
-        super(SimpleLNN, self).__init__()
-        # group, groupBits, groupRepeat
-        self.lutg1 = LutGroup(1, 10, 4)
-        self.lutg2 = LutGroup(1, 4, 10)
-
-    def forward(self, x, t):
-        batch = x.shape[0]
-
-        x = torch.zeros_like(t).unsqueeze(-1).repeat(1, 10)
-        x[torch.arange(0, batch), t] = 1
-
-        x = self.lutg1(x)
-        x = self.lutg2(x)
-
-        return x
-
-    def printWeight(self):
-        print("self.lutg1")
-        print(self.lutg1.weight[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy())
-        print("=============================")
-        print("=============================")
-        print("self.lutg1.grad")
-        print(self.lutg1.weight.grad[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy())
-        print("=============================")
-        print("=============================")
-        # print("self.lutg2")
-        # print(self.lutg2.weight.detach().cpu().numpy())
-        # print("=============================")
-        # print("=============================")
-
-
 torch.autograd.set_detect_anomaly(True)
 # model = SimpleCNN().to(device)
-# model = SimpleBNN().to(device)
-model = SimpleLNN().to(device)
+model = SimpleBNN().to(device)
 criterion = nn.CrossEntropyLoss()
-optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
+optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
 
-tbWriter = None
-
-
-def AddScalar(tag, value, epoch):
-    global tbWriter
-    if not tbWriter:
-        current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
-        tbWriter = SummaryWriter(f"log/{current_time}")
-        hparam_dict = {"lr": LR, "batch_size": BS}
-        tbWriter.add_hparams(hparam_dict, {}, run_name=f"./")
-
-    tbWriter.add_scalar(tag, value, epoch)
+current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
+writer = SummaryWriter(f"log/{current_time}")
+hparam_dict = {"lr": lr, "batch_size": batch_size}
+writer.add_hparams(hparam_dict, {}, run_name=f"./")
 
 
 def train(epoch):
@@ -314,12 +265,19 @@ def train(epoch):
     for batch_idx, (data, target) in enumerate(train_loader):
         data, target = data.to(device), target.to(device)
         optimizer.zero_grad()
-        output = model(data, target)
+        output = model(data)
+        # output = output * 1.0
+        # output = F.softmax(output, dim=1)
+
+        # print(output.requires_grad)
+        # print(output.grad_fn)
+
         loss = criterion(output, target)
         loss.backward()
         optimizer.step()
-        AddScalar("loss", loss, epoch)
-        if batch_idx % 1024 == 0 and batch_idx > 0:
+        writer.add_scalar("loss", loss, epoch)
+
+        if batch_idx % 512 == 0:
             print(
                 f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                 f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
             )
@@ -333,45 +291,24 @@ def test(epoch):
     with torch.no_grad():
         for data, target in test_loader:
             data, target = data.to(device), target.to(device)
-            output = model(data, target)
+            output = model(data)
             test_loss += criterion(output, target).item()
             pred = output.argmax(dim=1, keepdim=True)
             correct += pred.eq(target.view_as(pred)).sum().item()
 
     test_loss /= len(test_loader.dataset)
     accuracy = 100.0 * correct / len(test_loader.dataset)
-    AddScalar("accuracy", accuracy, epoch)
+    writer.add_scalar("accuracy", accuracy, epoch)
     print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
        f"({accuracy:.0f}%)\n"
    )
-    model.printWeight()
 
 
-def profiler():
-    for batch_idx, (data, target) in enumerate(train_loader):
-        data, target = data.to(device), target.to(device)
-        optimizer.zero_grad()
-        with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
-            with record_function("model_inference"):
-                output = model(data)
-        loss = criterion(output, target)
-        loss.backward()
-        optimizer.step()
-        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
-        if batch_idx > 10:
-            prof.export_chrome_trace("local.json")
-            assert False
-
-
-# profiler()
-
 for epoch in range(1, 300):
     train(epoch)
     test(epoch)
     # torch.save(model.state_dict(), "mnist_cnn.pth")
     print("Model saved to mnist_cnn.pth")
-
-if tbWriter:
-    tbWriter.close()
+writer.close()
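
For reference, a minimal standalone sketch (not part of the patch) of the table lookup that the new Lut.forward performs: each group of sign bits is packed into an integer that indexes that group's 2**bits-entry table. Shapes follow the comments in the diff (input [batch, count, bits], weight [count, 2**bits]); the function and variable names below are illustrative only.

import torch


def lut_lookup(x, weight):
    # x: [batch, count, bits] real-valued, weight: [count, 2**bits]
    batch, count, bits = x.shape
    index = 2 ** torch.arange(bits - 1, -1, -1, device=x.device)  # MSB-first bit weights
    ind = ((x > 0).long() * index).sum(dim=-1)                    # [batch, count] table indices
    rows = torch.arange(count, device=x.device).unsqueeze(0).expand(batch, -1)
    return weight[rows, ind]                                      # [batch, count] selected entries


if __name__ == "__main__":
    x = torch.randn(16, 196, 4)    # e.g. 196 groups of 4 sign bits
    w = torch.randn(196, 2 ** 4)   # one 16-entry table per group
    print(lut_lookup(x, w).shape)  # torch.Size([16, 196])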