enable pretrain.
This commit is contained in:
parent
b655153ec7
commit
7d16743184
|
@ -4,23 +4,15 @@
|
||||||
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
|
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
|
||||||
"version": "0.2.0",
|
"version": "0.2.0",
|
||||||
"configurations": [
|
"configurations": [
|
||||||
{
|
|
||||||
"name": "Python: lit train",
|
|
||||||
"type": "python",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "lit_train.py",
|
|
||||||
"args": [],
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"justMyCode": true
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"name": "Python: generate",
|
"name": "Python: generate",
|
||||||
"type": "python",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "generate.py",
|
"program": "${file}",
|
||||||
"args": [],
|
"args": [],
|
||||||
|
"cwd": "${fileDirname}",
|
||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
"justMyCode": true
|
"justMyCode": false
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
13
lit_train.py
13
lit_train.py
|
@ -106,13 +106,13 @@ def parse_args():
|
||||||
"--train_batch_size",
|
"--train_batch_size",
|
||||||
type=int,
|
type=int,
|
||||||
help="Batch size of training",
|
help="Batch size of training",
|
||||||
default=8,
|
default=2,
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--val_batch_size",
|
"--val_batch_size",
|
||||||
type=int,
|
type=int,
|
||||||
help="Batch size of validating",
|
help="Batch size of validating",
|
||||||
default=16,
|
default=2,
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--accumulate_grad_batches",
|
"--accumulate_grad_batches",
|
||||||
|
@ -124,7 +124,7 @@ def parse_args():
|
||||||
"--num_proc",
|
"--num_proc",
|
||||||
type=str,
|
type=str,
|
||||||
help="Number of data processes",
|
help="Number of data processes",
|
||||||
default=16,
|
default=1,
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_epochs",
|
"--max_epochs",
|
||||||
|
@ -136,7 +136,7 @@ def parse_args():
|
||||||
"--strategy",
|
"--strategy",
|
||||||
type=str,
|
type=str,
|
||||||
help="Name of pytorch lightning distribution strategy",
|
help="Name of pytorch lightning distribution strategy",
|
||||||
default='fsdp',
|
default='ddp',
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--resume_from_ckpt_path",
|
"--resume_from_ckpt_path",
|
||||||
|
@ -197,8 +197,11 @@ if __name__ == '__main__':
|
||||||
persistent_workers=True,
|
persistent_workers=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ne = next(train_dataloader._get_iterator())
|
||||||
|
print((ne["input_ids"]-ne["labels"]).numpy().tolist())
|
||||||
|
|
||||||
# trainer
|
# trainer
|
||||||
apply_all_patches()
|
# apply_all_patches()
|
||||||
torch.set_float32_matmul_precision('medium')
|
torch.set_float32_matmul_precision('medium')
|
||||||
if args.bf16:
|
if args.bf16:
|
||||||
precision = 'bf16-mixed'
|
precision = 'bf16-mixed'
|
||||||
|
|
Loading…
Reference in New Issue