diff --git a/wit/train.py b/wit/train.py
index 6baec8a..be23a24 100644
--- a/wit/train.py
+++ b/wit/train.py
@@ -27,12 +27,12 @@
 seed = 42
 vocab_size = 1024
 level_ratio = 4
-level = 4
+level = 6
 dataset_level = 1
 
-hidden_size = 256  # 128 1024 2048 32
-num_attention_heads = 8  # 8 8 16
-num_hidden_layers = 2  # 6 12 24 3
+hidden_size = 2048  # 128 1024 2048 32
+num_attention_heads = 16  # 8 8 16
+num_hidden_layers = 12  # 6 12 24 3
 
 name = "vocab_ratio_level_data_hidden_head_layer"
 ver = f"{vocab_size}" + "_" + f"{level_ratio}" + "_" + f"{level}" + "_" + f"{dataset_level}"
@@ -51,16 +51,14 @@ if __name__ == "__main__":
     tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")
 
     start = vocab_size * (level_ratio**level)
-    end = start * level_ratio
-    size = int(vocab_size * (level_ratio**dataset_level))
-    raw_dataset = MeaningDataset(start, end, size, vocab_size, level_ratio)
+    size = vocab_size * (level_ratio**dataset_level)
+    raw_dataset = MeaningDataset(start, start + size, size, vocab_size, level_ratio)
     train_dataset, val_dataset = raw_dataset.split(0.9)
     train_dataloader = BatchGroupMeaningDataloader(train_dataset, train_batch_size)
     val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size)
-    # it = iter(train_dataloader)
-    # print("data samples:")
-    # for i in range(10):
-    #     print(next(it)["input_ids"].numpy().tolist())
+
+    # for i in range(len(train_dataloader)):
+    #     print(train_dataloader.print_mapping(i))
 
     torch.set_float32_matmul_precision("medium")
     lit_trainer = pl.Trainer(