Refine mnist LUT by adding LutCnn.

Colin 2025-05-27 18:51:07 +08:00
parent 30dd319d8c
commit dd71a5dedd
3 changed files with 189 additions and 72 deletions

binary/cudagraph.py (new file, 44 lines)

@@ -0,0 +1,44 @@
import torch


# 1. Define the model (it must use static shapes and static control flow)
class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(3 * 224 * 224, 1000)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # static-shape operation
        return torch.relu(self.fc(x))  # avoid dynamic control flow


model = SimpleModel().cuda()

# 2. Prepare static input/output placeholder tensors
static_input = torch.randn(16, 3, 224, 224, device='cuda')
static_output = torch.zeros(16, 1000, device='cuda')

# 3. Warm-up phase (must run on a non-default stream)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):  # warm up 3 times
        static_output = model(static_input)
torch.cuda.current_stream().wait_stream(s)

# 4. Create and capture the CUDA graph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # Note: the ops here are not executed now; they are only recorded into the graph
    static_output = model(static_input)


# 5. Compute with the graph (update the data + replay)
def run_graph(new_input):
    # copy the new data into the captured input tensor
    static_input.copy_(new_input)
    # replay the captured graph
    g.replay()
    return static_output.clone()  # return a copy of the result


# test
new_data = torch.randn(16, 3, 224, 224, device='cuda')
result = run_graph(new_data)
print(result.shape)  # torch.Size([16, 1000])
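
# Optional timing comparison (illustrative sketch, assuming the definitions above are in scope):
# measure eager execution vs. graph replay with CUDA events.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
for _ in range(100):
    model(static_input)
end.record()
torch.cuda.synchronize()
print(f"eager:        {start.elapsed_time(end):.2f} ms")

start.record()
for _ in range(100):
    g.replay()
end.record()
torch.cuda.synchronize()
print(f"graph replay: {start.elapsed_time(end):.2f} ms")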

(modified file)

@@ -14,14 +14,15 @@ import torch.nn.functional as F
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torch.profiler import profile, ProfilerActivity, record_function
from unfold import generate_unfold_index
import datetime
torch.manual_seed(1234)
np.random.seed(1234)
torch.cuda.manual_seed_all(1234)
batch_size = 16
lr = 0.001
BS = 16
LR = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
@@ -33,15 +34,8 @@ transform = transforms.Compose(
train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)


def to_binary_tensor(input_tensor, bits):
    int_tensor = torch.round(input_tensor).clamp(0, 2**bits - 1).to(torch.int64)
    shifts = torch.arange(bits - 1, -1, -1, device=int_tensor.device)
    binary_bits = (int_tensor.unsqueeze(-1) >> shifts) & 1
    return binary_bits


train_loader = DataLoader(train_dataset, batch_size=BS, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False)
class Lut(torch.autograd.Function):
@@ -133,6 +127,30 @@ class LutGroup(nn.Module):
        return x


class LutCnn(nn.Module):
    # Applies a LutGroup to every kernel_size x kernel_size window of the input;
    # windows are gathered with precomputed indices, matching F.unfold up to a reshape
    # (see the check in unfold.py).
    def __init__(self, output_c, input_shape, kernel_size, stride, dilation):
        super(LutCnn, self).__init__()
        self.output_c = output_c
        self.input_shape = input_shape
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        batch_idx, channel_idx, h_idx, w_idx = generate_unfold_index(input_shape, kernel_size, stride, dilation)
        self.batch_idx = nn.Parameter(batch_idx, requires_grad=False)
        self.channel_idx = nn.Parameter(channel_idx, requires_grad=False)
        self.h_idx = nn.Parameter(h_idx, requires_grad=False)
        self.w_idx = nn.Parameter(w_idx, requires_grad=False)
        self.lut = LutGroup(len(self.batch_idx), kernel_size * kernel_size, output_c)

    def forward(self, x):
        B, C, H, W = self.input_shape
        x = x.view(self.input_shape)
        x = x[(self.batch_idx, self.channel_idx, self.h_idx, self.w_idx)]
        x = x.view(B, -1, self.kernel_size * self.kernel_size)
        x = self.lut(x)
        return x
class SimpleBNN(nn.Module):
def __init__(self):
super(SimpleBNN, self).__init__()
@@ -142,14 +160,12 @@ class SimpleBNN(nn.Module):
        self.w = nn.Parameter(torch.randn(3, 784))
        self.b = nn.Parameter(torch.zeros(3, 784))

        self.lut2_p = LutGroup(784, 4, 8)  # pool2 => 14*14*4
        self.lut3_p = LutGroup(8 * 14 * 14, 4, 1)  # pool2 => 7*7*4
        self.lut4_p = LutGroup(8 * 5 * 5 * 3 * 3, 9, 1)  # conv 3 => 5*5*8
        # self.lut4 = LutGroup(8 * 5 * 5, 8, 8)  # conv 3 => 5*5*8
        self.lut5_p = LutGroup(8 * 3 * 3 * 3 * 3, 9, 1)  # conv 3 => 3*3*16
        # self.lut5 = LutGroup(16 * 3 * 3, 16, 10)  # conv 3 => 3*3*16
        self.lut6_p = LutGroup(8 * 3 * 3, 9, 10)  # conv 3 => 128
        # self.lut7_p = LutGroup(8 * 16, 16, 10)  # fc 128 => 8

        # output_c, input_shape, kernel_size, stride, dilation
        self.lnn1 = LutCnn(8, (BS, 1, 28, 28), 2, 2, 1)
        self.lnn2 = LutCnn(1, (BS, 8, 14, 14), 2, 2, 1)
        self.lnn3 = LutCnn(1, (BS, 8, 7, 7), 3, 1, 1)
        self.lnn4 = LutCnn(1, (BS, 8, 5, 5), 3, 1, 1)
        self.lnn5 = LutCnn(10, (BS, 8, 3, 3), 3, 1, 1)
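        # Spatial sizes implied by the kernel/stride settings above:
        # lnn1: 28x28 -> 14x14, lnn2: 14x14 -> 7x7, lnn3: 7x7 -> 5x5, lnn4: 5x5 -> 3x3, lnn5: 3x3 -> 1x1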

        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
@@ -160,7 +176,7 @@ class SimpleBNN(nn.Module):

    def forward(self, x):
        batch = x.shape[0]
        x = x.view(batch, -1)
        # x = x.view(batch, -1)

        # map x from [-0.5, 0.5] to 0-255, then expand each value into 8 binary bits
        # x = (x * 256 + 128).clamp(0, 255).to(torch.uint8)
@@ -200,54 +216,11 @@ class SimpleBNN(nn.Module):

        #########################

        # unfold
        # original output shape: [batch, channels * kH * kW, L]
        # where L is the number of sliding windows
        x = x.view(batch, 1, 28, 28)
        x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
        # x = x.view(batch, 1, 4, -1)
        # x = x.permute(0, 1, 3, 2)
        x = x.reshape(batch, -1)
        x = self.lut2_p(x)

        x = x.view(batch, 8, 14, 14)
        x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
        # x = x.view(batch, 8, 4, -1)
        # x = x.permute(0, 1, 3, 2)
        x = x.reshape(batch, -1)
        x = self.lut3_p(x)

        x = x.view(batch, 8, 7, 7)
        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
        # x = x.view(batch, 8, 9, -1)
        # x = x.permute(0, 1, 3, 2)
        x = x.reshape(batch, -1)
        x = self.lut4_p(x)
        # x = x.view(batch, 8, 25)
        # x = x.permute(0, 2, 1)  # batch 25 8
        # x = x.reshape(batch, -1)  # batch 25*8
        # x = self.lut4(x)  # batch 8*25

        x = x.reshape(batch, 8, 5, 5)
        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
        # x = x.view(batch, 8, 9, -1)
        # x = x.permute(0, 1, 3, 2)
        x = x.reshape(batch, -1)
        x = self.lut5_p(x)
        # x = x.view(batch, 16, 3, 3)
        # x = x.permute(0, 2, 3, 1)
        # x = x.reshape(batch, -1)
        # x = self.lut5(x)

        x = x.view(batch, 8, 3, 3)
        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
        # x = x.view(batch, 8, 9, -1)
        # x = x.permute(0, 1, 3, 2)
        x = x.reshape(batch, -1)
        x = self.lut6_p(x)

        x = self.lnn1(x)
        x = self.lnn2(x)
        x = self.lnn3(x)
        x = self.lnn4(x)
        x = self.lnn5(x)

        # xx = 2 ** torch.arange(7, -1, -1).to(x.device)
        x = x.view(batch, -1, 8)
@@ -262,7 +235,7 @@ torch.autograd.set_detect_anomaly(True)
# model = SimpleCNN().to(device)
model = SimpleBNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
tbWriter = None
@@ -272,7 +245,7 @@ def AddScalar(tag, value, epoch):
    if not tbWriter:
        current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
        tbWriter = SummaryWriter(f"log/{current_time}")
        hparam_dict = {"lr": lr, "batch_size": batch_size}
        hparam_dict = {"lr": LR, "batch_size": BS}
        tbWriter.add_hparams(hparam_dict, {}, run_name=f"./")

    tbWriter.add_scalar(tag, value, epoch)
@@ -328,8 +301,8 @@ def profiler():
            optimizer.step()

            print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
            if batch_idx > 5:
                break
            if batch_idx > 10:
                assert False
# profiler()

binary/unfold.py (new file, 100 lines)

@@ -0,0 +1,100 @@
import torch
import torch.nn.functional as F


def to_binary_tensor(input_tensor, bits):
    int_tensor = torch.round(input_tensor).clamp(0, 2**bits - 1).to(torch.int64)
    shifts = torch.arange(bits - 1, -1, -1, device=int_tensor.device)
    binary_bits = (int_tensor.unsqueeze(-1) >> shifts) & 1
    return binary_bits
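
# Example (illustrative): to_binary_tensor(torch.tensor([5.0]), 4) -> tensor([[0, 1, 0, 1]]), MSB first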

def generate_unfold_index(input_shape, kernel_size, stride, dilation=1):
    padding = 0
    B, C, H, W = input_shape
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if isinstance(stride, int):
        stride = (stride, stride)
    if isinstance(padding, int):
        padding = (padding, padding)
    if isinstance(dilation, int):
        dilation = (dilation, dilation)

    kH, kW = kernel_size
    dH, dW = dilation
    pH, pW = padding
    sH, sW = stride

    # compute the number of output windows
    out_h = (H + 2 * pH - dH * (kH - 1) - 1) // sH + 1
    out_w = (W + 2 * pW - dW * (kW - 1) - 1) // sW + 1
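    # e.g. for a 28x28 input with kernel 2, stride 2, dilation 1, padding 0:
    # out_h = out_w = (28 - 1 * (2 - 1) - 1) // 2 + 1 = 14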
    # build the index lists
    batch_idx = []
    channel_idx = []
    h_idx = []
    w_idx = []

    for b in range(B):
        for c in range(C):
            for i in range(out_h):
                for j in range(out_w):
                    h_start = i * sH
                    w_start = j * sW
                    for kh in range(kH):
                        for kw in range(kW):
                            batch_idx.append(b)
                            channel_idx.append(c)
                            h_idx.append(h_start + kh * dH)
                            w_idx.append(w_start + kw * dW)

    # convert to tensors
    batch_idx = torch.tensor(batch_idx, dtype=torch.long)
    channel_idx = torch.tensor(channel_idx, dtype=torch.long)
    h_idx = torch.tensor(h_idx, dtype=torch.long)
    w_idx = torch.tensor(w_idx, dtype=torch.long)

    return (batch_idx, channel_idx, h_idx, w_idx)

def test(batch_size=2, channels=2, height=2, width=2, kernel_size=2, stride=1, dilation=1):
    x = torch.randn(batch_size, channels, height, width)
    index = generate_unfold_index(input_shape=x.shape, kernel_size=kernel_size, stride=stride, dilation=dilation)
    unfolded_by_index = x[index]

    unfolded_by_f = F.unfold(x, kernel_size=kernel_size, stride=stride, padding=0, dilation=dilation)
    unfolded_by_index = unfolded_by_index.view(batch_size, channels, -1, kernel_size * kernel_size)
    unfolded_by_index = unfolded_by_index.permute(0, 1, 3, 2)
    unfolded_by_index = unfolded_by_index.reshape(unfolded_by_f.shape)

    print("Shape of unfolded_by_index:", unfolded_by_index.shape)
    print("Shape of unfolded_by_f:", unfolded_by_f.shape)

    # check whether the two results match
    return torch.allclose(unfolded_by_index, unfolded_by_f)

if __name__ == "__main__":
    batch_size = 2
    channels = 2
    height = 2
    width = 2
    kernel_size = 2
    stride = 1
    dilation = 1
    result = test(batch_size, channels, height, width, kernel_size, stride, dilation)
    print("Are the results equal?", result)

    batch_size = 5
    channels = 3
    height = 4
    width = 5
    kernel_size = 2
    stride = 2
    dilation = 2
    result = test(batch_size, channels, height, width, kernel_size, stride, dilation)
    print("Are the results equal?", result)