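Refine LUT repeat: move the input reshape/repeat from LutGroup.forward into Lut.forward, passing groupBits and groupRepeat to the autograd function.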
This commit is contained in:
parent
c322ee8228
commit
6cb969ac3b
|
@ -39,35 +39,43 @@ test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False, drop_last=T
|
|||
|
||||
|
||||
class Lut(torch.autograd.Function):
|
||||
# input [batch, group, bits ]
|
||||
# output [batch, group ]
|
||||
# input [ batch, group * groupBits ]
|
||||
# output [ batch, group * groupRepeat ]
|
||||
# weight [2**bits, group ]
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight, index):
|
||||
ind = ((input > 0).long() * index).sum(dim=-1)
|
||||
def forward(ctx, input, weight, index, groupBits, groupRepeat):
|
||||
batch = input.shape[0]
|
||||
x = input.reshape(batch, -1, groupBits)
|
||||
if groupRepeat > 1:
|
||||
x = x.repeat(1, groupRepeat, 1)
|
||||
|
||||
ind = ((x > 0).long() * index).sum(dim=-1)
|
||||
output = torch.gather(weight, 0, ind)
|
||||
output = (output > 0).float()
|
||||
output = (output - 0.5) * 2.0
|
||||
ctx.groupBits = groupBits
|
||||
ctx.groupRepeat = groupRepeat
|
||||
ctx.batch = batch
|
||||
ctx.save_for_backward(input, weight, ind)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, weight, ind = ctx.saved_tensors
|
||||
groupBits = ctx.groupBits
|
||||
groupRepeat = ctx.groupRepeat
|
||||
batch = ctx.batch
|
||||
grad_input = grad_weight = None
|
||||
bits = input.shape[2]
|
||||
|
||||
if ctx.needs_input_grad[1]:
|
||||
grad_weight = torch.zeros_like(weight)
|
||||
grad_weight.scatter_add_(0, ind, grad_output)
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
|
||||
grad_input = grad_output * torch.gather(weight, 0, ind)
|
||||
# grad_input = grad_output
|
||||
grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
|
||||
# output = output.unsqueeze(-1).repeat(1, 1, bits)
|
||||
in_sign = ((input > 0).float() - 0.5) * 2.0
|
||||
grad_input = grad_input.view(batch, -1, groupRepeat).sum(2)
|
||||
grad_input = grad_input.unsqueeze(-1).repeat(1, 1, groupBits)
|
||||
grad_input = grad_input * in_sign
|
||||
grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
|
||||
|
||||
|
@ -85,10 +93,7 @@ class Lut(torch.autograd.Function):
|
|||
|
||||
# print(in_sign[0].detach().cpu().numpy())
|
||||
# print(out_sign[0].detach().cpu().numpy())
|
||||
# print(grad_sign[0].detach().cpu().numpy())
|
||||
# print(grad_input[0].detach().cpu().numpy())
|
||||
|
||||
return grad_input, grad_weight, None
|
||||
return grad_input, grad_weight, None, None, None
|
||||
|
||||
|
||||
class SimpleCNN(nn.Module):
|
||||
|
@ -121,7 +126,7 @@ class LutGroup(nn.Module):
|
|||
def __init__(self, group, groupBits, groupRepeat=1):
|
||||
assert groupBits > 1
|
||||
super(LutGroup, self).__init__()
|
||||
self.weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
|
||||
self.weight = nn.Parameter(torch.ones(pow(2, groupBits), int(groupRepeat * group)))
|
||||
self.group = group
|
||||
self.groupBits = groupBits
|
||||
self.groupRepeat = groupRepeat
|
||||
|
@ -130,11 +135,7 @@ class LutGroup(nn.Module):
|
|||
def forward(self, x):
|
||||
# input [ batch, group * groupBits ]
|
||||
# output [ batch, group * groupRepeat ]
|
||||
batch = x.shape[0]
|
||||
x = x.reshape(batch, -1, self.groupBits)
|
||||
if self.groupRepeat > 1:
|
||||
x = x.repeat(1, self.groupRepeat, 1)
|
||||
x = Lut.apply(x, self.weight, self.index)
|
||||
x = Lut.apply(x, self.weight, self.index, self.groupBits, self.groupRepeat)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -186,7 +187,7 @@ class SimpleBNN(nn.Module):
|
|||
self.lnn2 = LutCnn(1, (BS, 80, 14, 14), 2, 2, 1, False)
|
||||
self.lnn3 = LutCnn(1, (BS, 80, 7, 7), 3, 1, 1, False)
|
||||
self.lnn4 = LutCnn(1, (BS, 80, 5, 5), 3, 1, 1, False)
|
||||
self.lnn5 = LutCnn(10, (BS, 80, 3, 3), 3, 1, 1)
|
||||
self.lnn5 = LutCnn(1, (BS, 80, 3, 3), 3, 1, 1)
|
||||
|
||||
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
|
||||
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
|
||||
|
@ -282,10 +283,10 @@ class SimpleLNN(nn.Module):
|
|||
print(self.lutg1.weight.grad[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy())
|
||||
print("=============================")
|
||||
print("=============================")
|
||||
# print("self.lutg2")
|
||||
# print(self.lutg2.weight.detach().cpu().numpy())
|
||||
# print("=============================")
|
||||
# print("=============================")
|
||||
print("self.lutg2")
|
||||
print(self.lutg2.weight.detach().cpu().numpy())
|
||||
print("=============================")
|
||||
print("=============================")
|
||||
|
||||
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
|
|
Loading…
Reference in New Issue