303 lines
9.4 KiB
Python
303 lines
9.4 KiB
Python
import os
|
||
|
||
# Debug aid: ask the CUDA runtime to launch kernels synchronously so that a
# failing kernel is reported at the call that issued it. It is set here, ahead
# of the torch import that follows, so it is in place before CUDA starts up.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
||
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.optim as optim
|
||
import torchvision
|
||
from torchvision import transforms
|
||
from torch.utils.data import DataLoader
|
||
import math
|
||
import torch.nn.functional as F
|
||
import numpy as np
|
||
from torch.utils.tensorboard import SummaryWriter
|
||
import datetime
|
||
|
||
# Seed the NumPy, torch CPU and torch CUDA generators with one fixed value.
np.random.seed(1234)
torch.manual_seed(1234)
torch.cuda.manual_seed_all(1234)

# Optimisation hyper-parameters (also logged to TensorBoard further down).
lr = 0.001
batch_size = 16

# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Convert images to tensors, then normalise with the MNIST mean and std.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Both splits are downloaded into ./data on first use.
train_dataset = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
test_dataset = torchvision.datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)

# Training batches are shuffled; evaluation uses large, ordered batches.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
|
||
|
||
|
||
def to_binary_tensor(input_tensor, bits):
    """Expand every element of ``input_tensor`` into ``bits`` binary digits.

    Values are rounded, clamped to ``[0, 2**bits - 1]`` and converted to
    int64 before being split into bits.

    Args:
        input_tensor: Tensor of any shape.
        bits: Number of binary digits to produce per element.

    Returns:
        An int64 tensor with one extra trailing dimension of size ``bits``
        holding 0/1 values, most significant bit first.
    """
    upper = 2**bits - 1
    quantized = input_tensor.round().clamp(0, upper).long()
    # Shift amounts run bits-1, ..., 1, 0 so the most significant bit is first.
    shift_amounts = torch.arange(bits - 1, -1, -1, device=quantized.device)
    shifted = quantized.unsqueeze(-1) >> shift_amounts
    return shifted & 1
|
||
|
||
|
||
class Lut(torch.autograd.Function):
    """Lookup-table layer with a hand-written backward pass.

    Shapes:
        input:  [batch, count, bits]
        weight: [count, 2**bits]
        output: [batch, count]

    Each group of ``bits`` inputs is binarised (``input > 0``), read as an
    unsigned integer with the most significant bit first, and used to select
    one entry from that group's row of ``weight``.
    """

    @staticmethod
    def forward(ctx, input, weight):
        """Select one table entry per (sample, group).

        Args:
            ctx: Autograd context.
            input: Tensor of shape [batch, count, bits].
            weight: Tensor of shape [count, 2**bits].

        Returns:
            Tensor of shape [batch, count] holding the selected entries.

        Raises:
            AssertionError: If the last dimension of ``weight`` is not
                exactly ``2**bits``.
        """
        batch = input.shape[0]
        count = input.shape[1]
        bits = input.shape[2]
        # Exact integer comparison: truncating a float log2 would also accept
        # table widths that are not a power of two.
        assert weight.shape[-1] == 2**bits, (
            f"weight has {weight.shape[-1]} entries per row, expected {2**bits}"
        )

        # Place values 2**(bits-1), ..., 2, 1 turn the sign bits into an index.
        place_values = 2 ** torch.arange(bits - 1, -1, -1, device=input.device)
        ind = ((input > 0).long() * place_values).sum(dim=-1)

        # Row r of weight belongs to group r; build the indices on the same
        # device as the column indices instead of on the CPU.
        row_indices = torch.arange(count, device=ind.device).unsqueeze(0).expand(batch, -1)
        output = weight[row_indices, ind]

        ctx.save_for_backward(input, weight, ind)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate gradients to the table entries and to the inputs.

        The weight gradient accumulates ``grad_output`` into the entries that
        forward selected. The input gradient is ``grad_output`` multiplied by
        the selected entry, copied to each of the ``bits`` inputs of a group.

        Returns:
            Tuple ``(grad_input, grad_weight)``; an element is ``None`` when
            the corresponding input does not require a gradient.
        """
        input, weight, ind = ctx.saved_tensors
        grad_input = grad_weight = None

        batch = input.shape[0]
        count = input.shape[1]
        bits = input.shape[2]

        if ctx.needs_input_grad[1]:
            grad_weight = torch.zeros_like(weight)
            # Work in [count, batch] layout so dimension 1 indexes the table.
            grad_weight.scatter_add_(1, ind.permute(1, 0), grad_output.permute(1, 0))

        if ctx.needs_input_grad[0]:
            row_indices = torch.arange(count, device=ind.device).unsqueeze(0).expand(batch, -1)
            grad_input = grad_output * weight[row_indices, ind]
            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)

        return grad_input, grad_weight
|
||
|
||
|
||
class SimpleCNN(nn.Module):
    """Two convolutions, batch norm, one lookup-table stage and two linear layers.

    The flattened convolution output (1280 values per sample) is normalised,
    split into 160 groups of 8 values and sent through ``Lut`` with a
    ``[160, 256]`` table, giving 160 features for the classifier head.
    """

    def __init__(self):
        super().__init__()
        # Sub-modules are created in a fixed order so that random
        # initialisation draws and parameter ordering stay the same.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.bn = nn.BatchNorm1d(320 * 4)
        self.fc1 = nn.Linear(160, 50)
        self.fc2 = nn.Linear(50, 10)
        self.pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()
        # One 256-entry table for each of the 160 8-bit groups.
        self.weight = nn.Parameter(torch.randn(160, 2**8))

    def forward(self, x):
        """Return the 10 class scores for each sample in ``x``.

        Assumes ``x`` is shaped so that ``conv2`` yields 1280 values per
        sample (e.g. [batch, 1, 28, 28]) -- TODO confirm against callers.
        """
        features = self.conv1(x)
        features = self.relu(self.pool(features))
        features = self.relu(self.conv2(features))

        # Flatten, normalise, then regroup into 160 groups of 8 values.
        flat = self.bn(features.view(-1, 320 * 4))
        grouped = flat.view(-1, 160, 8)

        looked_up = Lut.apply(grouped, self.weight)

        hidden = self.relu(self.fc1(looked_up))
        return self.fc2(hidden)
|
||
|
||
|
||
class LutGroup(nn.Module):
    """Bank of lookup tables applied to consecutive groups of input values.

    The ``totalBits`` input values of a sample are cut into
    ``totalBits / groupBits`` consecutive groups. Every group is looked up in
    ``repeat`` independent tables of ``2**groupBits`` entries each.

    Args:
        totalBits: Number of input values per sample.
        groupBits: Values per group; ``0`` (the default) means a single group
            spanning all ``totalBits`` values.
        repeat: Number of independent tables per group.

    Raises:
        AssertionError: If ``totalBits`` is not a multiple of ``groupBits``.
    """

    def __init__(self, totalBits, groupBits=0, repeat=1):
        super().__init__()
        if not groupBits:
            groupBits = totalBits
        assert (totalBits % groupBits) == 0
        # One table row per (repeat, group) pair. forward() never produces
        # more than repeat * totalBits / groupBits rows of indices, so that
        # is exactly how many rows are allocated.
        rows = repeat * (totalBits // groupBits)
        self.weight = nn.Parameter(torch.randn(rows, 2**groupBits))
        self.totalBits = totalBits
        self.groupBits = groupBits
        self.repeat = repeat

    def forward(self, x):
        """Look up every group of ``x`` in each of its ``repeat`` tables.

        Args:
            x: Tensor whose first dimension is the batch and which holds
                ``totalBits`` values per sample.

        Returns:
            Tensor of shape ``[batch, totalBits / groupBits * repeat]``; the
            results for the first repeat come first, then the second, etc.

        Raises:
            ValueError: If ``x`` does not hold ``totalBits`` values per sample.
        """
        batch = x.shape[0]
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(batch, -1, self.groupBits)
        expected_groups = self.totalBits // self.groupBits
        if x.shape[1] != expected_groups:
            raise ValueError(
                f"expected {self.totalBits} values per sample, "
                f"got {x.shape[1] * self.groupBits}"
            )
        x = x.repeat(1, self.repeat, 1)
        return Lut.apply(x, self.weight)
|
||
|
||
|
||
class SimpleBNN(nn.Module):
    """MNIST classifier built only from ``LutGroup`` lookup-table stages.

    Every stage extracts sliding windows from the current feature map, feeds
    the values of each per-channel window to lookup tables, and reshapes the
    table outputs into the next feature map:

        1 x 28 x 28 -> 2 x 14 x 14 -> 4 x 7 x 7 -> 8 x 5 x 5 -> 16 x 3 x 3 -> 80

    The final 80 values are summed in groups of 8 to give 10 class scores.
    """

    def __init__(self):
        super().__init__()

        # NOTE: w, b and the conv/fc/pool/relu layers below are not used by
        # forward(). They are kept so parameters(), state_dict() and the
        # order of random initialisation draws are unaffected.
        self.w = nn.Parameter(torch.randn(3, 784))
        self.b = nn.Parameter(torch.zeros(3, 784))

        # 2x2 windows, stride 2: 1 x 28 x 28 -> 2 x 14 x 14
        self.lut2_p = LutGroup(784, 4, 2)
        # 2x2 windows, stride 2: 2 x 14 x 14 -> 4 x 7 x 7
        self.lut3_p = LutGroup(14 * 14 * 2, 4, 2)
        # 3x3 windows, stride 1: 4 x 7 x 7 -> 8 x 5 x 5
        self.lut4_p = LutGroup(3 * 3 * 5 * 5 * 4, 9, 2)
        # 3x3 windows, stride 1: 8 x 5 x 5 -> 16 x 3 x 3
        self.lut5_p = LutGroup(3 * 3 * 3 * 3 * 8, 9, 2)
        # 3x3 windows, stride 1: 16 x 3 x 3 -> 80 values
        self.lut6_p = LutGroup(3 * 3 * 16, 9, 5)

        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        self.pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()

    @staticmethod
    def _patch_bits(x, kernel_size, stride):
        """Flatten the per-channel sliding windows of ``x``, one window at a time.

        ``F.unfold`` returns ``[batch, channels * k * k, L]``: the ``k * k``
        values of a single window are spread along dimension 1, and dimension
        2 enumerates the ``L`` window positions. The kernel axis is moved last
        before flattening so that every run of ``k * k`` consecutive values
        in the result is exactly one window of one channel -- the layout
        ``LutGroup`` needs when ``groupBits == k * k``.

        Returns:
            Tensor of shape ``[batch, channels * L * k * k]``.
        """
        batch, channels = x.shape[0], x.shape[1]
        cols = F.unfold(x, kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
        cols = cols.reshape(batch, channels, kernel_size * kernel_size, -1)
        return cols.permute(0, 1, 3, 2).reshape(batch, -1)

    def forward(self, x):
        """Return 10 class scores per sample.

        Args:
            x: Tensor holding 784 values per sample; it is reshaped to
                ``[batch, 1, 28, 28]``.

        Returns:
            Tensor of shape ``[batch, 10]``.
        """
        batch = x.shape[0]

        x = x.reshape(batch, 1, 28, 28)
        x = self.lut2_p(self._patch_bits(x, kernel_size=2, stride=2))

        x = x.view(batch, 2, 14, 14)
        x = self.lut3_p(self._patch_bits(x, kernel_size=2, stride=2))

        x = x.view(batch, 4, 7, 7)
        x = self.lut4_p(self._patch_bits(x, kernel_size=3, stride=1))

        x = x.view(batch, 8, 5, 5)
        x = self.lut5_p(self._patch_bits(x, kernel_size=3, stride=1))

        x = x.view(batch, 16, 3, 3)
        x = self.lut6_p(self._patch_bits(x, kernel_size=3, stride=1))

        # 80 table outputs -> 10 classes, each the sum of 8 consecutive outputs.
        x = x.view(batch, -1, 8)
        return x.sum(2)
|
||
|
||
|
||
# Debug aid: have autograd check backward passes and report the forward
# operation responsible when one fails.
torch.autograd.set_detect_anomaly(True)

# Network under training. Use SimpleCNN() here instead to train the
# convolutional variant.
model = SimpleBNN()
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=lr)

# Each run logs into its own directory named after the start time.
current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
writer = SummaryWriter("log/" + current_time)

# Record the hyper-parameters of this run (no metrics attached).
hparam_dict = dict(lr=lr, batch_size=batch_size)
writer.add_hparams(hparam_dict, {}, run_name="./")
|
||
|
||
|
||
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``criterion``, ``device``
    and ``writer``. The loss of every batch is written to TensorBoard under
    the tag ``"loss"``, and progress is printed every 512 batches.

    Args:
        epoch: Epoch number, used for the printed progress line and to derive
            the TensorBoard step of each batch.
    """
    model.train()
    num_batches = len(train_loader)
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)

        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Give every batch its own, monotonically increasing step. Logging
        # with the bare epoch number would put all batches of an epoch on the
        # same x position of the curve.
        global_step = epoch * num_batches + batch_idx
        writer.add_scalar("loss", loss.item(), global_step)

        if batch_idx % 512 == 0:
            print(
                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
            )
|
||
|
||
|
||
def test(epoch):
    """Evaluate ``model`` on ``test_loader`` and report loss and accuracy.

    Uses the module-level ``model``, ``criterion``, ``device`` and ``writer``.
    Accuracy (in percent) is written to TensorBoard under the tag
    ``"accuracy"`` at step ``epoch``, and a summary line is printed.

    Args:
        epoch: Epoch number used as the TensorBoard step.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    num_samples = len(test_loader.dataset)
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # criterion (nn.CrossEntropyLoss with its default reduction)
            # returns the mean over the batch. Weight it by the batch size so
            # that dividing by the sample count below gives a per-sample mean.
            total_loss += criterion(output, target).item() * target.size(0)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss = total_loss / num_samples
    accuracy = 100.0 * correct / num_samples
    writer.add_scalar("accuracy", accuracy, epoch)
    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
        f"({accuracy:.0f}%)\n"
    )
|
||
|
||
|
||
# Train and evaluate once per epoch. The writer is closed in ``finally`` so
# pending TensorBoard events are flushed even if the run is interrupted.
try:
    for epoch in range(1, 300):
        train(epoch)
        test(epoch)
finally:
    writer.close()

# Saving is currently disabled, so do not claim a checkpoint was written.
# To persist the weights, re-enable:
# torch.save(model.state_dict(), "mnist_cnn.pth")
print("Training finished; model was not saved (torch.save is disabled)")
|