Fix and update binary network.
This commit is contained in:
parent
710c901f5e
commit
924a0ca9b4
|
@ -0,0 +1 @@
|
||||||
|
data
|
|
@ -22,7 +22,7 @@ np.random.seed(1234)
|
||||||
torch.cuda.manual_seed_all(1234)
|
torch.cuda.manual_seed_all(1234)
|
||||||
|
|
||||||
BS = 16
|
BS = 16
|
||||||
LR = 0.01
|
LR = 0.001
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
print(f"Using device: {device}")
|
print(f"Using device: {device}")
|
||||||
|
@ -63,8 +63,8 @@ class Lut(torch.autograd.Function):
|
||||||
|
|
||||||
if ctx.needs_input_grad[0]:
|
if ctx.needs_input_grad[0]:
|
||||||
|
|
||||||
# grad_input = grad_output * torch.gather(weight, 0, ind)
|
grad_input = grad_output * torch.gather(weight, 0, ind)
|
||||||
grad_input = grad_output
|
# grad_input = grad_output
|
||||||
grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
|
grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
|
||||||
output = output.unsqueeze(-1).repeat(1, 1, bits)
|
output = output.unsqueeze(-1).repeat(1, 1, bits)
|
||||||
in_sign = ((input > 0).float() - 0.5) * 2.0
|
in_sign = ((input > 0).float() - 0.5) * 2.0
|
||||||
|
@ -121,7 +121,7 @@ class LutGroup(nn.Module):
|
||||||
def __init__(self, group, groupBits, groupRepeat=1):
|
def __init__(self, group, groupBits, groupRepeat=1):
|
||||||
assert groupBits > 1
|
assert groupBits > 1
|
||||||
super(LutGroup, self).__init__()
|
super(LutGroup, self).__init__()
|
||||||
self.weight = nn.Parameter(torch.ones(pow(2, groupBits), int(groupRepeat * group)))
|
self.weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
|
||||||
self.group = group
|
self.group = group
|
||||||
self.groupBits = groupBits
|
self.groupBits = groupBits
|
||||||
self.groupRepeat = groupRepeat
|
self.groupRepeat = groupRepeat
|
||||||
|
@ -166,7 +166,7 @@ class LutCnn(nn.Module):
|
||||||
x = x.view(B, -1, self.kernel_size * self.kernel_size)
|
x = x.view(B, -1, self.kernel_size * self.kernel_size)
|
||||||
x = self.lut(x)
|
x = self.lut(x)
|
||||||
if self.fc:
|
if self.fc:
|
||||||
x = x.view(B, -1, self.channel_repeat)
|
x = x.view(B, self.channel_repeat * C, -1)
|
||||||
x = x.permute(0, 2, 1)
|
x = x.permute(0, 2, 1)
|
||||||
x = self.lutc(x)
|
x = self.lutc(x)
|
||||||
return x
|
return x
|
||||||
|
|
Loading…
Reference in New Issue