Refine MNIST LUT by adding LutCnn.
parent 30dd319d8c · commit dd71a5dedd
@@ -0,0 +1,44 @@
import torch


# 1. Define the model (must use static shapes and static control flow)
class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(3 * 224 * 224, 1000)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # static-shape operation
        return torch.relu(self.fc(x))  # avoid dynamic control flow


model = SimpleModel().cuda()

# 2. Prepare static input/output placeholder tensors
static_input = torch.randn(16, 3, 224, 224, device='cuda')
static_output = torch.zeros(16, 1000, device='cuda')

# 3. Warm-up phase (must run on a non-default stream)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):  # warm up 3 times
        static_output = model(static_input)
torch.cuda.current_stream().wait_stream(s)

# 4. Create and capture the CUDA graph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # Note: the ops here are not actually executed, only recorded into the graph
    static_output = model(static_input)

# 5. Run via the graph (update data + replay)
def run_graph(new_input):
    # Copy new data into the captured input tensor
    static_input.copy_(new_input)
    # Replay the captured graph
    g.replay()
    return static_output.clone()  # return a copy of the result

# Test
new_data = torch.randn(16, 3, 224, 224, device='cuda')
result = run_graph(new_data)
print(result.shape)  # torch.Size([16, 1000])
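A usage sketch (not part of the commit): the captured graph can be replayed repeatedly, as long as every new batch matches static_input's shape, dtype and device.

# Hypothetical driver loop: only tensor contents change between replays;
# the kernels and their launch order stay fixed inside the captured graph.
for step in range(4):
    batch = torch.randn(16, 3, 224, 224, device='cuda')
    out = run_graph(batch)
    print(step, out.shape)  # torch.Size([16, 1000]) each time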
117  binary/mnist.py
@@ -14,14 +14,15 @@ import torch.nn.functional as F
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 from torch.profiler import profile, ProfilerActivity, record_function
+from unfold import generate_unfold_index
 import datetime

 torch.manual_seed(1234)
 np.random.seed(1234)
 torch.cuda.manual_seed_all(1234)

-batch_size = 16
-lr = 0.001
+BS = 16
+LR = 0.001

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")
@@ -33,15 +34,8 @@ transform = transforms.Compose(
 train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
 test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)

-train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
-test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
-
-
-def to_binary_tensor(input_tensor, bits):
-    int_tensor = torch.round(input_tensor).clamp(0, 2**bits - 1).to(torch.int64)
-    shifts = torch.arange(bits - 1, -1, -1, device=int_tensor.device)
-    binary_bits = (int_tensor.unsqueeze(-1) >> shifts) & 1
-    return binary_bits
+train_loader = DataLoader(train_dataset, batch_size=BS, shuffle=True)
+test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False)


 class Lut(torch.autograd.Function):
@@ -133,6 +127,30 @@ class LutGroup(nn.Module):
         return x


+class LutCnn(nn.Module):
+    def __init__(self, output_c, input_shape, kernel_size, stride, dilation):
+        super(LutCnn, self).__init__()
+        self.output_c = output_c
+        self.input_shape = input_shape
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.dilation = dilation
+        batch_idx, channel_idx, h_idx, w_idx = generate_unfold_index(input_shape, kernel_size, stride, dilation)
+        self.batch_idx = nn.Parameter(batch_idx, requires_grad=False)
+        self.channel_idx = nn.Parameter(channel_idx, requires_grad=False)
+        self.h_idx = nn.Parameter(h_idx, requires_grad=False)
+        self.w_idx = nn.Parameter(w_idx, requires_grad=False)
+        self.lut = LutGroup(len(self.batch_idx), kernel_size * kernel_size, output_c)
+
+    def forward(self, x):
+        B, C, H, W = self.input_shape
+        x = x.view(self.input_shape)
+        x = x[(self.batch_idx, self.channel_idx, self.h_idx, self.w_idx)]
+        x = x.view(B, -1, self.kernel_size * self.kernel_size)
+        x = self.lut(x)
+        return x
+
+
 class SimpleBNN(nn.Module):
     def __init__(self):
         super(SimpleBNN, self).__init__()
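A standalone sketch (not part of the commit; the shapes here are arbitrary) of why LutCnn's index gather works: it collects the same elements as F.unfold, just grouped as [B, C * L, kH * kW] instead of [B, C * kH * kW, L].

import torch
import torch.nn.functional as F
from unfold import generate_unfold_index

# Arbitrary example: B=2, C=3, H=W=8, kernel 2, stride 2 -> L=16 windows.
x = torch.randn(2, 3, 8, 8)
idx = generate_unfold_index(x.shape, kernel_size=2, stride=2, dilation=1)
patches = x[idx].view(2, -1, 2 * 2)          # the grouping LutCnn feeds into its LutGroup
ref = F.unfold(x, kernel_size=2, stride=2)   # [B, C*kH*kW, L]
ref = ref.view(2, 3, 2 * 2, -1).permute(0, 1, 3, 2).reshape(2, -1, 2 * 2)
print(torch.allclose(patches, ref))          # True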
@@ -142,14 +160,12 @@ class SimpleBNN(nn.Module):
         self.w = nn.Parameter(torch.randn(3, 784))
         self.b = nn.Parameter(torch.zeros(3, 784))

-        self.lut2_p = LutGroup(784, 4, 8)  # pool2 => 14*14*4
-        self.lut3_p = LutGroup(8 * 14 * 14, 4, 1)  # pool2 => 7*7*4
-        self.lut4_p = LutGroup(8 * 5 * 5 * 3 * 3, 9, 1)  # conv 3 => 5*5*8
-        # self.lut4 = LutGroup(8 * 5 * 5, 8, 8)  # conv 3 => 5*5*8
-        self.lut5_p = LutGroup(8 * 3 * 3 * 3 * 3, 9, 1)  # conv 3 => 3*3*16
-        # self.lut5 = LutGroup(16 * 3 * 3, 16, 10)  # conv 3 => 3*3*16
-        self.lut6_p = LutGroup(8 * 3 * 3, 9, 10)  # conv 3 => 128
-        # self.lut7_p = LutGroup(8 * 16, 16, 10)  # fc 128 => 8
+        # output_c, input_shape, kernel_size, stride, dilation
+        self.lnn1 = LutCnn(8, (BS, 1, 28, 28), 2, 2, 1)
+        self.lnn2 = LutCnn(1, (BS, 8, 14, 14), 2, 2, 1)
+        self.lnn3 = LutCnn(1, (BS, 8, 7, 7), 3, 1, 1)
+        self.lnn4 = LutCnn(1, (BS, 8, 5, 5), 3, 1, 1)
+        self.lnn5 = LutCnn(10, (BS, 8, 3, 3), 3, 1, 1)

         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
@@ -160,7 +176,7 @@ class SimpleBNN(nn.Module):

     def forward(self, x):
         batch = x.shape[0]
-        x = x.view(batch, -1)
+        # x = x.view(batch, -1)

         # map x from [-0.5, 0.5] to 0-255, then expand it into 8 binary values
         # x = (x * 256 + 128).clamp(0, 255).to(torch.uint8)

@@ -200,54 +216,11 @@ class SimpleBNN(nn.Module):

         #########################

-        # unfold
-        # original output shape: [batch, channels * kH * kW, L]
-        # where L is the number of sliding windows
-
-        x = x.view(batch, 1, 28, 28)
-        x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
-        # x = x.view(batch, 1, 4, -1)
-        # x = x.permute(0, 1, 3, 2)
-        x = x.reshape(batch, -1)
-        x = self.lut2_p(x)
-
-        x = x.view(batch, 8, 14, 14)
-        x = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=2)
-        # x = x.view(batch, 8, 4, -1)
-        # x = x.permute(0, 1, 3, 2)
-        x = x.reshape(batch, -1)
-        x = self.lut3_p(x)
-
-        x = x.view(batch, 8, 7, 7)
-        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
-        # x = x.view(batch, 8, 9, -1)
-        # x = x.permute(0, 1, 3, 2)
-        x = x.reshape(batch, -1)
-        x = self.lut4_p(x)
-
-        # x = x.view(batch, 8, 25)
-        # x = x.permute(0, 2, 1)  # batch 25 8
-        # x = x.reshape(batch, -1)  # batch 25*8
-        # x = self.lut4(x)  # batch 8*25
-
-        x = x.reshape(batch, 8, 5, 5)
-        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
-        # x = x.view(batch, 8, 9, -1)
-        # x = x.permute(0, 1, 3, 2)
-        x = x.reshape(batch, -1)
-        x = self.lut5_p(x)
-
-        # x = x.view(batch, 16, 3, 3)
-        # x = x.permute(0, 2, 3, 1)
-        # x = x.reshape(batch, -1)
-        # x = self.lut5(x)
-
-        x = x.view(batch, 8, 3, 3)
-        x = F.unfold(x, kernel_size=3, dilation=1, padding=0, stride=1)
-        # x = x.view(batch, 8, 9, -1)
-        # x = x.permute(0, 1, 3, 2)
-        x = x.reshape(batch, -1)
-        x = self.lut6_p(x)
+        x = self.lnn1(x)
+        x = self.lnn2(x)
+        x = self.lnn3(x)
+        x = self.lnn4(x)
+        x = self.lnn5(x)

         # xx = 2 ** torch.arange(7, -1, -1).to(x.device)
         x = x.view(batch, -1, 8)
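A quick sanity check (not part of the commit) of the spatial sizes the five LutCnn stages assume, using the standard output-size formula for unfold/convolution:

# out = (in - dilation * (k - 1) - 1) // stride + 1
def out_size(n, k, s, d=1):
    return (n - d * (k - 1) - 1) // s + 1

n = 28
for k, s in [(2, 2), (2, 2), (3, 1), (3, 1), (3, 1)]:  # lnn1 .. lnn5
    n = out_size(n, k, s)
    print(n)  # 14, 7, 5, 3, 1 -> matches the input_shape of the next stage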
@@ -262,7 +235,7 @@ torch.autograd.set_detect_anomaly(True)
 # model = SimpleCNN().to(device)
 model = SimpleBNN().to(device)
 criterion = nn.CrossEntropyLoss()
-optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

 tbWriter = None

@@ -272,7 +245,7 @@ def AddScalar(tag, value, epoch):
     if not tbWriter:
         current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
         tbWriter = SummaryWriter(f"log/{current_time}")
-        hparam_dict = {"lr": lr, "batch_size": batch_size}
+        hparam_dict = {"lr": LR, "batch_size": BS}
         tbWriter.add_hparams(hparam_dict, {}, run_name=f"./")

     tbWriter.add_scalar(tag, value, epoch)
@@ -328,8 +301,8 @@ def profiler():
             optimizer.step()
         print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

-        if batch_idx > 5:
-            break
+        if batch_idx > 10:
+            assert False


 # profiler()
unfold.py
@@ -0,0 +1,100 @@
import torch
import torch.nn.functional as F


def to_binary_tensor(input_tensor, bits):
    int_tensor = torch.round(input_tensor).clamp(0, 2**bits - 1).to(torch.int64)
    shifts = torch.arange(bits - 1, -1, -1, device=int_tensor.device)
    binary_bits = (int_tensor.unsqueeze(-1) >> shifts) & 1
    return binary_bits


def generate_unfold_index(input_shape, kernel_size, stride, dilation=1):
    padding = 0
    B, C, H, W = input_shape
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if isinstance(stride, int):
        stride = (stride, stride)
    if isinstance(padding, int):
        padding = (padding, padding)
    if isinstance(dilation, int):
        dilation = (dilation, dilation)

    kH, kW = kernel_size
    dH, dW = dilation
    pH, pW = padding
    sH, sW = stride

    # compute the number of output windows
    out_h = (H + 2 * pH - dH * (kH - 1) - 1) // sH + 1
    out_w = (W + 2 * pW - dW * (kW - 1) - 1) // sW + 1

    # build the gather indices
    batch_idx = []
    channel_idx = []
    h_idx = []
    w_idx = []

    for b in range(B):
        for c in range(C):
            for i in range(out_h):
                for j in range(out_w):
                    h_start = i * sH
                    w_start = j * sW
                    for kh in range(kH):
                        for kw in range(kW):
                            batch_idx.append(b)
                            channel_idx.append(c)
                            h_idx.append(h_start + kh * dH)
                            w_idx.append(w_start + kw * dW)

    # convert to tensors
    batch_idx = torch.tensor(batch_idx, dtype=torch.long)
    channel_idx = torch.tensor(channel_idx, dtype=torch.long)
    h_idx = torch.tensor(h_idx, dtype=torch.long)
    w_idx = torch.tensor(w_idx, dtype=torch.long)

    return (batch_idx, channel_idx, h_idx, w_idx)


def test(batch_size=2, channels=2, height=2, width=2, kernel_size=2, stride=1, dilation=1):

    x = torch.randn(batch_size, channels, height, width)
    index = generate_unfold_index(input_shape=x.shape, kernel_size=kernel_size, stride=stride, dilation=dilation)
    unfolded_by_index = x[index]

    unfolded_by_f = F.unfold(x, kernel_size=kernel_size, stride=stride, padding=0, dilation=dilation)

    unfolded_by_index = unfolded_by_index.view(batch_size, channels, -1, kernel_size * kernel_size)
    unfolded_by_index = unfolded_by_index.permute(0, 1, 3, 2)
    unfolded_by_index = unfolded_by_index.reshape(unfolded_by_f.shape)

    print("Shape of unfolded_by_index:", unfolded_by_index.shape)
    print("Shape of unfolded_by_f:", unfolded_by_f.shape)

    # check whether the two results match
    return torch.allclose(unfolded_by_index, unfolded_by_f)


if __name__ == "__main__":

    batch_size = 2
    channels = 2
    height = 2
    width = 2
    kernel_size = 2
    stride = 1
    dilation = 1
    result = test(batch_size, channels, height, width, kernel_size, stride, dilation)
    print("Are the results equal?", result)

    batch_size = 5
    channels = 3
    height = 4
    width = 5
    kernel_size = 2
    stride = 2
    dilation = 2
    result = test(batch_size, channels, height, width, kernel_size, stride, dilation)
    print("Are the results equal?", result)
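A tiny example (not part of the file) of what to_binary_tensor produces: each value becomes its big-endian bit vector.

# 5 -> 0101 and 3 -> 0011 with bits=4
print(to_binary_tensor(torch.tensor([5.0, 3.0]), 4))
# tensor([[0, 1, 0, 1],
#         [0, 0, 1, 1]])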