[feature] custom model configs
This commit is contained in:
parent
b76d333f39
commit
fcb93e52c4
|
@ -3,6 +3,16 @@ from collections import OrderedDict
|
||||||
|
|
||||||
from transformers.models.auto import auto_factory, configuration_auto
|
from transformers.models.auto import auto_factory, configuration_auto
|
||||||
|
|
||||||
|
CONFIG_MAPPING_NAMES = OrderedDict([])
|
||||||
|
|
||||||
|
|
||||||
|
def register_custom_configs():
|
||||||
|
for model_type, map_name in CONFIG_MAPPING_NAMES.items():
|
||||||
|
module_name = configuration_auto.model_type_to_module_name(model_type)
|
||||||
|
module = importlib.import_module(f".{module_name}", "custom_models")
|
||||||
|
mapping = getattr(module, map_name)
|
||||||
|
configuration_auto.AutoConfig.register(model_type, mapping)
|
||||||
|
|
||||||
|
|
||||||
class _LazyAutoMapping(auto_factory._LazyAutoMapping):
|
class _LazyAutoMapping(auto_factory._LazyAutoMapping):
|
||||||
def _load_attr_from_module(self, model_type, attr):
|
def _load_attr_from_module(self, model_type, attr):
|
||||||
|
@ -26,11 +36,13 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
MODEL_MAPPING = _LazyAutoMapping(configuration_auto.CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
|
MODEL_MAPPING = _LazyAutoMapping(
|
||||||
|
{**CONFIG_MAPPING_NAMES, **configuration_auto.CONFIG_MAPPING_NAMES}, MODEL_MAPPING_NAMES
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
|
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
|
||||||
configuration_auto.CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
{**CONFIG_MAPPING_NAMES, **configuration_auto.CONFIG_MAPPING_NAMES}, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
6
utils.py
6
utils.py
|
@ -12,9 +12,11 @@ from transformers import (
|
||||||
|
|
||||||
import custom_models
|
import custom_models
|
||||||
|
|
||||||
|
custom_models.register_custom_configs()
|
||||||
|
|
||||||
def init_model(model_name: Union[str, os.PathLike]) -> PreTrainedModel:
|
|
||||||
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
def init_model(model_name: str) -> PreTrainedModel:
|
||||||
|
config = AutoConfig.for_model(model_type=model_name)
|
||||||
|
|
||||||
if model_name in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
if model_name in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||||
model = custom_models.AutoModelForCausalLM.from_config(config)
|
model = custom_models.AutoModelForCausalLM.from_config(config)
|
||||||
|
|
Loading…
Reference in New Issue