From 5392a845f7ee9d5fd3a21c0ad5ddae634d308a94 Mon Sep 17 00:00:00 2001
From: Yiqing-Zhou
Date: Sun, 7 May 2023 13:18:47 +0800
Subject: [PATCH] [feature] export model checkpoint from pl.LightningModule

---
 README.md     |  9 +++++++++
 lit_export.py | 31 +++++++++++++++++++++++++++++++
 2 files changed, 40 insertions(+)
 create mode 100644 README.md
 create mode 100644 lit_export.py

diff --git a/README.md b/README.md
new file mode 100644
index 0000000..58a200c
--- /dev/null
+++ b/README.md
@@ -0,0 +1,9 @@
+# GPT-Pretrain
+
+## Usage
+
+```
+python lit_train.py --model_name gpt2
+python lit_export.py --version 0
+python generate.py --model_name_or_path exports/version_0 --tokenizer_name_or_path gpt2
+```
diff --git a/lit_export.py b/lit_export.py
new file mode 100644
index 0000000..1f03924
--- /dev/null
+++ b/lit_export.py
@@ -0,0 +1,31 @@
+import argparse
+from pathlib import Path
+
+from transformers import PreTrainedModel
+
+from lit_module import LitModule
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--version",
+        type=int,
+        help="Version of the PyTorch Lightning checkpoint to export",
+        required=True,
+    )
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    lightning_logs_dir_path = Path("lightning_logs").joinpath(f"version_{args.version}")
+    exports_dir_path = Path("exports").joinpath(f"version_{args.version}")
+
+    checkpoint_file_path = next(lightning_logs_dir_path.glob("checkpoints/*.ckpt"))
+
+    lit_module = LitModule.load_from_checkpoint(checkpoint_file_path)
+    model: PreTrainedModel = lit_module.__core_module__
+    model.save_pretrained(exports_dir_path)
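
For reviewers, a minimal sketch of how the exported directory could be consumed downstream, in the spirit of the `generate.py` step from the README above. Because `save_pretrained()` writes a standard Hugging Face model directory (weights plus config) but no tokenizer files, the tokenizer is still loaded from the original `gpt2` name. The prompt text and generation arguments below are illustrative assumptions and are not part of this patch:

```
# Illustrative sketch only; not part of this patch.
from transformers import AutoModelForCausalLM, AutoTokenizer

# Reload the model exported by lit_export.py; the directory written by
# save_pretrained() can be read back with from_pretrained().
model = AutoModelForCausalLM.from_pretrained("exports/version_0")

# The export does not include tokenizer files, so load the tokenizer by name.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Generate a short continuation from an assumed example prompt.
inputs = tokenizer("Hello, world", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```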