Update mnist: use gather to speed up.
This commit is contained in:
parent dd71a5dedd
commit e2c8668a1b
@@ -39,44 +39,31 @@ test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False)


 class Lut(torch.autograd.Function):
-    # input [batch,count,bits]
-    # weight [count,2**bits]
+    # input  [batch, group, bits]
+    # output [batch, group]
+    # weight [2**bits, group]
     @staticmethod
-    def forward(ctx, input, weight):
-        batch = input.shape[0]
-        count = input.shape[1]
-        bits = input.shape[2]
-        assert int(math.log2(weight.shape[-1])) == bits
-
-        index = 2 ** torch.arange(bits - 1, -1, -1, device=input.device)
+    def forward(ctx, input, weight, index):
         ind = ((input > 0).long() * index).sum(dim=-1)
-
-        row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1).to(input.device)
-        output = weight[row_indices, ind]
-
-        ctx.save_for_backward(input, weight, row_indices, ind)
+        output = torch.gather(weight, 0, ind)
+        ctx.save_for_backward(input, weight, ind)
         return output

     @staticmethod
     def backward(ctx, grad_output):
-        input, weight, row_indices, ind = ctx.saved_tensors
+        input, weight, ind = ctx.saved_tensors
         grad_input = grad_weight = None
-
-        batch = input.shape[0]
-        count = input.shape[1]
         bits = input.shape[2]

         if ctx.needs_input_grad[1]:
             grad_weight = torch.zeros_like(weight)
-            ind_p = ind.permute(1, 0)
-            grad_output_p = grad_output.permute(1, 0)
-            grad_weight.scatter_add_(1, ind_p, grad_output_p)
+            grad_weight.scatter_add_(0, ind, grad_output)

         if ctx.needs_input_grad[0]:
-            grad_input = grad_output * weight[row_indices, ind]
+            grad_input = grad_output * torch.gather(weight, 0, ind)
             grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)

-        return grad_input, grad_weight
+        return grad_input, grad_weight, None


 class SimpleCNN(nn.Module):
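For context on the hunk above: the forward pass now indexes a transposed weight table with torch.gather along dim 0 instead of advanced indexing with an explicit row_indices tensor, and the weight gradient is accumulated with a single scatter_add_ along the same dim. A minimal standalone sketch of that correspondence (the sizes are illustrative assumptions, e.g. 9 bits per group as for a 3x3 kernel):

import torch

batch, group, bits = 4, 8, 9                      # illustrative sizes only
inp = torch.randn(batch, group, bits)
weight = torch.randn(2 ** bits, group)            # [2**bits, group] layout

# Pack each group's sign bits into an integer row index per (batch, group).
index = 2 ** torch.arange(bits - 1, -1, -1)
ind = ((inp > 0).long() * index).sum(dim=-1)      # [batch, group]

# New path: one gather along dim 0 reads weight[ind[b, g], g].
out_gather = torch.gather(weight, 0, ind)         # [batch, group]

# Old-style path: advanced indexing with an explicit column index tensor.
col = torch.arange(group).unsqueeze(0).expand(batch, -1)
out_fancy = weight[ind, col]
assert torch.equal(out_gather, out_fancy)

# Weight gradient: scatter grad_output back into the rows that were read,
# i.e. grad_weight[ind[b, g], g] += grad_output[b, g].
grad_output = torch.randn(batch, group)
grad_weight = torch.zeros_like(weight)
grad_weight.scatter_add_(0, ind, grad_output)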
@@ -106,31 +93,30 @@ class SimpleCNN(nn.Module):


 class LutGroup(nn.Module):
-    def __init__(self, totalBits, groupBits=0, repeat=1):
-        if not groupBits:
-            groupBits = totalBits
+    def __init__(self, group, groupBits, groupRepeat=1):
+        assert groupBits > 1
         super(LutGroup, self).__init__()
-        assert (totalBits % groupBits) == 0
-        self.weight = nn.Parameter(torch.randn(repeat * totalBits, pow(2, groupBits)))
-        self.totalBits = totalBits
+        self.weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
+        self.group = group
         self.groupBits = groupBits
-        self.repeat = repeat
+        self.groupRepeat = groupRepeat
+        self.index = nn.Parameter(2 ** torch.arange(groupBits - 1, -1, -1), requires_grad=False)

     def forward(self, x):
-        # input  [ batch, totalBits ]
-        # output [ batch, totalBits / groupBits * repeat ]
+        # input  [ batch, group * groupBits ]
+        # output [ batch, group * groupRepeat ]
         batch = x.shape[0]
         x = x.view(batch, -1, self.groupBits)
-        if self.repeat > 1:
-            x = x.repeat(1, self.repeat, 1)
-        x = Lut.apply(x, self.weight)
+        if self.groupRepeat > 1:
+            x = x.repeat(1, self.groupRepeat, 1)
+        x = Lut.apply(x, self.weight, self.index)
         return x


 class LutCnn(nn.Module):
     def __init__(self, output_c, input_shape, kernel_size, stride, dilation):
         super(LutCnn, self).__init__()
-        self.output_c = output_c
+        B, C, H, W = input_shape
         self.input_shape = input_shape
         self.kernel_size = kernel_size
         self.stride = stride
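A hypothetical usage sketch of the reworked LutGroup above, with the new (group, groupBits, groupRepeat) signature and the precomputed bit-weight index buffer; the sizes are made up for illustration:

group, groupBits, groupRepeat = 4, 9, 2
lut = LutGroup(group, groupBits, groupRepeat)

x = torch.randn(16, group * groupBits)     # [batch, group * groupBits]
y = lut(x)                                 # each group is repeated groupRepeat times,
                                           # one LUT column per repeated copy
assert y.shape == (16, group * groupRepeat)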
@@ -140,7 +126,8 @@ class LutCnn(nn.Module):
         self.channel_idx = nn.Parameter(channel_idx, requires_grad=False)
         self.h_idx = nn.Parameter(h_idx, requires_grad=False)
         self.w_idx = nn.Parameter(w_idx, requires_grad=False)
-        self.lut = LutGroup(len(self.batch_idx), kernel_size * kernel_size, output_c)
+        groupBits = kernel_size * kernel_size
+        self.lut = LutGroup(len(self.batch_idx) / B / groupBits, groupBits, output_c)

     def forward(self, x):
         B, C, H, W = self.input_shape
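The constructor now derives groupBits from the kernel and divides len(self.batch_idx) down to a per-sample group count; note this is Python 3 float division, which LutGroup absorbs via int(groupRepeat * group). A rough sketch of that arithmetic, under the assumption that batch_idx enumerates B * C * out_h * out_w * kernel_size**2 unfolded positions (the unfold setup itself is outside this hunk):

# Illustrative numbers only: one 28x28 channel, 2x2 kernel, stride 2.
B, C, H, W = 128, 1, 28, 28
kernel_size, stride = 2, 2
out_h, out_w = H // stride, W // stride

groupBits = kernel_size * kernel_size               # 4 bits per lookup
num_positions = B * C * out_h * out_w * groupBits   # assumed len(batch_idx)
group = num_positions / B / groupBits               # float: 196.0
assert int(group) == C * out_h * out_w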
@@ -299,9 +286,9 @@ def profiler():
         loss = criterion(output, target)
         loss.backward()
         optimizer.step()
-        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
+        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

         if batch_idx > 10:
+            prof.export_chrome_trace("local.json")
             assert False

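The profiler hunk switches the summary sort key from CPU to CUDA time and dumps a Chrome trace once enough batches have run. A minimal sketch of the pattern these lines presumably fit into, assuming the standard torch.profiler API (the actual setup code is outside this hunk, and train_step is a hypothetical stand-in for one forward/backward/optimizer step):

from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(5):
        train_step()

# Sort the per-op summary by GPU time instead of CPU time.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# Write a trace viewable in chrome://tracing or Perfetto.
prof.export_chrome_trace("local.json")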