Enable training with a custom dataset; the loss now goes down.

Colin 2024-02-26 22:42:50 +08:00
parent 1ef3e419cb
commit d1906629ab
2 changed files with 16 additions and 17 deletions

View File

@@ -97,4 +97,5 @@ class LitModule(pl.LightningModule):
             mode="max",
             stopping_threshold=1,
         )
-        return [checkpoint_callback, early_stop_callback]
+        return [checkpoint_callback]
+        # return [checkpoint_callback, early_stop_callback]
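
Note: a minimal sketch of the callback hook this hunk edits, assuming the standard pytorch_lightning.callbacks classes, a configure_callbacks hook, and a hypothetical monitor key ("accuracy"); after the change only the checkpoint callback is returned and early stopping is parked in a comment.

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint


class LitModule(pl.LightningModule):
    def configure_callbacks(self):
        # Hypothetical monitor key; the real metric name is not visible in this hunk.
        checkpoint_callback = ModelCheckpoint(monitor="accuracy", mode="max", save_top_k=1)
        early_stop_callback = EarlyStopping(
            monitor="accuracy",
            mode="max",
            stopping_threshold=1,  # would stop as soon as the metric reaches 1.0
        )
        # After this commit only checkpointing stays active.
        return [checkpoint_callback]
        # return [checkpoint_callback, early_stop_callback]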

View File

@@ -20,33 +20,36 @@ from tokenization_qwen import QWenTokenizer
 model_name = "qwen/Qwen-1_8B-Chat"
 learning_rate = 0.0001
 use_tril_attention_mask = None
-precision = "16-mixed"  # "precision:bf16-mixed,16-mixed,32-true"
+precision = "32-true"  # "precision:bf16-mixed,16-mixed,32-true"
 tokenizer_name_or_path = None
 dataset_name = ["/home/colin/develop/dataset/liwu/MNBVC/wiki"]
 dataset_name = ["/home/colin/develop/dataset/liwu/MNBVC/wiki/20230198/58.jsonl.gz"]
-train_batch_size = 1
-val_batch_size = 1
-accumulate_grad_batches = 32
+train_batch_size = 256
+val_batch_size = 16
+limit_val_batches = 128
 num_proc = 8
-max_epochs = None
+max_epochs = 1000
 strategy = "fsdp"
 resume_from_ckpt_path = None
 seed = 42
 
 
 class SpecialDataset(Dataset):
-    def __init__(self, size=4096):
+    def __init__(self, size=65536):
         self.size = size
         self.features = []
+        a = torch.randint(0, 1024, [size])
+        self.data = torch.stack([a, a * 2, a * 3, a * 4]).permute(1, 0)
 
     def __len__(self):
         return self.size
 
     def __getitem__(self, idx):
         output = {}
-        output["input_ids"] = torch.randint(0, 4096, [128])
-        output["labels"] = output["input_ids"]
-        output["token_type_ids"] = torch.zeros([128])
+        data = self.data[idx]
+        output["input_ids"] = data
+        output["labels"] = data
+        output["token_type_ids"] = torch.zeros(data.shape)
         return output
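
Note: the rewritten SpecialDataset trades per-item random tokens for a fixed, learnable pattern: each sample is the length-4 sequence [a, 2a, 3a, 4a] built from one random base token a < 1024, and labels equal the inputs, so the model has something it can actually fit and the loss can fall. Below is a self-contained sketch of the dataset as it reads after this hunk plus a quick shape check; the DefaultDataCollator call mirrors the loaders further down and the __main__ check is illustrative only.

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import DefaultDataCollator


class SpecialDataset(Dataset):
    def __init__(self, size=65536):
        self.size = size
        # One random base token per sample; each row is the deterministic sequence a, 2a, 3a, 4a.
        a = torch.randint(0, 1024, [size])
        self.data = torch.stack([a, a * 2, a * 3, a * 4]).permute(1, 0)  # [size, 4]

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        data = self.data[idx]
        return {
            "input_ids": data,
            "labels": data,                             # causal LM target equals the input
            "token_type_ids": torch.zeros(data.shape),  # float zeros, as in the diff
        }


if __name__ == "__main__":
    loader = DataLoader(SpecialDataset(size=8), batch_size=4, collate_fn=DefaultDataCollator())
    batch = next(iter(loader))
    print(batch["input_ids"].shape)  # torch.Size([4, 4])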
@@ -144,7 +147,6 @@ if __name__ == "__main__":
     train_dataset = SpecialDataset()
     val_dataset = SpecialDataset()
-    # dataloaders
     train_dataloader = DataLoader(
         train_dataset,
         batch_size=train_batch_size,
@@ -159,21 +161,17 @@ if __name__ == "__main__":
         num_workers=num_proc,
         collate_fn=DefaultDataCollator(),
         persistent_workers=True,
-        shuffle=True,
     )
-    ne = next(train_dataloader._get_iterator())
-    # trainer
-    # apply_all_patches()
     torch.set_float32_matmul_precision("medium")
     precision = precision
     lit_trainer = pl.Trainer(
         accelerator="gpu",
         precision=precision,
-        log_every_n_steps=5,
-        accumulate_grad_batches=accumulate_grad_batches,
         strategy=strategy,
         max_epochs=max_epochs,
+        limit_val_batches=limit_val_batches,
     )
     lit_trainer.fit(
         lit_module,
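
Note: tying the loader and trainer hunks together, a hedged sketch of the resulting run under the config values introduced above. The shuffle flag on the training loader and the keyword names in the fit(...) call are assumptions, since the hunk is truncated; what the diff does show is that accumulate_grad_batches and log_every_n_steps are dropped from the Trainer and limit_val_batches=128 now caps validation at 128 batches of 16.

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from transformers import DefaultDataCollator

# Config values introduced in the first hunk of this file.
train_batch_size, val_batch_size, limit_val_batches = 256, 16, 128
num_proc, max_epochs, strategy, precision = 8, 1000, "fsdp", "32-true"

train_dataset = SpecialDataset()  # the synthetic dataset defined above
val_dataset = SpecialDataset()

train_dataloader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    num_workers=num_proc,
    collate_fn=DefaultDataCollator(),
    persistent_workers=True,
    shuffle=True,  # assumed for the training split; not visible in the hunk
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=val_batch_size,
    num_workers=num_proc,
    collate_fn=DefaultDataCollator(),
    persistent_workers=True,  # shuffle=True is removed from the val loader in this commit
)

torch.set_float32_matmul_precision("medium")
lit_trainer = pl.Trainer(
    accelerator="gpu",
    precision=precision,
    strategy=strategy,
    max_epochs=max_epochs,
    limit_val_batches=limit_val_batches,  # added in this commit; accumulate_grad_batches and log_every_n_steps were removed
)
lit_trainer.fit(
    lit_module,  # the LitModule instance built earlier in the script (not shown in this view)
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)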