import os

# Force synchronous CUDA kernel launches so errors surface at the failing op
# (useful for debugging).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import math
import torch.nn.functional as F
import numpy as np

torch.manual_seed(1234)
np.random.seed(1234)
torch.cuda.manual_seed_all(1234)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]  # mean and std of the MNIST dataset
)

train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
def to_binary_tensor(input_tensor, bits):
    # Expand each integer value into its `bits`-bit binary representation,
    # most-significant bit first.
    int_tensor = torch.round(input_tensor).clamp(0, 2**bits - 1).to(torch.int64)
    shifts = torch.arange(bits - 1, -1, -1, device=int_tensor.device)
    binary_bits = (int_tensor.unsqueeze(-1) >> shifts) & 1
    return binary_bits
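
# Quick sanity check for to_binary_tensor (illustration only; the training
# code below does not depend on it): 5 expands to 0101 with bits=4, MSB first.
assert to_binary_tensor(torch.tensor([5.0]), 4).tolist() == [[0, 1, 0, 1]]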


class MyLut(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight):
        batch = input.shape[0]
        count = input.shape[1]
        bits = input.shape[2]
        assert int(math.log2(weight.shape[-1])) == bits

        # Interpret the sign of each input value as a 0/1 bit and pack the
        # bits (MSB first) into an integer index into the lookup table.
        index = 2 ** torch.arange(bits - 1, -1, -1, device=input.device)
        x = (input > 0).long()
        x = x * index
        ind = x.sum(dim=-1)

        row_indices = torch.arange(count, device=input.device).unsqueeze(0).expand(batch, -1)
        output = weight[row_indices, ind]

        ctx.save_for_backward(input, weight, ind)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, ind = ctx.saved_tensors
        grad_input = grad_weight = None

        batch = input.shape[0]
        count = input.shape[1]
        bits = input.shape[2]

        if ctx.needs_input_grad[1]:
            # Route each output gradient back to the table entry that was read.
            grad_weight = torch.zeros_like(weight)
            ind_p = ind.permute(1, 0)
            grad_output_p = grad_output.permute(1, 0)
            grad_weight.scatter_add_(1, ind_p, grad_output_p)

        if ctx.needs_input_grad[0]:
            # Surrogate gradient for the non-differentiable lookup: scale the
            # output gradient by the selected table entry and broadcast it to
            # every address bit.
            row_indices = torch.arange(count, device=input.device).unsqueeze(0).expand(batch, -1)
            grad_input = grad_output * weight[row_indices, ind]
            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)

        return grad_input, grad_weight
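
# Shape and gradient sanity check for MyLut (illustration only; the RNG state
# is restored so the seeded run below is unaffected). With three LUTs of 2
# address bits each, a (batch=2, count=3, bits=2) input and a (3, 2**2) table
# yield one scalar per LUT, i.e. a (2, 3) output.
with torch.random.fork_rng(devices=[]):
    _inp = torch.randn(2, 3, 2, requires_grad=True)
    _tbl = torch.randn(3, 4, requires_grad=True)
    _out = MyLut.apply(_inp, _tbl)
    assert _out.shape == (2, 3)
    _out.sum().backward()  # exercises the custom backward above
    assert _tbl.grad.shape == (3, 4) and _inp.grad.shape == (2, 3, 2)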


class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.bn = nn.BatchNorm1d(320 * 4)
        self.fc1 = nn.Linear(160, 50)
        self.fc2 = nn.Linear(50, 10)
        self.pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()
        # One 256-entry lookup table per group of 8 features (160 groups).
        self.weight = nn.Parameter(torch.randn(160, pow(2, 8)))

    def forward(self, x):
        x = self.relu(self.pool(self.conv1(x)))
        x = self.relu(self.conv2(x))
        x = x.view(-1, 320 * 4)  # 20 channels * 8 * 8 = 1280 features
        x = self.bn(x)
        x = x.view(-1, 160, 8)  # 160 groups of 8 "address bits" each

        x = MyLut.apply(x, self.weight)

        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
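
# SimpleCNN shape walk-through: 1x28x28 -> conv1(k5)+pool -> 10x12x12 ->
# conv2(k5) -> 20x8x8 -> flatten 1280 -> BatchNorm -> view as 160 groups of
# 8 bits -> MyLut -> 160 -> fc1 -> 50 -> fc2 -> 10 logits.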


class Lut(nn.Module):
    def __init__(self, bits):
        super(Lut, self).__init__()
        self.weight = nn.Parameter(torch.randn(pow(2, bits)))
        self.bias = nn.Parameter(torch.randn(pow(2, bits)))
        self.index = torch.pow(2, torch.arange(bits))
        self.bits = bits

    def forward(self, x):
        # x = MyLut.apply(x, self.weight, self.bias)  # disabled: MyLut.forward
        # takes (input, weight) only and expects a (batch, count, bits) input;
        # the lookup is done inline below instead.

        # Binarize with a straight-through estimator: the forward value is
        # hard 0/1, the backward gradient passes through unchanged.
        xx = (x > 0).float()
        x = xx + (x - x.detach())

        # Pack the bits (LSB first) into an integer table index.
        x = x * (self.index.to(x.device))
        x = torch.sum(x, dim=-1)

        w = torch.gather(self.weight, 0, x.long())
        b = torch.gather(self.bias, 0, x.long())
        x = w * x + b

        # Binarize the table output, again with a straight-through estimator.
        xx = (x > 0).float()
        x = xx + (x - x.detach())

        x = x.view(-1, 1)
        return x
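
# Straight-through-estimator pattern used in Lut.forward (illustration only):
# the forward value is the hard 0/1 binarization, while the backward gradient
# is that of the identity, so training signals pass through the threshold.
_z = torch.tensor([-1.0, 0.5, 2.0], requires_grad=True)
_hard = (_z > 0).float() + (_z - _z.detach())
assert _hard.detach().tolist() == [0.0, 1.0, 1.0]
_hard.sum().backward()
assert _z.grad.tolist() == [1.0, 1.0, 1.0]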


class LutGroup(nn.Module):
    def __init__(self, bits, subbits):
        super(LutGroup, self).__init__()
        assert (bits % subbits) == 0
        self.lutlist = nn.ModuleList([Lut(subbits) for _ in range(bits // subbits)])
        self.bits = bits
        self.subbits = subbits

    def forward(self, x):
        # Feed each subbits-wide slice of the input to its own LUT and
        # concatenate the per-LUT outputs.
        outputs = []
        for i, lut in enumerate(self.lutlist):
            start = i * self.subbits
            outputs.append(lut(x[:, start : start + self.subbits]))
        return torch.cat(outputs, dim=1)
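
# LutGroup shape check (illustration only; the RNG state is restored): 8 input
# bits split into two 4-bit sub-tables produce one value per sub-table.
with torch.random.fork_rng(devices=[]):
    assert LutGroup(8, 4)((torch.rand(5, 8) > 0.5).float()).shape == (5, 2)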


class LutParallel(nn.Module):
    def __init__(self, bits, number):
        super(LutParallel, self).__init__()
        self.lutlist = nn.ModuleList([Lut(bits) for _ in range(number)])
        self.bits = bits
        self.number = number

    def forward(self, x):
        # Apply `number` independent LUTs to the same input bits and
        # concatenate their outputs.
        return torch.cat([lut(x) for lut in self.lutlist], dim=1)


class SimpleBNN(nn.Module):
    def __init__(self):
        super(SimpleBNN, self).__init__()
        self.w = nn.Parameter(torch.randn(3, 784 * 8))
        self.b = nn.Parameter(torch.zeros(3, 784 * 8))
        self.lut1 = LutGroup(784 * 8, 8)
        self.lut2 = LutGroup(784, 8)
        self.lut3 = LutGroup(98, 14)
        self.lut4 = LutParallel(7, 10)

    def forward(self, x):
        batch = x.shape[0]
        x = x.view(batch, -1)

        # Map x from roughly [-0.5, 0.5] to 0-255, then expand each byte
        # into its 8 binary bit-planes.
        x = (x * 256 + 128).clamp(0, 255).to(torch.uint8)
        xx = torch.arange(8).to(x.device)
        bits = (x.unsqueeze(-1) >> xx) & 1
        bits = bits.view(batch, -1)

        x = bits
        # q = x * self.w[0] + self.b[0]
        # k = x * self.w[1] + self.b[1]
        # v = x * self.w[2] + self.b[2]

        # q = q.view(batch, -1, 1)
        # k = k.view(batch, 1, -1)
        # v = v.view(batch, -1, 1)
        # kq = q @ k
        # kqv = kq @ v
        # kqv = kqv.view(batch, -1)
        kqv = x

        x = self.lut1(kqv)
        x = self.lut2(x)
        x = self.lut3(x)
        x = self.lut4(x)
        return x
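
# SimpleBNN shape walk-through: 28*28 pixels -> 784 bytes -> 784*8 = 6272
# bit-planes; lut1 (6272 bits in 8-bit groups) -> 784 values; lut2 (784, 8)
# -> 98; lut3 (98, 14) -> 7; lut4 (10 parallel 7-bit LUTs) -> 10 class scores.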


# Anomaly detection helps localize autograd errors (it slows training down).
torch.autograd.set_detect_anomaly(True)
model = SimpleCNN().to(device)
# model = SimpleBNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # output = F.softmax(output, dim=1)

        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(
                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
            )


def test():
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Weight each batch's mean loss by its size so the average is
            # correct even when the last batch is smaller.
            test_loss += criterion(output, target).item() * data.size(0)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100.0 * correct / len(test_loader.dataset)
    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
        f"({accuracy:.0f}%)\n"
    )


for epoch in range(1, 30):
    train(epoch)
    test()

torch.save(model.state_dict(), "mnist_cnn.pth")
print("Model saved to mnist_cnn.pth")