diff --git a/README.md b/README.md
new file mode 100644
index 0000000..58a200c
--- /dev/null
+++ b/README.md
@@ -0,0 +1,9 @@
+# GPT-Pretrain
+
+## Usage
+
+```
+python lit_train.py --model_name gpt2
+python lit_export.py --version 0
+python generate.py --model_name_or_path exports/version_0 --tokenizer_name_or_path gpt2
+```
diff --git a/lit_export.py b/lit_export.py
new file mode 100644
index 0000000..1f03924
--- /dev/null
+++ b/lit_export.py
@@ -0,0 +1,31 @@
+import argparse
+from pathlib import Path
+
+from transformers import PreTrainedModel
+
+from lit_module import LitModule
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--version",
+        type=int,
+        help="Pytorch lightning checkpoint of version to export",
+        required=True,
+    )
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    lightning_logs_dir_path = Path("lightning_logs").joinpath(f"version_{args.version}")
+    exports_dir_path = Path("exports").joinpath(f"version_{args.version}")
+
+    checkpoint_file_path = next(lightning_logs_dir_path.glob("checkpoints/*.ckpt"))
+
+    lit_module = LitModule.load_from_checkpoint(checkpoint_file_path)
+    model: PreTrainedModel = lit_module.__core_module__
+    model.save_pretrained(exports_dir_path)