Refine LUT repeat: move the groupRepeat expansion from LutGroup into the Lut autograd function.
commit 6cb969ac3b
parent c322ee8228
@@ -39,35 +39,43 @@ test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False, drop_last=T
 
 
 class Lut(torch.autograd.Function):
-    # input  [batch, group, bits ]
-    # output [batch, group ]
+    # input   [ batch, group * groupBits ]
+    # output  [ batch, group * groupRepeat ]
     # weight [2**bits, group ]
     @staticmethod
-    def forward(ctx, input, weight, index):
-        ind = ((input > 0).long() * index).sum(dim=-1)
+    def forward(ctx, input, weight, index, groupBits, groupRepeat):
+        batch = input.shape[0]
+        x = input.reshape(batch, -1, groupBits)
+        if groupRepeat > 1:
+            x = x.repeat(1, groupRepeat, 1)
+
+        ind = ((x > 0).long() * index).sum(dim=-1)
         output = torch.gather(weight, 0, ind)
         output = (output > 0).float()
         output = (output - 0.5) * 2.0
+        ctx.groupBits = groupBits
+        ctx.groupRepeat = groupRepeat
+        ctx.batch = batch
         ctx.save_for_backward(input, weight, ind)
         return output
 
     @staticmethod
     def backward(ctx, grad_output):
         input, weight, ind = ctx.saved_tensors
+        groupBits = ctx.groupBits
+        groupRepeat = ctx.groupRepeat
+        batch = ctx.batch
         grad_input = grad_weight = None
-        bits = input.shape[2]
 
         if ctx.needs_input_grad[1]:
             grad_weight = torch.zeros_like(weight)
             grad_weight.scatter_add_(0, ind, grad_output)
 
         if ctx.needs_input_grad[0]:
 
             grad_input = grad_output * torch.gather(weight, 0, ind)
             # grad_input = grad_output
-            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
-            # output = output.unsqueeze(-1).repeat(1, 1, bits)
             in_sign = ((input > 0).float() - 0.5) * 2.0
+            grad_input = grad_input.view(batch, -1, groupRepeat).sum(2)
+            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, groupBits)
             grad_input = grad_input * in_sign
             grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
 
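As a reading aid: the forward pass above packs each group's sign bits into an integer LUT address and gathers one weight entry per (group, repeat) slot. A minimal standalone sketch, assuming `index` is the power-of-two buffer that `LutGroup` registers (2**0 .. 2**(groupBits-1)), which matches how `ind` is used as an address; all sizes here are made up:

import torch

batch, group, groupBits, groupRepeat = 2, 3, 4, 2
index = torch.pow(2, torch.arange(groupBits))        # [1, 2, 4, 8]
weight = torch.randn(2 ** groupBits, group * groupRepeat)

input = torch.randn(batch, group * groupBits)
x = input.reshape(batch, -1, groupBits)              # [batch, group, groupBits]
x = x.repeat(1, groupRepeat, 1)                      # [batch, group*groupRepeat, groupBits]

ind = ((x > 0).long() * index).sum(dim=-1)           # pack sign bits into a LUT address
out = torch.gather(weight, 0, ind)                   # one LUT entry per (group, repeat) slot
out = ((out > 0).float() - 0.5) * 2.0                # binarize to {-1.0, +1.0}
print(out.shape)                                     # torch.Size([2, 6])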
@@ -85,10 +93,7 @@ class Lut(torch.autograd.Function):
 
             # print(in_sign[0].detach().cpu().numpy())
-            # print(out_sign[0].detach().cpu().numpy())
-            # print(grad_sign[0].detach().cpu().numpy())
-            # print(grad_input[0].detach().cpu().numpy())
 
-        return grad_input, grad_weight, None
+        return grad_input, grad_weight, None, None, None
 
 
 class SimpleCNN(nn.Module):
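The two extra `None`s exist because `torch.autograd.Function.backward` must return one value per `forward` argument, and the new `groupBits`/`groupRepeat` are plain ints with no gradient. Below is a shape-consistent sketch of the gradient rule, under two assumptions the hunk leaves implicit (both flagged in comments): the repeat axis is undone with the same [repeat, group] layout that `x.repeat(1, groupRepeat, 1)` produced in forward, and `grad_input` is flattened back to the 2-D shape of the saved input before being returned.

import torch

def lut_backward(input, weight, ind, grad_output, groupBits, groupRepeat):
    # input: [batch, group*groupBits], weight: [2**groupBits, group*groupRepeat],
    # ind and grad_output: [batch, group*groupRepeat]
    batch = input.shape[0]

    # Weight gradient: route each upstream gradient to the LUT entry that
    # produced the output, accumulating collisions with scatter_add_.
    grad_weight = torch.zeros_like(weight)
    grad_weight.scatter_add_(0, ind, grad_output)

    # Input gradient: straight-through-style estimate scaled by the gathered
    # LUT value, reduced over repeats, then broadcast over each group's bits.
    grad_input = grad_output * torch.gather(weight, 0, ind)
    grad_input = grad_input.view(batch, groupRepeat, -1).sum(1)    # assumed [repeat, group] layout
    grad_input = grad_input.unsqueeze(-1).repeat(1, 1, groupBits)  # spread over bits
    in_sign = ((input > 0).float() - 0.5) * 2.0
    grad_input = grad_input.reshape(batch, -1) * in_sign           # assumed flatten to input shape
    # Small multiplicative jitter (+/-0.5%) to break ties between bits.
    grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
    return grad_input, grad_weight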
@@ -121,7 +126,7 @@ class LutGroup(nn.Module):
     def __init__(self, group, groupBits, groupRepeat=1):
         assert groupBits > 1
         super(LutGroup, self).__init__()
-        self.weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
+        self.weight = nn.Parameter(torch.ones(pow(2, groupBits), int(groupRepeat * group)))
         self.group = group
         self.groupBits = groupBits
         self.groupRepeat = groupRepeat
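One consequence of switching the initialization from `torch.randn` to `torch.ones`, shown here as a small illustration rather than anything the commit states: every LUT entry starts positive, so after the `(output > 0)` binarization in `Lut.forward` every unit initially emits +1 regardless of its input bits, and entries only differentiate once gradients push some of them negative.

import torch

w = torch.ones(2 ** 4, 8)            # [2**groupBits, groupRepeat * group]
out = ((w > 0).float() - 0.5) * 2.0
print(out.unique())                  # tensor([1.]) -- all outputs +1 at init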
@@ -130,11 +135,7 @@ class LutGroup(nn.Module):
     def forward(self, x):
         # input   [ batch, group * groupBits ]
         # output  [ batch, group * groupRepeat ]
-        batch = x.shape[0]
-        x = x.reshape(batch, -1, self.groupBits)
-        if self.groupRepeat > 1:
-            x = x.repeat(1, self.groupRepeat, 1)
-        x = Lut.apply(x, self.weight, self.index)
+        x = Lut.apply(x, self.weight, self.index, self.groupBits, self.groupRepeat)
         return x
 
 
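With the reshape/repeat moved into `Lut`, `LutGroup.forward` becomes a thin wrapper over flat 2-D tensors, matching its shape comments. Hypothetical usage, assuming the `LutGroup` definition in this file and made-up sizes:

import torch

group, groupBits, groupRepeat, BS = 5, 3, 4, 8
lut = LutGroup(group, groupBits, groupRepeat)

x = torch.randn(BS, group * groupBits)    # [ batch, group * groupBits ]
y = lut(x)                                # [ batch, group * groupRepeat ]
assert y.shape == (BS, group * groupRepeat)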
@@ -186,7 +187,7 @@ class SimpleBNN(nn.Module):
         self.lnn2 = LutCnn(1, (BS, 80, 14, 14), 2, 2, 1, False)
         self.lnn3 = LutCnn(1, (BS, 80, 7, 7), 3, 1, 1, False)
         self.lnn4 = LutCnn(1, (BS, 80, 5, 5), 3, 1, 1, False)
-        self.lnn5 = LutCnn(10, (BS, 80, 3, 3), 3, 1, 1)
+        self.lnn5 = LutCnn(1, (BS, 80, 3, 3), 3, 1, 1)
 
         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
@@ -282,10 +283,10 @@ class SimpleLNN(nn.Module):
         print(self.lutg1.weight.grad[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy())
         print("=============================")
         print("=============================")
-        # print("self.lutg2")
-        # print(self.lutg2.weight.detach().cpu().numpy())
-        # print("=============================")
-        # print("=============================")
+        print("self.lutg2")
+        print(self.lutg2.weight.detach().cpu().numpy())
+        print("=============================")
+        print("=============================")
 
 
 torch.autograd.set_detect_anomaly(True)