diff --git a/custom_models/__init__.py b/custom_models/__init__.py
index 8401ed2..b1a3e4d 100644
--- a/custom_models/__init__.py
+++ b/custom_models/__init__.py
@@ -37,12 +37,14 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
 
 
+# The local CONFIG_MAPPING_NAMES is unpacked last so that, on a duplicate key,
+# the local entry takes precedence over the one from configuration_auto.
 MODEL_MAPPING = _LazyAutoMapping(
-    {**CONFIG_MAPPING_NAMES, **configuration_auto.CONFIG_MAPPING_NAMES}, MODEL_MAPPING_NAMES
+    {**configuration_auto.CONFIG_MAPPING_NAMES, **CONFIG_MAPPING_NAMES}, MODEL_MAPPING_NAMES
 )
 
 
 MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
-    {**CONFIG_MAPPING_NAMES, **configuration_auto.CONFIG_MAPPING_NAMES}, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+    {**configuration_auto.CONFIG_MAPPING_NAMES, **CONFIG_MAPPING_NAMES}, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 )
 
 
diff --git a/utils.py b/utils.py
index 65d3acc..afb5d9e 100644
--- a/utils.py
+++ b/utils.py
@@ -31,15 +31,33 @@ def init_model(model_name: str) -> PreTrainedModel:
 
 
 def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
-    if model_name_or_path in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        model = custom_models.AutoModelForCausalLM.from_pretrained(model_name_or_path)
-    elif model_name_or_path in custom_models.MODEL_MAPPING_NAMES:
-        model = custom_models.AutoModel.from_pretrained(model_name_or_path)
+    """Load a pretrained model, dispatching on the checkpoint's ``model_type``.
+
+    The config is resolved first; its ``model_type`` selects the auto classes
+    from ``custom_models`` when it is listed in one of that package's mapping
+    tables. Otherwise the module-level ``AutoModelForCausalLM`` is tried,
+    falling back to ``AutoModel`` when that raises ``ValueError``.
+
+    Args:
+        model_name_or_path: Identifier or path handed to ``from_pretrained``.
+
+    Returns:
+        The instantiated model.
+    """
+    # trust_remote_code mirrors the fallback branch below: without it,
+    # resolving the config of a checkpoint that ships its own configuration
+    # code can fail before that branch is ever reached.
+    config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
+
+    if config.model_type in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
+        model = custom_models.AutoModelForCausalLM.from_pretrained(model_name_or_path, config=config)
+    elif config.model_type in custom_models.MODEL_MAPPING_NAMES:
+        model = custom_models.AutoModel.from_pretrained(model_name_or_path, config=config)
     else:
         try:
-            model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True)
+            model = AutoModelForCausalLM.from_pretrained(model_name_or_path, config=config, trust_remote_code=True)
         except ValueError:
-            model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
+            model = AutoModel.from_pretrained(model_name_or_path, config=config, trust_remote_code=True)
 
     return model
 