[optimize] map_location='cpu' for load_from_checkpoint

This commit is contained in:
Yiqing-Zhou 2023-05-09 00:37:52 +08:00
parent 3f92bbbaa2
commit 8a5e2043bb
1 changed files with 3 additions and 1 deletions

View File

@ -26,6 +26,8 @@ if __name__ == '__main__':
checkpoint_file_path = next(lightning_logs_dir_path.glob("checkpoints/*.ckpt"))
lit_module = LitModule.load_from_checkpoint(checkpoint_file_path)
lit_module = LitModule.load_from_checkpoint(
checkpoint_file_path, map_location='cpu'
)
model: PreTrainedModel = lit_module.__core_module__
model.save_pretrained(exports_dir_path)