diff --git a/lit_export.py b/lit_export.py index 1f03924..07cac08 100644 --- a/lit_export.py +++ b/lit_export.py @@ -26,6 +26,8 @@ if __name__ == '__main__': checkpoint_file_path = next(lightning_logs_dir_path.glob("checkpoints/*.ckpt")) - lit_module = LitModule.load_from_checkpoint(checkpoint_file_path) + lit_module = LitModule.load_from_checkpoint( + checkpoint_file_path, map_location='cpu' + ) model: PreTrainedModel = lit_module.__core_module__ model.save_pretrained(exports_dir_path)