Enable pretraining.

This commit is contained in:
Colin 2024-02-22 15:03:32 +08:00
parent b655153ec7
commit 7d16743184
2 changed files with 12 additions and 17 deletions

16
.vscode/launch.json vendored
View File

@ -4,23 +4,15 @@
// 访问: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: lit train",
"type": "python",
"request": "launch",
"program": "lit_train.py",
"args": [],
"console": "integratedTerminal",
"justMyCode": true
},
{
"name": "Python: generate",
"type": "python",
"type": "debugpy",
"request": "launch",
"program": "generate.py",
"program": "${file}",
"args": [],
"cwd": "${fileDirname}",
"console": "integratedTerminal",
"justMyCode": true
"justMyCode": false
}
]
}

View File

@ -106,13 +106,13 @@ def parse_args():
"--train_batch_size",
type=int,
help="Batch size of training",
default=8,
default=2,
)
parser.add_argument(
"--val_batch_size",
type=int,
help="Batch size of validating",
default=16,
default=2,
)
parser.add_argument(
"--accumulate_grad_batches",
@ -124,7 +124,7 @@ def parse_args():
"--num_proc",
type=str,
help="Number of data processes",
default=16,
default=1,
)
parser.add_argument(
"--max_epochs",
@ -136,7 +136,7 @@ def parse_args():
"--strategy",
type=str,
help="Name of pytorch lightning distribution strategy",
default='fsdp',
default='ddp',
)
parser.add_argument(
"--resume_from_ckpt_path",
@ -197,8 +197,11 @@ if __name__ == '__main__':
persistent_workers=True,
)
ne = next(train_dataloader._get_iterator())
print((ne["input_ids"]-ne["labels"]).numpy().tolist())
# trainer
apply_all_patches()
# apply_all_patches()
torch.set_float32_matmul_precision('medium')
if args.bf16:
precision = 'bf16-mixed'