# Listing metadata left by the file viewer: 61 lines, 1.7 KiB, Python.
import importlib
|
|
from collections import OrderedDict
|
|
|
|
from transformers.models.auto import auto_factory, configuration_auto
|
|
|
|
# Maps a custom model type to the name of the attribute that
# `register_custom_configs` fetches from the matching `custom_models` module.
# No custom configurations are declared here.
CONFIG_MAPPING_NAMES = OrderedDict()
|
|
|
|
|
|
def register_custom_configs():
    """Register every entry of ``CONFIG_MAPPING_NAMES`` with ``AutoConfig``.

    For each ``(model_type, map_name)`` pair, the module name derived from
    ``model_type`` is imported relative to the ``custom_models`` package, the
    attribute called ``map_name`` is read from that module, and the result is
    handed to ``AutoConfig.register`` under ``model_type``.

    Does nothing when ``CONFIG_MAPPING_NAMES`` is empty.

    Raises:
        ImportError: if the derived module cannot be imported from
            ``custom_models``.
        AttributeError: if the module has no attribute named ``map_name``.
    """
    for model_type, map_name in CONFIG_MAPPING_NAMES.items():
        module_name = configuration_auto.model_type_to_module_name(model_type)
        # Relative import: resolves to ``custom_models.<module_name>``.
        module = importlib.import_module(f".{module_name}", "custom_models")
        mapping = getattr(module, map_name)
        configuration_auto.AutoConfig.register(model_type, mapping)
|
|
|
|
|
|
class _LazyAutoMapping(auto_factory._LazyAutoMapping):
    """Lazy auto-mapping that loads model modules from ``custom_models``."""

    def _load_attr_from_module(self, model_type, attr):
        """Return ``attr`` from the ``custom_models`` module for ``model_type``.

        The module is imported on first use and kept in ``self._modules`` so
        later lookups for the same module skip the import.
        NOTE(review): ``self._modules`` is presumably created by the parent
        class initializer -- confirm against the installed transformers.

        Args:
            model_type: Model type key (e.g. ``"gpt2"``); converted to a
                module name with ``auto_factory.model_type_to_module_name``.
            attr: Name of the attribute to fetch from that module.

        Returns:
            Whatever ``auto_factory.getattribute_from_module`` yields for
            ``attr`` on the loaded module.
        """
        module_name = auto_factory.model_type_to_module_name(model_type)
        if module_name not in self._modules:
            # Relative import: resolves to ``custom_models.<module_name>``.
            self._modules[module_name] = importlib.import_module(f".{module_name}", "custom_models")
        return auto_factory.getattribute_from_module(self._modules[module_name], attr)
|
|
|
|
|
|
# Model type -> name of the base model class, used to build MODEL_MAPPING.
MODEL_MAPPING_NAMES = OrderedDict(
    [
        ("gpt2", "GPT2Model"),
    ]
)
|
|
|
|
|
|
# Model type -> name of the causal-LM model class, used to build
# MODEL_FOR_CAUSAL_LM_MAPPING.
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        ("gpt2", "GPT2LMHeadModel"),
    ]
)
|
|
|
|
|
|
# Lazy mapping for the base model classes listed in MODEL_MAPPING_NAMES.
# The config-name dict merges the local CONFIG_MAPPING_NAMES with the one from
# transformers; the transformers dict is unpacked last, so its entry wins if
# both define the same model type.
MODEL_MAPPING = _LazyAutoMapping(
    {**CONFIG_MAPPING_NAMES, **configuration_auto.CONFIG_MAPPING_NAMES}, MODEL_MAPPING_NAMES
)
|
|
|
|
|
|
# Lazy mapping for the causal-LM classes listed in
# MODEL_FOR_CAUSAL_LM_MAPPING_NAMES. The config-name dict merges the local
# CONFIG_MAPPING_NAMES with the one from transformers; the transformers dict is
# unpacked last, so its entry wins if both define the same model type.
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    {**CONFIG_MAPPING_NAMES, **configuration_auto.CONFIG_MAPPING_NAMES}, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
)
|
|
|
|
|
|
# Auto class whose model lookup goes through MODEL_MAPPING.
# Documented with comments rather than a class docstring.
# NOTE(review): auto_class_update presumably manages the class docs itself --
# confirm against the installed transformers before adding a docstring.
class AutoModel(auto_factory._BaseAutoModelClass):
    _model_mapping = MODEL_MAPPING


# Rebind the name to whatever auto_class_update returns for the class.
AutoModel = auto_factory.auto_class_update(AutoModel)
|
|
|
|
|
|
# Auto class whose model lookup goes through MODEL_FOR_CAUSAL_LM_MAPPING.
# Documented with comments rather than a class docstring.
# NOTE(review): auto_class_update presumably manages the class docs itself --
# confirm against the installed transformers before adding a docstring.
class AutoModelForCausalLM(auto_factory._BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING


# Rebind the name to whatever auto_class_update returns for the class.
AutoModelForCausalLM = auto_factory.auto_class_update(AutoModelForCausalLM)
|