Add llamafactory, refine binary LUT, add SimpleLNN.
parent 9a8434df61
commit 392f507945
@@ -9,6 +9,5 @@ checkpoints
 build
 log
 logs
-data
 
 mlruns

binary/mnist.py (115 changed lines)
@@ -22,7 +22,7 @@ np.random.seed(1234)
 torch.cuda.manual_seed_all(1234)
 
 BS = 16
-LR = 0.001
+LR = 0.01
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")
@@ -46,13 +46,14 @@ class Lut(torch.autograd.Function):
     def forward(ctx, input, weight, index):
         ind = ((input > 0).long() * index).sum(dim=-1)
         output = torch.gather(weight, 0, ind)
-        ctx.save_for_backward(input, weight, ind)
         output = (output > 0).float()
+        output = (output - 0.5) * 2.0
+        ctx.save_for_backward(input, weight, ind, output)
         return output
 
     @staticmethod
     def backward(ctx, grad_output):
-        input, weight, ind = ctx.saved_tensors
+        input, weight, ind, output = ctx.saved_tensors
         grad_input = grad_weight = None
         bits = input.shape[2]
 
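Note on the forward pass above: each group of groupBits sign bits is packed into an integer and used as a row index into that group's lookup table; the gathered entry is then binarized to ±1 (the new (output - 0.5) * 2.0 step). A minimal standalone sketch, not taken from the repo, assuming the index argument holds power-of-two bit weights:

import torch

bits, group, batch = 2, 3, 5
index = torch.tensor([2, 1])                 # assumed power-of-two bit weights
weight = torch.randn(2 ** bits, group)       # one 2^bits-entry table per group
x = torch.randn(batch, group, bits)          # pre-binarization activations

ind = ((x > 0).long() * index).sum(dim=-1)   # [batch, group] table addresses
out = torch.gather(weight, 0, ind)           # out[b, g] = weight[ind[b, g], g]
out = ((out > 0).float() - 0.5) * 2.0        # binarize to +/-1, as in the new forward
print(out.shape)                             # torch.Size([5, 3])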
@@ -61,10 +62,33 @@ class Lut(torch.autograd.Function):
             grad_weight.scatter_add_(0, ind, grad_output)
 
         if ctx.needs_input_grad[0]:
-            grad_input = grad_output * torch.gather(weight, 0, ind)
-            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
 
-        return grad_input, grad_weight, None
+            # grad_input = grad_output * torch.gather(weight, 0, ind)
+            grad_input = grad_output
+            grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
+            output = output.unsqueeze(-1).repeat(1, 1, bits)
+            in_sign = ((input > 0).float() - 0.5) * 2.0
+            grad_input = grad_input * in_sign
+            grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
+
+            # grad_input = grad_output
+            # grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
+            # output = output.unsqueeze(-1).repeat(1, 1, bits)
+            # in_sign = ((input > 0).float() - 0.5) * 2.0
+            # out_sign = ((output > 0).float() - 0.5) * 2.0
+            # grad_sign = ((grad_input > 0).float() - 0.5) * 2.0
+            # grad_input = grad_input * in_sign * (out_sign * grad_sign)
+            # grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
+
+            # a dynamic adjustment coefficient is needed here
+            # so that training converges stably
+
+            # print(in_sign[0].detach().cpu().numpy())
+            # print(out_sign[0].detach().cpu().numpy())
+            # print(grad_sign[0].detach().cpu().numpy())
+            # print(grad_input[0].detach().cpu().numpy())
+
+        return grad_input, grad_weight, None, None
 
 
 class SimpleCNN(nn.Module):
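Reading the new backward: the input gradient no longer flows through the gathered table entry. Each input bit now receives the group's incoming gradient straight through, signed by that bit's own sign and scaled by a random jitter of roughly ±0.5%, while the table weights keep accumulating grad_output via scatter_add_. A shape-level sketch of the input-gradient path (my reading of the diff, not repo code):

import torch

grad_output = torch.randn(5, 3)    # [batch, group] gradient w.r.t. the LUT output
x = torch.randn(5, 3, 2)           # [batch, group, bits] LUT input
bits = x.shape[2]

grad_input = grad_output.unsqueeze(-1).repeat(1, 1, bits)   # copy to every input bit
in_sign = ((x > 0).float() - 0.5) * 2.0                     # +/-1 sign of each bit
grad_input = grad_input * in_sign
grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)  # ~0.5% jitter
print(grad_input.shape)            # torch.Size([5, 3, 2])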
@@ -97,7 +121,7 @@ class LutGroup(nn.Module):
     def __init__(self, group, groupBits, groupRepeat=1):
         assert groupBits > 1
         super(LutGroup, self).__init__()
-        self.weight = nn.Parameter(torch.randn(pow(2, groupBits), int(groupRepeat * group)))
+        self.weight = nn.Parameter(torch.ones(pow(2, groupBits), int(groupRepeat * group)))
         self.group = group
         self.groupBits = groupBits
         self.groupRepeat = groupRepeat
@@ -107,7 +131,7 @@ class LutGroup(nn.Module):
         # input [ batch, group * groupBits ]
         # output [ batch, group * groupRepeat ]
         batch = x.shape[0]
-        x = x.view(batch, -1, self.groupBits)
+        x = x.reshape(batch, -1, self.groupBits)
         if self.groupRepeat > 1:
             x = x.repeat(1, self.groupRepeat, 1)
         x = Lut.apply(x, self.weight, self.index)
@@ -115,11 +139,12 @@ class LutGroup(nn.Module):
 
 
 class LutCnn(nn.Module):
-    def __init__(self, output_c, input_shape, kernel_size, stride, dilation):
+    def __init__(self, channel_repeat, input_shape, kernel_size, stride, dilation, fc=False):
         super(LutCnn, self).__init__()
         B, C, H, W = input_shape
         self.input_shape = input_shape
         self.kernel_size = kernel_size
+        self.channel_repeat = channel_repeat
         self.stride = stride
         self.dilation = dilation
         batch_idx, channel_idx, h_idx, w_idx = generate_unfold_index(input_shape, kernel_size, stride, dilation)
@@ -128,7 +153,11 @@ class LutCnn(nn.Module):
         self.h_idx = nn.Parameter(h_idx, requires_grad=False)
         self.w_idx = nn.Parameter(w_idx, requires_grad=False)
         groupBits = kernel_size * kernel_size
-        self.lut = LutGroup(len(self.batch_idx) / B / groupBits, groupBits, output_c)
+        group = int(len(self.batch_idx) / B / groupBits)
+        self.lut = LutGroup(group, groupBits, channel_repeat)
+        self.fc = fc
+        if fc:
+            self.lutc = LutGroup(group, channel_repeat * C, channel_repeat * C)
 
     def forward(self, x):
         B, C, H, W = self.input_shape
@@ -136,6 +165,10 @@ class LutCnn(nn.Module):
         x = x[(self.batch_idx, self.channel_idx, self.h_idx, self.w_idx)]
         x = x.view(B, -1, self.kernel_size * self.kernel_size)
         x = self.lut(x)
+        if self.fc:
+            x = x.view(B, -1, self.channel_repeat)
+            x = x.permute(0, 2, 1)
+            x = self.lutc(x)
         return x
 
 
@@ -148,17 +181,13 @@ class SimpleBNN(nn.Module):
         self.w = nn.Parameter(torch.randn(3, 784))
         self.b = nn.Parameter(torch.zeros(3, 784))
 
-        # output_c, input_shape, kernel_size, stride, dilation
-        self.lnn1 = LutCnn(8, (BS, 1, 28, 28), 2, 2, 1)
-        self.lnn2 = LutCnn(1, (BS, 8, 14, 14), 2, 2, 1)
-        self.lnn3 = LutCnn(1, (BS, 8, 7, 7), 3, 1, 1)
-        self.lnn4 = LutCnn(1, (BS, 8, 5, 5), 3, 1, 1)
+        # channel_repeat, input_shape, kernel_size, stride, dilation, fc
+        self.lnn1 = LutCnn(8, (BS, 1, 28, 28), 2, 2, 1, False)
+        self.lnn2 = LutCnn(1, (BS, 8, 14, 14), 2, 2, 1, False)
+        self.lnn3 = LutCnn(1, (BS, 8, 7, 7), 3, 1, 1, False)
+        self.lnn4 = LutCnn(1, (BS, 8, 5, 5), 3, 1, 1, False)
         self.lnn5 = LutCnn(10, (BS, 8, 3, 3), 3, 1, 1)
 
-        # self.lutg = LutGroup()
-        # class LutGroup(nn.Module):
-        #     def __init__(self, group, groupBits, groupRepeat=1):
-
         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
         self.fc1 = nn.Linear(320, 50)
@@ -166,7 +195,7 @@ class SimpleBNN(nn.Module):
         self.pool = nn.MaxPool2d(2)
         self.relu = nn.ReLU()
 
-    def forward(self, x):
+    def forward(self, x, t):
         batch = x.shape[0]
         # x = x.view(batch, -1)
 
@@ -222,10 +251,47 @@ class SimpleBNN(nn.Module):
 
         return x
 
+    def printWeight(self):
+        pass
+
+
+class SimpleLNN(nn.Module):
+    def __init__(self):
+        super(SimpleLNN, self).__init__()
+        # group, groupBits, groupRepeat
+        self.lutg1 = LutGroup(1, 10, 4)
+        self.lutg2 = LutGroup(1, 4, 10)
+
+    def forward(self, x, t):
+        batch = x.shape[0]
+
+        x = torch.zeros_like(t).unsqueeze(-1).repeat(1, 10)
+        x[torch.arange(0, batch), t] = 1
+
+        x = self.lutg1(x)
+        x = self.lutg2(x)
+
+        return x
+
+    def printWeight(self):
+        print("self.lutg1")
+        print(self.lutg1.weight[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy())
+        print("=============================")
+        print("=============================")
+        print("self.lutg1.grad")
+        print(self.lutg1.weight.grad[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy())
+        print("=============================")
+        print("=============================")
+        # print("self.lutg2")
+        # print(self.lutg2.weight.detach().cpu().numpy())
+        # print("=============================")
+        # print("=============================")
+
 
 torch.autograd.set_detect_anomaly(True)
 # model = SimpleCNN().to(device)
-model = SimpleBNN().to(device)
+# model = SimpleBNN().to(device)
+model = SimpleLNN().to(device)
 criterion = nn.CrossEntropyLoss()
 optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
 
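SimpleLNN reads like a sanity-check experiment: the target label itself, one-hot encoded, is pushed through two LUT groups (LutGroup(1, 10, 4) then LutGroup(1, 4, 10)), so the tables only need to memorize a 10-bit pattern, and printWeight dumps the relevant table rows after each test pass. A rough shape walk-through with plain tensors (a sketch; the real path goes through LutGroup/Lut above, and the power-of-two bit weights are an assumption):

import torch

batch = 16
t = torch.randint(0, 10, (batch,))
x = torch.zeros(batch, 10)
x[torch.arange(batch), t] = 1.0              # one-hot label, as in SimpleLNN.forward

w1 = torch.ones(2 ** 10, 4)                  # lutg1 table, shaped like LutGroup(1, 10, 4)
w2 = torch.ones(2 ** 4, 10)                  # lutg2 table, shaped like LutGroup(1, 4, 10)
idx10 = 2 ** torch.arange(10)                # assumed bit weights
idx4 = 2 ** torch.arange(4)

h = x.reshape(batch, 1, 10).repeat(1, 4, 1)                  # [batch, 4, 10]
h = torch.gather(w1, 0, ((h > 0).long() * idx10).sum(-1))    # [batch, 4]
h = ((h > 0).float() - 0.5) * 2.0                            # binarize to +/-1
o = h.reshape(batch, 1, 4).repeat(1, 10, 1)                  # [batch, 10, 4]
o = torch.gather(w2, 0, ((o > 0).long() * idx4).sum(-1))     # [batch, 10] logits
print(o.shape)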
@@ -248,12 +314,12 @@ def train(epoch):
     for batch_idx, (data, target) in enumerate(train_loader):
         data, target = data.to(device), target.to(device)
         optimizer.zero_grad()
-        output = model(data)
+        output = model(data, target)
         loss = criterion(output, target)
         loss.backward()
         optimizer.step()
         AddScalar("loss", loss, epoch)
-        if batch_idx % 512 == 0 and batch_idx > 0:
+        if batch_idx % 1024 == 0 and batch_idx > 0:
             print(
                 f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                 f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
@@ -267,7 +333,7 @@ def test(epoch):
     with torch.no_grad():
         for data, target in test_loader:
            data, target = data.to(device), target.to(device)
-            output = model(data)
+            output = model(data, target)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
@@ -279,6 +345,7 @@ def test(epoch):
         f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
         f"({accuracy:.0f}%)\n"
     )
+    model.printWeight()
 
 
 def profiler():
@@ -0,0 +1 @@
+saves
File diff suppressed because it is too large
@@ -0,0 +1,5 @@
+{
+  "alpaca_zh_demo": {
+    "file_name": "alpaca_zh_demo.json"
+  }
+}
@@ -0,0 +1,5 @@
+model_name_or_path: Qwen/Qwen3-4B
+adapter_name_or_path: saves/qwen3-4b/lora/sft
+template: qwen3
+infer_backend: huggingface # choices: [huggingface, vllm, sglang]
+trust_remote_code: true
@@ -0,0 +1,13 @@
+### Note: DO NOT use quantized model or quantization_bit when merging lora adapters
+
+### model
+model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
+adapter_name_or_path: saves/llama3-8b/lora/sft
+template: llama3
+trust_remote_code: true
+
+### export
+export_dir: output/llama3_lora_sft
+export_size: 5
+export_device: cpu # choices: [cpu, auto]
+export_legacy_format: false
@@ -0,0 +1,46 @@
+### model
+model_name_or_path: Qwen/Qwen3-4B
+trust_remote_code: true
+
+### method
+stage: sft
+do_train: true
+finetuning_type: lora
+lora_rank: 8
+lora_target: all
+
+### dataset
+dataset: alpaca_zh_demo
+template: qwen3
+cutoff_len: 2048
+max_samples: 1000
+overwrite_cache: true
+preprocessing_num_workers: 16
+dataloader_num_workers: 4
+
+### output
+output_dir: saves/qwen3-4b/lora/sft
+logging_steps: 10
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+save_only_model: false
+report_to: tensorboard # choices: [none, wandb, tensorboard, swanlab, mlflow]
+
+### train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+learning_rate: 1.0e-4
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 180000000
+resume_from_checkpoint: null
+
+### eval
+# eval_dataset: alpaca_en_demo
+# val_size: 0.1
+# per_device_eval_batch_size: 1
+# eval_strategy: steps
+# eval_steps: 500
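For reference, the YAML files above mirror LLaMA-Factory's example configs (dataset registry, LoRA SFT training, inference with an adapter, and LoRA merge/export). In that project a run like this is normally launched with llamafactory-cli train pointed at the SFT YAML, after which the adapter written to output_dir is what the inference and export configs reference; the exact CLI entry points depend on the installed LLaMA-Factory version.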
@@ -1,4 +1,5 @@
 dump1
 dump2
 *.png
 *.log
+data
@@ -0,0 +1 @@
+data