Enable training on a custom dataset; loss goes down.

Colin 2024-02-26 22:42:50 +08:00
parent 1ef3e419cb
commit d1906629ab
2 changed files with 16 additions and 17 deletions

View File

@@ -97,4 +97,5 @@ class LitModule(pl.LightningModule):
             mode="max",
             stopping_threshold=1,
         )
-        return [checkpoint_callback, early_stop_callback]
+        return [checkpoint_callback]
+        # return [checkpoint_callback, early_stop_callback]
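
For context, the method around this change plausibly reads as below once the commit is applied. This is a sketch: the diff only shows the tail of the method, and the hook name configure_callbacks and the "accuracy" monitor key are assumptions. Disabling EarlyStopping keeps the run alive for the full max_epochs, so the loss on the synthetic data can be watched all the way down.

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

    class LitModule(pl.LightningModule):
        def configure_callbacks(self):
            # Hypothetical metric name; the real monitor key is not in the diff.
            checkpoint_callback = ModelCheckpoint(monitor="accuracy", mode="max")
            early_stop_callback = EarlyStopping(
                monitor="accuracy",
                mode="max",
                stopping_threshold=1,
            )
            # Early stopping is disabled so training is not cut short the
            # moment the monitored metric reaches the threshold.
            return [checkpoint_callback]
            # return [checkpoint_callback, early_stop_callback]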

View File

@@ -20,33 +20,36 @@ from tokenization_qwen import QWenTokenizer
 model_name = "qwen/Qwen-1_8B-Chat"
 learning_rate = 0.0001
 use_tril_attention_mask = None
-precision = "16-mixed"  # "precision:bf16-mixed,16-mixed,32-true"
+precision = "32-true"  # "precision:bf16-mixed,16-mixed,32-true"
 tokenizer_name_or_path = None
-dataset_name = ["/home/colin/develop/dataset/liwu/MNBVC/wiki"]
+dataset_name = ["/home/colin/develop/dataset/liwu/MNBVC/wiki/20230198/58.jsonl.gz"]
-train_batch_size = 1
-val_batch_size = 1
-accumulate_grad_batches = 32
+train_batch_size = 256
+val_batch_size = 16
+limit_val_batches = 128
 num_proc = 8
-max_epochs = None
+max_epochs = 1000
 strategy = "fsdp"
 resume_from_ckpt_path = None
 seed = 42

 class SpecialDataset(Dataset):
-    def __init__(self, size=4096):
+    def __init__(self, size=65536):
         self.size = size
         self.features = []
+        a = torch.randint(0, 1024, [size])
+        self.data = torch.stack([a, a * 2, a * 3, a * 4]).permute(1, 0)

     def __len__(self):
         return self.size

     def __getitem__(self, idx):
         output = {}
-        output["input_ids"] = torch.randint(0, 4096, [128])
-        output["labels"] = output["input_ids"]
-        output["token_type_ids"] = torch.zeros([128])
+        data = self.data[idx]
+        output["input_ids"] = data
+        output["labels"] = data
+        output["token_type_ids"] = torch.zeros(data.shape)
         return output
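
Assembled from the hunk above, the dataset class after this commit is self-contained enough to run on its own; a minimal sketch (only the class body comes from the diff, the demo lines at the bottom are added for illustration). Unlike the old version, which emitted fresh random tokens on every access, every sample is now the fixed sequence [a, 2a, 3a, 4a], a deterministic pattern a model can memorize, which is what lets the training loss fall.

    import torch
    from torch.utils.data import Dataset

    class SpecialDataset(Dataset):
        def __init__(self, size=65536):
            self.size = size
            self.features = []
            # One base token per sample; token values stay below 4 * 1024.
            a = torch.randint(0, 1024, [size])
            self.data = torch.stack([a, a * 2, a * 3, a * 4]).permute(1, 0)

        def __len__(self):
            return self.size

        def __getitem__(self, idx):
            output = {}
            data = self.data[idx]
            output["input_ids"] = data
            output["labels"] = data
            output["token_type_ids"] = torch.zeros(data.shape)
            return output

    ds = SpecialDataset(size=8)
    print(ds[0]["input_ids"])  # e.g. tensor([ 317,  634,  951, 1268])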
@@ -144,7 +147,6 @@ if __name__ == "__main__":
     train_dataset = SpecialDataset()
     val_dataset = SpecialDataset()

-    # dataloaders
     train_dataloader = DataLoader(
         train_dataset,
         batch_size=train_batch_size,
@@ -159,21 +161,17 @@ if __name__ == "__main__":
         num_workers=num_proc,
         collate_fn=DefaultDataCollator(),
         persistent_workers=True,
-        shuffle=True,
     )
-    ne = next(train_dataloader._get_iterator())

     # trainer
     # apply_all_patches()
     torch.set_float32_matmul_precision("medium")
-    precision = precision
     lit_trainer = pl.Trainer(
         accelerator="gpu",
         precision=precision,
         log_every_n_steps=5,
-        accumulate_grad_batches=accumulate_grad_batches,
         strategy=strategy,
         max_epochs=max_epochs,
+        limit_val_batches=limit_val_batches,
     )
     lit_trainer.fit(
         lit_module,
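
Pulling the surviving settings together, the training setup after this commit reads roughly as below, reusing the SpecialDataset sketch above. This is a sketch, not the full file: the lit_module construction and the val_dataloader are not shown in the diff, and the removed debug line next(train_dataloader._get_iterator()) relied on a private DataLoader method whose public equivalent is next(iter(train_dataloader)).

    import torch
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader
    from transformers import DefaultDataCollator

    train_dataloader = DataLoader(
        SpecialDataset(),
        batch_size=256,        # train_batch_size
        num_workers=8,         # num_proc
        collate_fn=DefaultDataCollator(),
        persistent_workers=True,
    )

    torch.set_float32_matmul_precision("medium")
    lit_trainer = pl.Trainer(
        accelerator="gpu",
        precision="32-true",   # switched from "16-mixed"
        log_every_n_steps=5,
        strategy="fsdp",
        max_epochs=1000,       # was None
        limit_val_batches=128, # added; accumulate_grad_batches dropped
    )
    # lit_trainer.fit(lit_module, train_dataloaders=train_dataloader, ...)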