Compare commits
1 Commits
master
...
query_matm
Author | SHA1 | Date |
---|---|---|
|
7cf19b15cf |
|
@ -1,13 +1,3 @@
|
|||
__pycache__
|
||||
.vscode
|
||||
|
||||
*.npy
|
||||
temp
|
||||
lightning_logs
|
||||
|
||||
checkpoints
|
||||
build
|
||||
log
|
||||
logs
|
||||
|
||||
mlruns
|
||||
*.txt
|
|
@ -1,44 +0,0 @@
|
|||
import torch
|
||||
|
||||
# 1. 定义模型(需满足静态形状和静态控制流)
|
||||
class SimpleModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fc = torch.nn.Linear(3 * 224 * 224, 1000)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.view(x.size(0), -1) # 静态形状操作
|
||||
return torch.relu(self.fc(x)) # 避免动态控制流
|
||||
|
||||
model = SimpleModel().cuda()
|
||||
|
||||
# 2. 准备静态输入/输出占位张量
|
||||
static_input = torch.randn(16, 3, 224, 224, device='cuda')
|
||||
static_output = torch.zeros(16, 1000, device='cuda')
|
||||
|
||||
# 3. 预热阶段(必须在非默认流)
|
||||
s = torch.cuda.Stream()
|
||||
s.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(s):
|
||||
for _ in range(3): # 预热3次
|
||||
static_output = model(static_input)
|
||||
torch.cuda.current_stream().wait_stream(s)
|
||||
|
||||
# 4. 创建并捕获CUDA图
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g):
|
||||
# 注意:此处操作不会实际执行,仅记录计算图
|
||||
static_output = model(static_input)
|
||||
|
||||
# 5. 使用图计算(更新数据+重放)
|
||||
def run_graph(new_input):
|
||||
# 将新数据复制到捕获的输入张量
|
||||
static_input.copy_(new_input)
|
||||
# 重放计算图
|
||||
g.replay()
|
||||
return static_output.clone() # 返回结果副本
|
||||
|
||||
# 测试
|
||||
new_data = torch.randn(16, 3, 224, 224, device='cuda')
|
||||
result = run_graph(new_data)
|
||||
print(result.shape) # torch.Size([16, 1000])
|
|
@ -1,65 +0,0 @@
|
|||
import torch
|
||||
|
||||
data_size = 20
|
||||
|
||||
|
||||
torch.manual_seed(1234)
|
||||
x_data = torch.abs(torch.randn((data_size)))
|
||||
y_data = 10 * x_data
|
||||
|
||||
|
||||
class MUL(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight):
|
||||
ctx.save_for_backward(input, weight)
|
||||
return input * weight
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, weight = ctx.saved_tensors
|
||||
grad_weight = input * grad_output
|
||||
grad_input = weight * grad_output
|
||||
print(f"grad_output:{grad_output.item():.4f}")
|
||||
return grad_input, grad_weight
|
||||
|
||||
|
||||
class LinearModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.tensor([[1.0]]), requires_grad=True)
|
||||
|
||||
def forward(self, x):
|
||||
return MUL.apply(x, self.weight)
|
||||
return x * self.weight
|
||||
|
||||
|
||||
model = LinearModel()
|
||||
criterion = torch.nn.MSELoss()
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
|
||||
loss_history = []
|
||||
|
||||
|
||||
for step in range(data_size):
|
||||
y_pred = model(x_data[step])
|
||||
|
||||
# loss = criterion(y_pred, y_data[step])
|
||||
# loss = y_data[step] / y_pred - 1.0
|
||||
# loss = torch.abs(y_data[step] - y_pred)
|
||||
# loss = y_data[step] - y_pred
|
||||
loss = (y_data[step] - y_pred) * (y_data[step] - y_pred)
|
||||
|
||||
loss_history.append(loss.item())
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
if (step + 1) % 1 == 0:
|
||||
w = model.weight.item()
|
||||
print(
|
||||
f"Step {step+1}: w={w:.4f} loss={loss.item():.6f} input:{x_data[step]:.4f} output:{y_pred.item():.4f} label:{y_data[step].item():.4f} w_grad:{model.weight.grad.item():.4f}"
|
||||
)
|
||||
|
||||
optimizer.step() # w = w - lr * ∇w[5](@ref)
|
||||
|
||||
test_x = torch.tensor([[5.0]])
|
||||
print(f"\n预测结果: x=5 → y={model(test_x).item():.2f}")
|
377
binary/mnist.py
|
@ -1,377 +0,0 @@
|
|||
import os
|
||||
|
||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torchvision
|
||||
from torchvision import transforms
|
||||
from torch.utils.data import DataLoader
|
||||
import math
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torch.profiler import profile, ProfilerActivity, record_function
|
||||
from unfold import generate_unfold_index
|
||||
import datetime
|
||||
|
||||
torch.manual_seed(1234)
|
||||
np.random.seed(1234)
|
||||
torch.cuda.manual_seed_all(1234)
|
||||
|
||||
BS = 16
|
||||
LR = 0.01
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
transform = transforms.Compose(
|
||||
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] # MNIST数据集的均值和标准差
|
||||
)
|
||||
|
||||
train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
|
||||
test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
|
||||
|
||||
train_loader = DataLoader(train_dataset, batch_size=BS, shuffle=True, drop_last=True, num_workers=4)
|
||||
test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False, drop_last=True)
|
||||
|
||||
|
||||
class Lut(torch.autograd.Function):
|
||||
# input [batch, group, bits ]
|
||||
# output [batch, group ]
|
||||
# weight [2**bits, group ]
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight, index):
|
||||
ind = ((input > 0).long() * index).sum(dim=-1)
|
||||
output = torch.gather(weight, 0, ind)
|
||||
output = (output > 0).float()
|
||||
output = (output - 0.5) * 2.0
|
||||
ctx.save_for_backward(input, weight, ind, output)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, weight, ind, output = ctx.saved_tensors
|
||||
grad_input = grad_weight = None
|
||||
bits = input.shape[2]
|
||||
|
||||
if ctx.needs_input_grad[1]:
|
||||
grad_weight = torch.zeros_like(weight)
|
||||
grad_weight.scatter_add_(0, ind, grad_output)
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
|
||||
# grad_input = grad_output * torch.gather(weight, 0, ind)
|
||||
grad_input = grad_output
|
||||
grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
|
||||
output = output.unsqueeze(-1).repeat(1, 1, bits)
|
||||
in_sign = ((input > 0).float() - 0.5) * 2.0
|
||||
grad_input = grad_input * in_sign
|
||||
grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
|
||||
|
||||
# grad_input = grad_output
|
||||
# grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
|
||||
# output = output.unsqueeze(-1).repeat(1, 1, bits)
|
||||
# in_sign = ((input > 0).float() - 0.5) * 2.0
|
||||
# out_sign = ((output > 0).float() - 0.5) * 2.0
|
||||
# grad_sign = ((grad_input > 0).float() - 0.5) * 2.0
|
||||
# grad_input = grad_input * in_sign * (out_sign * grad_sign)
|
||||
# grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
|
||||
|
||||
# 需要一个动态的调整系数
|
||||
# 能稳定的收敛
|
||||
|
||||
# print(in_sign[0].detach().cpu().numpy())
|
||||
# print(out_sign[0].detach().cpu().numpy())
|
||||
# print(grad_sign[0].detach().cpu().numpy())
|
||||
# print(grad_input[0].detach().cpu().numpy())
|
||||
|
||||
return grad_input, grad_weight, None, None
|
||||
|
||||
|
||||
class SimpleCNN(nn.Module):
|
||||
def __init__(self):
|
||||
super(SimpleCNN, self).__init__()
|
||||
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
|
||||
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
|
||||
self.bn = nn.BatchNorm1d(320 * 4)
|
||||
self.fc1 = nn.Linear(160, 50)
|
||||
self.fc2 = nn.Linear(50, 10)
|
||||
self.pool = nn.MaxPool2d(2)
|
||||
self.relu = nn.ReLU()
|
||||
self.weight = nn.Parameter(torch.randn(160, pow(2, 8)))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu(self.pool(self.conv1(x)))
|
||||
x = self.relu((self.conv2(x)))
|
||||
x = x.view(-1, 320 * 4)
|
||||
x = self.bn(x)
|
||||
x = x.view(-1, 160, 8)
|
||||
|
||||
x = Lut.apply(x, self.weight)
|
||||
|
||||
x = self.relu(self.fc1(x))
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
|
||||
class LutGroup(nn.Module):
|
||||
def __init__(self, group, groupBits, groupRepeat=1):
|
||||
assert groupBits > 1
|
||||
super(LutGroup, self).__init__()
|
||||
self.weight = nn.Parameter(torch.ones(pow(2, groupBits), int(groupRepeat * group)))
|
||||
self.group = group
|
||||
self.groupBits = groupBits
|
||||
self.groupRepeat = groupRepeat
|
||||
self.index = nn.Parameter(2 ** torch.arange(groupBits - 1, -1, -1), requires_grad=False)
|
||||
|
||||
def forward(self, x):
|
||||
# input [ batch, group * groupBits ]
|
||||
# output [ batch, group * groupRepeat ]
|
||||
batch = x.shape[0]
|
||||
x = x.reshape(batch, -1, self.groupBits)
|
||||
if self.groupRepeat > 1:
|
||||
x = x.repeat(1, self.groupRepeat, 1)
|
||||
x = Lut.apply(x, self.weight, self.index)
|
||||
return x
|
||||
|
||||
|
||||
class LutCnn(nn.Module):
|
||||
def __init__(self, channel_repeat, input_shape, kernel_size, stride, dilation, fc=False):
|
||||
super(LutCnn, self).__init__()
|
||||
B, C, H, W = input_shape
|
||||
self.input_shape = input_shape
|
||||
self.kernel_size = kernel_size
|
||||
self.channel_repeat = channel_repeat
|
||||
self.stride = stride
|
||||
self.dilation = dilation
|
||||
batch_idx, channel_idx, h_idx, w_idx = generate_unfold_index(input_shape, kernel_size, stride, dilation)
|
||||
self.batch_idx = nn.Parameter(batch_idx, requires_grad=False)
|
||||
self.channel_idx = nn.Parameter(channel_idx, requires_grad=False)
|
||||
self.h_idx = nn.Parameter(h_idx, requires_grad=False)
|
||||
self.w_idx = nn.Parameter(w_idx, requires_grad=False)
|
||||
groupBits = kernel_size * kernel_size
|
||||
group = int(len(self.batch_idx) / B / groupBits)
|
||||
self.lut = LutGroup(group, groupBits, channel_repeat)
|
||||
self.fc = fc
|
||||
if fc:
|
||||
self.lutc = LutGroup(group, channel_repeat * C, channel_repeat * C)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = self.input_shape
|
||||
x = x.view(self.input_shape)
|
||||
x = x[(self.batch_idx, self.channel_idx, self.h_idx, self.w_idx)]
|
||||
x = x.view(B, -1, self.kernel_size * self.kernel_size)
|
||||
x = self.lut(x)
|
||||
if self.fc:
|
||||
x = x.view(B, -1, self.channel_repeat)
|
||||
x = x.permute(0, 2, 1)
|
||||
x = self.lutc(x)
|
||||
return x
|
||||
|
||||
|
||||
class SimpleBNN(nn.Module):
|
||||
def __init__(self):
|
||||
super(SimpleBNN, self).__init__()
|
||||
# self.w = nn.Parameter(torch.randn(3, 784 * 8))
|
||||
# self.b = nn.Parameter(torch.zeros(3, 784 * 8))
|
||||
|
||||
self.w = nn.Parameter(torch.randn(3, 784))
|
||||
self.b = nn.Parameter(torch.zeros(3, 784))
|
||||
|
||||
# channel_repeat, input_shape, kernel_size, stride, dilation, fc
|
||||
self.lnn1 = LutCnn(8, (BS, 1, 28, 28), 2, 2, 1, False)
|
||||
self.lnn2 = LutCnn(1, (BS, 8, 14, 14), 2, 2, 1, False)
|
||||
self.lnn3 = LutCnn(1, (BS, 8, 7, 7), 3, 1, 1, False)
|
||||
self.lnn4 = LutCnn(1, (BS, 8, 5, 5), 3, 1, 1, False)
|
||||
self.lnn5 = LutCnn(10, (BS, 8, 3, 3), 3, 1, 1)
|
||||
|
||||
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
|
||||
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
|
||||
self.fc1 = nn.Linear(320, 50)
|
||||
self.fc2 = nn.Linear(50, 10)
|
||||
self.pool = nn.MaxPool2d(2)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, x, t):
|
||||
batch = x.shape[0]
|
||||
# x = x.view(batch, -1)
|
||||
|
||||
# 变换x [-0.5:0.5] 到 0-255,然后按照二进制展开成8个值
|
||||
# x = (x * 256 + 128).clamp(0, 255).to(torch.uint8)
|
||||
# xx = torch.arange(7, -1, -1).to(x.device)
|
||||
# bits = (x.unsqueeze(-1) >> xx) & 1
|
||||
# x = bits.view(batch, -1)
|
||||
# x = x.float() - 0.5
|
||||
|
||||
# x = (x > 0).float()
|
||||
|
||||
# q = x * self.w[0] + self.b[0]
|
||||
# k = x * self.w[1] + self.b[1]
|
||||
# v = x * self.w[2] + self.b[2]
|
||||
# q = q.view(batch, -1, 1)
|
||||
# k = k.view(batch, 1, -1)
|
||||
# v = v.view(batch, -1, 1)
|
||||
# kq = q @ k
|
||||
# kqv = kq @ v
|
||||
# kqv = kqv.view(batch, -1, 8)
|
||||
# x = kqv
|
||||
|
||||
#########################
|
||||
|
||||
# # x = (x > 0) << xx
|
||||
# # x = x.sum(2)
|
||||
# # x = x.view(batch, 1, 28, 28)
|
||||
# # x = (x - 128.0) / 256.0
|
||||
|
||||
# x = (x > 0).float()
|
||||
# x = x.view(batch, 1, 28, 28)
|
||||
|
||||
# x = self.relu(self.pool(self.conv1(x)))
|
||||
# x = self.relu(self.pool((self.conv2(x))))
|
||||
# x = x.view(-1, 320)
|
||||
# x = self.relu(self.fc1(x))
|
||||
# x = self.fc2(x)
|
||||
|
||||
#########################
|
||||
|
||||
x = self.lnn1(x)
|
||||
x = self.lnn2(x)
|
||||
x = self.lnn3(x)
|
||||
x = self.lnn4(x)
|
||||
x = self.lnn5(x)
|
||||
|
||||
# xx = 2 ** torch.arange(7, -1, -1).to(x.device)
|
||||
x = x.view(batch, -1, 8)
|
||||
# x = x * xx
|
||||
# x = (x - 128.0) / 256.0
|
||||
x = x.sum(2)
|
||||
|
||||
return x
|
||||
|
||||
def printWeight(self):
|
||||
pass
|
||||
|
||||
|
||||
class SimpleLNN(nn.Module):
|
||||
def __init__(self):
|
||||
super(SimpleLNN, self).__init__()
|
||||
# group, groupBits, groupRepeat
|
||||
self.lutg1 = LutGroup(1, 10, 4)
|
||||
self.lutg2 = LutGroup(1, 4, 10)
|
||||
|
||||
def forward(self, x, t):
|
||||
batch = x.shape[0]
|
||||
|
||||
x = torch.zeros_like(t).unsqueeze(-1).repeat(1, 10)
|
||||
x[torch.arange(0, batch), t] = 1
|
||||
|
||||
x = self.lutg1(x)
|
||||
x = self.lutg2(x)
|
||||
|
||||
return x
|
||||
|
||||
def printWeight(self):
|
||||
print("self.lutg1")
|
||||
print(self.lutg1.weight[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy())
|
||||
print("=============================")
|
||||
print("=============================")
|
||||
print("self.lutg1.grad")
|
||||
print(self.lutg1.weight.grad[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy())
|
||||
print("=============================")
|
||||
print("=============================")
|
||||
# print("self.lutg2")
|
||||
# print(self.lutg2.weight.detach().cpu().numpy())
|
||||
# print("=============================")
|
||||
# print("=============================")
|
||||
|
||||
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
# model = SimpleCNN().to(device)
|
||||
# model = SimpleBNN().to(device)
|
||||
model = SimpleLNN().to(device)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
|
||||
|
||||
tbWriter = None
|
||||
|
||||
|
||||
def AddScalar(tag, value, epoch):
|
||||
global tbWriter
|
||||
if not tbWriter:
|
||||
current_time = datetime.datetime.now().strftime("%m%d-%H%M%S")
|
||||
tbWriter = SummaryWriter(f"log/{current_time}")
|
||||
hparam_dict = {"lr": LR, "batch_size": BS}
|
||||
tbWriter.add_hparams(hparam_dict, {}, run_name=f"./")
|
||||
|
||||
tbWriter.add_scalar(tag, value, epoch)
|
||||
|
||||
|
||||
def train(epoch):
|
||||
model.train()
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
data, target = data.to(device), target.to(device)
|
||||
optimizer.zero_grad()
|
||||
output = model(data, target)
|
||||
loss = criterion(output, target)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
AddScalar("loss", loss, epoch)
|
||||
if batch_idx % 1024 == 0 and batch_idx > 0:
|
||||
print(
|
||||
f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
|
||||
f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
|
||||
)
|
||||
|
||||
|
||||
def test(epoch):
|
||||
model.eval()
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
with torch.no_grad():
|
||||
for data, target in test_loader:
|
||||
data, target = data.to(device), target.to(device)
|
||||
output = model(data, target)
|
||||
test_loss += criterion(output, target).item()
|
||||
pred = output.argmax(dim=1, keepdim=True)
|
||||
correct += pred.eq(target.view_as(pred)).sum().item()
|
||||
|
||||
test_loss /= len(test_loader.dataset)
|
||||
accuracy = 100.0 * correct / len(test_loader.dataset)
|
||||
AddScalar("accuracy", accuracy, epoch)
|
||||
print(
|
||||
f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
|
||||
f"({accuracy:.0f}%)\n"
|
||||
)
|
||||
model.printWeight()
|
||||
|
||||
|
||||
def profiler():
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
data, target = data.to(device), target.to(device)
|
||||
optimizer.zero_grad()
|
||||
with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
|
||||
with record_function("model_inference"):
|
||||
output = model(data)
|
||||
loss = criterion(output, target)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||
if batch_idx > 10:
|
||||
prof.export_chrome_trace("local.json")
|
||||
assert False
|
||||
|
||||
|
||||
# profiler()
|
||||
|
||||
for epoch in range(1, 300):
|
||||
train(epoch)
|
||||
test(epoch)
|
||||
|
||||
# torch.save(model.state_dict(), "mnist_cnn.pth")
|
||||
print("Model saved to mnist_cnn.pth")
|
||||
|
||||
if tbWriter:
|
||||
tbWriter.close()
|
|
@ -1,39 +0,0 @@
|
|||
|
||||
## 定义
|
||||
|
||||
### 梯度
|
||||
1. 预测变化对整体损失 L 的影响程度, 参数θ在当前点的变化对损失值的影响方向和幅度
|
||||
2. grad物理含义:loss L = 0 的时候,需要的变化量
|
||||
3. w = w - grad * lr 若梯度为正,权重应该减小
|
||||
4. w_grad = output_grad * input, input越大,grad越大,w调整的量越大
|
||||
1. input越大->对weight的放大倍数越大->才能达到loss=0的调整量
|
||||
2. 所以,weight的调整比例应该越大,才能弥补小input的loss=0
|
||||
5. 梯度的大小反应了影响损失的“快慢”
|
||||
1. 梯度大 → 损失曲面陡峭 → 微小变化导致损失剧烈波动
|
||||
2. 梯度大,微小变化就可以使得loss变化一个单位
|
||||
2. 梯度大,和loss的关系越相关
|
||||
|
||||
## 问题
|
||||
|
||||
* 在一串的binary lut网络中
|
||||
1. 如果每个卷积的channel相互之间没有关系
|
||||
2. 中间插入一层,交换各个channel之间的数据,生成新的相同数量的channel
|
||||
3. 方法2的效果很差
|
||||
1. 好像是破坏了训练,可能是训练的方法不对,梯度下降不适合这种模型
|
||||
2. 最终分类是10,10个输出之间有关系就会很差?
|
||||
* LUT层梯度计算的问题
|
||||
1. 发现LUT的反向计算grad_weight没有考虑weight本来的正负符号,grad表示的是>0的置信度
|
||||
1. 考虑梯度符号之后,由于整个选择的梯度是一个,没有机会变换到别的
|
||||
2. weight_grad:后面一级计算的grad_input,对于当前weight的grad是一样的,没有机会变换到别的
|
||||
3. 当前的选择不可信后的grad会导致直接0/1整体取反,而不会改变分布
|
||||
2. 输出级别用于criterion的LUT的梯度计算和基于Binary的输出1概率的梯度的计算方式不一样
|
||||
1. LUT的是输出1的概率,不能直接和criterion的梯度进行下降
|
||||
3. grad input的目标是,要不要更换别的index
|
||||
1. 梯度的大小表示更换别的index的程度
|
||||
2. 梯度正负无所谓,需要随机?
|
||||
|
||||
* unfold输出的维度不对
|
||||
1. LUT不是对卷积核进行计算,更容易收敛,但是精度没有更高
|
||||
2. LUT不是对卷积核进行计算,不容易收敛,精度差不多
|
||||
* 好像只有AdamW优化器可以优化参数,明显收敛
|
||||
* LUT的输出进行二值化对精度有影响,大概94->81
|
100
binary/unfold.py
|
@ -1,100 +0,0 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def to_binary_tensor(input_tensor, bits):
|
||||
int_tensor = torch.round(input_tensor).clamp(0, 2**bits - 1).to(torch.int64)
|
||||
shifts = torch.arange(bits - 1, -1, -1, device=int_tensor.device)
|
||||
binary_bits = (int_tensor.unsqueeze(-1) >> shifts) & 1
|
||||
return binary_bits
|
||||
|
||||
|
||||
def generate_unfold_index(input_shape, kernel_size, stride, dilation=1):
|
||||
padding = 0
|
||||
B, C, H, W = input_shape
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size)
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride)
|
||||
if isinstance(padding, int):
|
||||
padding = (padding, padding)
|
||||
if isinstance(dilation, int):
|
||||
dilation = (dilation, dilation)
|
||||
|
||||
kH, kW = kernel_size
|
||||
dH, dW = dilation
|
||||
pH, pW = padding
|
||||
sH, sW = stride
|
||||
|
||||
# 计算输出窗口数量
|
||||
out_h = (H + 2 * pH - dH * (kH - 1) - 1) // sH + 1
|
||||
out_w = (W + 2 * pW - dW * (kW - 1) - 1) // sW + 1
|
||||
|
||||
# 构造索引
|
||||
batch_idx = []
|
||||
channel_idx = []
|
||||
h_idx = []
|
||||
w_idx = []
|
||||
|
||||
for b in range(B):
|
||||
for c in range(C):
|
||||
for i in range(out_h):
|
||||
for j in range(out_w):
|
||||
h_start = i * sH
|
||||
w_start = j * sW
|
||||
for kh in range(kH):
|
||||
for kw in range(kW):
|
||||
batch_idx.append(b)
|
||||
channel_idx.append(c)
|
||||
h_idx.append(h_start + kh * dH)
|
||||
w_idx.append(w_start + kw * dW)
|
||||
|
||||
# 转换为 tensor
|
||||
batch_idx = torch.tensor(batch_idx, dtype=torch.long)
|
||||
channel_idx = torch.tensor(channel_idx, dtype=torch.long)
|
||||
h_idx = torch.tensor(h_idx, dtype=torch.long)
|
||||
w_idx = torch.tensor(w_idx, dtype=torch.long)
|
||||
|
||||
return (batch_idx, channel_idx, h_idx, w_idx)
|
||||
|
||||
|
||||
def test(batch_size=2, channels=2, height=2, width=2, kernel_size=2, stride=1, dilation=1):
|
||||
|
||||
x = torch.randn(batch_size, channels, height, width)
|
||||
index = generate_unfold_index(input_shape=x.shape, kernel_size=kernel_size, stride=stride, dilation=dilation)
|
||||
unfolded_by_index = x[index]
|
||||
|
||||
unfolded_by_f = F.unfold(x, kernel_size=kernel_size, stride=stride, padding=0, dilation=dilation)
|
||||
|
||||
unfolded_by_index = unfolded_by_index.view(batch_size, channels, -1, kernel_size * kernel_size)
|
||||
unfolded_by_index = unfolded_by_index.permute(0, 1, 3, 2)
|
||||
unfolded_by_index = unfolded_by_index.reshape(unfolded_by_f.shape)
|
||||
|
||||
print("Shape of unfolded_by_index:", unfolded_by_index.shape)
|
||||
print("Shape of unfolded_by_f:", unfolded_by_f.shape)
|
||||
|
||||
# 检查是否一致
|
||||
return torch.allclose(unfolded_by_index, unfolded_by_f)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
batch_size = 2
|
||||
channels = 2
|
||||
height = 2
|
||||
width = 2
|
||||
kernel_size = 2
|
||||
stride = 1
|
||||
dilation = 1
|
||||
result = test(batch_size, channels, height, width, kernel_size, stride, dilation)
|
||||
print("Are the results equal?", result)
|
||||
|
||||
batch_size = 5
|
||||
channels = 3
|
||||
height = 4
|
||||
width = 5
|
||||
kernel_size = 2
|
||||
stride = 2
|
||||
dilation = 2
|
||||
result = test(batch_size, channels, height, width, kernel_size, stride, dilation)
|
||||
print("Are the results equal?", result)
|
|
@ -0,0 +1,88 @@
|
|||
## data flow
|
||||
|
||||
```
|
||||
query -> "你好"
|
||||
┃
|
||||
tokenizer -> input_ids [6]
|
||||
┃
|
||||
rotary_pos_emb embedding -> [1, 6, 4096]
|
||||
╲ ╱
|
||||
GLMBlock x 28 -> [6, 1, 4096] <━━━┓
|
||||
RMSNorm -> [6, 1, 4096] ┃ final_layernorm
|
||||
[-1:] -> [1, 1, 4096] ┃
|
||||
Linear -> [1, 1, 65024] ┃ output_layer 4096->65024
|
||||
softmax -> [1, 65024] ┃
|
||||
multinomial -> [1] ┃
|
||||
cat([input_ids, next_tokens]) ━━━┛
|
||||
↓
|
||||
tokenizer.decode( )
|
||||
|
||||
# GLMBlock
|
||||
|
||||
input
|
||||
╱ ╲
|
||||
╱ RMSNorm hidden_states -> [6, 1, 4096]
|
||||
┃ ┋ ╱ ╲
|
||||
┃ ┋ ┃ pow(2) -> [6, 1, 4096]
|
||||
┃ ┋ ┃ ┃
|
||||
┃ ┋ ┃ mean -> [6, 1, 1]
|
||||
┃ ┋ ┃ ↓
|
||||
┃ ┋ ┃ rsqrt( + eps) -> [6, 1, 1]
|
||||
┃ ┋ ╲ ╱
|
||||
┃ ┋ mul -> [6, 1, 4096]
|
||||
┃ ┋ ╲ weight -> [4096]
|
||||
┃ ┋ ╲ ╱
|
||||
┃ RMSNorm mul -> [6, 1, 4096]
|
||||
┃ ╲
|
||||
┃ SelfAttention x -> [6, 1, 4096]
|
||||
┃ ┋ ┃
|
||||
┃ ┋ Linear -> [6, 1, 4608] 4096->4608
|
||||
┃ ┋ ╱ ┃ ╲
|
||||
┃ ┋ q k v [6, 1, 32, 128] [6, 1, 2, 128] [6, 1, 2, 128]
|
||||
┃ ┋ ╱ ┃ ╲
|
||||
┃ ┋ pos_emb pos_emb ╲ -> cat( x0*y0-x1*y1, x1*y0-x0*y1, x, y)
|
||||
┃ ┋ ┃ ┃ ┃
|
||||
┃ ┋ ┃ expand expand -> [6, 1, 32, 128] [6, 1, 32, 128]
|
||||
┃ ┋ permute permute permute -> [1, 32, 6, 128] [1, 32, 6, 128] [1, 32, 6, 128]
|
||||
┃ ┋ ╲ ╱ ┃
|
||||
┃ ┋ ┏---- matmul ┃ -> [1, 32, 6, 128] [1, 32, 128, 6] -> [1, 32, 6, 6]
|
||||
┃ ┋ ┃ add(mask) ╱ -> [1, 32, 6, 6]
|
||||
┃ ┋ attention┃ softmax ╱ -> [1, 32, 6, 6] dim:-1
|
||||
┃ ┋ ┃ ╲ ╱
|
||||
┃ ┋ ┗---- matmul -> [1, 32, 6, 6] [1, 32, 6, 128] -> [1, 32, 6, 128] -> [6, 1, 4096]
|
||||
┃ SelfAttention Linear -> [6, 1, 4096] 4096->4096
|
||||
┃ ╱
|
||||
┃ dropout
|
||||
╲ ╱
|
||||
Add
|
||||
╱ ╲
|
||||
┃ RMSNorm hidden_states -> [6, 1, 4096]
|
||||
┃ ┋ ╱ ╲
|
||||
┃ ┋ ┃ pow(2) -> [6, 1, 4096]
|
||||
┃ ┋ ┃ ┃
|
||||
┃ ┋ ┃ mean -> [6, 1, 1]
|
||||
┃ ┋ ┃ ↓
|
||||
┃ ┋ ┃ rsqrt( + eps) -> [6, 1, 1]
|
||||
┃ ┋ ╲ ╱
|
||||
┃ ┋ mul -> [6, 1, 4096]
|
||||
┃ ┋ ╲ weight -> [4096]
|
||||
┃ ┋ ╲ ╱
|
||||
┃ RMSNorm mul -> [6, 1, 4096]
|
||||
┃ ╱
|
||||
┃ mlp ╱
|
||||
┃ ┋ Linear -> [6, 1, 27392] 4096->27392
|
||||
┃ ┋ ╱ ╲
|
||||
┃ ┋ chunk1 chunk0 -> [6, 1, 13696]
|
||||
┃ ┋ ┃ ┃ ╲
|
||||
┃ ┋ ┃ ┃ sigmoid
|
||||
┃ ┋ ┃ ┃ ╱
|
||||
┃ ┋ ┃ mul
|
||||
┃ ┋ ╲ ╱
|
||||
┃ ┋ mul -> [6, 1, 13696]
|
||||
┃ mlp Linear -> [6, 1, 4096] 13696->4096
|
||||
┃ ╱
|
||||
┃ dropout
|
||||
┃ ╱
|
||||
Add
|
||||
|
||||
```
|
104
dataset/MNBVC.py
|
@ -1,104 +0,0 @@
|
|||
import argparse
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import datasets
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
||||
|
||||
from transformers import (
|
||||
BatchEncoding,
|
||||
DefaultDataCollator,
|
||||
PreTrainedTokenizer,
|
||||
set_seed,
|
||||
)
|
||||
from tokenization_qwen import QWenTokenizer
|
||||
|
||||
dataset_name = ["/home/colin/develop/dataset/liwu/MNBVC/wiki"]
|
||||
dataset_name = ["/home/colin/develop/dataset/liwu/MNBVC/wiki/20230198/58.jsonl.gz"]
|
||||
num_proc = 8
|
||||
seed = 42
|
||||
|
||||
|
||||
def split_raw_dataset(
|
||||
raw_dataset: datasets.DatasetDict,
|
||||
) -> Tuple[datasets.Dataset, datasets.Dataset]:
|
||||
if "validation" in raw_dataset:
|
||||
train_dataset, val_dataset = raw_dataset["train"], raw_dataset["validation"]
|
||||
else:
|
||||
raw_dataset = raw_dataset["train"].train_test_split(test_size=0.05, seed=seed)
|
||||
train_dataset, val_dataset = raw_dataset["train"], raw_dataset["test"]
|
||||
return train_dataset, val_dataset
|
||||
|
||||
|
||||
def process_dataset(dataset: datasets.Dataset, tokenizer: PreTrainedTokenizer) -> datasets.Dataset:
|
||||
def group_texts(examples: Dict[str, list], block_size: int = 512) -> BatchEncoding:
|
||||
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
|
||||
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
||||
total_length = (total_length // block_size) * block_size
|
||||
result = {
|
||||
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
result = BatchEncoding(result)
|
||||
return result
|
||||
|
||||
def format_inputs(examples):
|
||||
p = examples["段落"]
|
||||
mergeLine = ""
|
||||
for line in p:
|
||||
mergeLine += line["内容"] + "\n"
|
||||
return {"text": mergeLine}
|
||||
|
||||
def tokenize_inputs(
|
||||
examples: Dict[str, list],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
column_name: str = "text",
|
||||
) -> BatchEncoding:
|
||||
logits = tokenizer(examples[column_name], return_attention_mask=False)
|
||||
return logits
|
||||
|
||||
dataset_column_names = list(dataset.features)
|
||||
dataset = dataset.map(
|
||||
partial(format_inputs),
|
||||
batched=False,
|
||||
num_proc=num_proc,
|
||||
remove_columns=dataset_column_names,
|
||||
)
|
||||
dataset_column_names = list(dataset.features)
|
||||
dataset = dataset.map(
|
||||
partial(tokenize_inputs, tokenizer=tokenizer),
|
||||
batched=True,
|
||||
num_proc=num_proc,
|
||||
remove_columns=dataset_column_names,
|
||||
)
|
||||
dataset = dataset.map(
|
||||
partial(group_texts, block_size=tokenizer.model_max_length),
|
||||
batched=True,
|
||||
num_proc=num_proc,
|
||||
)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
set_seed(seed)
|
||||
tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")
|
||||
train_dataset_list = []
|
||||
val_dataset_list = []
|
||||
for dn in dataset_name:
|
||||
datanames = dn.split(".")
|
||||
if datanames[-1] == "gz" and datanames[-2] == "jsonl":
|
||||
raw_dataset = datasets.load_dataset("json", data_files=dn)
|
||||
elif datanames[-1] == "json":
|
||||
raw_dataset = datasets.load_dataset("json", data_files=dn)
|
||||
else:
|
||||
raw_dataset = datasets.load_dataset(dn)
|
||||
train_dataset, val_dataset = split_raw_dataset(raw_dataset)
|
||||
train_dataset = process_dataset(train_dataset, tokenizer)
|
||||
val_dataset = process_dataset(val_dataset, tokenizer)
|
||||
train_dataset_list.append(train_dataset)
|
||||
val_dataset_list.append(val_dataset)
|
|
@ -0,0 +1,42 @@
|
|||
# q matmul k
|
||||
|
||||
## model
|
||||
|
||||
qwen/Qwen-1_8B-Chat
|
||||
|
||||
## input
|
||||
|
||||
```python
|
||||
<|im_start|>system
|
||||
You are a helpful assistant.<|im_end|>
|
||||
<|im_start|>user
|
||||
你好<|im_end|>
|
||||
<|im_start|>assistant
|
||||
莎是现代汉语的男性的名字,出自《诗经》中的“采采卷耳
|
||||
```
|
||||
|
||||
|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
|
@ -0,0 +1,67 @@
|
|||
当前状态:
|
||||
1.情绪: 狼狈不堪。
|
||||
2.注意力: 检查自己的左手臂和膝盖是否有红肿或擦痕。
|
||||
3.当前在思考什么: 应该及时处理伤口,免得后患无穷。也许回家再处理会更明智一点。
|
||||
4.生理状态: 左手臂和膝盖疼痛不已,可能摔下有伤。
|
||||
发现新情况: 发现膝盖流血了
|
||||
请按照下面()里面回答我的问题,尽量得详细描述每一项,不要给其他的建议
|
||||
1.情绪:()
|
||||
2.注意力()
|
||||
3.当前在思考什么()
|
||||
4.生理状态()
|
||||
|
||||
|
||||
当前状态:
|
||||
情绪: 狼狈不堪。你感到困惑、不安和不舒服,可能还有一些惊恐或焦虑的情绪。这是因为你的身体受伤并且发现膝盖流血,这种情况会引起担忧和紧张情绪。
|
||||
注意力: 你的注意力仍然集中在检查自己的左手臂和膝盖是否有红肿或擦痕上。你关注的重点是身体的伤势和状况,以及现在又出现了膝盖流血的情况,这进一步引起了你对伤势的关注和紧张。
|
||||
当前在思考什么: 你正在思考应该如何及时处理伤口,以免留下后患。你之前考虑过回家再处理伤口,但现在面对膝盖流血的情况,你可能会重新评估并考虑寻求更及时的医疗帮助。你可能在思考是否需要立即就医或寻求其他形式的紧急救助。
|
||||
生理状态: 你的左手臂和膝盖疼痛不已,这表明可能发生了摔倒或其他伤害。这种持续的疼痛会对你的日常活动和舒适感造成影响,并进一步增加了你的焦虑和不适感。此外,你现在还发现膝盖流血,这需要特别关注和处理。
|
||||
发现新情况: 有一辆救护车过来了
|
||||
请按照下面()里面回答我的问题,尽量得详细描述每一项,不要给其他的建议
|
||||
1.情绪:()
|
||||
2.注意力()
|
||||
3.当前在思考什么()
|
||||
4.生理状态()
|
||||
|
||||
|
||||
当前状态:
|
||||
情绪: 狼狈不堪。你的情绪可能有些紧张和不安,但同时也有一丝希望和宽慰,因为有一辆救护车过来了,这意味着有专业的医护人员将会提供帮助。
|
||||
注意力: 你的注意力现在可能转移到了救护车的到来上。你可能在观察救护车的行动并期待医护人员的到来,以便能够得到适当的医疗和救助。
|
||||
当前在思考什么: 你正在思考救护车的到来是否能够及时帮助你处理伤口。你可能在考虑接下来的行动,如如何与医护人员沟通、告知他们你的伤势情况以及对伤口的处理需求等。
|
||||
生理状态: 你的左手臂和膝盖仍然疼痛,但现在有救护车的到来,你可以期待专业医护人员对你的伤势进行评估和处理。膝盖的出血也需要专业处理,救护车的到来为你提供了及时的医疗资源。
|
||||
发现新情况: 救护车开走了
|
||||
请按照下面()里面回答我的问题,尽量得详细描述每一项,不要给其他的建议
|
||||
1.情绪:()
|
||||
2.注意力()
|
||||
3.当前在思考什么()
|
||||
4.生理状态()
|
||||
|
||||
当前状态:
|
||||
情绪: 你可能感到失望、焦虑和困惑,因为救护车开走了,没有提供及时的医疗帮助。你可能感到担心自己的伤势没有得到妥善处理,这可能增加了你的不安情绪。
|
||||
注意力: 你的注意力可能再次转回到自己的伤势上,因为救护车离开了,你需要重新考虑接下来的行动。你可能在思考是否需要寻求其他的医疗资源或者联系其他人寻求帮助。
|
||||
当前在思考什么: 你正在思考应该采取什么措施来处理伤口和缓解疼痛。你可能会考虑联系其他的医疗机构或者寻求朋友、家人的建议和支持。你也可能在思考如何前往医疗机构或医院寻求进一步的治疗。
|
||||
生理状态: 你的左手臂和膝盖仍然疼痛,而且现在没有专业的医护人员提供帮助。你可能需要寻找其他方法来处理伤口和控制疼痛。你可以使用清洁的纱布或绷带暂时包扎伤口,并尽量避免过度活动,以免加重伤势。
|
||||
发现新情况: 发现休息一会儿后,膝盖不疼了,而且可以走路了
|
||||
请按照下面()里面回答我的问题,尽量得详细描述每一项,只要回答下面几项,不要给其他的建议
|
||||
1.情绪:()
|
||||
2.注意力()
|
||||
3.当前在思考什么()
|
||||
4.生理状态()
|
||||
|
||||
当前状态:
|
||||
情绪: 你可能感到一些松了一口气,因为休息一会儿后,膝盖的疼痛减轻了,你能够走路了。你可能还感到一些困惑,因为之前的疼痛情况可能与你的预期不符。你可能会感到一些轻微的希望和安慰,因为你的伤势表现出一些好转的迹象。
|
||||
注意力: 你的注意力可能会再次转移到你的身体感受上,因为你意识到膝盖的疼痛有所减轻,你能够走路了。你可能会开始思考接下来的行动,如是否需要进一步的医疗评估,或者是否可以自行处理伤口。
|
||||
当前在思考什么: 你正在思考如何进一步处理伤口和控制疼痛。考虑到你能够走路了,你可能会思考是否仍然需要寻求医疗资源,或者是否可以采取自我护理措施来处理伤口。你可能会考虑使用适当的药物或采取其他方法来缓解疼痛和促进伤口的康复。
|
||||
生理状态: 你的左手臂仍然疼痛,但你的膝盖疼痛减轻,并且你能够走路了。尽管如此,你的伤口仍然需要得到妥善处理和护理。你可以继续使用清洁的纱布或绷带来包扎伤口,以保持伤口的清洁和干燥。此外,继续避免过度活动是很重要的,以免加重伤势或引发其他并发症。
|
||||
发现新情况: 肚子很饿
|
||||
请按照下面()里面回答我的问题,尽量得详细描述每一项,只要回答下面几项,不要给其他的建议
|
||||
1.情绪:()
|
||||
2.注意力()
|
||||
3.当前在思考什么()
|
||||
4.生理状态()
|
||||
|
||||
|
||||
情绪: 你可能感到一些松了一口气,因为伤口的疼痛有所减轻,你能够走路了。然而,由于你肚子很饿,你可能会感到一些不舒服和分散注意力的情绪。你可能会希望尽快解决饥饿问题,以恢复体力和精神状态。
|
||||
注意力: 你的注意力可能会转移到你的饥饿感上,因为它变得非常强烈。你可能会开始思考如何解决肚子饿的问题,可能会考虑吃些东西来满足身体的需求。这可能会分散你对伤口处理和疼痛管理的注意力。
|
||||
当前在思考什么: 你正在思考如何满足饥饿感和解决肚子饿的问题。你可能会思考是否有食物可用,以及如何获取食物。同时,你可能也会思考如何平衡满足饥饿感和继续处理伤口的需要。
|
||||
生理状态: 你的左手臂仍然疼痛,但膝盖疼痛减轻,并且你能够走路了。然而,你现在感到非常饥饿,这可能对你的体力和精神状态产生影响。为了满足饥饿感,你可以尝试找到食物来满足身体的能量需求,并确保继续给伤口提供适当的护理。
|
|
@ -1 +0,0 @@
|
|||
saves
|
|
@ -1,5 +0,0 @@
|
|||
{
|
||||
"alpaca_zh_demo": {
|
||||
"file_name": "alpaca_zh_demo.json"
|
||||
}
|
||||
}
|
|
@ -1,5 +0,0 @@
|
|||
model_name_or_path: Qwen/Qwen3-4B
|
||||
adapter_name_or_path: saves/qwen3-4b/lora/sft
|
||||
template: qwen3
|
||||
infer_backend: huggingface # choices: [huggingface, vllm, sglang]
|
||||
trust_remote_code: true
|
|
@ -1,13 +0,0 @@
|
|||
### Note: DO NOT use quantized model or quantization_bit when merging lora adapters
|
||||
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
adapter_name_or_path: saves/llama3-8b/lora/sft
|
||||
template: llama3
|
||||
trust_remote_code: true
|
||||
|
||||
### export
|
||||
export_dir: output/llama3_lora_sft
|
||||
export_size: 5
|
||||
export_device: cpu # choices: [cpu, auto]
|
||||
export_legacy_format: false
|
|
@ -1,46 +0,0 @@
|
|||
### model
|
||||
model_name_or_path: Qwen/Qwen3-4B
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_rank: 8
|
||||
lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: alpaca_zh_demo
|
||||
template: qwen3
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/qwen3-4b/lora/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
save_only_model: false
|
||||
report_to: tensorboard # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 4
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 5.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
resume_from_checkpoint: null
|
||||
|
||||
### eval
|
||||
# eval_dataset: alpaca_en_demo
|
||||
# val_size: 0.1
|
||||
# per_device_eval_batch_size: 1
|
||||
# eval_strategy: steps
|
||||
# eval_steps: 500
|
|
@ -1,3 +0,0 @@
|
|||
outputs
|
||||
unsloth_compiled_cache
|
||||
wandb
|
|
@ -1,79 +0,0 @@
|
|||
from unsloth import FastLanguageModel, FastModel
|
||||
import torch
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
from datasets import load_dataset
|
||||
max_seq_length = 2048 # Supports RoPE Scaling internally, so choose any!
|
||||
# Get LAION dataset
|
||||
url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
|
||||
dataset = load_dataset("json", data_files = {"train" : url}, split = "train")
|
||||
|
||||
# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
|
||||
fourbit_models = [
|
||||
"unsloth/Meta-Llama-3.1-8B-bnb-4bit", # Llama-3.1 2x faster
|
||||
"unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
|
||||
"unsloth/Meta-Llama-3.1-70B-bnb-4bit",
|
||||
"unsloth/Meta-Llama-3.1-405B-bnb-4bit", # 4bit for 405b!
|
||||
"unsloth/Mistral-Small-Instruct-2409", # Mistral 22b 2x faster!
|
||||
"unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
|
||||
"unsloth/Phi-3.5-mini-instruct", # Phi-3.5 2x faster!
|
||||
"unsloth/Phi-3-medium-4k-instruct",
|
||||
"unsloth/gemma-2-9b-bnb-4bit",
|
||||
"unsloth/gemma-2-27b-bnb-4bit", # Gemma 2x faster!
|
||||
|
||||
"unsloth/Llama-3.2-1B-bnb-4bit", # NEW! Llama 3.2 models
|
||||
"unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
|
||||
"unsloth/Llama-3.2-3B-bnb-4bit",
|
||||
"unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
|
||||
|
||||
"unsloth/Llama-3.3-70B-Instruct-bnb-4bit" # NEW! Llama 3.3 70B!
|
||||
] # More models at https://huggingface.co/unsloth
|
||||
|
||||
model, tokenizer = FastModel.from_pretrained(
|
||||
model_name = "unsloth/Qwen3-4B",
|
||||
max_seq_length = 2048, # Choose any for long context!
|
||||
load_in_4bit = False, # 4 bit quantization to reduce memory
|
||||
load_in_8bit = True, # [NEW!] A bit more accurate, uses 2x memory
|
||||
full_finetuning = False, # [NEW!] We have full finetuning now!
|
||||
# token = "hf_...", # use one if using gated models
|
||||
)
|
||||
|
||||
# Do model patching and add fast LoRA weights
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r = 16,
|
||||
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj",],
|
||||
lora_alpha = 16,
|
||||
lora_dropout = 0, # Supports any, but = 0 is optimized
|
||||
bias = "none", # Supports any, but = "none" is optimized
|
||||
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
|
||||
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
|
||||
random_state = 3407,
|
||||
max_seq_length = max_seq_length,
|
||||
use_rslora = False, # We support rank stabilized LoRA
|
||||
loftq_config = None, # And LoftQ
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model = model,
|
||||
train_dataset = dataset,
|
||||
tokenizer = tokenizer,
|
||||
args = SFTConfig(
|
||||
max_seq_length = max_seq_length,
|
||||
per_device_train_batch_size = 2,
|
||||
gradient_accumulation_steps = 4,
|
||||
warmup_steps = 10,
|
||||
max_steps = 60,
|
||||
logging_steps = 1,
|
||||
output_dir = "outputs",
|
||||
optim = "adamw_8bit",
|
||||
seed = 3407,
|
||||
),
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
# Go to https://github.com/unslothai/unsloth/wiki for advanced tips like
|
||||
# (1) Saving to GGUF / merging to 16bit for vLLM
|
||||
# (2) Continued training from a saved LoRA adapter
|
||||
# (3) Adding an evaluation loop / OOMs
|
||||
# (4) Customized chat templates
|
After Width: | Height: | Size: 7.9 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 9.4 KiB |
After Width: | Height: | Size: 10 KiB |
After Width: | Height: | Size: 9.9 KiB |
After Width: | Height: | Size: 10 KiB |
After Width: | Height: | Size: 10 KiB |
After Width: | Height: | Size: 10 KiB |
After Width: | Height: | Size: 10 KiB |
After Width: | Height: | Size: 10 KiB |
After Width: | Height: | Size: 10 KiB |
After Width: | Height: | Size: 10 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 9.9 KiB |
After Width: | Height: | Size: 9.6 KiB |
After Width: | Height: | Size: 9.5 KiB |
After Width: | Height: | Size: 9.3 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 9.1 KiB |
After Width: | Height: | Size: 9.7 KiB |
After Width: | Height: | Size: 9.7 KiB |
After Width: | Height: | Size: 8.4 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 10 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 10 KiB |
After Width: | Height: | Size: 10 KiB |
After Width: | Height: | Size: 10 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 9.8 KiB |
After Width: | Height: | Size: 10 KiB |
After Width: | Height: | Size: 10 KiB |
After Width: | Height: | Size: 9.0 KiB |
After Width: | Height: | Size: 13 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 13 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 13 KiB |
After Width: | Height: | Size: 13 KiB |
After Width: | Height: | Size: 13 KiB |
After Width: | Height: | Size: 10 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 9.6 KiB |
After Width: | Height: | Size: 14 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 12 KiB |