Update mnist: use gather to speed up.
parent dd71a5dedd
commit e2c8668a1b
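The core change: the LUT weight is stored transposed as [2**bits, group], so the forward lookup becomes a single torch.gather along dim 0 instead of advanced indexing with an explicitly materialized row-index grid. A minimal sketch of the two lookups (sizes are invented for illustration, not taken from the training script):

    import torch

    batch, group, bits = 5, 8, 2                       # illustrative sizes only
    ind = torch.randint(0, 2 ** bits, (batch, group))  # per-group LUT addresses

    # old layout: weight [group, 2**bits], looked up with advanced indexing
    w_old = torch.randn(group, 2 ** bits)
    rows = torch.arange(group).unsqueeze(0).expand(batch, -1)
    out_old = w_old[rows, ind]                         # [batch, group]

    # new layout: weight [2**bits, group], looked up with one gather kernel
    w_new = w_old.t().contiguous()
    out_new = torch.gather(w_new, 0, ind)              # [batch, group]

    assert torch.allclose(out_old, out_new)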
@@ -39,44 +39,31 @@ test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False)


 class Lut(torch.autograd.Function):
-    # input [batch,count,bits]
-    # weight [count,2**bits]
+    # input [batch, group, bits ]
+    # output [batch, group ]
+    # weight [2**bits, group ]
     @staticmethod
-    def forward(ctx, input, weight):
-        batch = input.shape[0]
-        count = input.shape[1]
-        bits = input.shape[2]
-        assert int(math.log2(weight.shape[-1])) == bits
-
-        index = 2 ** torch.arange(bits - 1, -1, -1, device=input.device)
+    def forward(ctx, input, weight, index):
         ind = ((input > 0).long() * index).sum(dim=-1)
-
-        row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1).to(input.device)
-        output = weight[row_indices, ind]
-
-        ctx.save_for_backward(input, weight, row_indices, ind)
+        output = torch.gather(weight, 0, ind)
+        ctx.save_for_backward(input, weight, ind)
         return output

     @staticmethod
     def backward(ctx, grad_output):
-        input, weight, row_indices, ind = ctx.saved_tensors
+        input, weight, ind = ctx.saved_tensors
         grad_input = grad_weight = None

         batch = input.shape[0]
         count = input.shape[1]
         bits = input.shape[2]

         if ctx.needs_input_grad[1]:
             grad_weight = torch.zeros_like(weight)
-            ind_p = ind.permute(1, 0)
-            grad_output_p = grad_output.permute(1, 0)
-            grad_weight.scatter_add_(1, ind_p, grad_output_p)
+            grad_weight.scatter_add_(0, ind, grad_output)

         if ctx.needs_input_grad[0]:
-            grad_input = grad_output * weight[row_indices, ind]
+            grad_input = grad_output * torch.gather(weight, 0, ind)
             grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)

-        return grad_input, grad_weight
+        return grad_input, grad_weight, None


 class SimpleCNN(nn.Module):
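For the backward pass the same layout lets grad_weight be accumulated with one scatter_add_ along dim 0, without the ind / grad_output permutes the old code needed. A standalone sketch with assumed shapes (not code from the commit):

    import torch

    bits, group, batch = 2, 8, 5                       # assumed sizes
    ind = torch.randint(0, 2 ** bits, (batch, group))  # addresses saved by forward
    grad_output = torch.randn(batch, group)

    grad_weight = torch.zeros(2 ** bits, group)        # same layout as the new weight
    # grad_weight[ind[b, g], g] += grad_output[b, g] for every (b, g)
    grad_weight.scatter_add_(0, ind, grad_output)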
@@ -106,31 +93,30 @@ class SimpleCNN(nn.Module):


 class LutGroup(nn.Module):
-    def __init__(self, totalBits, groupBits=0, repeat=1):
-        if not groupBits:
-            groupBits = totalBits
+    def __init__(self, group, groupBits, groupRepeat=1):
+        assert groupBits > 1
         super(LutGroup, self).__init__()
-        assert (totalBits % groupBits) == 0
-        self.weight = nn.Parameter(torch.randn(repeat * totalBits, pow(2, groupBits)))
-        self.totalBits = totalBits
+        self.weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
+        self.group = group
         self.groupBits = groupBits
-        self.repeat = repeat
+        self.groupRepeat = groupRepeat
+        self.index = nn.Parameter(2 ** torch.arange(groupBits - 1, -1, -1), requires_grad=False)

     def forward(self, x):
-        # input [ batch, totalBits ]
-        # output [ batch, totalBits / groupBits * repeat ]
+        # input [ batch, group * groupBits ]
+        # output [ batch, group * groupRepeat ]
         batch = x.shape[0]
         x = x.view(batch, -1, self.groupBits)
-        if self.repeat > 1:
-            x = x.repeat(1, self.repeat, 1)
-        x = Lut.apply(x, self.weight)
+        if self.groupRepeat > 1:
+            x = x.repeat(1, self.groupRepeat, 1)
+        x = Lut.apply(x, self.weight, self.index)
         return x


 class LutCnn(nn.Module):
     def __init__(self, output_c, input_shape, kernel_size, stride, dilation):
         super(LutCnn, self).__init__()
         self.output_c = output_c
         B, C, H, W = input_shape
         self.input_shape = input_shape
         self.kernel_size = kernel_size
         self.stride = stride
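With the reworked constructor, LutGroup takes the number of groups explicitly and keeps the bit-weight index as a non-trainable parameter that is passed into Lut.apply. A hypothetical usage, with invented sizes, assuming the Lut and LutGroup definitions from this file:

    import torch

    lut = LutGroup(group=16, groupBits=4, groupRepeat=3)  # weight: [2**4, 3 * 16]
    x = torch.randn(32, 16 * 4)                           # [batch, group * groupBits]
    y = lut(x)                                            # [batch, group * groupRepeat] -> [32, 48]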
@@ -140,7 +126,8 @@ class LutCnn(nn.Module):
         self.channel_idx = nn.Parameter(channel_idx, requires_grad=False)
         self.h_idx = nn.Parameter(h_idx, requires_grad=False)
         self.w_idx = nn.Parameter(w_idx, requires_grad=False)
-        self.lut = LutGroup(len(self.batch_idx), kernel_size * kernel_size, output_c)
+        groupBits = kernel_size * kernel_size
+        self.lut = LutGroup(len(self.batch_idx) / B / groupBits, groupBits, output_c)

     def forward(self, x):
         B, C, H, W = self.input_shape
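The group count passed to LutGroup is derived from the unfold index length. Assuming batch_idx enumerates every element of the unfolded input, i.e. B * C * out_h * out_w * kernel_size**2 entries (which is how the surrounding code appears to build it), dividing by B and by groupBits leaves one group per output position and channel:

    # back-of-the-envelope check with invented sizes
    B, C, out_h, out_w, k = 2, 1, 13, 13, 3
    len_batch_idx = B * C * out_h * out_w * k * k
    groupBits = k * k
    group = len_batch_idx / B / groupBits   # 169.0 == C * out_h * out_w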
@@ -299,9 +286,9 @@ def profiler():
             loss = criterion(output, target)
             loss.backward()
             optimizer.step()
-            print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

+            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
             if batch_idx > 10:
                 prof.export_chrome_trace("local.json")
                 assert False

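The profiler hunk only changes the sort key to CUDA time. A self-contained sketch of that kind of profiling run (an assumed standalone example, not the script's profiler() function):

    import torch
    from torch.profiler import profile, ProfilerActivity

    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(256, 256, device=device, requires_grad=True)

    with profile(activities=activities) as prof:
        (x @ x).sum().backward()

    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
    prof.export_chrome_trace("local.json")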