53 lines
1.3 KiB
Python
53 lines
1.3 KiB
Python
|
import importlib
|
||
|
from collections import OrderedDict
|
||
|
|
||
|
from transformers.models.auto import auto_factory, configuration_auto
|
||
|
|
||
|
|
||
|
class _LazyAutoMapping(auto_factory._LazyAutoMapping):
    """Lazy auto mapping whose model modules live in the ``custom_models`` package.

    Only the module-loading hook is overridden; all other behaviour is
    inherited from ``auto_factory._LazyAutoMapping``.
    """

    def _load_attr_from_module(self, model_type, attr):
        """Fetch ``attr`` from the module that implements ``model_type``.

        The module ``custom_models.<module name>`` is imported the first time
        it is needed and memoised in ``self._modules``; later calls reuse the
        cached module object.

        Args:
            model_type: Model-type key, translated to a module name by
                ``auto_factory.model_type_to_module_name``.
            attr: Name of the attribute to resolve in that module (e.g. a
                class name such as ``"GPT2Model"``).

        Returns:
            The value produced by ``auto_factory.getattribute_from_module``
            for the cached module and ``attr``.
        """
        name = auto_factory.model_type_to_module_name(model_type)
        cache = self._modules
        if name not in cache:
            # Relative import anchored at the local ``custom_models`` package.
            cache[name] = importlib.import_module("." + name, package="custom_models")
        return auto_factory.getattribute_from_module(cache[name], attr)
|
||
|
|
||
|
|
||
|
# Model-type key -> name of the base model class that implements it.
# Class names are resolved lazily through MODEL_MAPPING.
MODEL_MAPPING_NAMES = OrderedDict(
    (
        ("gpt2", "GPT2Model"),
    )
)
|
||
|
|
||
|
|
||
|
# Model-type key -> name of the causal language-modeling head class for it.
# Class names are resolved lazily through MODEL_FOR_CAUSAL_LM_MAPPING.
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    (
        ("gpt2", "GPT2LMHeadModel"),
    )
)
|
||
|
|
||
|
|
||
|
# Lazily pairs each config class name (from transformers' CONFIG_MAPPING_NAMES)
# with the base model class named in MODEL_MAPPING_NAMES.
MODEL_MAPPING = _LazyAutoMapping(configuration_auto.CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
|
||
|
|
||
|
|
||
|
# Lazily pairs each config class name (from transformers' CONFIG_MAPPING_NAMES)
# with the causal-LM head class named in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    configuration_auto.CONFIG_MAPPING_NAMES,
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
)
|
||
|
|
||
|
|
||
|
# Generic auto class for the base models registered in MODEL_MAPPING.
# NOTE: no docstring here on purpose. auto_class_update() (transformers'
# auto_factory) writes the class docstring itself, so one defined on the
# class body would be replaced.
class AutoModel(auto_factory._BaseAutoModelClass):
    _model_mapping = MODEL_MAPPING


# The generated from_pretrained/from_config examples use
# `checkpoint_for_example`. transformers' default is a BERT checkpoint,
# which this gpt2-only mapping cannot load, so point the examples at "gpt2".
AutoModel = auto_factory.auto_class_update(AutoModel, checkpoint_for_example="gpt2")
|
||
|
|
||
|
|
||
|
# Generic auto class for the causal language-modeling heads registered in
# MODEL_FOR_CAUSAL_LM_MAPPING.
# NOTE: no docstring here on purpose. auto_class_update() (transformers'
# auto_factory) writes the class docstring itself, so one defined on the
# class body would be replaced.
class AutoModelForCausalLM(auto_factory._BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING


# `head_doc` makes the generated class docstring describe the causal-LM head
# instead of leaving it generic. `checkpoint_for_example` replaces
# transformers' default BERT example checkpoint, which this gpt2-only mapping
# cannot load.
AutoModelForCausalLM = auto_factory.auto_class_update(
    AutoModelForCausalLM,
    checkpoint_for_example="gpt2",
    head_doc="causal language modeling",
)
|