gpt-pretrain/lit_export.py

34 lines
890 B
Python
Raw Normal View History

import argparse
from pathlib import Path
from transformers import PreTrainedModel
from lit_module import LitModule
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--version",
type=int,
help="Pytorch lightning checkpoint of version to export",
required=True,
)
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
lightning_logs_dir_path = Path("lightning_logs").joinpath(f"version_{args.version}")
exports_dir_path = Path("exports").joinpath(f"version_{args.version}")
checkpoint_file_path = next(lightning_logs_dir_path.glob("checkpoints/*.ckpt"))
lit_module = LitModule.load_from_checkpoint(
checkpoint_file_path, map_location='cpu'
)
model: PreTrainedModel = lit_module.__core_module__
model.save_pretrained(exports_dir_path)