Merge pull request #2 from Yiqing-Zhou/fix-custom-models

[fix] fix genarate with custom models does not go to custom_models
This commit is contained in:
周以晴 2023-05-28 22:58:51 +08:00 committed by GitHub
commit b655153ec7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 8 deletions

View File

@ -37,12 +37,12 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
MODEL_MAPPING = _LazyAutoMapping(
{**CONFIG_MAPPING_NAMES, **configuration_auto.CONFIG_MAPPING_NAMES}, MODEL_MAPPING_NAMES
{**configuration_auto.CONFIG_MAPPING_NAMES, **CONFIG_MAPPING_NAMES}, MODEL_MAPPING_NAMES
)
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
{**CONFIG_MAPPING_NAMES, **configuration_auto.CONFIG_MAPPING_NAMES}, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
{**configuration_auto.CONFIG_MAPPING_NAMES, **CONFIG_MAPPING_NAMES}, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
)

View File

@ -31,15 +31,17 @@ def init_model(model_name: str) -> PreTrainedModel:
def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
if model_name_or_path in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
model = custom_models.AutoModelForCausalLM.from_pretrained(model_name_or_path)
elif model_name_or_path in custom_models.MODEL_MAPPING_NAMES:
model = custom_models.AutoModel.from_pretrained(model_name_or_path)
config = AutoConfig.from_pretrained(model_name_or_path)
if config.model_type in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
model = custom_models.AutoModelForCausalLM.from_pretrained(model_name_or_path, config=config)
elif config.model_type in custom_models.MODEL_MAPPING_NAMES:
model = custom_models.AutoModel.from_pretrained(model_name_or_path, config=config)
else:
try:
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, config=config, trust_remote_code=True)
except ValueError:
model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name_or_path, config=config, trust_remote_code=True)
return model