diff --git a/binary/mnist.py b/binary/mnist.py
index 60518f8..57127bb 100644
--- a/binary/mnist.py
+++ b/binary/mnist.py
@@ -39,44 +39,31 @@ test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False)
 
 
 class Lut(torch.autograd.Function):
-    # input [batch,count,bits]
-    # weight [count,2**bits]
+    # input [batch, group, bits ]
+    # output [batch, group ]
+    # weight [2**bits, group ]
     @staticmethod
-    def forward(ctx, input, weight):
-        batch = input.shape[0]
-        count = input.shape[1]
-        bits = input.shape[2]
-        assert int(math.log2(weight.shape[-1])) == bits
-
-        index = 2 ** torch.arange(bits - 1, -1, -1, device=input.device)
+    def forward(ctx, input, weight, index):
         ind = ((input > 0).long() * index).sum(dim=-1)
-
-        row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1).to(input.device)
-        output = weight[row_indices, ind]
-
-        ctx.save_for_backward(input, weight, row_indices, ind)
+        output = torch.gather(weight, 0, ind)
+        ctx.save_for_backward(input, weight, ind)
         return output
 
     @staticmethod
     def backward(ctx, grad_output):
-        input, weight, row_indices, ind = ctx.saved_tensors
+        input, weight, ind = ctx.saved_tensors
         grad_input = grad_weight = None
-
-        batch = input.shape[0]
-        count = input.shape[1]
         bits = input.shape[2]
 
         if ctx.needs_input_grad[1]:
             grad_weight = torch.zeros_like(weight)
-            ind_p = ind.permute(1, 0)
-            grad_output_p = grad_output.permute(1, 0)
-            grad_weight.scatter_add_(1, ind_p, grad_output_p)
+            grad_weight.scatter_add_(0, ind, grad_output)
 
         if ctx.needs_input_grad[0]:
-            grad_input = grad_output * weight[row_indices, ind]
+            grad_input = grad_output * torch.gather(weight, 0, ind)
             grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
 
-        return grad_input, grad_weight
+        return grad_input, grad_weight, None
 
 
 class SimpleCNN(nn.Module):
@@ -106,31 +93,30 @@ class SimpleCNN(nn.Module):
 
 
 class LutGroup(nn.Module):
-    def __init__(self, totalBits, groupBits=0, repeat=1):
-        if not groupBits:
-            groupBits = totalBits
+    def __init__(self, group, groupBits, groupRepeat=1):
+        assert groupBits > 1
         super(LutGroup, self).__init__()
-        assert (totalBits % groupBits) == 0
-        self.weight = nn.Parameter(torch.randn(repeat * totalBits, pow(2, groupBits)))
-        self.totalBits = totalBits
+        self.weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
+        self.group = group
         self.groupBits = groupBits
-        self.repeat = repeat
+        self.groupRepeat = groupRepeat
+        self.index = nn.Parameter(2 ** torch.arange(groupBits - 1, -1, -1), requires_grad=False)
 
     def forward(self, x):
-        # input [ batch, totalBits ]
-        # output [ batch, totalBits / groupBits * repeat ]
+        # input [ batch, group * groupBits ]
+        # output [ batch, group * groupRepeat ]
         batch = x.shape[0]
         x = x.view(batch, -1, self.groupBits)
-        if self.repeat > 1:
-            x = x.repeat(1, self.repeat, 1)
-        x = Lut.apply(x, self.weight)
+        if self.groupRepeat > 1:
+            x = x.repeat(1, self.groupRepeat, 1)
+        x = Lut.apply(x, self.weight, self.index)
         return x
 
 
 class LutCnn(nn.Module):
     def __init__(self, output_c, input_shape, kernel_size, stride, dilation):
         super(LutCnn, self).__init__()
-        self.output_c = output_c
+        B, C, H, W = input_shape
         self.input_shape = input_shape
         self.kernel_size = kernel_size
         self.stride = stride
@@ -140,7 +126,8 @@ class LutCnn(nn.Module):
         self.channel_idx = nn.Parameter(channel_idx, requires_grad=False)
         self.h_idx = nn.Parameter(h_idx, requires_grad=False)
         self.w_idx = nn.Parameter(w_idx, requires_grad=False)
-        self.lut = LutGroup(len(self.batch_idx), kernel_size * kernel_size, output_c)
+        groupBits = kernel_size * kernel_size
+        self.lut = LutGroup(len(self.batch_idx) / B / groupBits, groupBits, output_c)
 
     def forward(self, x):
         B, C, H, W = self.input_shape
@@ -299,9 +286,9 @@ def profiler():
         loss = criterion(output, target)
         loss.backward()
         optimizer.step()
-        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
-
+        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
         if batch_idx > 10:
+            prof.export_chrome_trace("local.json")
             assert False
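The heart of this refactor is the weight layout change from [count, 2**bits] to [2**bits, group]: the forward pass becomes a single torch.gather along dim 0, and its exact adjoint, scatter_add_ along dim 0, replaces the old permute-then-scatter in backward. A minimal standalone sketch of that duality (the shapes and variable names here are illustrative, not taken from the repo):

    import torch

    batch, group, bits = 4, 9, 4

    weight = torch.randn(2 ** bits, group, requires_grad=True)
    index = 2 ** torch.arange(bits - 1, -1, -1)     # [8, 4, 2, 1], built once

    x = torch.randn(batch, group, bits)
    ind = ((x > 0).long() * index).sum(dim=-1)      # [batch, group], values in [0, 2**bits)

    out = torch.gather(weight, 0, ind)              # out[b, g] = weight[ind[b, g], g]

    # autograd's backward of this gather is exactly the scatter_add_ used above
    out.backward(torch.ones_like(out))
    manual = torch.zeros(2 ** bits, group).scatter_add_(0, ind, torch.ones_like(out))
    assert torch.allclose(weight.grad, manual)

Keeping index as a requires_grad=False parameter, rather than recomputing it every forward, also lets it move with the module under .to()/.cuda(), which is presumably why it is now passed into Lut.apply and answered with a None gradient.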
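Under the new constructor signature, LutGroup is a shape contract [batch, group * groupBits] -> [batch, group * groupRepeat]. A hypothetical usage sketch, assuming the patched binary/mnist.py is importable:

    lut = LutGroup(group=9, groupBits=4, groupRepeat=2)
    x = torch.randn(8, 9 * 4)   # the sign of each entry supplies one bit
    y = lut(x)                  # -> [8, 9 * 2]
    assert y.shape == (8, 18)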
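On the profiling side, sorting key_averages() by cuda_time_total matches the GPU-bound workload better than cpu_time_total, and the local.json written by prof.export_chrome_trace() is a standard chrome trace that can be inspected in chrome://tracing or Perfetto.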