Update Lut forward and backward.
parent a676025d20
commit c322ee8228
@@ -46,14 +46,14 @@ class Lut(torch.autograd.Function):
     def forward(ctx, input, weight, index):
         ind = ((input > 0).long() * index).sum(dim=-1)
         output = torch.gather(weight, 0, ind)
-        # output = (output > 0).float()
-        # output = (output - 0.5) * 2.0
-        ctx.save_for_backward(input, weight, ind, output)
+        output = (output > 0).float()
+        output = (output - 0.5) * 2.0
+        ctx.save_for_backward(input, weight, ind)
         return output
 
     @staticmethod
     def backward(ctx, grad_output):
-        input, weight, ind, output = ctx.saved_tensors
+        input, weight, ind = ctx.saved_tensors
         grad_input = grad_weight = None
         bits = input.shape[2]
 
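The first hunk promotes the previously commented-out binarization into the live forward path: the gathered LUT entry is thresholded to {0, 1}, remapped to {-1, +1}, and `output` no longer needs to be saved for backward. Below is a minimal sketch of the lookup in isolation, assuming `input` is shaped (batch, groups, bits), `index` holds the power-of-two weight of each bit position, and `weight` is shaped (2**bits, groups); none of these shapes are stated in the hunk itself.

import torch

# hypothetical shapes, for illustration only
batch, groups, bits = 2, 3, 4
input = torch.randn(batch, groups, bits)
index = 2 ** torch.arange(bits)                 # bit-position weights (assumption)
weight = torch.randn(2 ** bits, groups)

ind = ((input > 0).long() * index).sum(dim=-1)  # bit pattern -> integer row id
output = torch.gather(weight, 0, ind)           # one LUT entry per (batch, group)
output = (output > 0).float()                   # binarize, as the new forward does
output = (output - 0.5) * 2.0                   # map {0, 1} to {-1, +1}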
@@ -62,10 +62,33 @@ class Lut(torch.autograd.Function):
             grad_weight.scatter_add_(0, ind, grad_output)
 
         if ctx.needs_input_grad[0]:
-            grad_input = grad_output * torch.gather(weight, 0, ind)
-            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
-
-        return grad_input, grad_weight, None, None
+            grad_input = grad_output * torch.gather(weight, 0, ind)
+            # grad_input = grad_output
+            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
+            # output = output.unsqueeze(-1).repeat(1, 1, bits)
+            in_sign = ((input > 0).float() - 0.5) * 2.0
+            grad_input = grad_input * in_sign
+            grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
+
+            # grad_input = grad_output
+            # grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
+            # output = output.unsqueeze(-1).repeat(1, 1, bits)
+            # in_sign = ((input > 0).float() - 0.5) * 2.0
+            # out_sign = ((output > 0).float() - 0.5) * 2.0
+            # grad_sign = ((grad_input > 0).float() - 0.5) * 2.0
+            # grad_input = grad_input * in_sign * (out_sign * grad_sign)
+            # grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
+
+            # a dynamic scaling coefficient is needed
+            # so that training converges stably
+
+            # print(in_sign[0].detach().cpu().numpy())
+            # print(out_sign[0].detach().cpu().numpy())
+            # print(grad_sign[0].detach().cpu().numpy())
+            # print(grad_input[0].detach().cpu().numpy())
+
+        return grad_input, grad_weight, None
 
 
 class SimpleCNN(nn.Module):
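The second hunk rewrites the input gradient as a straight-through-style estimate: the upstream gradient is scaled by the looked-up weight, broadcast back across the bits, sign-modulated per input bit, and jittered by roughly ±0.5% multiplicative noise. It also fixes the arity of backward, which now returns three gradients to match forward's three inputs (`input`, `weight`, `index`). The sketch below reconstructs the class as it stands after the commit, under the same shape assumptions as above; the `grad_weight` branch is inferred from the single context line shown in the hunk, so treat this as an illustration rather than the surrounding file verbatim.

import torch

class Lut(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, index):
        # bit pattern of each (batch, group) slice selects one LUT row
        ind = ((input > 0).long() * index).sum(dim=-1)
        output = torch.gather(weight, 0, ind)
        output = (output > 0).float()
        output = (output - 0.5) * 2.0
        ctx.save_for_backward(input, weight, ind)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, ind = ctx.saved_tensors
        grad_input = grad_weight = None
        bits = input.shape[2]
        if ctx.needs_input_grad[1]:
            # assumed setup for the scatter_add_ context line in the hunk
            grad_weight = torch.zeros_like(weight)
            grad_weight.scatter_add_(0, ind, grad_output)
        if ctx.needs_input_grad[0]:
            grad_input = grad_output * torch.gather(weight, 0, ind)
            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
            in_sign = ((input > 0).float() - 0.5) * 2.0
            grad_input = grad_input * in_sign
            # ~±0.5% multiplicative noise on the gradient
            grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
        return grad_input, grad_weight, None

# quick smoke test with the hypothetical shapes from above
batch, groups, bits = 2, 3, 4
x = torch.randn(batch, groups, bits, requires_grad=True)
w = torch.randn(2 ** bits, groups, requires_grad=True)
idx = 2 ** torch.arange(bits)
y = Lut.apply(x, w, idx)
y.sum().backward()
print(x.grad.shape, w.grad.shape)  # torch.Size([2, 3, 4]) torch.Size([16, 3])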