Update Lut forward and backward.

This commit is contained in:
Colin 2025-06-10 10:52:26 +08:00
parent a676025d20
commit c322ee8228
1 changed files with 30 additions and 7 deletions

View File

@ -46,14 +46,14 @@ class Lut(torch.autograd.Function):
def forward(ctx, input, weight, index):
ind = ((input > 0).long() * index).sum(dim=-1)
output = torch.gather(weight, 0, ind)
# output = (output > 0).float()
# output = (output - 0.5) * 2.0
ctx.save_for_backward(input, weight, ind, output)
output = (output > 0).float()
output = (output - 0.5) * 2.0
ctx.save_for_backward(input, weight, ind)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, ind, output = ctx.saved_tensors
input, weight, ind = ctx.saved_tensors
grad_input = grad_weight = None
bits = input.shape[2]
@ -62,10 +62,33 @@ class Lut(torch.autograd.Function):
grad_weight.scatter_add_(0, ind, grad_output)
if ctx.needs_input_grad[0]:
grad_input = grad_output * torch.gather(weight, 0, ind)
grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
return grad_input, grad_weight, None, None
grad_input = grad_output * torch.gather(weight, 0, ind)
# grad_input = grad_output
grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
# output = output.unsqueeze(-1).repeat(1, 1, bits)
in_sign = ((input > 0).float() - 0.5) * 2.0
grad_input = grad_input * in_sign
grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
# grad_input = grad_output
# grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
# output = output.unsqueeze(-1).repeat(1, 1, bits)
# in_sign = ((input > 0).float() - 0.5) * 2.0
# out_sign = ((output > 0).float() - 0.5) * 2.0
# grad_sign = ((grad_input > 0).float() - 0.5) * 2.0
# grad_input = grad_input * in_sign * (out_sign * grad_sign)
# grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
# 需要一个动态的调整系数
# 能稳定的收敛
# print(in_sign[0].detach().cpu().numpy())
# print(out_sign[0].detach().cpu().numpy())
# print(grad_sign[0].detach().cpu().numpy())
# print(grad_input[0].detach().cpu().numpy())
return grad_input, grad_weight, None
class SimpleCNN(nn.Module):