Add custom dataset support.
parent e5f97af291
commit 1ef3e419cb
@@ -6,7 +6,7 @@ from typing import Dict, Tuple
 import datasets
 import pytorch_lightning as pl
 import torch
-from torch.utils.data import ConcatDataset, DataLoader
+from torch.utils.data import ConcatDataset, DataLoader, Dataset
 from transformers import (
     BatchEncoding,
     DefaultDataCollator,
@@ -22,9 +22,9 @@ learning_rate = 0.0001
 use_tril_attention_mask = None
 precision = "16-mixed"  # "precision:bf16-mixed,16-mixed,32-true"
 tokenizer_name_or_path = None
-dataset_name = "/home/colin/develop/dataset/liwu/MNBVC/wiki/20230197/0.jsonl.gz"
-dataset_name = "/home/colin/develop/dataset/liwu/MNBVC/wiki"
-train_batch_size = 8
+dataset_name = ["/home/colin/develop/dataset/liwu/MNBVC/wiki"]
+dataset_name = ["/home/colin/develop/dataset/liwu/MNBVC/wiki/20230198/58.jsonl.gz"]
+train_batch_size = 1
 val_batch_size = 1
 accumulate_grad_batches = 32
 num_proc = 8
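Note: `dataset_name` is now a list, matching the `for dn in dataset_name:` loop below, so several sources can be mixed in one run. As committed, the second assignment overrides the first; a single list combining both forms from this diff would read:

    dataset_name = [
        "/home/colin/develop/dataset/liwu/MNBVC/wiki",                       # dataset directory
        "/home/colin/develop/dataset/liwu/MNBVC/wiki/20230198/58.jsonl.gz",  # single shard
    ]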
@@ -34,6 +34,22 @@ resume_from_ckpt_path = None
 seed = 42
 
 
+class SpecialDataset(Dataset):
+    def __init__(self, size=4096):
+        self.size = size
+        self.features = []
+
+    def __len__(self):
+        return self.size
+
+    def __getitem__(self, idx):
+        output = {}
+        output["input_ids"] = torch.randint(0, 4096, [128])
+        output["labels"] = output["input_ids"]
+        output["token_type_ids"] = torch.zeros([128])
+        return output
+
+
 def split_raw_dataset(
     raw_dataset: datasets.DatasetDict,
 ) -> Tuple[datasets.Dataset, datasets.Dataset]:
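Note: the new `SpecialDataset` emits random token ids at a fixed sequence length of 128, so the training loop can be exercised without loading any real data (`self.features` is assigned but never used). A minimal usage sketch, assuming `SpecialDataset` as defined above; everything else is the standard torch/transformers API:

    from torch.utils.data import DataLoader
    from transformers import DefaultDataCollator

    ds = SpecialDataset(size=16)
    loader = DataLoader(ds, batch_size=4, collate_fn=DefaultDataCollator())

    batch = next(iter(loader))
    print(batch["input_ids"].shape)  # torch.Size([4, 128])
    print(batch["labels"].shape)     # torch.Size([4, 128])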
@@ -106,17 +122,17 @@ if __name__ == "__main__":
     model_dir = snapshot_download(model_name)
     lit_module = LitModule(model_dir, learning_rate, use_tril_attention_mask)
 
-    # datasets
-    # tokenizer = load_tokenizer("./custom_models/gpt2")
     tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")
     train_dataset_list = []
     val_dataset_list = []
-    for dataset_name in dataset_name:
-        dataset_args = dataset_name.split(":")
-        raw_dataset = datasets.load_dataset(
-            "json", data_files="/home/colin/develop/dataset/liwu/MNBVC/wiki/20230197/0.jsonl.gz"
-        )
-        # raw_dataset = datasets.load_dataset(*dataset_args)
+    for dn in dataset_name:
+        datanames = dn.split(".")
+        if datanames[-1] == "gz" and datanames[-2] == "jsonl":
+            raw_dataset = datasets.load_dataset("json", data_files=dn)
+        elif datanames[-1] == "json":
+            raw_dataset = datasets.load_dataset("json", data_files=dn)
+        else:
+            raw_dataset = datasets.load_dataset(dn)
         train_dataset, val_dataset = split_raw_dataset(raw_dataset)
         train_dataset = process_dataset(train_dataset, tokenizer)
         val_dataset = process_dataset(val_dataset, tokenizer)
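Note: the `jsonl.gz` and `json` branches call the identical loader, and the `json` builder in Hugging Face `datasets` decompresses `.gz` files transparently, so the dispatch could be collapsed. Renaming the loop variable to `dn` also stops it from shadowing the module-level `dataset_name` list. A possible collapsed form (the helper name `load_raw_dataset` is hypothetical, not part of this commit):

    def load_raw_dataset(path: str) -> datasets.DatasetDict:
        # The "json" builder handles plain and gzip-compressed JSON/JSONL;
        # anything else is treated as a dataset directory or hub name.
        if path.endswith((".json", ".jsonl", ".jsonl.gz")):
            return datasets.load_dataset("json", data_files=path)
        return datasets.load_dataset(path)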
@@ -125,6 +141,9 @@ if __name__ == "__main__":
     train_dataset = ConcatDataset(train_dataset_list)
     val_dataset = ConcatDataset(val_dataset_list)
 
+    train_dataset = SpecialDataset()
+    val_dataset = SpecialDataset()
+
     # dataloaders
     train_dataloader = DataLoader(
         train_dataset,
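Note: these two `SpecialDataset()` assignments unconditionally discard the `ConcatDataset` objects built just above, so the corpora loaded in the loop are never reached while they remain; this reads like a temporary debugging switch. One way to make that explicit (the `use_special_dataset` flag is hypothetical, not part of this commit):

    use_special_dataset = True  # hypothetical flag; False trains on the real corpora

    if use_special_dataset:
        train_dataset = SpecialDataset()
        val_dataset = SpecialDataset()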