Move the LUT repeat from LutGroup into Lut.
parent c322ee8228
commit 6cb969ac3b
@@ -39,35 +39,43 @@ test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False, drop_last=T
 class Lut(torch.autograd.Function):
-    # input  [batch, group, bits]
-    # output [batch, group]
+    # input  [batch, group * groupBits]
+    # output [batch, group * groupRepeat]
     # weight [2**bits, group]
     @staticmethod
-    def forward(ctx, input, weight, index):
-        ind = ((input > 0).long() * index).sum(dim=-1)
+    def forward(ctx, input, weight, index, groupBits, groupRepeat):
+        batch = input.shape[0]
+        x = input.reshape(batch, -1, groupBits)
+        if groupRepeat > 1:
+            x = x.repeat(1, groupRepeat, 1)
+        ind = ((x > 0).long() * index).sum(dim=-1)
         output = torch.gather(weight, 0, ind)
         output = (output > 0).float()
         output = (output - 0.5) * 2.0
+        ctx.groupBits = groupBits
+        ctx.groupRepeat = groupRepeat
+        ctx.batch = batch
         ctx.save_for_backward(input, weight, ind)
         return output

     @staticmethod
     def backward(ctx, grad_output):
         input, weight, ind = ctx.saved_tensors
+        groupBits = ctx.groupBits
+        groupRepeat = ctx.groupRepeat
+        batch = ctx.batch
         grad_input = grad_weight = None
-        bits = input.shape[2]

         if ctx.needs_input_grad[1]:
             grad_weight = torch.zeros_like(weight)
             grad_weight.scatter_add_(0, ind, grad_output)

         if ctx.needs_input_grad[0]:
             grad_input = grad_output * torch.gather(weight, 0, ind)
-            # grad_input = grad_output
-            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
-            # output = output.unsqueeze(-1).repeat(1, 1, bits)
             in_sign = ((input > 0).float() - 0.5) * 2.0
+            grad_input = grad_input.view(batch, -1, groupRepeat).sum(2)
+            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, groupBits)
             grad_input = grad_input * in_sign
             grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
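For reference, a minimal standalone sketch of what the reworked forward computes, with the repeat now applied inside Lut. One assumption is flagged in the code: `index` is defined elsewhere in the file, and the bit-packing expression only makes sense if it holds the powers of two [1, 2, ..., 2**(groupBits-1)].

    import torch

    batch, group, groupBits, groupRepeat = 2, 3, 4, 2
    index = torch.pow(2, torch.arange(groupBits))      # assumed bit weights [1, 2, 4, 8]
    weight = torch.randn(2 ** groupBits, group * groupRepeat)

    x = torch.randn(batch, group * groupBits)          # [batch, group * groupBits]
    x = x.reshape(batch, -1, groupBits)                # [batch, group, groupBits]
    x = x.repeat(1, groupRepeat, 1)                    # [batch, group * groupRepeat, groupBits]
    ind = ((x > 0).long() * index).sum(dim=-1)         # pack each row's sign bits into a table address
    output = torch.gather(weight, 0, ind)              # one table entry per (sample, LUT)
    output = ((output > 0).float() - 0.5) * 2.0        # binarize to {-1.0, +1.0}
    print(output.shape)                                # torch.Size([2, 6])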
@@ -85,10 +93,7 @@ class Lut(torch.autograd.Function):
         # print(in_sign[0].detach().cpu().numpy())
         # print(out_sign[0].detach().cpu().numpy())
-        # print(grad_sign[0].detach().cpu().numpy())
-        # print(grad_input[0].detach().cpu().numpy())
-
-        return grad_input, grad_weight, None
+        return grad_input, grad_weight, None, None, None


 class SimpleCNN(nn.Module):
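The backward pass is a straight-through-style estimator: the table gradient scatter-adds each output gradient into the LUT row that produced it, and the input gradient folds the groupRepeat copies back onto their source group before fanning out across the group's bits. A rough, shape-consistent sketch follows (same toy sizes as above); two deliberate deviations are worth naming. First, `Tensor.repeat` in forward tiles whole group blocks, so this sketch folds with `view(batch, groupRepeat, -1).sum(1)`, whereas the committed `view(batch, -1, groupRepeat).sum(2)` would match a `repeat_interleave` layout instead. Second, the sketch reshapes back to the input's 2D shape before applying the input sign so the broadcast lines up.

    import torch

    batch, group, groupBits, groupRepeat = 2, 3, 4, 2
    weight = torch.randn(2 ** groupBits, group * groupRepeat)
    ind = torch.randint(2 ** groupBits, (batch, group * groupRepeat))  # addresses saved by forward
    grad_output = torch.randn(batch, group * groupRepeat)
    input = torch.randn(batch, group * groupBits)

    # Table gradient: route each output grad into the LUT row it read.
    grad_weight = torch.zeros_like(weight)
    grad_weight.scatter_add_(0, ind, grad_output)

    # Input gradient: the looked-up weight acts as a surrogate slope.
    grad_input = grad_output * torch.gather(weight, 0, ind)      # [batch, group * groupRepeat]
    grad_input = grad_input.view(batch, groupRepeat, -1).sum(1)  # fold repeats: [batch, group]
    grad_input = grad_input.unsqueeze(-1).repeat(1, 1, groupBits).reshape(batch, -1)
    grad_input = grad_input * ((input > 0).float() - 0.5) * 2.0  # re-sign by the input, as in the diff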
@@ -121,7 +126,7 @@ class LutGroup(nn.Module):
     def __init__(self, group, groupBits, groupRepeat=1):
         assert groupBits > 1
         super(LutGroup, self).__init__()
-        self.weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
+        self.weight = nn.Parameter(torch.ones(pow(2, groupBits), int(groupRepeat * group)))
         self.group = group
         self.groupBits = groupBits
         self.groupRepeat = groupRepeat
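One consequence of the switch to a `torch.ones` initialization worth noting: every table entry starts positive, so after the `(output > 0)` binarization the layer initially returns +1 for every input, and the early training signal reaches the table only through the scatter-add weight gradient. A quick check of that claim:

    import torch

    weight = torch.ones(2 ** 4, 6)          # all-ones LUT, as in the new __init__
    ind = torch.randint(2 ** 4, (2, 6))     # arbitrary addresses
    out = ((torch.gather(weight, 0, ind) > 0).float() - 0.5) * 2.0
    print(out)                              # all +1.0, regardless of the input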
@@ -130,11 +135,7 @@ class LutGroup(nn.Module):
     def forward(self, x):
         # input  [batch, group * groupBits]
         # output [batch, group * groupRepeat]
-        batch = x.shape[0]
-        x = x.reshape(batch, -1, self.groupBits)
-        if self.groupRepeat > 1:
-            x = x.repeat(1, self.groupRepeat, 1)
-        x = Lut.apply(x, self.weight, self.index)
+        x = Lut.apply(x, self.weight, self.index, self.groupBits, self.groupRepeat)
         return x
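With the reshape/repeat moved into Lut, LutGroup.forward becomes a thin wrapper. A hypothetical usage sketch (shapes only; LutGroup and its `index` buffer are defined in this file, with `index` assumed to hold the groupBits bit weights):

    layer = LutGroup(group=3, groupBits=4, groupRepeat=2)
    x = torch.randn(8, 3 * 4)   # [batch, group * groupBits]
    y = layer(x)                # Lut.apply does the reshape, repeat, and lookup
    print(y.shape)              # expected: torch.Size([8, 6]) = [batch, group * groupRepeat]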
@@ -186,7 +187,7 @@ class SimpleBNN(nn.Module):
         self.lnn2 = LutCnn(1, (BS, 80, 14, 14), 2, 2, 1, False)
         self.lnn3 = LutCnn(1, (BS, 80, 7, 7), 3, 1, 1, False)
         self.lnn4 = LutCnn(1, (BS, 80, 5, 5), 3, 1, 1, False)
-        self.lnn5 = LutCnn(10, (BS, 80, 3, 3), 3, 1, 1)
+        self.lnn5 = LutCnn(1, (BS, 80, 3, 3), 3, 1, 1)

         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
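The input shapes threaded through the successive LutCnn layers follow the usual no-padding convolution rule out = (in - k) // s + 1, assuming (as the shapes suggest; LutCnn's definition is outside this hunk) that the third and fourth arguments are kernel size and stride:

    def out_size(n, k, s):
        # valid (no padding) spatial output size
        return (n - k) // s + 1

    print(out_size(14, 2, 2))  # 7: lnn2 takes (BS, 80, 14, 14) to 7x7, matching lnn3's input
    print(out_size(7, 3, 1))   # 5: lnn3 -> (BS, 80, 5, 5)
    print(out_size(5, 3, 1))   # 3: lnn4 -> (BS, 80, 3, 3)
    print(out_size(3, 3, 1))   # 1: lnn5 collapses the feature map to 1x1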
@@ -282,10 +283,10 @@ class SimpleLNN(nn.Module):
             print(self.lutg1.weight.grad[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy())
             print("=============================")
             print("=============================")
-            # print("self.lutg2")
-            # print(self.lutg2.weight.detach().cpu().numpy())
-            # print("=============================")
-            # print("=============================")
+            print("self.lutg2")
+            print(self.lutg2.weight.detach().cpu().numpy())
+            print("=============================")
+            print("=============================")


 torch.autograd.set_detect_anomaly(True)