gpt-pretrain/custom_models/__init__.py

61 lines
1.7 KiB
Python
Raw Normal View History

2023-05-14 22:23:16 +08:00
import importlib
from collections import OrderedDict
from transformers.models.auto import auto_factory, configuration_auto
2023-05-28 21:33:46 +08:00
# Registry of custom configurations: model type -> name of the attribute to
# fetch from the matching ``custom_models`` submodule. Empty for now.
CONFIG_MAPPING_NAMES = OrderedDict()
def register_custom_configs():
    """Register every entry of ``CONFIG_MAPPING_NAMES`` with ``AutoConfig``.

    For each ``(model_type, attribute name)`` pair, the submodule of
    ``custom_models`` named by
    ``configuration_auto.model_type_to_module_name(model_type)`` is imported,
    the named attribute is read from it, and that object is handed to
    ``configuration_auto.AutoConfig.register`` under ``model_type``.

    Raises:
        ImportError: If the submodule for a model type cannot be imported.
        AttributeError: If the submodule lacks the named attribute.
        Anything ``AutoConfig.register`` itself raises is propagated as-is.
    """
    for model_type, config_attr in CONFIG_MAPPING_NAMES.items():
        module_name = configuration_auto.model_type_to_module_name(model_type)
        submodule = importlib.import_module(f".{module_name}", "custom_models")
        # NOTE(review): presumably a config class, since it is what
        # AutoConfig.register receives -- confirm against the submodules.
        configuration_auto.AutoConfig.register(
            model_type, getattr(submodule, config_attr)
        )
2023-05-14 22:23:16 +08:00
class _LazyAutoMapping(auto_factory._LazyAutoMapping):
    # Variant of the upstream lazy mapping whose attribute lookups import
    # submodules of the local ``custom_models`` package.

    def _load_attr_from_module(self, model_type, attr):
        """Return ``attr`` looked up on the ``custom_models`` submodule for ``model_type``.

        The submodule name comes from ``auto_factory.model_type_to_module_name``.
        Each submodule is imported at most once per mapping instance and then
        reused from ``self._modules``.
        """
        module_name = auto_factory.model_type_to_module_name(model_type)
        loaded = self._modules.get(module_name)
        if loaded is None:
            # First request for this submodule: import it and remember it.
            loaded = importlib.import_module(f".{module_name}", "custom_models")
            self._modules[module_name] = loaded
        return auto_factory.getattribute_from_module(loaded, attr)
# Model type -> name of the base model class used for ``MODEL_MAPPING``.
MODEL_MAPPING_NAMES = OrderedDict(
    {
        "gpt2": "GPT2Model",
    }
)
# Model type -> name of the causal-LM class used for
# ``MODEL_FOR_CAUSAL_LM_MAPPING``.
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    {
        "gpt2": "GPT2LMHeadModel",
    }
)
2023-05-28 21:33:46 +08:00
# Lazy mapping for base models. The first argument merges the upstream config
# names with the custom ones; on a duplicate model type the custom entry wins
# because it is unpacked last.
MODEL_MAPPING = _LazyAutoMapping(
    {
        **configuration_auto.CONFIG_MAPPING_NAMES,
        **CONFIG_MAPPING_NAMES,
    },
    MODEL_MAPPING_NAMES,
)
2023-05-14 22:23:16 +08:00
# Lazy mapping for causal-LM models. The first argument merges the upstream
# config names with the custom ones; on a duplicate model type the custom
# entry wins because it is unpacked last.
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    {
        **configuration_auto.CONFIG_MAPPING_NAMES,
        **CONFIG_MAPPING_NAMES,
    },
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
)
# Auto class backed by MODEL_MAPPING. The decorator passes the class through
# auto_factory.auto_class_update and binds the returned class to the name.
@auto_factory.auto_class_update
class AutoModel(auto_factory._BaseAutoModelClass):
    _model_mapping = MODEL_MAPPING
# Auto class backed by MODEL_FOR_CAUSAL_LM_MAPPING. The decorator passes the
# class through auto_factory.auto_class_update and binds the returned class
# to the name.
@auto_factory.auto_class_update
class AutoModelForCausalLM(auto_factory._BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING