[fix] fix genarate with custom models does not go to custom_models

This commit is contained in:
yiqing-zhou 2023-05-28 22:51:42 +08:00
parent e8d543558c
commit 9f8f9ecc89
2 changed files with 10 additions and 8 deletions

View File

@ -37,12 +37,12 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
MODEL_MAPPING = _LazyAutoMapping( MODEL_MAPPING = _LazyAutoMapping(
{**CONFIG_MAPPING_NAMES, **configuration_auto.CONFIG_MAPPING_NAMES}, MODEL_MAPPING_NAMES {**configuration_auto.CONFIG_MAPPING_NAMES, **CONFIG_MAPPING_NAMES}, MODEL_MAPPING_NAMES
) )
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping( MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
{**CONFIG_MAPPING_NAMES, **configuration_auto.CONFIG_MAPPING_NAMES}, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES {**configuration_auto.CONFIG_MAPPING_NAMES, **CONFIG_MAPPING_NAMES}, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
) )

View File

@ -31,15 +31,17 @@ def init_model(model_name: str) -> PreTrainedModel:
def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel: def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
if model_name_or_path in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: config = AutoConfig.from_pretrained(model_name_or_path)
model = custom_models.AutoModelForCausalLM.from_pretrained(model_name_or_path)
elif model_name_or_path in custom_models.MODEL_MAPPING_NAMES: if config.model_type in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
model = custom_models.AutoModel.from_pretrained(model_name_or_path) model = custom_models.AutoModelForCausalLM.from_pretrained(model_name_or_path, config=config)
elif config.model_type in custom_models.MODEL_MAPPING_NAMES:
model = custom_models.AutoModel.from_pretrained(model_name_or_path, config=config)
else: else:
try: try:
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained(model_name_or_path, config=config, trust_remote_code=True)
except ValueError: except ValueError:
model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True) model = AutoModel.from_pretrained(model_name_or_path, config=config, trust_remote_code=True)
return model return model