Move the LUT repeat handling from LutGroup into the Lut autograd function.

Colin 2025-06-10 20:03:23 +08:00
parent c322ee8228
commit 6cb969ac3b
1 changed file with 25 additions and 24 deletions


@@ -39,35 +39,43 @@ test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False, drop_last=T
 class Lut(torch.autograd.Function):
-    # input [batch, group, bits ]
-    # output [batch, group ]
+    # input [ batch, group * groupBits ]
+    # output [ batch, group * groupRepeat ]
     # weight [2**bits, group ]
     @staticmethod
-    def forward(ctx, input, weight, index):
-        ind = ((input > 0).long() * index).sum(dim=-1)
+    def forward(ctx, input, weight, index, groupBits, groupRepeat):
+        batch = input.shape[0]
+        x = input.reshape(batch, -1, groupBits)
+        if groupRepeat > 1:
+            x = x.repeat(1, groupRepeat, 1)
+        ind = ((x > 0).long() * index).sum(dim=-1)
         output = torch.gather(weight, 0, ind)
         output = (output > 0).float()
         output = (output - 0.5) * 2.0
+        ctx.groupBits = groupBits
+        ctx.groupRepeat = groupRepeat
+        ctx.batch = batch
         ctx.save_for_backward(input, weight, ind)
         return output

     @staticmethod
     def backward(ctx, grad_output):
         input, weight, ind = ctx.saved_tensors
+        groupBits = ctx.groupBits
+        groupRepeat = ctx.groupRepeat
+        batch = ctx.batch
         grad_input = grad_weight = None
-        bits = input.shape[2]
         if ctx.needs_input_grad[1]:
             grad_weight = torch.zeros_like(weight)
             grad_weight.scatter_add_(0, ind, grad_output)
         if ctx.needs_input_grad[0]:
             grad_input = grad_output * torch.gather(weight, 0, ind)
             # grad_input = grad_output
-            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
-            # output = output.unsqueeze(-1).repeat(1, 1, bits)
             in_sign = ((input > 0).float() - 0.5) * 2.0
+            grad_input = grad_input.view(batch, -1, groupRepeat).sum(2)
+            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, groupBits)
             grad_input = grad_input * in_sign
             grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
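Note: the repeat/fold bookkeeping added in this hunk can be checked in isolation. The sketch below (not part of the commit; shapes and names are illustrative) shows that x.repeat(1, groupRepeat, 1) tiles the whole group axis, so the repeat index is the outer axis after folding; the hunk's view(batch, -1, groupRepeat).sum(2) instead matches a layout where the copies of each group are contiguous (what repeat_interleave would produce), which may be worth double-checking.

import torch

batch, group, groupBits, groupRepeat = 2, 3, 4, 5

# Forward-side expansion: [batch, group, bits] -> [batch, group * groupRepeat, bits].
# repeat(1, groupRepeat, 1) tiles dim 1 as [g0, g1, g2, g0, g1, g2, ...].
x = torch.randn(batch, group, groupBits)
x_rep = x.repeat(1, groupRepeat, 1)
assert torch.equal(x_rep[:, group:2 * group], x)  # second copy equals the original

# Backward-side fold: each original group should receive the summed gradients
# of its groupRepeat copies. For the tiled layout above, the repeat index is
# the outer axis, so the fold is:
grad_out = torch.randn(batch, group * groupRepeat)
folded = grad_out.view(batch, groupRepeat, -1).sum(1)  # -> [batch, group]
assert folded.shape == (batch, group)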
@@ -85,10 +93,7 @@ class Lut(torch.autograd.Function):
             # print(in_sign[0].detach().cpu().numpy())
             # print(out_sign[0].detach().cpu().numpy())
             # print(grad_sign[0].detach().cpu().numpy())
             # print(grad_input[0].detach().cpu().numpy())
-        return grad_input, grad_weight, None
+        return grad_input, grad_weight, None, None, None

 class SimpleCNN(nn.Module):
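Note: the two extra Nones follow the torch.autograd.Function contract: backward must return one value per forward argument, with None for non-tensor inputs such as groupBits and groupRepeat. A minimal sketch of the pattern, with hypothetical names:

import torch

class Scale(torch.autograd.Function):
    # Hypothetical example: one tensor input plus one plain-Python setting.
    @staticmethod
    def forward(ctx, input, factor):
        ctx.factor = factor
        return input * factor

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward argument: a gradient for input,
        # None for the non-tensor factor.
        return grad_output * ctx.factor, None

x = torch.randn(3, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
assert torch.equal(x.grad, torch.full((3,), 2.0))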
@@ -121,7 +126,7 @@ class LutGroup(nn.Module):
     def __init__(self, group, groupBits, groupRepeat=1):
         assert groupBits > 1
         super(LutGroup, self).__init__()
-        self.weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
+        self.weight = nn.Parameter(torch.ones(pow(2, groupBits), int(groupRepeat * group)))
         self.group = group
         self.groupBits = groupBits
         self.groupRepeat = groupRepeat
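Note: with the weight switched from randn to ones, torch.gather initially returns 1 for every index, so the sign threshold in Lut.forward makes each LUT start out emitting a constant +1. A small standalone illustration (shapes illustrative, not part of the commit):

import torch

groupBits, group = 2, 3
weight = torch.ones(2 ** groupBits, group)
ind = torch.randint(0, 2 ** groupBits, (5, group))  # [batch, group]
out = torch.gather(weight, 0, ind)
out = ((out > 0).float() - 0.5) * 2.0               # sign threshold to +/-1
assert torch.equal(out, torch.ones(5, group))       # all-ones weights -> all +1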
@@ -130,11 +135,7 @@ class LutGroup(nn.Module):
     def forward(self, x):
         # input [ batch, group * groupBits ]
         # output [ batch, group * groupRepeat ]
-        batch = x.shape[0]
-        x = x.reshape(batch, -1, self.groupBits)
-        if self.groupRepeat > 1:
-            x = x.repeat(1, self.groupRepeat, 1)
-        x = Lut.apply(x, self.weight, self.index)
+        x = Lut.apply(x, self.weight, self.index, self.groupBits, self.groupRepeat)
         return x
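Note: after this hunk LutGroup.forward is a thin wrapper; the reshape and repeat live inside Lut. A rough call sketch with the Lut class from this commit in scope; the powers-of-two index is an assumption about how self.index is built elsewhere in the file:

import torch

group, groupBits, groupRepeat, batch = 4, 3, 2, 8

# Assumed bit weights (how self.index is presumably built): they turn each
# groupBits-long sign pattern into an integer table index.
index = torch.pow(2, torch.arange(groupBits))

weight = torch.ones(2 ** groupBits, group * groupRepeat)
x = torch.randn(batch, group * groupBits)

# New call signature: Lut reshapes and repeats internally.
y = Lut.apply(x, weight, index, groupBits, groupRepeat)
assert y.shape == (batch, group * groupRepeat)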
@@ -186,7 +187,7 @@ class SimpleBNN(nn.Module):
         self.lnn2 = LutCnn(1, (BS, 80, 14, 14), 2, 2, 1, False)
         self.lnn3 = LutCnn(1, (BS, 80, 7, 7), 3, 1, 1, False)
         self.lnn4 = LutCnn(1, (BS, 80, 5, 5), 3, 1, 1, False)
-        self.lnn5 = LutCnn(10, (BS, 80, 3, 3), 3, 1, 1)
+        self.lnn5 = LutCnn(1, (BS, 80, 3, 3), 3, 1, 1)

         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
@@ -282,10 +283,10 @@ class SimpleLNN(nn.Module):
         print(self.lutg1.weight.grad[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy())
         print("=============================")
         print("=============================")
-        # print("self.lutg2")
-        # print(self.lutg2.weight.detach().cpu().numpy())
-        # print("=============================")
-        # print("=============================")
+        print("self.lutg2")
+        print(self.lutg2.weight.detach().cpu().numpy())
+        print("=============================")
+        print("=============================")

 torch.autograd.set_detect_anomaly(True)
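Note: torch.autograd.set_detect_anomaly(True) makes backward raise at the op that produced a NaN/Inf gradient, at a substantial slowdown, so it is usually reserved for debugging runs. A scoped equivalent:

import torch

x = torch.randn(4, requires_grad=True)
# Scoped form of the same switch: enable anomaly detection only around the
# backward being debugged, since it slows autograd considerably.
with torch.autograd.detect_anomaly():
    (x * 2).sum().backward()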