Add finetune/.

parent e2c8668a1b
commit eed28ac06e
@@ -34,8 +34,8 @@ transform = transforms.Compose(
 train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
 test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
 
-train_loader = DataLoader(train_dataset, batch_size=BS, shuffle=True)
-test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False)
+train_loader = DataLoader(train_dataset, batch_size=BS, shuffle=True, drop_last=True, num_workers=4)
+test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False, drop_last=True)
 
 
 class Lut(torch.autograd.Function):
@@ -47,6 +47,7 @@ class Lut(torch.autograd.Function):
         ind = ((input > 0).long() * index).sum(dim=-1)
         output = torch.gather(weight, 0, ind)
         ctx.save_for_backward(input, weight, ind)
+        output = (output > 0).float()
         return output
 
     @staticmethod
@@ -154,6 +155,10 @@ class SimpleBNN(nn.Module):
         self.lnn4 = LutCnn(1, (BS, 8, 5, 5), 3, 1, 1)
         self.lnn5 = LutCnn(10, (BS, 8, 3, 3), 3, 1, 1)
 
+        # self.lutg = LutGroup()
+        # class LutGroup(nn.Module):
+        #     def __init__(self, group, groupBits, groupRepeat=1):
+
         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
         self.fc1 = nn.Linear(320, 50)
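
The drop_last=True added to both loaders presumably keeps every batch at exactly BS samples, matching the fixed (BS, ...) shapes the LutCnn layers are constructed with. The hunk at @@ -47,6 +47,7 @@ also hard-thresholds the gathered LUT values in the forward pass while the class keeps its own custom backward. As a reference for that general pattern only, here is a minimal straight-through binarizer; it is a hypothetical sketch, not the repository's Lut/LutCnn implementation:

import torch

class BinarizeSTE(torch.autograd.Function):
    # Hypothetical illustration: hard 0/1 output in forward,
    # identity (straight-through) gradient in backward.

    @staticmethod
    def forward(ctx, x):
        # Same hard threshold as `output = (output > 0).float()` above
        return (x > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the upstream gradient through unchanged
        return grad_output

x = torch.randn(4, requires_grad=True)
y = BinarizeSTE.apply(x)
y.sum().backward()  # y is binary; x.grad is all ones
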
@@ -0,0 +1,3 @@
+outputs
+unsloth_compiled_cache
+wandb
@@ -0,0 +1,79 @@
+from unsloth import FastLanguageModel, FastModel
+import torch
+from trl import SFTTrainer, SFTConfig
+from datasets import load_dataset
+max_seq_length = 2048 # Supports RoPE Scaling internally, so choose any!
+# Get LAION dataset
+url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
+dataset = load_dataset("json", data_files = {"train" : url}, split = "train")
+
+# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
+fourbit_models = [
+    "unsloth/Meta-Llama-3.1-8B-bnb-4bit",      # Llama-3.1 2x faster
+    "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
+    "unsloth/Meta-Llama-3.1-70B-bnb-4bit",
+    "unsloth/Meta-Llama-3.1-405B-bnb-4bit",    # 4bit for 405b!
+    "unsloth/Mistral-Small-Instruct-2409",     # Mistral 22b 2x faster!
+    "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
+    "unsloth/Phi-3.5-mini-instruct",           # Phi-3.5 2x faster!
+    "unsloth/Phi-3-medium-4k-instruct",
+    "unsloth/gemma-2-9b-bnb-4bit",
+    "unsloth/gemma-2-27b-bnb-4bit",            # Gemma 2x faster!
+
+    "unsloth/Llama-3.2-1B-bnb-4bit",           # NEW! Llama 3.2 models
+    "unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
+    "unsloth/Llama-3.2-3B-bnb-4bit",
+    "unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
+
+    "unsloth/Llama-3.3-70B-Instruct-bnb-4bit" # NEW! Llama 3.3 70B!
+] # More models at https://huggingface.co/unsloth
+
+model, tokenizer = FastModel.from_pretrained(
+    model_name = "unsloth/Qwen3-4B",
+    max_seq_length = 2048, # Choose any for long context!
+    load_in_4bit = False,  # 4 bit quantization to reduce memory
+    load_in_8bit = True, # [NEW!] A bit more accurate, uses 2x memory
+    full_finetuning = False, # [NEW!] We have full finetuning now!
+    # token = "hf_...", # use one if using gated models
+)
+
+# Do model patching and add fast LoRA weights
+model = FastLanguageModel.get_peft_model(
+    model,
+    r = 16,
+    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
+                      "gate_proj", "up_proj", "down_proj",],
+    lora_alpha = 16,
+    lora_dropout = 0, # Supports any, but = 0 is optimized
+    bias = "none",    # Supports any, but = "none" is optimized
+    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
+    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
+    random_state = 3407,
+    max_seq_length = max_seq_length,
+    use_rslora = False,  # We support rank stabilized LoRA
+    loftq_config = None, # And LoftQ
+)
+
+trainer = SFTTrainer(
+    model = model,
+    train_dataset = dataset,
+    tokenizer = tokenizer,
+    args = SFTConfig(
+        max_seq_length = max_seq_length,
+        per_device_train_batch_size = 2,
+        gradient_accumulation_steps = 4,
+        warmup_steps = 10,
+        max_steps = 60,
+        logging_steps = 1,
+        output_dir = "outputs",
+        optim = "adamw_8bit",
+        seed = 3407,
+    ),
+)
+trainer.train()
+
+# Go to https://github.com/unslothai/unsloth/wiki for advanced tips like
+# (1) Saving to GGUF / merging to 16bit for vLLM
+# (2) Continued training from a saved LoRA adapter
+# (3) Adding an evaluation loop / OOMs
+# (4) Customized chat templates
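
The new script stops after trainer.train() and does not explicitly save the trained LoRA adapter. A minimal follow-up sketch, assuming the standard Hugging Face/PEFT save_pretrained API and a hypothetical local path (the unsloth wiki referenced in the closing comments covers the merged 16-bit and GGUF export options):

# Persist the trained adapter and tokenizer locally (hypothetical path).
model.save_pretrained("outputs/lora_adapter")
tokenizer.save_pretrained("outputs/lora_adapter")
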