2023-05-07 13:01:02 +08:00
|
|
|
import os
|
|
|
|
from typing import Union
|
|
|
|
|
|
|
|
from transformers import (
|
|
|
|
AutoConfig,
|
|
|
|
AutoModel,
|
|
|
|
AutoModelForCausalLM,
|
|
|
|
AutoTokenizer,
|
|
|
|
PreTrainedModel,
|
|
|
|
PreTrainedTokenizer,
|
|
|
|
)
|
|
|
|
|
2023-05-14 22:23:16 +08:00
|
|
|
import custom_models
|
|
|
|
|
2023-05-28 21:33:46 +08:00
|
|
|
# Module-level side effect: runs once at import time.
# NOTE(review): presumably registers the project's custom config classes with
# transformers' AutoConfig, so that the custom model types listed in
# custom_models can be resolved by AutoConfig.for_model / from_pretrained
# in the loaders of this module — confirm in custom_models.
custom_models.register_custom_configs()
|
2023-05-07 13:01:02 +08:00
|
|
|
|
2023-05-28 21:33:46 +08:00
|
|
|
|
|
|
|
def init_model(model_name: str) -> PreTrainedModel:
    """Build a fresh model of type ``model_name`` from its default configuration.

    No pretrained weights are loaded; the model is created with ``from_config``.
    The model class is chosen as follows:

    1. ``custom_models.AutoModelForCausalLM`` when ``model_name`` is listed in
       ``custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES``;
    2. ``custom_models.AutoModel`` when it is listed in
       ``custom_models.MODEL_MAPPING_NAMES``;
    3. otherwise transformers' ``AutoModelForCausalLM``, falling back to
       ``AutoModel`` if that raises ``ValueError``.

    Args:
        model_name: A model type string understood by ``AutoConfig.for_model``.

    Returns:
        The newly constructed model.
    """
    cfg = AutoConfig.for_model(model_type=model_name)

    # Project-local architectures are checked before the stock transformers ones.
    if model_name in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return custom_models.AutoModelForCausalLM.from_config(cfg)
    if model_name in custom_models.MODEL_MAPPING_NAMES:
        return custom_models.AutoModel.from_config(cfg)

    # Prefer a causal-LM head. Fall back to the bare model when
    # AutoModelForCausalLM rejects the config with ValueError (presumably
    # because no causal-LM class is mapped for this config type).
    try:
        return AutoModelForCausalLM.from_config(cfg, trust_remote_code=True)
    except ValueError:
        return AutoModel.from_config(cfg, trust_remote_code=True)
|
|
|
|
|
|
|
|
|
|
|
|
def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
    """Load a pretrained model from a hub model id or a local directory.

    The configuration is loaded first, and its ``model_type`` selects the loader:

    1. ``custom_models.AutoModelForCausalLM`` when the type is listed in
       ``custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES``;
    2. ``custom_models.AutoModel`` when it is listed in
       ``custom_models.MODEL_MAPPING_NAMES``;
    3. otherwise transformers' ``AutoModelForCausalLM``, falling back to
       ``AutoModel`` if that raises ``ValueError``.

    Args:
        model_name_or_path: Hub model id or path to a saved model directory.

    Returns:
        The loaded model, built with the configuration read at the start.

    Raises:
        ValueError: propagated if the final ``AutoModel`` fallback also
            rejects the configuration.
    """
    # The config is loaded separately and passed explicitly to every loader
    # below, so the loaders' own ``trust_remote_code=True`` never applies to
    # it. Repos whose config class lives in remote code (an ``auto_map``
    # entry) need the flag here too. Without it, AutoConfig refuses to load
    # them, and the remote-code fallback below can never be reached.
    # NOTE(review): trust_remote_code executes Python shipped with the model
    # repository — only point this at repositories you trust.
    config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
    model_type = config.model_type

    # Project-local architectures are checked before the stock transformers ones.
    if model_type in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return custom_models.AutoModelForCausalLM.from_pretrained(model_name_or_path, config=config)
    if model_type in custom_models.MODEL_MAPPING_NAMES:
        return custom_models.AutoModel.from_pretrained(model_name_or_path, config=config)

    # Prefer a causal-LM head. Fall back to the bare model when
    # AutoModelForCausalLM rejects the config with ValueError (presumably
    # because no causal-LM class is mapped for this config type).
    try:
        return AutoModelForCausalLM.from_pretrained(model_name_or_path, config=config, trust_remote_code=True)
    except ValueError:
        return AutoModel.from_pretrained(model_name_or_path, config=config, trust_remote_code=True)
|
|
|
|
|
|
|
|
|
2023-05-28 20:02:56 +08:00
|
|
|
def load_tokenizer(tokenizer_name_or_path: Union[str, os.PathLike]) -> PreTrainedTokenizer:
    """Load a tokenizer configured for left padding.

    If the tokenizer defines no padding token, its EOS token is reused as the
    padding token. If the tokenizer has no EOS token either, this assigns
    ``None`` and the padding token stays unset.

    Args:
        tokenizer_name_or_path: Hub tokenizer id or path to a saved tokenizer.

    Returns:
        The loaded tokenizer.
    """
    # NOTE(review): left padding is presumably chosen for batched generation
    # with decoder-only models — confirm against callers.
    tok = AutoTokenizer.from_pretrained(
        tokenizer_name_or_path,
        padding_side='left',
        trust_remote_code=True,
    )

    missing_pad = tok.pad_token_id is None
    if missing_pad:
        tok.pad_token = tok.eos_token
    return tok
|