gpt-pretrain/utils.py

51 lines
1.8 KiB
Python
Raw Normal View History

2023-05-07 13:01:02 +08:00
import os
from typing import Union
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
)
2023-05-14 22:23:16 +08:00
import custom_models
2023-05-28 21:33:46 +08:00
# Module-level side effect: runs once when this module is imported.
# Presumably registers the project's custom config classes so that
# AutoConfig.for_model() below can resolve custom model types -- TODO confirm
# against custom_models.register_custom_configs.
custom_models.register_custom_configs()
2023-05-07 13:01:02 +08:00
2023-05-28 21:33:46 +08:00
def init_model(model_name: str) -> PreTrainedModel:
    """Instantiate a model from the configuration of model type *model_name*.

    The configuration is obtained with ``AutoConfig.for_model`` and the model
    is built from it via ``from_config`` (no checkpoint path is involved).

    Resolution order:
      1. ``custom_models.AutoModelForCausalLM`` if the name is listed in
         ``custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES``;
      2. ``custom_models.AutoModel`` if the name is listed in
         ``custom_models.MODEL_MAPPING_NAMES``;
      3. transformers' ``AutoModelForCausalLM``, falling back to ``AutoModel``
         when the former raises ``ValueError``.

    Args:
        model_name: Model type identifier passed to ``AutoConfig.for_model``.

    Returns:
        The model instance built from the resolved configuration.
    """
    model_config = AutoConfig.for_model(model_type=model_name)

    # Project-specific architectures take precedence over the stock ones.
    if model_name in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return custom_models.AutoModelForCausalLM.from_config(model_config)
    if model_name in custom_models.MODEL_MAPPING_NAMES:
        return custom_models.AutoModel.from_config(model_config)

    # NOTE: trust_remote_code=True allows transformers to execute modelling
    # code supplied by the model repository; only use with trusted sources.
    try:
        return AutoModelForCausalLM.from_config(model_config, trust_remote_code=True)
    except ValueError:
        # No causal-LM class accepted this config; use the generic Auto class.
        return AutoModel.from_config(model_config, trust_remote_code=True)
def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
    """Load a model with ``from_pretrained`` from *model_name_or_path*.

    Resolution order:
      1. ``custom_models.AutoModelForCausalLM`` if the argument is listed in
         ``custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES``;
      2. ``custom_models.AutoModel`` if the argument is listed in
         ``custom_models.MODEL_MAPPING_NAMES``;
      3. transformers' ``AutoModelForCausalLM``, falling back to ``AutoModel``
         when the former raises ``ValueError``.

    Args:
        model_name_or_path: Value forwarded unchanged to ``from_pretrained``.

    Returns:
        The loaded model instance.
    """
    # Project-specific architectures take precedence over the stock ones.
    if model_name_or_path in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return custom_models.AutoModelForCausalLM.from_pretrained(model_name_or_path)
    if model_name_or_path in custom_models.MODEL_MAPPING_NAMES:
        return custom_models.AutoModel.from_pretrained(model_name_or_path)

    # NOTE: trust_remote_code=True allows transformers to execute modelling
    # code supplied by the model repository; only use with trusted sources.
    try:
        return AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True)
    except ValueError:
        # No causal-LM class accepted this checkpoint; use the generic Auto class.
        return AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
2023-05-28 20:02:56 +08:00
def load_tokenizer(tokenizer_name_or_path: Union[str, os.PathLike]) -> PreTrainedTokenizer:
    """Load a tokenizer configured for left padding.

    If the loaded tokenizer has no pad token id, its pad token is set to its
    EOS token.

    Args:
        tokenizer_name_or_path: Value forwarded to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        The loaded tokenizer.
    """
    # NOTE: trust_remote_code=True allows transformers to execute tokenizer
    # code supplied by the repository; only use with trusted sources.
    tok = AutoTokenizer.from_pretrained(
        tokenizer_name_or_path,
        padding_side='left',
        trust_remote_code=True,
    )

    if tok.pad_token_id is not None:
        return tok

    # No pad token defined: reuse the EOS token for padding.
    tok.pad_token = tok.eos_token
    return tok