Refine train dataset.
commit 2bc9e3b57e
parent 3c774983d4

wit/train.py | 20 +++++++++-----------
1 file changed, 9 insertions(+), 11 deletions(-)
diff --git a/wit/train.py b/wit/train.py
--- a/wit/train.py
+++ b/wit/train.py
@@ -27,12 +27,12 @@ seed = 42
 
 vocab_size = 1024
 level_ratio = 4
-level = 4
+level = 6
 dataset_level = 1
 
-hidden_size = 256  # 128 1024 2048  32
-num_attention_heads = 8  # 8 8 16
-num_hidden_layers = 2  # 6 12 24  3
+hidden_size = 2048  # 128 1024 2048  32
+num_attention_heads = 16  # 8 8 16
+num_hidden_layers = 12  # 6 12 24  3
 
 name = "vocab_ratio_level_data_hidden_head_layer"
 ver = f"{vocab_size}" + "_" + f"{level_ratio}" + "_" + f"{level}" + "_" + f"{dataset_level}"
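This hunk scales the run up: the meaning level moves from 4 to 6, and the model grows from 256 hidden units, 8 heads, and 2 layers to 2048 hidden units, 16 heads, and 12 layers. As a quick check, here is a minimal sketch of the run-version tag the new values produce, assuming the ver expression is evaluated standalone rather than inside train.py:

    # New config values from the hunk above; ver is built exactly as in train.py.
    vocab_size, level_ratio, level, dataset_level = 1024, 4, 6, 1
    ver = f"{vocab_size}" + "_" + f"{level_ratio}" + "_" + f"{level}" + "_" + f"{dataset_level}"
    print(ver)  # prints: 1024_4_6_1

That tag, together with name = "vocab_ratio_level_data_hidden_head_layer", presumably identifies this training run.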
@@ -51,16 +51,14 @@ if __name__ == "__main__":
     tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")
 
     start = vocab_size * (level_ratio**level)
-    end = start * level_ratio
-    size = int(vocab_size * (level_ratio**dataset_level))
-    raw_dataset = MeaningDataset(start, end, size, vocab_size, level_ratio)
+    size = vocab_size * (level_ratio**dataset_level)
+    raw_dataset = MeaningDataset(start, start + size, size, vocab_size, level_ratio)
     train_dataset, val_dataset = raw_dataset.split(0.9)
     train_dataloader = BatchGroupMeaningDataloader(train_dataset, train_batch_size)
     val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size)
-    # it = iter(train_dataloader)
-    # print("data samples:")
-    # for i in range(10):
-    #     print(next(it)["input_ids"].numpy().tolist())
+
+    # for i in range(len(train_dataloader)):
+    #     print(train_dataloader.print_mapping(i))
 
     torch.set_float32_matmul_precision("medium")
     lit_trainer = pl.Trainer(
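The second hunk is the actual dataset refinement: the old code spanned meanings from start up to start * level_ratio and passed a separately computed size, while the new code requests exactly size meanings in the range from start to start + size; the commented-out sample dump is also swapped for a commented-out mapping printout. A small sketch of the arithmetic under this commit's config values follows; only the numbers are computed here, since MeaningDataset and BatchGroupMeaningDataloader live elsewhere in the wit repo:

    # Range arithmetic implied by the refined train.py, with this commit's values.
    vocab_size, level_ratio, level, dataset_level = 1024, 4, 6, 1

    start = vocab_size * (level_ratio**level)         # 1024 * 4**6 = 4_194_304
    size = vocab_size * (level_ratio**dataset_level)  # 1024 * 4**1 = 4_096

    old_end = start * level_ratio                     # old upper bound: 16_777_216
    new_end = start + size                            # new upper bound:  4_198_400
    print(start, size, old_end, new_end)

The requested range is far narrower than before: instead of presumably drawing its 4,096 samples from a span of roughly 12.6 million candidate meaning ids, the dataset now covers exactly the 4,096 ids it uses.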