Fix and update binary network.
parent 710c901f5e
commit 924a0ca9b4
@@ -0,0 +1 @@
+data
@@ -22,7 +22,7 @@ np.random.seed(1234)
 torch.cuda.manual_seed_all(1234)
 
 BS = 16
-LR = 0.01
+LR = 0.001
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")
@@ -63,8 +63,8 @@ class Lut(torch.autograd.Function):
 
         if ctx.needs_input_grad[0]:
 
-            # grad_input = grad_output * torch.gather(weight, 0, ind)
-            grad_input = grad_output
+            grad_input = grad_output * torch.gather(weight, 0, ind)
+            # grad_input = grad_output
             grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
             output = output.unsqueeze(-1).repeat(1, 1, bits)
             in_sign = ((input > 0).float() - 0.5) * 2.0
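The backward change stops passing grad_output straight through and instead scales it by the LUT weights selected during the forward pass, before the bit-wise expansion. The sketch below is a minimal, self-contained illustration of that gather-scaled backward pattern, not the repository's Lut class: the 1-bit indexing, the shapes, and the ToyLut / weight / ind names are all assumptions made for the example.

import torch

class ToyLut(torch.autograd.Function):
    """Toy lookup: forward returns weight[ind]; backward scales the incoming
    gradient by the same gathered weights (a weighted straight-through
    estimate for the non-differentiable table lookup)."""

    @staticmethod
    def forward(ctx, input, weight):
        # input: (B, G) float scores, weight: (2**bits, G) lookup table.
        ind = (input > 0).long()          # toy 1-bit index per group
        ctx.save_for_backward(weight, ind)
        return torch.gather(weight, 0, ind)

    @staticmethod
    def backward(ctx, grad_output):
        weight, ind = ctx.saved_tensors
        grad_input = None
        if ctx.needs_input_grad[0]:
            # Same pattern as the new line in the diff:
            # scale the upstream gradient by the selected table entries.
            grad_input = grad_output * torch.gather(weight, 0, ind)
        # Route the upstream gradient back to the entries that were selected.
        grad_weight = torch.zeros_like(weight).scatter_add_(0, ind, grad_output)
        return grad_input, grad_weight

x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(2, 3, requires_grad=True)
ToyLut.apply(x, w).sum().backward()
print(x.grad.shape, w.grad.shape)  # torch.Size([4, 3]) torch.Size([2, 3])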
@@ -121,7 +121,7 @@ class LutGroup(nn.Module):
     def __init__(self, group, groupBits, groupRepeat=1):
         assert groupBits > 1
         super(LutGroup, self).__init__()
-        self.weight = nn.Parameter(torch.ones(pow(2, groupBits), int(groupRepeat * group)))
+        self.weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
         self.group = group
         self.groupBits = groupBits
         self.groupRepeat = groupRepeat
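Replacing torch.ones with torch.randn changes the LUT table from a constant, fully symmetric start to a randomly broken one: with an all-ones table every possible index selects the same value, so the initial outputs are identical and the columns receive identical updates. A quick stand-alone check of that difference (the table shape mirrors the LutGroup constructor; the index tensor is made up for the example):

import torch

group, groupBits, groupRepeat = 4, 3, 2
rows, cols = 2 ** groupBits, groupRepeat * group   # table shape as in LutGroup

ones_table = torch.ones(rows, cols)    # old init
randn_table = torch.randn(rows, cols)  # new init

ind = torch.randint(0, rows, (5, cols))  # fake per-sample LUT indices

# All-ones table: every lookup returns 1.0, so the initial output carries
# no information about which entry was selected.
print(torch.gather(ones_table, 0, ind).unique())     # tensor([1.])

# Random table: different indices already select different values.
print(torch.gather(randn_table, 0, ind).std() > 0)   # tensor(True)

Breaking that symmetry at initialization lets different table entries, and different groups, learn distinct values from the first step.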
@@ -166,7 +166,7 @@ class LutCnn(nn.Module):
         x = x.view(B, -1, self.kernel_size * self.kernel_size)
         x = self.lut(x)
         if self.fc:
-            x = x.view(B, -1, self.channel_repeat)
+            x = x.view(B, self.channel_repeat * C, -1)
             x = x.permute(0, 2, 1)
             x = self.lutc(x)
         return x
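The forward fix changes how the flat LUT output is regrouped before it reaches self.lutc: the old view split the last axis into channel_repeat-sized chunks, while the new one keeps all channel_repeat * C lookup results together and lets the permute move them onto the feature axis. Exactly which grouping matches the memory layout produced by self.lut is not visible in this hunk, so the snippet below only contrasts the resulting shapes on made-up toy dimensions:

import torch

B, C, channel_repeat, spatial = 2, 3, 4, 5          # toy sizes, not the model's real ones
x = torch.randn(B, channel_repeat * C * spatial)    # pretend flat output of self.lut

old = x.view(B, -1, channel_repeat).permute(0, 2, 1)       # old grouping
new = x.view(B, channel_repeat * C, -1).permute(0, 2, 1)   # new grouping

print(old.shape)  # torch.Size([2, 4, 15])  -> channel_repeat on the middle axis
print(new.shape)  # torch.Size([2, 5, 12])  -> channel_repeat * C features per position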