Merge pull request #2 from Yiqing-Zhou/fix-custom-models
[fix] fix generate with custom models not going to custom_models
This commit is contained in:
commit
b655153ec7
|
@ -37,12 +37,12 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||||
|
|
||||||
|
|
||||||
MODEL_MAPPING = _LazyAutoMapping(
|
MODEL_MAPPING = _LazyAutoMapping(
|
||||||
{**CONFIG_MAPPING_NAMES, **configuration_auto.CONFIG_MAPPING_NAMES}, MODEL_MAPPING_NAMES
|
{**configuration_auto.CONFIG_MAPPING_NAMES, **CONFIG_MAPPING_NAMES}, MODEL_MAPPING_NAMES
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
|
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
|
||||||
{**CONFIG_MAPPING_NAMES, **configuration_auto.CONFIG_MAPPING_NAMES}, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
{**configuration_auto.CONFIG_MAPPING_NAMES, **CONFIG_MAPPING_NAMES}, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
14
utils.py
14
utils.py
|
@ -31,15 +31,17 @@ def init_model(model_name: str) -> PreTrainedModel:
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
|
def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
|
||||||
if model_name_or_path in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
config = AutoConfig.from_pretrained(model_name_or_path)
|
||||||
model = custom_models.AutoModelForCausalLM.from_pretrained(model_name_or_path)
|
|
||||||
elif model_name_or_path in custom_models.MODEL_MAPPING_NAMES:
|
if config.model_type in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||||
model = custom_models.AutoModel.from_pretrained(model_name_or_path)
|
model = custom_models.AutoModelForCausalLM.from_pretrained(model_name_or_path, config=config)
|
||||||
|
elif config.model_type in custom_models.MODEL_MAPPING_NAMES:
|
||||||
|
model = custom_models.AutoModel.from_pretrained(model_name_or_path, config=config)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True)
|
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, config=config, trust_remote_code=True)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
|
model = AutoModel.from_pretrained(model_name_or_path, config=config, trust_remote_code=True)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue