Merge pull request #2 from Yiqing-Zhou/fix-custom-models
[fix] fix genarate with custom models does not go to custom_models
This commit is contained in:
		
						commit
						b655153ec7
					
				|  | @ -37,12 +37,12 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( | |||
| 
 | ||||
| 
 | ||||
| MODEL_MAPPING = _LazyAutoMapping( | ||||
|     {**CONFIG_MAPPING_NAMES, **configuration_auto.CONFIG_MAPPING_NAMES}, MODEL_MAPPING_NAMES | ||||
|     {**configuration_auto.CONFIG_MAPPING_NAMES, **CONFIG_MAPPING_NAMES}, MODEL_MAPPING_NAMES | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping( | ||||
|     {**CONFIG_MAPPING_NAMES, **configuration_auto.CONFIG_MAPPING_NAMES}, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES | ||||
|     {**configuration_auto.CONFIG_MAPPING_NAMES, **CONFIG_MAPPING_NAMES}, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										14
									
								
								utils.py
								
								
								
								
							
							
						
						
									
										14
									
								
								utils.py
								
								
								
								
							|  | @ -31,15 +31,17 @@ def init_model(model_name: str) -> PreTrainedModel: | |||
| 
 | ||||
| 
 | ||||
| def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel: | ||||
|     if model_name_or_path in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: | ||||
|         model = custom_models.AutoModelForCausalLM.from_pretrained(model_name_or_path) | ||||
|     elif model_name_or_path in custom_models.MODEL_MAPPING_NAMES: | ||||
|         model = custom_models.AutoModel.from_pretrained(model_name_or_path) | ||||
|     config = AutoConfig.from_pretrained(model_name_or_path) | ||||
| 
 | ||||
|     if config.model_type in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: | ||||
|         model = custom_models.AutoModelForCausalLM.from_pretrained(model_name_or_path, config=config) | ||||
|     elif config.model_type in custom_models.MODEL_MAPPING_NAMES: | ||||
|         model = custom_models.AutoModel.from_pretrained(model_name_or_path, config=config) | ||||
|     else: | ||||
|         try: | ||||
|             model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True) | ||||
|             model = AutoModelForCausalLM.from_pretrained(model_name_or_path, config=config, trust_remote_code=True) | ||||
|         except ValueError: | ||||
|             model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True) | ||||
|             model = AutoModel.from_pretrained(model_name_or_path, config=config, trust_remote_code=True) | ||||
|     return model | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue