[feature] export model checkpoint from pl.LightningModule

This commit is contained in:
Yiqing-Zhou 2023-05-07 13:18:47 +08:00
parent 09507449f7
commit 5392a845f7
2 changed files with 40 additions and 0 deletions

9
README.md Normal file
View File

@ -0,0 +1,9 @@
# GPT-Pretrain
## Usage
```
python lit_train.py --model_name gpt2
python lit_export.py --version 0
python generate.py --model_name_or_path exports/version_0 --tokenizer_name_or_path gpt2
```

31
lit_export.py Normal file
View File

@ -0,0 +1,31 @@
import argparse
from pathlib import Path
from transformers import PreTrainedModel
from lit_module import LitModule
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--version",
type=int,
help="Pytorch lightning checkpoint of version to export",
required=True,
)
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
lightning_logs_dir_path = Path("lightning_logs").joinpath(f"version_{args.version}")
exports_dir_path = Path("exports").joinpath(f"version_{args.version}")
checkpoint_file_path = next(lightning_logs_dir_path.glob("checkpoints/*.ckpt"))
lit_module = LitModule.load_from_checkpoint(checkpoint_file_path)
model: PreTrainedModel = lit_module.__core_module__
model.save_pretrained(exports_dir_path)