Update Lut forward and backward.
This commit is contained in:
parent
a676025d20
commit
c322ee8228
|
@ -46,14 +46,14 @@ class Lut(torch.autograd.Function):
|
|||
def forward(ctx, input, weight, index):
    """Binarized LUT lookup: map each group of input sign bits to a ±1.0 value.

    Reconstructed post-commit body of this diff hunk (the scrape shows both
    the pre- and post-image lines; the post-image is the one that binarizes
    the gathered weight and saves only ``(input, weight, ind)``).

    Args:
        ctx: autograd context; tensors stashed here are read by ``backward``.
        input: tensor whose last dimension holds the per-entry sign bits
            (backward reads ``input.shape[2]`` as ``bits``, so presumably
            shaped ``(batch, groups, bits)`` — TODO confirm against callers).
        weight: LUT table, gathered along dim 0 by the packed integer key.
        index: per-bit positional weights (presumably powers of two — verify)
            used to pack the sign pattern of ``input`` into a LUT key.

    Returns:
        Tensor of exactly -1.0 / +1.0 values, one per packed key.
    """
    # Pack the sign bits into one integer key per group: every positive
    # entry of `input` contributes its `index` value to the key.
    ind = ((input > 0).long() * index).sum(dim=-1)
    # Look up the table entry for each key.
    output = torch.gather(weight, 0, ind)
    # Hard-binarize the looked-up value: (x > 0) -> 1.0 else 0.0, then
    # rescale {0, 1} to {-1, +1}.
    output = (output > 0).float()
    output = (output - 0.5) * 2.0
    # Post-commit change: save the pre-binarization inputs only; backward
    # re-gathers from `weight` and does not need the binarized output.
    ctx.save_for_backward(input, weight, ind)
    return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, weight, ind, output = ctx.saved_tensors
|
||||
input, weight, ind = ctx.saved_tensors
|
||||
grad_input = grad_weight = None
|
||||
bits = input.shape[2]
|
||||
|
||||
|
@ -62,10 +62,33 @@ class Lut(torch.autograd.Function):
|
|||
grad_weight.scatter_add_(0, ind, grad_output)
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_input = grad_output * torch.gather(weight, 0, ind)
|
||||
grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
|
||||
|
||||
return grad_input, grad_weight, None, None
|
||||
grad_input = grad_output * torch.gather(weight, 0, ind)
|
||||
# grad_input = grad_output
|
||||
grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
|
||||
# output = output.unsqueeze(-1).repeat(1, 1, bits)
|
||||
in_sign = ((input > 0).float() - 0.5) * 2.0
|
||||
grad_input = grad_input * in_sign
|
||||
grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
|
||||
|
||||
# grad_input = grad_output
|
||||
# grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
|
||||
# output = output.unsqueeze(-1).repeat(1, 1, bits)
|
||||
# in_sign = ((input > 0).float() - 0.5) * 2.0
|
||||
# out_sign = ((output > 0).float() - 0.5) * 2.0
|
||||
# grad_sign = ((grad_input > 0).float() - 0.5) * 2.0
|
||||
# grad_input = grad_input * in_sign * (out_sign * grad_sign)
|
||||
# grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
|
||||
|
||||
# TODO: needs a dynamic adjustment coefficient
|
||||
# so that training can converge stably
|
||||
|
||||
# print(in_sign[0].detach().cpu().numpy())
|
||||
# print(out_sign[0].detach().cpu().numpy())
|
||||
# print(grad_sign[0].detach().cpu().numpy())
|
||||
# print(grad_input[0].detach().cpu().numpy())
|
||||
|
||||
return grad_input, grad_weight, None
|
||||
|
||||
|
||||
class SimpleCNN(nn.Module):
|
||||
|
|
Loading…
Reference in New Issue