Update binary mnist.

Colin 2025-05-21 11:29:15 +08:00
parent 9194595716
commit f1124bc3b1
1 changed file with 80 additions and 82 deletions


@@ -19,6 +19,8 @@ torch.cuda.manual_seed_all(1234)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# device = "cpu"
+# torch.set_num_threads(16)
 print(f"Using device: {device}")

 transform = transforms.Compose(
@@ -39,7 +41,9 @@ def to_binary_tensor(input_tensor, bits):
     return binary_bits

-class MyLut(torch.autograd.Function):
+class Lut(torch.autograd.Function):
+    # input [batch,count,bits]
+    # weight [count,2**bits]
     @staticmethod
     def forward(ctx, input, weight):
         batch = input.shape[0]
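The two comments added above pin down the tensor contract of the renamed Lut function: input is [batch, count, bits] of binarized values, and weight holds one 2**bits-entry lookup table per position, [count, 2**bits]. As a reading aid, here is a minimal sketch of what a forward pass with that contract computes, with each group of bits packed into an integer that indexes its table (lut_forward is a hypothetical name for illustration, not code from this commit):

import torch

def lut_forward(x, weight):
    # x: [batch, count, bits], entries binarized to 0/1
    # weight: [count, 2**bits], one lookup table per group
    batch, count, bits = x.shape
    scale = torch.pow(2, torch.arange(bits, device=x.device))
    idx = (x.long() * scale).sum(dim=-1)  # [batch, count], values in [0, 2**bits)
    # out[b, c] = weight[c, idx[b, c]]
    return weight[torch.arange(count, device=x.device), idx]

x = (torch.randn(32, 160, 8) > 0).long()  # e.g. the 160 groups of 8 bits in SimpleCNN
w = torch.randn(160, 256)
out = lut_forward(x, w)                   # [32, 160]

Batching the tables this way is what lets LutGroup and LutParallel in the next hunk replace their per-group nn.ModuleList loops with a single weight tensor and one Lut.apply call.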
@@ -100,131 +104,125 @@ class SimpleCNN(nn.Module):
         x = self.bn(x)
         x = x.view(-1, 160, 8)
-        x = MyLut.apply(x, self.weight)
+        x = Lut.apply(x, self.weight)
         x = self.relu(self.fc1(x))
         x = self.fc2(x)
         return x

-class Lut(nn.Module):
-    def __init__(self, bits):
-        super(Lut, self).__init__()
-        self.weight = nn.Parameter(torch.randn(pow(2, bits)))
-        self.bias = nn.Parameter(torch.randn(pow(2, bits)))
-        self.index = torch.pow(2, torch.arange(bits))
-        self.bits = bits
-
-    def forward(self, x):
-        x = MyLut.apply(x, self.weight, self.bias)
-        # tmp = torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))
-        # x = tmp + x - x.detach()
-        xx = (x > 0).float()
-        x = xx + (x - x.detach())
-        # print(xx.requires_grad)
-        # print(xx.grad_fn)
-        x = x * (self.index.to(x.device))
-        x = torch.sum(x, dim=-1)
-        w = torch.gather(self.weight, 0, x.long())
-        b = torch.gather(self.weight, 0, x.long())
-        x = w * x + b
-        # tmp = torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))
-        # x = tmp + x - x.detach()
-        xx = (x > 0).float()
-        x = xx + (x - x.detach())
-        x = x.view(-1, 1)
-        return x
-
 class LutGroup(nn.Module):
     def __init__(self, bits, subbits):
         super(LutGroup, self).__init__()
         assert (bits % subbits) == 0
-        self.lutlist = nn.ModuleList([Lut(subbits) for _ in range(int(bits / subbits))])
+        self.weight = nn.Parameter(torch.randn(bits, pow(2, subbits)))
         self.bits = bits
         self.subbits = subbits

     def forward(self, x):
-        ll = len(self.lutlist)
-        tmp = torch.empty((x.shape[0], 0), dtype=x.dtype, device=x.device)
-        start = 0
-        end = self.subbits
-        for i in range(ll):
-            tx = self.lutlist[i](x[:, start:end])
-            tmp = torch.cat((tmp, tx), dim=1)
-            start += self.subbits
-            end += self.subbits
-        return tmp
+        batch = x.shape[0]
+        x = x.view(batch, -1, self.subbits)
+        x = Lut.apply(x, self.weight)
+        return x

 class LutParallel(nn.Module):
     def __init__(self, bits, number):
         super(LutParallel, self).__init__()
-        self.lutlist = nn.ModuleList([Lut(bits) for _ in range(number)])
-        self.bits = bits
         self.number = number
+        self.weight = nn.Parameter(torch.randn(number, pow(2, bits)))

     def forward(self, x):
-        tmp = torch.empty((x.shape[0], 0), dtype=x.dtype, device=x.device)
-        for i in range(self.number):
-            tx = self.lutlist[i](x)
-            tmp = torch.cat((tmp, tx), dim=1)
-        return tmp
+        x = x.unsqueeze(1).repeat(1, self.number, 1)
+        x = Lut.apply(x, self.weight)
+        return x

 class SimpleBNN(nn.Module):
     def __init__(self):
-        super(SimpleCNN, self).__init__()
-        self.w = nn.Parameter(torch.randn(3, 784 * 8))
-        self.b = nn.Parameter(torch.zeros(3, 784 * 8))
+        super(SimpleBNN, self).__init__()
+        # self.w = nn.Parameter(torch.randn(3, 784 * 8))
+        # self.b = nn.Parameter(torch.zeros(3, 784 * 8))
+        self.w = nn.Parameter(torch.randn(3, 784))
+        self.b = nn.Parameter(torch.zeros(3, 784))
         self.lut1 = LutGroup(784 * 8, 8)
-        self.lut2 = LutGroup(784, 8)
-        self.lut3 = LutGroup(98, 14)
-        self.lut4 = LutParallel(7, 10)
+        self.lut2 = LutGroup(784, 4)
+        self.lut3 = LutGroup(196, 4)
+        self.lut4 = LutGroup(49, 7)
+        self.lut5 = LutParallel(7, 10)
+        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
+        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
+        self.fc1 = nn.Linear(320, 50)
+        self.fc2 = nn.Linear(50, 10)
+        self.pool = nn.MaxPool2d(2)
+        self.relu = nn.ReLU()

     def forward(self, x):
         batch = x.shape[0]
         x = x.view(batch, -1)

         # Map x from [-0.5, 0.5] to 0-255, then expand it into 8 binary bits
-        x = (x * 256 + 128).clamp(0, 255).to(torch.uint8)
-        xx = torch.arange(8).to(x.device)
-        bits = (x.unsqueeze(-1) >> xx) & 1
-        bits = bits.view(batch, -1)
-        x = bits
+        # x = (x * 256 + 128).clamp(0, 255).to(torch.uint8)
+        # xx = torch.arange(7, -1, -1).to(x.device)
+        # bits = (x.unsqueeze(-1) >> xx) & 1
+        # x = bits.view(batch, -1)
+        # x = x.float() - 0.5
+        x = (x > 0).float()

-        # q = x * self.w[0] + self.b[0]
-        # k = x * self.w[1] + self.b[1]
-        # v = x * self.w[2] + self.b[2]
-        # q = q.view(batch, -1, 1)
-        # k = k.view(batch, 1, -1)
-        # v = v.view(batch, -1, 1)
-        # kq = q @ k
-        # kqv = kq @ v
-        # kqv = kqv.view(batch, -1)
-        kqv = x
-        x = self.lut1(kqv)
+        q = x * self.w[0] + self.b[0]
+        k = x * self.w[1] + self.b[1]
+        v = x * self.w[2] + self.b[2]
+        q = q.view(batch, -1, 1)
+        k = k.view(batch, 1, -1)
+        v = v.view(batch, -1, 1)
+        kq = q @ k
+        kqv = kq @ v
+        kqv = kqv.view(batch, -1, 8)
+        x = kqv
+
+        #########################
+        # x = (x > 0) << xx
+        # x = x.sum(2)
+        # x = x.view(batch, 1, 28, 28)
+        # x = (x - 128.0) / 256.0
+        # x = x.view(batch, 1, 28, 28)
+        # x = (x > 0).float()
+        # x = self.relu(self.pool(self.conv1(x)))
+        # x = self.relu(self.pool((self.conv2(x))))
+        # x = x.view(-1, 320)
+        # x = self.relu(self.fc1(x))
+        # x = self.fc2(x)
+        #########################
+
+        x = (x > 0).float()
+        x = x.view(batch, 196, 4)
+
+        # x = self.lut1(x)
         x = self.lut2(x)
+        x = x.view(-1, 28, 7)
+        x = x.permute(0, 2, 1)
+        x = x.reshape(-1, 28 * 7)
         x = self.lut3(x)
         x = self.lut4(x)
+        x = self.lut5(x)
         return x

 torch.autograd.set_detect_anomaly(True)
-model = SimpleCNN().to(device)
-# model = SimpleBNN().to(device)
+# model = SimpleCNN().to(device)
+model = SimpleBNN().to(device)
 criterion = nn.CrossEntropyLoss()
-optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

 def train(epoch):
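A note on the binarization idiom that recurs on both sides of this diff: the deleted Lut module computed xx = (x > 0).float() followed by x = xx + (x - x.detach()), a straight-through estimator that emits hard 0/1 bits in the forward pass while behaving like the identity for gradients. A self-contained illustration of the trick (plain PyTorch, not code from this repository):

import torch

x = torch.randn(5, requires_grad=True)
hard = (x > 0).float()         # forward: exact 0/1 bits, but no gradient path
ste = hard + (x - x.detach())  # value == hard, since x - x.detach() is zero
ste.sum().backward()
print(ste)                     # binary values
print(x.grad)                  # all ones: the gradient skips the threshold

Without the x - x.detach() term the threshold has zero gradient almost everywhere, and the parameters feeding it would never receive a training signal.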