From f1124bc3b140296950c3e3fa743705d7b5ea1edb Mon Sep 17 00:00:00 2001
From: Colin
Date: Wed, 21 May 2025 11:29:15 +0800
Subject: [PATCH] Update binary mnist.

---
 binary/mnist.py | 162 ++++++++++++++++++++++++------------------------
 1 file changed, 80 insertions(+), 82 deletions(-)

diff --git a/binary/mnist.py b/binary/mnist.py
index db6361b..d098e58 100644
--- a/binary/mnist.py
+++ b/binary/mnist.py
@@ -19,6 +19,8 @@
 torch.cuda.manual_seed_all(1234)
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# device = "cpu"
+# torch.set_num_threads(16)
 print(f"Using device: {device}")
 
 transform = transforms.Compose(
@@ -39,7 +41,9 @@ def to_binary_tensor(input_tensor, bits):
     return binary_bits
 
 
-class MyLut(torch.autograd.Function):
+class Lut(torch.autograd.Function):
+    # input [batch,count,bits]
+    # weight [count,2**bits]
     @staticmethod
     def forward(ctx, input, weight):
         batch = input.shape[0]
@@ -100,131 +104,125 @@ class SimpleCNN(nn.Module):
         x = self.bn(x)
         x = x.view(-1, 160, 8)
-        x = MyLut.apply(x, self.weight)
+        x = Lut.apply(x, self.weight)
         x = self.relu(self.fc1(x))
         x = self.fc2(x)
         return x
 
 
-class Lut(nn.Module):
-    def __init__(self, bits):
-        super(Lut, self).__init__()
-        self.weight = nn.Parameter(torch.randn(pow(2, bits)))
-        self.bias = nn.Parameter(torch.randn(pow(2, bits)))
-        self.index = torch.pow(2, torch.arange(bits))
-        self.bits = bits
-
-    def forward(self, x):
-
-        x = MyLut.apply(x, self.weight, self.bias)
-
-        # tmp = torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))
-        # x = tmp + x - x.detach()
-
-        xx = (x > 0).float()
-        x = xx + (x - x.detach())
-
-        # print(xx.requires_grad)
-        # print(xx.grad_fn)
-        x = x * (self.index.to(x.device))
-        x = torch.sum(x, dim=-1)
-
-        w = torch.gather(self.weight, 0, x.long())
-        b = torch.gather(self.weight, 0, x.long())
-        x = w * x + b
-
-        # tmp = torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))
-        # x = tmp + x - x.detach()
-
-        xx = (x > 0).float()
-        x = xx + (x - x.detach())
-
-        x = x.view(-1, 1)
-        return x
-
-
 class LutGroup(nn.Module):
     def __init__(self, bits, subbits):
         super(LutGroup, self).__init__()
         assert (bits % subbits) == 0
-        self.lutlist = nn.ModuleList([Lut(subbits) for _ in range(int(bits / subbits))])
+        self.weight = nn.Parameter(torch.randn(bits, pow(2, subbits)))
         self.bits = bits
         self.subbits = subbits
 
     def forward(self, x):
-        ll = len(self.lutlist)
-        tmp = torch.empty((x.shape[0], 0), dtype=x.dtype, device=x.device)
-        start = 0
-        end = self.subbits
-        for i in range(ll):
-            tx = self.lutlist[i](x[:, start:end])
-            tmp = torch.cat((tmp, tx), dim=1)
-            start += self.subbits
-            end += self.subbits
-        return tmp
+        batch = x.shape[0]
+        x = x.view(batch, -1, self.subbits)
+        x = Lut.apply(x, self.weight)
+        return x
 
 
 class LutParallel(nn.Module):
     def __init__(self, bits, number):
         super(LutParallel, self).__init__()
-        self.lutlist = nn.ModuleList([Lut(bits) for _ in range(number)])
-        self.bits = bits
         self.number = number
+        self.weight = nn.Parameter(torch.randn(number, pow(2, bits)))
 
     def forward(self, x):
-        tmp = torch.empty((x.shape[0], 0), dtype=x.dtype, device=x.device)
-        for i in range(self.number):
-            tx = self.lutlist[i](x)
-            tmp = torch.cat((tmp, tx), dim=1)
-        return tmp
+        x = x.unsqueeze(1).repeat(1, self.number, 1)
+        x = Lut.apply(x, self.weight)
+        return x
 
 
 class SimpleBNN(nn.Module):
     def __init__(self):
-        super(SimpleCNN, self).__init__()
-        self.w = nn.Parameter(torch.randn(3, 784 * 8))
-        self.b = nn.Parameter(torch.zeros(3, 784 * 8))
+        super(SimpleBNN, self).__init__()
+        # self.w = nn.Parameter(torch.randn(3, 784 * 8))
+        # self.b = nn.Parameter(torch.zeros(3, 784 * 8))
+
+        self.w = nn.Parameter(torch.randn(3, 784))
+        self.b = nn.Parameter(torch.zeros(3, 784))
+
         self.lut1 = LutGroup(784 * 8, 8)
-        self.lut2 = LutGroup(784, 8)
-        self.lut3 = LutGroup(98, 14)
-        self.lut4 = LutParallel(7, 10)
+        self.lut2 = LutGroup(784, 4)
+        self.lut3 = LutGroup(196, 4)
+        self.lut4 = LutGroup(49, 7)
+        self.lut5 = LutParallel(7, 10)
+
+        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
+        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
+        self.fc1 = nn.Linear(320, 50)
+        self.fc2 = nn.Linear(50, 10)
+        self.pool = nn.MaxPool2d(2)
+        self.relu = nn.ReLU()
 
     def forward(self, x):
         batch = x.shape[0]
         x = x.view(batch, -1)
 
         # Map x from [-0.5, 0.5] to 0-255, then expand each value into its 8 binary bits
-        x = (x * 256 + 128).clamp(0, 255).to(torch.uint8)
-        xx = torch.arange(8).to(x.device)
-        bits = (x.unsqueeze(-1) >> xx) & 1
-        bits = bits.view(batch, -1)
+        # x = (x * 256 + 128).clamp(0, 255).to(torch.uint8)
+        # xx = torch.arange(7, -1, -1).to(x.device)
+        # bits = (x.unsqueeze(-1) >> xx) & 1
+        # x = bits.view(batch, -1)
+        # x = x.float() - 0.5
 
-        x = bits
-        # q = x * self.w[0] + self.b[0]
-        # k = x * self.w[1] + self.b[1]
-        # v = x * self.w[2] + self.b[2]
+        x = (x > 0).float()
 
-        # q = q.view(batch, -1, 1)
-        # k = k.view(batch, 1, -1)
-        # v = v.view(batch, -1, 1)
-        # kq = q @ k
-        # kqv = kq @ v
-        # kqv = kqv.view(batch, -1)
-        kqv = x
+        q = x * self.w[0] + self.b[0]
+        k = x * self.w[1] + self.b[1]
+        v = x * self.w[2] + self.b[2]
+        q = q.view(batch, -1, 1)
+        k = k.view(batch, 1, -1)
+        v = v.view(batch, -1, 1)
+        kq = q @ k
+        kqv = kq @ v
+        kqv = kqv.view(batch, -1, 8)
+        x = kqv
 
-        x = self.lut1(kqv)
+        #########################
+
+        # x = (x > 0) << xx
+        # x = x.sum(2)
+        # x = x.view(batch, 1, 28, 28)
+        # x = (x - 128.0) / 256.0
+
+        # x = x.view(batch, 1, 28, 28)
+        # x = (x > 0).float()
+
+        # x = self.relu(self.pool(self.conv1(x)))
+        # x = self.relu(self.pool((self.conv2(x))))
+        # x = x.view(-1, 320)
+        # x = self.relu(self.fc1(x))
+        # x = self.fc2(x)
+
+        #########################
+
+        x = (x > 0).float()
+        x = x.view(batch, 196, 4)
+
+        # x = self.lut1(x)
         x = self.lut2(x)
+
+        x = x.view(-1, 28, 7)
+        x = x.permute(0, 2, 1)
+        x = x.reshape(-1, 28 * 7)
+
         x = self.lut3(x)
         x = self.lut4(x)
+        x = self.lut5(x)
         return x
 
 
 torch.autograd.set_detect_anomaly(True)
-model = SimpleCNN().to(device)
-# model = SimpleBNN().to(device)
+# model = SimpleCNN().to(device)
+model = SimpleBNN().to(device)
 criterion = nn.CrossEntropyLoss()
-optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
 
 
 def train(epoch):
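
Note on the new Lut function: the hunk above shows only the first line of
Lut.forward, so the actual lookup logic is not visible in this patch. The
sketch below is an assumption based purely on the shape comments
(input [batch, count, bits], weight [count, 2**bits]); the bit order and the
gather call are guesses, not the code in binary/mnist.py.

    import torch

    # Hypothetical LUT forward pass (a sketch, not the patched implementation).
    def lut_forward_sketch(input, weight):
        batch, count, bits = input.shape
        # Binarize each entry, then pack every LUT's bit vector into an
        # integer index in [0, 2**bits), MSB first (bit order is assumed).
        ind = (input > 0).long()                               # [batch, count, bits]
        scale = 2 ** torch.arange(bits - 1, -1, -1, device=input.device)
        idx = (ind * scale).sum(dim=-1)                        # [batch, count]
        # Each of the `count` LUTs owns one row of `weight`; pick the entry
        # selected by its index for every sample.
        w = weight.unsqueeze(0).expand(batch, -1, -1)          # [batch, count, 2**bits]
        return torch.gather(w, 2, idx.unsqueeze(-1)).squeeze(-1)  # [batch, count]

This layout matches how the new LutGroup reshapes its input to
[batch, -1, subbits] before calling Lut.apply.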
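The removed nn.Module version of Lut relied on a straight-through estimator
for its binarization steps (xx = (x > 0).float(); x = xx + (x - x.detach())).
For reference, a minimal standalone demonstration of that trick:

    import torch

    # Straight-through estimator: the forward value is the hard 0/1
    # threshold, but the backward pass differentiates the identity term.
    x = torch.randn(4, requires_grad=True)
    hard = (x > 0).float()           # non-differentiable step
    y = hard + (x - x.detach())      # value == hard, dy/dx == 1
    y.sum().backward()
    print(x.grad)                    # tensor([1., 1., 1., 1.])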