import os

# Make CUDA errors surface synchronously at the failing kernel (debugging aid).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import math
import torch.nn.functional as F
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torch.profiler import profile, ProfilerActivity, record_function
from unfold import generate_unfold_index
import datetime

torch.manual_seed(1234)
np.random.seed(1234)
torch.cuda.manual_seed_all(1234)

BS = 16
LR = 0.001

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]  # mean and std of the MNIST dataset
)

train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BS, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False)


class Lut(torch.autograd.Function):
    # input  [batch, group, bits]
    # output [batch, group]
    # weight [2**bits, group]

    @staticmethod
    def forward(ctx, input, weight, index):
        # index holds the per-bit place values [2**(bits-1), ..., 2, 1], so the
        # sign pattern of each group becomes an integer row address into weight.
        ind = ((input > 0).long() * index).sum(dim=-1)
        output = torch.gather(weight, 0, ind)
        ctx.save_for_backward(input, weight, ind)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, ind = ctx.saved_tensors
        grad_input = grad_weight = None
        bits = input.shape[2]

        if ctx.needs_input_grad[1]:
            # Route each output gradient back to the table entry that produced it.
            grad_weight = torch.zeros_like(weight)
            grad_weight.scatter_add_(0, ind, grad_output)

        if ctx.needs_input_grad[0]:
            # Straight-through-style estimate: scale by the selected table value
            # and broadcast the same gradient to every address bit of the group.
            grad_input = grad_output * torch.gather(weight, 0, ind)
            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)

        return grad_input, grad_weight, None
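

# A minimal sanity-check sketch of the LUT lookup (my addition, illustrative
# only; shapes follow the comments above). Not called during training.
def _lut_demo():
    batch, group, bits = 4, 160, 8
    inp = torch.randn(batch, group, bits)        # sign of each entry is one address bit
    weight = torch.randn(2 ** bits, group)       # one 256-entry table per group
    index = 2 ** torch.arange(bits - 1, -1, -1)  # place values 128, 64, ..., 1
    out = Lut.apply(inp, weight, index)
    assert out.shape == (batch, group)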


class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.bn = nn.BatchNorm1d(320 * 4)
        self.fc1 = nn.Linear(160, 50)
        self.fc2 = nn.Linear(50, 10)
        self.pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()
        # Lut expects weight as [2**bits, group] plus per-bit place values.
        self.weight = nn.Parameter(torch.randn(pow(2, 8), 160))
        self.index = nn.Parameter(2 ** torch.arange(7, -1, -1), requires_grad=False)

    def forward(self, x):
        x = self.relu(self.pool(self.conv1(x)))
        x = self.relu(self.conv2(x))
        x = x.view(-1, 320 * 4)
        x = self.bn(x)
        x = x.view(-1, 160, 8)
        x = Lut.apply(x, self.weight, self.index)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class LutGroup(nn.Module):
    def __init__(self, group, groupBits, groupRepeat=1):
        assert groupBits > 1
        super(LutGroup, self).__init__()
        self.weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
        self.group = group
        self.groupBits = groupBits
        self.groupRepeat = groupRepeat
        self.index = nn.Parameter(2 ** torch.arange(groupBits - 1, -1, -1), requires_grad=False)

    def forward(self, x):
        # input  [batch, group * groupBits]
        # output [batch, group * groupRepeat]
        batch = x.shape[0]
        x = x.view(batch, -1, self.groupBits)
        if self.groupRepeat > 1:
            # Tile the groups so each repeat indexes its own column of tables.
            x = x.repeat(1, self.groupRepeat, 1)
        x = Lut.apply(x, self.weight, self.index)
        return x
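

# Hedged shape sketch for LutGroup (my addition; not called during training):
# four groups of two bits each, each group looked up in three separate tables.
def _lut_group_demo():
    lg = LutGroup(group=4, groupBits=2, groupRepeat=3)
    x = torch.randn(8, 4 * 2)      # [batch, group * groupBits]
    y = lg(x)
    assert y.shape == (8, 4 * 3)   # [batch, group * groupRepeat]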


class LutCnn(nn.Module):
    def __init__(self, output_c, input_shape, kernel_size, stride, dilation):
        super(LutCnn, self).__init__()
        B, C, H, W = input_shape
        self.input_shape = input_shape
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        batch_idx, channel_idx, h_idx, w_idx = generate_unfold_index(input_shape, kernel_size, stride, dilation)
        self.batch_idx = nn.Parameter(batch_idx, requires_grad=False)
        self.channel_idx = nn.Parameter(channel_idx, requires_grad=False)
        self.h_idx = nn.Parameter(h_idx, requires_grad=False)
        self.w_idx = nn.Parameter(w_idx, requires_grad=False)
        groupBits = kernel_size * kernel_size
        self.lut = LutGroup(len(self.batch_idx) // B // groupBits, groupBits, output_c)

    def forward(self, x):
        B, C, H, W = self.input_shape
        x = x.view(self.input_shape)
        # Gather every kernel_size x kernel_size patch via the precomputed indices.
        x = x[(self.batch_idx, self.channel_idx, self.h_idx, self.w_idx)]
        x = x.view(B, -1, self.kernel_size * self.kernel_size)
        x = self.lut(x)
        return x
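

# Shape sketch for one LutCnn stage (my addition; assumes generate_unfold_index
# from the local `unfold` module enumerates 2x2 patches at stride 2). Not called.
def _lut_cnn_demo():
    stage = LutCnn(8, (BS, 1, 28, 28), 2, 2, 1)
    y = stage(torch.randn(BS, 1, 28, 28))
    print(y.shape)  # expected: [BS, 8 * 14 * 14] = [16, 1568]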


class SimpleBNN(nn.Module):
    def __init__(self):
        super(SimpleBNN, self).__init__()

        # self.w = nn.Parameter(torch.randn(3, 784 * 8))
        # self.b = nn.Parameter(torch.zeros(3, 784 * 8))

        self.w = nn.Parameter(torch.randn(3, 784))
        self.b = nn.Parameter(torch.zeros(3, 784))

        # output_c, input_shape, kernel_size, stride, dilation
        self.lnn1 = LutCnn(8, (BS, 1, 28, 28), 2, 2, 1)
        self.lnn2 = LutCnn(1, (BS, 8, 14, 14), 2, 2, 1)
        self.lnn3 = LutCnn(1, (BS, 8, 7, 7), 3, 1, 1)
        self.lnn4 = LutCnn(1, (BS, 8, 5, 5), 3, 1, 1)
        self.lnn5 = LutCnn(10, (BS, 8, 3, 3), 3, 1, 1)

        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        self.pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()

    def forward(self, x):
        batch = x.shape[0]
        # x = x.view(batch, -1)

        # Map x from [-0.5, 0.5] to 0-255, then expand each pixel into 8 binary bits.
        # x = (x * 256 + 128).clamp(0, 255).to(torch.uint8)
        # xx = torch.arange(7, -1, -1).to(x.device)
        # bits = (x.unsqueeze(-1) >> xx) & 1
        # x = bits.view(batch, -1)
        # x = x.float() - 0.5

        # x = (x > 0).float()

        # q = x * self.w[0] + self.b[0]
        # k = x * self.w[1] + self.b[1]
        # v = x * self.w[2] + self.b[2]
        # q = q.view(batch, -1, 1)
        # k = k.view(batch, 1, -1)
        # v = v.view(batch, -1, 1)
        # kq = q @ k
        # kqv = kq @ v
        # kqv = kqv.view(batch, -1, 8)
        # x = kqv

        #########################

        # # x = (x > 0) << xx
        # # x = x.sum(2)
        # # x = x.view(batch, 1, 28, 28)
        # # x = (x - 128.0) / 256.0

        # x = (x > 0).float()
        # x = x.view(batch, 1, 28, 28)

        # x = self.relu(self.pool(self.conv1(x)))
        # x = self.relu(self.pool(self.conv2(x)))
        # x = x.view(-1, 320)
        # x = self.relu(self.fc1(x))
        # x = self.fc2(x)

        #########################

        x = self.lnn1(x)
        x = self.lnn2(x)
        x = self.lnn3(x)
        x = self.lnn4(x)
        x = self.lnn5(x)

        # xx = 2 ** torch.arange(7, -1, -1).to(x.device)
        x = x.view(batch, -1, 8)
        # x = x * xx
        # x = (x - 128.0) / 256.0
        # Each class logit is the sum of its 8 LUT outputs (a soft popcount).
        x = x.sum(2)

        return x
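

# End-to-end shape flow through the LUT stack (my annotation; assumes the
# patch enumeration described in LutCnn):
#   [BS, 1, 28, 28] -> lnn1 -> [BS, 8*14*14] -> lnn2 -> [BS, 8*7*7]
#   -> lnn3 -> [BS, 8*5*5] -> lnn4 -> [BS, 8*3*3] -> lnn5 -> [BS, 10*8]
#   -> view/sum -> [BS, 10] class logits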


torch.autograd.set_detect_anomaly(True)
# model = SimpleCNN().to(device)
model = SimpleBNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

tbWriter = None


def AddScalar(tag, value, epoch):
    global tbWriter
    if not tbWriter:
        # Create the writer lazily so a run that never logs leaves no directory.
        current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
        tbWriter = SummaryWriter(f"log/{current_time}")
        hparam_dict = {"lr": LR, "batch_size": BS}
        tbWriter.add_hparams(hparam_dict, {}, run_name="./")
    tbWriter.add_scalar(tag, value, epoch)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        AddScalar("loss", loss, epoch)
        if batch_idx % 512 == 0 and batch_idx > 0:
            print(
                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
            )


def test(epoch):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100.0 * correct / len(test_loader.dataset)
    AddScalar("accuracy", accuracy, epoch)
    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
        f"({accuracy:.0f}%)\n"
    )


def profiler():
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
            with record_function("model_inference"):
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        if batch_idx > 10:
            prof.export_chrome_trace("local.json")
            # Stop after a dozen profiled batches; the trace is already on disk.
            assert False


# Uncomment to capture a short CUDA profile instead of training.
# profiler()


for epoch in range(1, 300):
    train(epoch)
    test(epoch)

# torch.save(model.state_dict(), "mnist_cnn.pth")
# print("Model saved to mnist_cnn.pth")

if tbWriter:
    tbWriter.close()