Update finetune

This commit is contained in:
Colin 2024-01-04 19:12:28 +08:00
parent ec72ee1141
commit 9deb809a88
15 changed files with 74 additions and 34 deletions

0
__init__.py Normal file
View File

View File

@ -1,5 +1,4 @@
import sys import sys
sys.path.append("..") sys.path.append("..")
import json import json

View File

@ -1,5 +1,4 @@
import sys import sys
sys.path.append("..") sys.path.append("..")
import json import json

2
qwen/__init__.py Normal file
View File

@ -0,0 +1,2 @@
from qwen.modeling_qwen import QWenLMHeadModel
from qwen.configuration_qwen import QWenConfig

View File

@ -18,13 +18,16 @@ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate.utils import DistributedType from accelerate.utils import DistributedType
from modelscope import snapshot_download from modelscope import snapshot_download
from modeling_qwen import QWenLMHeadModel
IGNORE_TOKEN_ID = LabelSmoother.ignore_index IGNORE_TOKEN_ID = LabelSmoother.ignore_index
@dataclass @dataclass
class ModelArguments: class ModelArguments:
model_name_or_path: Optional[str] = field(default="qwen/Qwen-1_8B-Chat") model_name_or_path: Optional[str] = field(default="qwen/Qwen-1_8B-Chat")
@dataclass @dataclass
class DataArguments: class DataArguments:
@ -101,12 +104,15 @@ def get_peft_state_maybe_zero_3(named_params, bias):
local_rank = None local_rank = None
def rank0_print(*args): def rank0_print(*args):
if local_rank == 0: if local_rank == 0:
print(*args) print(*args)
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str, bias="none"): def safe_save_model_for_hf_trainer(
trainer: transformers.Trainer, output_dir: str, bias="none"
):
"""Collects the state dict and dump to disk.""" """Collects the state dict and dump to disk."""
# check if zero3 mode enabled # check if zero3 mode enabled
if deepspeed.is_deepspeed_zero3_enabled(): if deepspeed.is_deepspeed_zero3_enabled():
@ -126,16 +132,16 @@ def preprocess(
sources, sources,
tokenizer: transformers.PreTrainedTokenizer, tokenizer: transformers.PreTrainedTokenizer,
max_len: int, max_len: int,
system_message: str = "You are a helpful assistant." system_message: str = "You are a helpful assistant.",
) -> Dict: ) -> Dict:
roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"} roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}
im_start = tokenizer.im_start_id im_start = tokenizer.im_start_id
im_end = tokenizer.im_end_id im_end = tokenizer.im_end_id
nl_tokens = tokenizer('\n').input_ids nl_tokens = tokenizer("\n").input_ids
_system = tokenizer('system').input_ids + nl_tokens _system = tokenizer("system").input_ids + nl_tokens
_user = tokenizer('user').input_ids + nl_tokens _user = tokenizer("user").input_ids + nl_tokens
_assistant = tokenizer('assistant').input_ids + nl_tokens _assistant = tokenizer("assistant").input_ids + nl_tokens
# Apply prompt templates # Apply prompt templates
input_ids, targets = [], [] input_ids, targets = [], []
@ -144,20 +150,43 @@ def preprocess(
source = source[1:] source = source[1:]
input_id, target = [], [] input_id, target = [], []
system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens system = (
[im_start]
+ _system
+ tokenizer(system_message).input_ids
+ [im_end]
+ nl_tokens
)
input_id += system input_id += system
target += [im_start] + [IGNORE_TOKEN_ID] * (len(system)-3) + [im_end] + nl_tokens target += (
[im_start] + [IGNORE_TOKEN_ID] * (len(system) - 3) + [im_end] + nl_tokens
)
assert len(input_id) == len(target) assert len(input_id) == len(target)
for j, sentence in enumerate(source): for j, sentence in enumerate(source):
role = roles[sentence["from"]] role = roles[sentence["from"]]
_input_id = tokenizer(role).input_ids + nl_tokens + \ _input_id = (
tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens tokenizer(role).input_ids
+ nl_tokens
+ tokenizer(sentence["value"]).input_ids
+ [im_end]
+ nl_tokens
)
input_id += _input_id input_id += _input_id
if role == '<|im_start|>user': if role == "<|im_start|>user":
_target = [im_start] + [IGNORE_TOKEN_ID] * (len(_input_id)-3) + [im_end] + nl_tokens _target = (
elif role == '<|im_start|>assistant': [im_start]
_target = [im_start] + [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids) + \ + [IGNORE_TOKEN_ID] * (len(_input_id) - 3)
_input_id[len(tokenizer(role).input_ids)+1:-2] + [im_end] + nl_tokens + [im_end]
+ nl_tokens
)
elif role == "<|im_start|>assistant":
_target = (
[im_start]
+ [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids)
+ _input_id[len(tokenizer(role).input_ids) + 1 : -2]
+ [im_end]
+ nl_tokens
)
else: else:
raise NotImplementedError raise NotImplementedError
target += _target target += _target
@ -179,7 +208,9 @@ def preprocess(
class SupervisedDataset(Dataset): class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning.""" """Dataset for supervised fine-tuning."""
def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int): def __init__(
self, raw_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int
):
super(SupervisedDataset, self).__init__() super(SupervisedDataset, self).__init__()
rank0_print("Formatting inputs...") rank0_print("Formatting inputs...")
@ -204,7 +235,9 @@ class SupervisedDataset(Dataset):
class LazySupervisedDataset(Dataset): class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning.""" """Dataset for supervised fine-tuning."""
def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int): def __init__(
self, raw_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int
):
super(LazySupervisedDataset, self).__init__() super(LazySupervisedDataset, self).__init__()
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.max_len = max_len self.max_len = max_len
@ -221,7 +254,9 @@ class LazySupervisedDataset(Dataset):
if i in self.cached_data_dict: if i in self.cached_data_dict:
return self.cached_data_dict[i] return self.cached_data_dict[i]
ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer, self.max_len) ret = preprocess(
[self.raw_data[i]["conversations"]], self.tokenizer, self.max_len
)
ret = dict( ret = dict(
input_ids=ret["input_ids"][0], input_ids=ret["input_ids"][0],
labels=ret["labels"][0], labels=ret["labels"][0],
@ -233,7 +268,9 @@ class LazySupervisedDataset(Dataset):
def make_supervised_data_module( def make_supervised_data_module(
tokenizer: transformers.PreTrainedTokenizer, data_args, max_len, tokenizer: transformers.PreTrainedTokenizer,
data_args,
max_len,
) -> Dict: ) -> Dict:
"""Make dataset and collator for supervised fine-tuning.""" """Make dataset and collator for supervised fine-tuning."""
dataset_cls = ( dataset_cls = (
@ -267,7 +304,10 @@ def train():
) = parser.parse_args_into_dataclasses() ) = parser.parse_args_into_dataclasses()
# This serves for single-gpu qlora. # This serves for single-gpu qlora.
if getattr(training_args, 'deepspeed', None) and int(os.environ.get("WORLD_SIZE", 1))==1: if (
getattr(training_args, "deepspeed", None)
and int(os.environ.get("WORLD_SIZE", 1)) == 1
):
training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED
local_rank = training_args.local_rank local_rank = training_args.local_rank
@ -278,9 +318,7 @@ def train():
if lora_args.q_lora: if lora_args.q_lora:
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else "auto" device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else "auto"
if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled(): if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
logging.warning( logging.warning("FSDP or ZeRO3 are incompatible with QLoRA.")
"FSDP or ZeRO3 are incompatible with QLoRA."
)
model_dir = snapshot_download(model_args.model_name_or_path) model_dir = snapshot_download(model_args.model_name_or_path)
@ -294,19 +332,18 @@ def train():
# Load model and tokenizer # Load model and tokenizer
model = QWenLMHeadModel(config)
model = transformers.AutoModelForCausalLM.from_pretrained( model = model.from_pretrained(
model_dir, model_dir,
config=config, config=config,
cache_dir=training_args.cache_dir, cache_dir=training_args.cache_dir,
device_map=device_map, device_map=device_map,
trust_remote_code=True, trust_remote_code=True,
quantization_config=GPTQConfig( quantization_config=GPTQConfig(bits=4, disable_exllama=True)
bits=4, disable_exllama=True
)
if training_args.use_lora and lora_args.q_lora if training_args.use_lora and lora_args.q_lora
else None, else None,
) )
tokenizer = transformers.AutoTokenizer.from_pretrained( tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir, model_dir,
cache_dir=training_args.cache_dir, cache_dir=training_args.cache_dir,
@ -318,7 +355,7 @@ def train():
tokenizer.pad_token_id = tokenizer.eod_id tokenizer.pad_token_id = tokenizer.eod_id
if training_args.use_lora: if training_args.use_lora:
if lora_args.q_lora or 'chat' in model_dir.lower(): if lora_args.q_lora or "chat" in model_dir.lower():
modules_to_save = None modules_to_save = None
else: else:
modules_to_save = ["wte", "lm_head"] modules_to_save = ["wte", "lm_head"]
@ -329,7 +366,7 @@ def train():
lora_dropout=lora_args.lora_dropout, lora_dropout=lora_args.lora_dropout,
bias=lora_args.lora_bias, bias=lora_args.lora_bias,
task_type="CAUSAL_LM", task_type="CAUSAL_LM",
modules_to_save=modules_to_save # This argument serves for adding new tokens. modules_to_save=modules_to_save, # This argument serves for adding new tokens.
) )
if lora_args.q_lora: if lora_args.q_lora:
model = prepare_model_for_kbit_training( model = prepare_model_for_kbit_training(
@ -357,7 +394,9 @@ def train():
trainer.train() trainer.train()
trainer.save_state() trainer.save_state()
safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir, bias=lora_args.lora_bias) safe_save_model_for_hf_trainer(
trainer=trainer, output_dir=training_args.output_dir, bias=lora_args.lora_bias
)
if __name__ == "__main__": if __name__ == "__main__":

0
qwen/finetune/finetune_ds.sh → qwen/finetune_ds.sh Normal file → Executable file
View File

View File

View File

View File

@ -0,0 +1 @@
import show