From fcb93e52c468a23501f8dc825d44a6f91a10635f Mon Sep 17 00:00:00 2001 From: yiqing-zhou Date: Sun, 28 May 2023 21:33:46 +0800 Subject: [PATCH] [feature] custom model configs --- custom_models/__init__.py | 16 ++++++++++++++-- utils.py | 6 ++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/custom_models/__init__.py b/custom_models/__init__.py index 5534944..8401ed2 100644 --- a/custom_models/__init__.py +++ b/custom_models/__init__.py @@ -3,6 +3,16 @@ from collections import OrderedDict from transformers.models.auto import auto_factory, configuration_auto +CONFIG_MAPPING_NAMES = OrderedDict([]) + + +def register_custom_configs(): + for model_type, map_name in CONFIG_MAPPING_NAMES.items(): + module_name = configuration_auto.model_type_to_module_name(model_type) + module = importlib.import_module(f".{module_name}", "custom_models") + mapping = getattr(module, map_name) + configuration_auto.AutoConfig.register(model_type, mapping) + class _LazyAutoMapping(auto_factory._LazyAutoMapping): def _load_attr_from_module(self, model_type, attr): @@ -26,11 +36,13 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ) -MODEL_MAPPING = _LazyAutoMapping(configuration_auto.CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) +MODEL_MAPPING = _LazyAutoMapping( + {**CONFIG_MAPPING_NAMES, **configuration_auto.CONFIG_MAPPING_NAMES}, MODEL_MAPPING_NAMES +) MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping( - configuration_auto.CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + {**CONFIG_MAPPING_NAMES, **configuration_auto.CONFIG_MAPPING_NAMES}, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES ) diff --git a/utils.py b/utils.py index 00bbc6f..65d3acc 100644 --- a/utils.py +++ b/utils.py @@ -12,9 +12,11 @@ from transformers import ( import custom_models +custom_models.register_custom_configs() -def init_model(model_name: Union[str, os.PathLike]) -> PreTrainedModel: - config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + +def init_model(model_name: str) -> PreTrainedModel: + config = AutoConfig.for_model(model_type=model_name) if model_name in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: model = custom_models.AutoModelForCausalLM.from_config(config)