Update mnist: use gather to speed up.

Colin 2025-05-28 16:05:34 +08:00
parent dd71a5dedd
commit e2c8668a1b
1 changed file with 26 additions and 39 deletions


@@ -39,44 +39,31 @@ test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False)
 class Lut(torch.autograd.Function):
-    # input [batch,count,bits]
-    # weight [count,2**bits]
+    # input [batch, group, bits ]
+    # output [batch, group ]
+    # weight [2**bits, group ]
     @staticmethod
-    def forward(ctx, input, weight):
-        batch = input.shape[0]
-        count = input.shape[1]
-        bits = input.shape[2]
-        assert int(math.log2(weight.shape[-1])) == bits
-        index = 2 ** torch.arange(bits - 1, -1, -1, device=input.device)
+    def forward(ctx, input, weight, index):
         ind = ((input > 0).long() * index).sum(dim=-1)
-        row_indices = torch.arange(count).unsqueeze(0).expand(batch, -1).to(input.device)
-        output = weight[row_indices, ind]
-        ctx.save_for_backward(input, weight, row_indices, ind)
+        output = torch.gather(weight, 0, ind)
+        ctx.save_for_backward(input, weight, ind)
         return output
     @staticmethod
     def backward(ctx, grad_output):
-        input, weight, row_indices, ind = ctx.saved_tensors
+        input, weight, ind = ctx.saved_tensors
         grad_input = grad_weight = None
         batch = input.shape[0]
         count = input.shape[1]
         bits = input.shape[2]
         if ctx.needs_input_grad[1]:
             grad_weight = torch.zeros_like(weight)
-            ind_p = ind.permute(1, 0)
-            grad_output_p = grad_output.permute(1, 0)
-            grad_weight.scatter_add_(1, ind_p, grad_output_p)
+            grad_weight.scatter_add_(0, ind, grad_output)
         if ctx.needs_input_grad[0]:
-            grad_input = grad_output * weight[row_indices, ind]
+            grad_input = grad_output * torch.gather(weight, 0, ind)
             grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
-        return grad_input, grad_weight
+        return grad_input, grad_weight, None
 class SimpleCNN(nn.Module):
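
The core of the change above: the per-row lookup weight[row_indices, ind] becomes a single torch.gather over a transposed table, so no row_indices tensor is built on every call. A minimal sketch (not from the repo) of why the two lookups agree, assuming the new [2**bits, group] weight is the transpose of the old [count, 2**bits] one:

import torch

batch, group, bits = 8, 5, 4
x = torch.randn(batch, group, bits)
w_new = torch.randn(2 ** bits, group)        # new layout: [2**bits, group]
w_old = w_new.t().contiguous()               # old layout: [group, 2**bits]

index = 2 ** torch.arange(bits - 1, -1, -1)
ind = ((x > 0).long() * index).sum(dim=-1)   # [batch, group], values in [0, 2**bits)

# Old lookup: an explicit row index per group column.
rows = torch.arange(group).unsqueeze(0).expand(batch, -1)
out_old = w_old[rows, ind]

# New lookup: one gather along dim 0 of the transposed table.
out_new = torch.gather(w_new, 0, ind)
print(torch.allclose(out_old, out_new))      # True

# The backward pass mirrors the lookup: scatter-add along the same dim.
grad_out = torch.randn(batch, group)
grad_w = torch.zeros_like(w_new)
grad_w.scatter_add_(0, ind, grad_out)        # grad_w[ind[b, g], g] += grad_out[b, g]
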
@@ -106,31 +93,30 @@ class SimpleCNN(nn.Module):
 class LutGroup(nn.Module):
-    def __init__(self, totalBits, groupBits=0, repeat=1):
-        if not groupBits:
-            groupBits = totalBits
+    def __init__(self, group, groupBits, groupRepeat=1):
+        assert groupBits > 1
         super(LutGroup, self).__init__()
-        assert (totalBits % groupBits) == 0
-        self.weight = nn.Parameter(torch.randn(repeat * totalBits, pow(2, groupBits)))
-        self.totalBits = totalBits
+        self.weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
+        self.group = group
         self.groupBits = groupBits
-        self.repeat = repeat
+        self.groupRepeat = groupRepeat
+        self.index = nn.Parameter(2 ** torch.arange(groupBits - 1, -1, -1), requires_grad=False)
     def forward(self, x):
-        # input [ batch, totalBits ]
-        # output [ batch, totalBits / groupBits * repeat ]
+        # input [ batch, group * groupBits ]
+        # output [ batch, group * groupRepeat ]
         batch = x.shape[0]
         x = x.view(batch, -1, self.groupBits)
-        if self.repeat > 1:
-            x = x.repeat(1, self.repeat, 1)
-        x = Lut.apply(x, self.weight)
+        if self.groupRepeat > 1:
+            x = x.repeat(1, self.groupRepeat, 1)
+        x = Lut.apply(x, self.weight, self.index)
         return x
 class LutCnn(nn.Module):
     def __init__(self, output_c, input_shape, kernel_size, stride, dilation):
         super(LutCnn, self).__init__()
         self.output_c = output_c
         B, C, H, W = input_shape
         self.input_shape = input_shape
         self.kernel_size = kernel_size
         self.stride = stride
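
LutGroup now stores the bit weights once as a non-trainable index parameter, so they follow the module to the GPU and are not re-allocated inside every forward call, and hands them to Lut.apply. A tiny sketch of the address computation this buffer feeds:

import torch

group_bits = 3
index = 2 ** torch.arange(group_bits - 1, -1, -1)   # tensor([4, 2, 1]), MSB first

x = torch.tensor([[[0.7, -1.2, 0.3]]])              # [batch=1, group=1, bits=3]
bits = (x > 0).long()                                # tensor([[[1, 0, 1]]])
ind = (bits * index).sum(dim=-1)                     # tensor([[5]]): row 5 of the 2**3-entry LUT
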
@@ -140,7 +126,8 @@ class LutCnn(nn.Module):
         self.channel_idx = nn.Parameter(channel_idx, requires_grad=False)
         self.h_idx = nn.Parameter(h_idx, requires_grad=False)
         self.w_idx = nn.Parameter(w_idx, requires_grad=False)
-        self.lut = LutGroup(len(self.batch_idx), kernel_size * kernel_size, output_c)
+        groupBits = kernel_size * kernel_size
+        self.lut = LutGroup(len(self.batch_idx) / B / groupBits, groupBits, output_c)
     def forward(self, x):
         B, C, H, W = self.input_shape
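
The new LutGroup signature takes the group count directly, so the constructor derives it from the flattened index length. A hedged sketch of that arithmetic, assuming batch_idx enumerates every element produced by unfold (B * C * k*k * L entries, where L is the number of sliding positions); the names here are illustrative only:

import torch
import torch.nn.functional as F

B, C, H, W = 4, 1, 28, 28
k, stride = 4, 4
x = torch.randn(B, C, H, W)

patches = F.unfold(x, kernel_size=k, stride=stride)   # [B, C*k*k, L]
L = patches.shape[-1]
total_elems = patches.numel()                          # len(batch_idx) under the assumption above

groupBits = k * k
groups = total_elems // B // groupBits                 # C * L groups per sample
print(groups == C * L)                                 # True

Note that the committed line uses plain /, which yields a float in Python 3; LutGroup tolerates this only because the weight shape is wrapped in int(...).
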
@@ -299,9 +286,9 @@ def profiler():
             loss = criterion(output, target)
             loss.backward()
             optimizer.step()
-            print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
+            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
             if batch_idx > 10:
                 prof.export_chrome_trace("local.json")
                 assert False
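
The last hunk only swaps the sort key from CPU to CUDA time; the enclosing with profile(...) context sits outside the visible diff. A self-contained sketch of such a wrapper, assuming a CUDA-enabled run and a placeholder model (the script's real model and data loader are not shown here):

import torch
from torch.profiler import profile, ProfilerActivity

model = torch.nn.Linear(784, 10)                 # placeholder model, illustration only
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()
data = torch.randn(64, 784)
target = torch.randint(0, 10, (64,))

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities) as prof:
    for step in range(12):
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()

# Sort by GPU time, as the commit does, then dump a Chrome trace for chrome://tracing.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
prof.export_chrome_trace("local.json")
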