Enable pretraining.

This commit is contained in:
Colin 2024-02-22 15:03:32 +08:00
parent b655153ec7
commit 7d16743184
2 changed files with 12 additions and 17 deletions

16
.vscode/launch.json vendored
View File

@ -4,23 +4,15 @@
// 访问: https://go.microsoft.com/fwlink/?linkid=830387 // 访问: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0", "version": "0.2.0",
"configurations": [ "configurations": [
{
"name": "Python: lit train",
"type": "python",
"request": "launch",
"program": "lit_train.py",
"args": [],
"console": "integratedTerminal",
"justMyCode": true
},
{ {
"name": "Python: generate", "name": "Python: generate",
"type": "python", "type": "debugpy",
"request": "launch", "request": "launch",
"program": "generate.py", "program": "${file}",
"args": [], "args": [],
"cwd": "${fileDirname}",
"console": "integratedTerminal", "console": "integratedTerminal",
"justMyCode": true "justMyCode": false
} }
] ]
} }

View File

@ -106,13 +106,13 @@ def parse_args():
"--train_batch_size", "--train_batch_size",
type=int, type=int,
help="Batch size of training", help="Batch size of training",
default=8, default=2,
) )
parser.add_argument( parser.add_argument(
"--val_batch_size", "--val_batch_size",
type=int, type=int,
help="Batch size of validating", help="Batch size of validating",
default=16, default=2,
) )
parser.add_argument( parser.add_argument(
"--accumulate_grad_batches", "--accumulate_grad_batches",
@ -124,7 +124,7 @@ def parse_args():
"--num_proc", "--num_proc",
type=str, type=str,
help="Number of data processes", help="Number of data processes",
default=16, default=1,
) )
parser.add_argument( parser.add_argument(
"--max_epochs", "--max_epochs",
@ -136,7 +136,7 @@ def parse_args():
"--strategy", "--strategy",
type=str, type=str,
help="Name of pytorch lightning distribution strategy", help="Name of pytorch lightning distribution strategy",
default='fsdp', default='ddp',
) )
parser.add_argument( parser.add_argument(
"--resume_from_ckpt_path", "--resume_from_ckpt_path",
@ -197,8 +197,11 @@ if __name__ == '__main__':
persistent_workers=True, persistent_workers=True,
) )
ne = next(train_dataloader._get_iterator())
print((ne["input_ids"]-ne["labels"]).numpy().tolist())
# trainer # trainer
apply_all_patches() # apply_all_patches()
torch.set_float32_matmul_precision('medium') torch.set_float32_matmul_precision('medium')
if args.bf16: if args.bf16:
precision = 'bf16-mixed' precision = 'bf16-mixed'