From 9f8f9ecc8939927e57d55b6dbe83d67775dc3a67 Mon Sep 17 00:00:00 2001 From: yiqing-zhou Date: Sun, 28 May 2023 22:51:42 +0800 Subject: [PATCH] [fix] fix genarate with custom models does not go to custom_models --- custom_models/__init__.py | 4 ++-- utils.py | 14 ++++++++------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/custom_models/__init__.py b/custom_models/__init__.py index 8401ed2..b1a3e4d 100644 --- a/custom_models/__init__.py +++ b/custom_models/__init__.py @@ -37,12 +37,12 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( MODEL_MAPPING = _LazyAutoMapping( - {**CONFIG_MAPPING_NAMES, **configuration_auto.CONFIG_MAPPING_NAMES}, MODEL_MAPPING_NAMES + {**configuration_auto.CONFIG_MAPPING_NAMES, **CONFIG_MAPPING_NAMES}, MODEL_MAPPING_NAMES ) MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping( - {**CONFIG_MAPPING_NAMES, **configuration_auto.CONFIG_MAPPING_NAMES}, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + {**configuration_auto.CONFIG_MAPPING_NAMES, **CONFIG_MAPPING_NAMES}, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES ) diff --git a/utils.py b/utils.py index 65d3acc..afb5d9e 100644 --- a/utils.py +++ b/utils.py @@ -31,15 +31,17 @@ def init_model(model_name: str) -> PreTrainedModel: def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel: - if model_name_or_path in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: - model = custom_models.AutoModelForCausalLM.from_pretrained(model_name_or_path) - elif model_name_or_path in custom_models.MODEL_MAPPING_NAMES: - model = custom_models.AutoModel.from_pretrained(model_name_or_path) + config = AutoConfig.from_pretrained(model_name_or_path) + + if config.model_type in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: + model = custom_models.AutoModelForCausalLM.from_pretrained(model_name_or_path, config=config) + elif config.model_type in custom_models.MODEL_MAPPING_NAMES: + model = custom_models.AutoModel.from_pretrained(model_name_or_path, config=config) else: try: - model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained(model_name_or_path, config=config, trust_remote_code=True) except ValueError: - model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True) + model = AutoModel.from_pretrained(model_name_or_path, config=config, trust_remote_code=True) return model