Refine LutCnn, keep accuracy at ~93%

This commit is contained in:
Colin 2025-06-09 15:56:03 +08:00
parent 5d03634595
commit 878c690ac4
1 changed file with 104 additions and 115 deletions

@@ -13,14 +13,16 @@ import math
import torch.nn.functional as F
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torch.profiler import profile, ProfilerActivity, record_function
from unfold import generate_unfold_index
import datetime
torch.manual_seed(1234)
np.random.seed(1234)
torch.cuda.manual_seed_all(1234)
batch_size = 16
lr = 0.001
BS = 16
LR = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
@@ -32,59 +34,38 @@ transform = transforms.Compose(
train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
def to_binary_tensor(input_tensor, bits):
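# e.g. bits=8 turns 5.0 into [0, 0, 0, 0, 0, 1, 0, 1] (MSB first)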
int_tensor = torch.round(input_tensor).clamp(0, 2**bits - 1).to(torch.int64)
shifts = torch.arange(bits - 1, -1, -1, device=int_tensor.device)
binary_bits = (int_tensor.unsqueeze(-1) >> shifts) & 1
return binary_bits
train_loader = DataLoader(train_dataset, batch_size=BS, shuffle=True, drop_last=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False, drop_last=True)
class Lut(torch.autograd.Function):
# input [batch,count,bits]
# weight [count,2**bits]
# input [batch, group, bits ]
# output [batch, group ]
# weight [2**bits, group ]
@staticmethod
def forward(ctx, input, weight):
batch = input.shape[0]
count = input.shape[1]
bits = input.shape[2]
assert int(math.log2(weight.shape[-1])) == bits
index = 2 ** torch.arange(bits - 1, -1, -1, device=input.device)
x = (input > 0).long()
x = x * index
ind = x.sum(dim=-1)
row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1)
output = weight[row_indices, ind]
ctx.save_for_backward(input, weight, ind)
def forward(ctx, input, weight, index):
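# pack the thresholded bits (MSB first) into an integer address per group;
# gather then reads output[b, g] = weight[ind[b, g], g]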
ind = ((input > 0).long() * index).sum(dim=-1)
output = torch.gather(weight, 0, ind)
# output = (output > 0).float()
# output = (output - 0.5) * 2.0
ctx.save_for_backward(input, weight, ind, output)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, ind = ctx.saved_tensors
input, weight, ind, output = ctx.saved_tensors
grad_input = grad_weight = None
batch = input.shape[0]
count = input.shape[1]
bits = input.shape[2]
if ctx.needs_input_grad[1]:
grad_weight = torch.zeros_like(weight)
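# route grad_output back into the LUT entries that forward selected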
ind_p = ind.permute(1, 0)
grad_output_p = grad_output.permute(1, 0)
grad_weight.scatter_add_(1, ind_p, grad_output_p)
grad_weight.scatter_add_(0, ind, grad_output)
if ctx.needs_input_grad[0]:
row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1)
grad_input = grad_output * weight[row_indices, ind]
grad_input = grad_output * torch.gather(weight, 0, ind)
grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
return grad_input, grad_weight
return grad_input, grad_weight, None  # one gradient per forward input: (input, weight, index)
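A minimal sketch of driving Lut.apply directly; the sizes here are illustrative assumptions, not values from this commit:
inp = torch.randn(2, 3, 2) # [batch=2, group=3, bits=2]
w = torch.randn(2**2, 3) # [2**bits, group]
idx = 2 ** torch.arange(1, -1, -1) # [2, 1], MSB first
out = Lut.apply(inp, w, idx) # [batch, group] = [2, 3]
# Backward acts as a straight-through-style estimator: grad_output is
# scatter-added into the addressed LUT rows, and every bit of a group
# receives the same upstream grad scaled by the selected weight.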
class SimpleCNN(nn.Module):
@@ -114,23 +95,57 @@ class SimpleCNN(nn.Module):
class LutGroup(nn.Module):
def __init__(self, totalBits, groupBits=0, repeat=1):
if not groupBits:
groupBits = totalBits
def __init__(self, group, groupBits, groupRepeat=1):
assert groupBits > 1
super(LutGroup, self).__init__()
assert (totalBits % groupBits) == 0
self.weight = nn.Parameter(torch.randn(repeat * totalBits, pow(2, groupBits)))
self.totalBits = totalBits
self.weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
self.group = group
self.groupBits = groupBits
self.repeat = repeat
self.groupRepeat = groupRepeat
self.index = nn.Parameter(2 ** torch.arange(groupBits - 1, -1, -1), requires_grad=False)
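# e.g. groupBits=4 gives index = [8, 4, 2, 1], so bits (1, 0, 1, 1) address LUT row 11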
def forward(self, x):
# input [ batch, totalBits ]
# output [ batch, totalBits / groupBits * repeat ]
# input [ batch, group * groupBits ]
# output [ batch, group * groupRepeat ]
batch = x.shape[0]
x = x.view(batch, -1, self.groupBits)
x = x.repeat(1, self.repeat, 1)
x = Lut.apply(x, self.weight)
x = x.reshape(batch, -1, self.groupBits)
if self.groupRepeat > 1:
x = x.repeat(1, self.groupRepeat, 1)
x = Lut.apply(x, self.weight, self.index)
return x
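A toy invocation of LutGroup under assumed sizes (not from this commit):
lg = LutGroup(group=5, groupBits=3, groupRepeat=2)
y = lg(torch.randn(4, 15)) # [batch=4, group * groupBits = 15]
print(y.shape) # torch.Size([4, 10]) == [batch, group * groupRepeat]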
class LutCnn(nn.Module):
def __init__(self, channel_repeat, input_shape, kernel_size, stride, dilation, fc=False):
super(LutCnn, self).__init__()
B, C, H, W = input_shape
self.input_shape = input_shape
self.kernel_size = kernel_size
self.channel_repeat = channel_repeat
self.stride = stride
self.dilation = dilation
batch_idx, channel_idx, h_idx, w_idx = generate_unfold_index(input_shape, kernel_size, stride, dilation)
self.batch_idx = nn.Parameter(batch_idx, requires_grad=False)
self.channel_idx = nn.Parameter(channel_idx, requires_grad=False)
self.h_idx = nn.Parameter(h_idx, requires_grad=False)
self.w_idx = nn.Parameter(w_idx, requires_grad=False)
groupBits = kernel_size * kernel_size
group = int(len(self.batch_idx) / B / groupBits)
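# one LUT group per extracted k*k patch: group = C * H_out * W_out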
self.lut = LutGroup(group, groupBits, channel_repeat)
self.fc = fc
if fc:
self.lutc = LutGroup(group, channel_repeat * C, channel_repeat * C)
def forward(self, x):
B, C, H, W = self.input_shape
x = x.view(self.input_shape)
x = x[(self.batch_idx, self.channel_idx, self.h_idx, self.w_idx)]
x = x.view(B, -1, self.kernel_size * self.kernel_size)
x = self.lut(x)
if self.fc:
x = x.view(B, self.channel_repeat * C, -1)
x = x.permute(0, 2, 1)
x = self.lutc(x)
return x
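The precomputed index tuple plays the role of a runtime F.unfold; assuming generate_unfold_index enumerates the same patch elements, the gather in forward corresponds to:
x = torch.randn(16, 1, 28, 28)
patches = F.unfold(x, kernel_size=2, stride=2) # [16, 1*2*2, 196]
patches = patches.transpose(1, 2).reshape(16, -1, 4) # [16, 196, 4]: one 2x2 patch per LUT group
# (the exact group ordering depends on generate_unfold_index)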
@@ -143,14 +158,12 @@ class SimpleBNN(nn.Module):
self.w = nn.Parameter(torch.randn(3, 784))
self.b = nn.Parameter(torch.zeros(3, 784))
self.lut2_p = LutGroup(784, 4, 80) # pool2 => 14*14*4
self.lut3_p = LutGroup(80 * 14 * 14, 4, 1) # pool2 => 7*7*4
self.lut4_p = LutGroup(80 * 5 * 5 * 3 * 3, 9, 1) # conv 3 => 5*5*8
# self.lut4 = LutGroup(8 * 5 * 5, 8, 8) # conv 3 => 5*5*8
self.lut5_p = LutGroup(80 * 3 * 3 * 3 * 3, 9, 1) # conv 3 => 3*3*16
# self.lut5 = LutGroup(16 * 3 * 3, 16, 10) # conv 3 => 3*3*16
self.lut6_p = LutGroup(80 * 3 * 3, 9, 1) # conv 3 => 128
# self.lut7_p = LutGroup(8 * 16, 16, 10) # fc 128 => 80
# channel_repeat, input_shape, kernel_size, stride, dilation, fc
self.lnn1 = LutCnn(80, (BS, 1, 28, 28), 2, 2, 1, False)
self.lnn2 = LutCnn(1, (BS, 80, 14, 14), 2, 2, 1, False)
self.lnn3 = LutCnn(1, (BS, 80, 7, 7), 3, 1, 1, False)
self.lnn4 = LutCnn(1, (BS, 80, 5, 5), 3, 1, 1, False)
self.lnn5 = LutCnn(10, (BS, 80, 3, 3), 3, 1, 1)
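# Shape trace derived from the configs above:
# lnn1: (BS, 1, 28, 28) --2x2/s2--> 80 ch @ 14x14
# lnn2: (BS, 80, 14, 14) --2x2/s2--> 80 ch @ 7x7
# lnn3: (BS, 80, 7, 7) --3x3/s1--> 80 ch @ 5x5
# lnn4: (BS, 80, 5, 5) --3x3/s1--> 80 ch @ 3x3
# lnn5: (BS, 80, 3, 3) --3x3/s1--> 1x1, repeated 10x => [BS, 800]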
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
@@ -159,9 +172,9 @@ class SimpleBNN(nn.Module):
self.pool = nn.MaxPool2d(2)
self.relu = nn.ReLU()
def forward(self, x):
def forward(self, x, t):
batch = x.shape[0]
x = x.view(batch, -1)
# x = x.view(batch, -1)
# Map x from [-0.5, 0.5] to 0-255, then expand it into 8 binary values
# x = (x * 256 + 128).clamp(0, 255).to(torch.uint8)
@@ -201,43 +214,11 @@ class SimpleBNN(nn.Module):
#########################
x = x.view(batch, 1, 28, 28)
x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
x = x.view(batch, -1)
x = self.lut2_p(x)
x = x.view(batch, 80, 14, 14)
x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
x = x.view(batch, -1)
x = self.lut3_p(x)
x = x.view(batch, 80, 7, 7)
x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
x = x.view(batch, -1)
x = self.lut4_p(x)
# x = x.view(batch, 8, 5, 5)
# x = x.permute(0, 2, 3, 1)
# x = x.reshape(batch, -1)
# x = self.lut4(x)
x = x.view(batch, 80, 5, 5)
x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
x = x.view(batch, -1)
x = self.lut5_p(x)
# x = x.view(batch, 16, 3, 3)
# x = x.permute(0, 2, 3, 1)
# x = x.reshape(batch, -1)
# x = self.lut5(x)
x = x.view(batch, 80, 3, 3)
x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
x = x.view(batch, -1)
x = self.lut6_p(x)
# x = x.view(batch, 8, 16) # 16 channel 8 value/channel
# x = self.lut7_p(x) # 10 * 8 value/channel
x = self.lnn1(x)
x = self.lnn2(x)
x = self.lnn3(x)
x = self.lnn4(x)
x = self.lnn5(x)
# xx = 2 ** torch.arange(7, -1, -1).to(x.device)
x = x.view(batch, -1, 8)
@@ -247,17 +228,29 @@ class SimpleBNN(nn.Module):
return x
def printWeight(self):
pass
torch.autograd.set_detect_anomaly(True)
# model = SimpleCNN().to(device)
model = SimpleBNN().to(device)
# model = SimpleLNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
writer = SummaryWriter(f"log/{current_time}")
hparam_dict = {"lr": lr, "batch_size": batch_size}
writer.add_hparams(hparam_dict, {}, run_name=f"./")
tbWriter = None
def AddScalar(tag, value, epoch):
global tbWriter
if not tbWriter:
current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
tbWriter = SummaryWriter(f"log/{current_time}")
hparam_dict = {"lr": LR, "batch_size": BS}
tbWriter.add_hparams(hparam_dict, {}, run_name=f"./")
tbWriter.add_scalar(tag, value, epoch)
def train(epoch):
@@ -265,19 +258,12 @@ def train(epoch):
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
# output = output * 1.0
# output = F.softmax(output, dim=1)
# print(output.requires_grad)
# print(output.grad_fn)
output = model(data, target)
loss = criterion(output, target)
loss.backward()
optimizer.step()
writer.add_scalar("loss", loss, epoch)
if batch_idx % 512 == 0:
AddScalar("loss", loss, epoch)
if batch_idx % 1024 == 0 and batch_idx > 0:
print(
f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
@@ -291,18 +277,19 @@ def test(epoch):
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
output = model(data, target)
test_loss += criterion(output, target).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
accuracy = 100.0 * correct / len(test_loader.dataset)
writer.add_scalar("accuracy", accuracy, epoch)
AddScalar("accuracy", accuracy, epoch)
print(
f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
f"({accuracy:.0f}%)\n"
)
model.printWeight()
for epoch in range(1, 300):
@@ -311,4 +298,6 @@ for epoch in range(1, 300):
# torch.save(model.state_dict(), "mnist_cnn.pth")
print("Model saved to mnist_cnn.pth")
writer.close()
if tbWriter:
tbWriter.close()