diff --git a/.vscode/launch.json b/.vscode/launch.json
index 56747ee..f191eb4 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -4,23 +4,15 @@
     // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
     "version": "0.2.0",
     "configurations": [
-        {
-            "name": "Python: lit train",
-            "type": "python",
-            "request": "launch",
-            "program": "lit_train.py",
-            "args": [],
-            "console": "integratedTerminal",
-            "justMyCode": true
-        },
         {
             "name": "Python: generate",
-            "type": "python",
+            "type": "debugpy",
             "request": "launch",
-            "program": "generate.py",
+            "program": "${file}",
             "args": [],
+            "cwd": "${fileDirname}",
             "console": "integratedTerminal",
-            "justMyCode": true
+            "justMyCode": false
         }
     ]
 }
\ No newline at end of file
diff --git a/lit_train.py b/lit_train.py
index 9a5968c..9a9b3bd 100644
--- a/lit_train.py
+++ b/lit_train.py
@@ -106,13 +106,13 @@ def parse_args():
         "--train_batch_size",
         type=int,
         help="Batch size of training",
-        default=8,
+        default=2,
     )
     parser.add_argument(
         "--val_batch_size",
         type=int,
         help="Batch size of validating",
-        default=16,
+        default=2,
     )
     parser.add_argument(
         "--accumulate_grad_batches",
@@ -124,7 +124,7 @@ def parse_args():
         "--num_proc",
         type=str,
         help="Number of data processes",
-        default=16,
+        default=1,
     )
     parser.add_argument(
         "--max_epochs",
@@ -136,7 +136,7 @@ def parse_args():
         "--strategy",
         type=str,
         help="Name of pytorch lightning distribution strategy",
-        default='fsdp',
+        default='ddp',
     )
     parser.add_argument(
         "--resume_from_ckpt_path",
@@ -197,8 +197,11 @@ if __name__ == '__main__':
         persistent_workers=True,
     )
 
+    ne = next(train_dataloader._get_iterator())
+    print((ne["input_ids"]-ne["labels"]).numpy().tolist())
+
     # trainer
-    apply_all_patches()
+    # apply_all_patches()
     torch.set_float32_matmul_precision('medium')
     if args.bf16:
         precision = 'bf16-mixed'