[feature] custom model configs

This commit is contained in:
yiqing-zhou 2023-05-28 21:33:46 +08:00
parent b76d333f39
commit fcb93e52c4
2 changed files with 18 additions and 4 deletions

View File

@ -3,6 +3,16 @@ from collections import OrderedDict
from transformers.models.auto import auto_factory, configuration_auto from transformers.models.auto import auto_factory, configuration_auto
# Custom model types mapped to the name of the attribute that
# register_custom_configs() looks up in the matching custom_models submodule.
# Currently empty.
CONFIG_MAPPING_NAMES = OrderedDict()
def register_custom_configs():
    """Register every entry of CONFIG_MAPPING_NAMES with ``AutoConfig``.

    For each ``model_type -> attribute name`` pair, the submodule of
    ``custom_models`` whose name is derived from the model type (via
    ``configuration_auto.model_type_to_module_name``) is imported, the named
    attribute is read from it, and that object is passed to
    ``AutoConfig.register`` under the model type.

    NOTE(review): the looked-up attribute is presumably a config class, as
    ``AutoConfig.register`` expects -- confirm against the custom modules.
    """
    for model_type, config_attr in CONFIG_MAPPING_NAMES.items():
        submodule_name = configuration_auto.model_type_to_module_name(model_type)
        # Relative import: resolves to custom_models.<submodule_name>.
        config_cls = getattr(
            importlib.import_module(f".{submodule_name}", "custom_models"),
            config_attr,
        )
        configuration_auto.AutoConfig.register(model_type, config_cls)
class _LazyAutoMapping(auto_factory._LazyAutoMapping): class _LazyAutoMapping(auto_factory._LazyAutoMapping):
def _load_attr_from_module(self, model_type, attr): def _load_attr_from_module(self, model_type, attr):
@ -26,11 +36,13 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
) )
MODEL_MAPPING = _LazyAutoMapping(configuration_auto.CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) MODEL_MAPPING = _LazyAutoMapping(
{**CONFIG_MAPPING_NAMES, **configuration_auto.CONFIG_MAPPING_NAMES}, MODEL_MAPPING_NAMES
)
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping( MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
configuration_auto.CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES {**CONFIG_MAPPING_NAMES, **configuration_auto.CONFIG_MAPPING_NAMES}, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
) )

View File

@ -12,9 +12,11 @@ from transformers import (
import custom_models import custom_models
# Module-level call: runs when this module is loaded, so the custom configs
# are registered with AutoConfig before init_model() can be invoked.
custom_models.register_custom_configs()
def init_model(model_name: Union[str, os.PathLike]) -> PreTrainedModel:
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) def init_model(model_name: str) -> PreTrainedModel:
config = AutoConfig.for_model(model_type=model_name)
if model_name in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: if model_name in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
model = custom_models.AutoModelForCausalLM.from_config(config) model = custom_models.AutoModelForCausalLM.from_config(config)