enable pretrain.
This commit is contained in:
parent
b655153ec7
commit
7d16743184
|
@ -4,23 +4,15 @@
|
|||
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python: lit train",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "lit_train.py",
|
||||
"args": [],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true
|
||||
},
|
||||
{
|
||||
"name": "Python: generate",
|
||||
"type": "python",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "generate.py",
|
||||
"program": "${file}",
|
||||
"args": [],
|
||||
"cwd": "${fileDirname}",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true
|
||||
"justMyCode": false
|
||||
}
|
||||
]
|
||||
}
|
13
lit_train.py
13
lit_train.py
|
@ -106,13 +106,13 @@ def parse_args():
|
|||
"--train_batch_size",
|
||||
type=int,
|
||||
help="Batch size of training",
|
||||
default=8,
|
||||
default=2,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--val_batch_size",
|
||||
type=int,
|
||||
help="Batch size of validating",
|
||||
default=16,
|
||||
default=2,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--accumulate_grad_batches",
|
||||
|
@ -124,7 +124,7 @@ def parse_args():
|
|||
"--num_proc",
|
||||
type=str,
|
||||
help="Number of data processes",
|
||||
default=16,
|
||||
default=1,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_epochs",
|
||||
|
@ -136,7 +136,7 @@ def parse_args():
|
|||
"--strategy",
|
||||
type=str,
|
||||
help="Name of pytorch lightning distribution strategy",
|
||||
default='fsdp',
|
||||
default='ddp',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_ckpt_path",
|
||||
|
@ -197,8 +197,11 @@ if __name__ == '__main__':
|
|||
persistent_workers=True,
|
||||
)
|
||||
|
||||
ne = next(train_dataloader._get_iterator())
|
||||
print((ne["input_ids"]-ne["labels"]).numpy().tolist())
|
||||
|
||||
# trainer
|
||||
apply_all_patches()
|
||||
# apply_all_patches()
|
||||
torch.set_float32_matmul_precision('medium')
|
||||
if args.bf16:
|
||||
precision = 'bf16-mixed'
|
||||
|
|
Loading…
Reference in New Issue