Update Lut forward and backward.
This commit is contained in:
		
							parent
							
								
									a676025d20
								
							
						
					
					
						commit
						c322ee8228
					
				| 
						 | 
				
			
			@ -46,14 +46,14 @@ class Lut(torch.autograd.Function):
 | 
			
		|||
    def forward(ctx, input, weight, index):
 | 
			
		||||
        ind = ((input > 0).long() * index).sum(dim=-1)
 | 
			
		||||
        output = torch.gather(weight, 0, ind)
 | 
			
		||||
        # output = (output > 0).float()
 | 
			
		||||
        # output = (output - 0.5) * 2.0
 | 
			
		||||
        ctx.save_for_backward(input, weight, ind, output)
 | 
			
		||||
        output = (output > 0).float()
 | 
			
		||||
        output = (output - 0.5) * 2.0
 | 
			
		||||
        ctx.save_for_backward(input, weight, ind)
 | 
			
		||||
        return output
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def backward(ctx, grad_output):
 | 
			
		||||
        input, weight, ind, output = ctx.saved_tensors
 | 
			
		||||
        input, weight, ind = ctx.saved_tensors
 | 
			
		||||
        grad_input = grad_weight = None
 | 
			
		||||
        bits = input.shape[2]
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -62,10 +62,33 @@ class Lut(torch.autograd.Function):
 | 
			
		|||
            grad_weight.scatter_add_(0, ind, grad_output)
 | 
			
		||||
 | 
			
		||||
        if ctx.needs_input_grad[0]:
 | 
			
		||||
            grad_input = grad_output * torch.gather(weight, 0, ind)
 | 
			
		||||
            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
 | 
			
		||||
 | 
			
		||||
        return grad_input, grad_weight, None, None
 | 
			
		||||
            grad_input = grad_output * torch.gather(weight, 0, ind)
 | 
			
		||||
            # grad_input = grad_output
 | 
			
		||||
            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
 | 
			
		||||
            # output = output.unsqueeze(-1).repeat(1, 1, bits)
 | 
			
		||||
            in_sign = ((input > 0).float() - 0.5) * 2.0
 | 
			
		||||
            grad_input = grad_input * in_sign
 | 
			
		||||
            grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
 | 
			
		||||
 | 
			
		||||
            # grad_input = grad_output
 | 
			
		||||
            # grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
 | 
			
		||||
            # output = output.unsqueeze(-1).repeat(1, 1, bits)
 | 
			
		||||
            # in_sign = ((input > 0).float() - 0.5) * 2.0
 | 
			
		||||
            # out_sign = ((output > 0).float() - 0.5) * 2.0
 | 
			
		||||
            # grad_sign = ((grad_input > 0).float() - 0.5) * 2.0
 | 
			
		||||
            # grad_input = grad_input * in_sign * (out_sign * grad_sign)
 | 
			
		||||
            # grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
 | 
			
		||||
 | 
			
		||||
            # A dynamic adjustment coefficient is needed
 | 
			
		||||
            # to achieve stable convergence
 | 
			
		||||
 | 
			
		||||
            # print(in_sign[0].detach().cpu().numpy())
 | 
			
		||||
            # print(out_sign[0].detach().cpu().numpy())
 | 
			
		||||
            # print(grad_sign[0].detach().cpu().numpy())
 | 
			
		||||
            # print(grad_input[0].detach().cpu().numpy())
 | 
			
		||||
 | 
			
		||||
        return grad_input, grad_weight, None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SimpleCNN(nn.Module):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue