Compare commits
No commits in common. "a676025d20cb7286b1ef3a9a4ee977b72d97ee94" and "924a0ca9b438b888a49087dbaffbf8953288f34b" have entirely different histories.
a676025d20
...
924a0ca9b4
|
@ -46,8 +46,8 @@ class Lut(torch.autograd.Function):
|
|||
def forward(ctx, input, weight, index):
|
||||
ind = ((input > 0).long() * index).sum(dim=-1)
|
||||
output = torch.gather(weight, 0, ind)
|
||||
# output = (output > 0).float()
|
||||
# output = (output - 0.5) * 2.0
|
||||
output = (output > 0).float()
|
||||
output = (output - 0.5) * 2.0
|
||||
ctx.save_for_backward(input, weight, ind, output)
|
||||
return output
|
||||
|
||||
|
@ -62,8 +62,31 @@ class Lut(torch.autograd.Function):
|
|||
grad_weight.scatter_add_(0, ind, grad_output)
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
|
||||
grad_input = grad_output * torch.gather(weight, 0, ind)
|
||||
# grad_input = grad_output
|
||||
grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
|
||||
output = output.unsqueeze(-1).repeat(1, 1, bits)
|
||||
in_sign = ((input > 0).float() - 0.5) * 2.0
|
||||
grad_input = grad_input * in_sign
|
||||
grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
|
||||
|
||||
# grad_input = grad_output
|
||||
# grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
|
||||
# output = output.unsqueeze(-1).repeat(1, 1, bits)
|
||||
# in_sign = ((input > 0).float() - 0.5) * 2.0
|
||||
# out_sign = ((output > 0).float() - 0.5) * 2.0
|
||||
# grad_sign = ((grad_input > 0).float() - 0.5) * 2.0
|
||||
# grad_input = grad_input * in_sign * (out_sign * grad_sign)
|
||||
# grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
|
||||
|
||||
# 需要一个动态的调整系数
|
||||
# 能稳定的收敛
|
||||
|
||||
# print(in_sign[0].detach().cpu().numpy())
|
||||
# print(out_sign[0].detach().cpu().numpy())
|
||||
# print(grad_sign[0].detach().cpu().numpy())
|
||||
# print(grad_input[0].detach().cpu().numpy())
|
||||
|
||||
return grad_input, grad_weight, None, None
|
||||
|
||||
|
@ -159,11 +182,11 @@ class SimpleBNN(nn.Module):
|
|||
self.b = nn.Parameter(torch.zeros(3, 784))
|
||||
|
||||
# channel_repeat, input_shape, kernel_size, stride, dilation, fc
|
||||
self.lnn1 = LutCnn(80, (BS, 1, 28, 28), 2, 2, 1, False)
|
||||
self.lnn2 = LutCnn(1, (BS, 80, 14, 14), 2, 2, 1, False)
|
||||
self.lnn3 = LutCnn(1, (BS, 80, 7, 7), 3, 1, 1, False)
|
||||
self.lnn4 = LutCnn(1, (BS, 80, 5, 5), 3, 1, 1, False)
|
||||
self.lnn5 = LutCnn(10, (BS, 80, 3, 3), 3, 1, 1)
|
||||
self.lnn1 = LutCnn(8, (BS, 1, 28, 28), 2, 2, 1, False)
|
||||
self.lnn2 = LutCnn(1, (BS, 8, 14, 14), 2, 2, 1, False)
|
||||
self.lnn3 = LutCnn(1, (BS, 8, 7, 7), 3, 1, 1, False)
|
||||
self.lnn4 = LutCnn(1, (BS, 8, 5, 5), 3, 1, 1, False)
|
||||
self.lnn5 = LutCnn(10, (BS, 8, 3, 3), 3, 1, 1)
|
||||
|
||||
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
|
||||
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
|
||||
|
@ -267,8 +290,8 @@ class SimpleLNN(nn.Module):
|
|||
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
# model = SimpleCNN().to(device)
|
||||
model = SimpleBNN().to(device)
|
||||
# model = SimpleLNN().to(device)
|
||||
# model = SimpleBNN().to(device)
|
||||
model = SimpleLNN().to(device)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
|
||||
|
||||
|
@ -341,6 +364,8 @@ def profiler():
|
|||
assert False
|
||||
|
||||
|
||||
# profiler()
|
||||
|
||||
for epoch in range(1, 300):
|
||||
train(epoch)
|
||||
test(epoch)
|
||||
|
|
|
@ -6,7 +6,3 @@
|
|||
|
||||
v100_32G_PCIE batch=1 bf16=false 5.5s/it
|
||||
v100_32G_PCIE batch=1 bf16=true 8.7s/it
|
||||
v100_32G_PCIE batch=4 bf16=false 21.53s/it
|
||||
|
||||
|
||||
A100_40G_PCIE batch=4 bf16=false 3.58s/it
|
||||
|
|
Loading…
Reference in New Issue