34 lines
890 B
Python
34 lines
890 B
Python
import argparse
|
|
from pathlib import Path
|
|
|
|
from transformers import PreTrainedModel
|
|
|
|
from lit_module import LitModule
|
|
|
|
|
|
def parse_args():
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument(
|
|
"--version",
|
|
type=int,
|
|
help="Pytorch lightning checkpoint of version to export",
|
|
required=True,
|
|
)
|
|
args = parser.parse_args()
|
|
return args
|
|
|
|
|
|
if __name__ == '__main__':
|
|
args = parse_args()
|
|
|
|
lightning_logs_dir_path = Path("lightning_logs").joinpath(f"version_{args.version}")
|
|
exports_dir_path = Path("exports").joinpath(f"version_{args.version}")
|
|
|
|
checkpoint_file_path = next(lightning_logs_dir_path.glob("checkpoints/*.ckpt"))
|
|
|
|
lit_module = LitModule.load_from_checkpoint(
|
|
checkpoint_file_path, map_location='cpu'
|
|
)
|
|
model: PreTrainedModel = lit_module.__core_module__
|
|
model.save_pretrained(exports_dir_path)
|