[feature] export model checkpoint from pl.LightningModule
This commit is contained in:
parent
09507449f7
commit
5392a845f7
|
@ -0,0 +1,9 @@
|
|||
# GPT-Pretrain
|
||||
|
||||
## Usage
|
||||
|
||||
```
|
||||
python lit_train.py --model_name gpt2
|
||||
python lit_export.py --version 0
|
||||
python generate.py --model_name_or_path exports/version_0 --tokenizer_name_or_path gpt2
|
||||
```
|
|
@ -0,0 +1,31 @@
|
|||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from lit_module import LitModule
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--version",
|
||||
type=int,
|
||||
help="Pytorch lightning checkpoint of version to export",
|
||||
required=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
|
||||
lightning_logs_dir_path = Path("lightning_logs").joinpath(f"version_{args.version}")
|
||||
exports_dir_path = Path("exports").joinpath(f"version_{args.version}")
|
||||
|
||||
checkpoint_file_path = next(lightning_logs_dir_path.glob("checkpoints/*.ckpt"))
|
||||
|
||||
lit_module = LitModule.load_from_checkpoint(checkpoint_file_path)
|
||||
model: PreTrainedModel = lit_module.__core_module__
|
||||
model.save_pretrained(exports_dir_path)
|
Loading…
Reference in New Issue