add MNBVC dataset.
This commit is contained in:
parent 8120be66a6
commit 1622bf3054
@@ -0,0 +1,104 @@
import argparse
from functools import partial
from itertools import chain
from typing import Dict, Tuple

import datasets
import pytorch_lightning as pl
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split

from transformers import (
    BatchEncoding,
    DefaultDataCollator,
    PreTrainedTokenizer,
    set_seed,
)
from tokenization_qwen import QWenTokenizer

# Either a whole MNBVC subset directory or a single .jsonl.gz shard can be listed here.
# dataset_name = ["/home/colin/develop/dataset/liwu/MNBVC/wiki"]
dataset_name = ["/home/colin/develop/dataset/liwu/MNBVC/wiki/20230198/58.jsonl.gz"]
num_proc = 8
seed = 42


def split_raw_dataset(
    raw_dataset: datasets.DatasetDict,
) -> Tuple[datasets.Dataset, datasets.Dataset]:
    # Use the provided validation split if present; otherwise hold out 5% of train.
    if "validation" in raw_dataset:
        train_dataset, val_dataset = raw_dataset["train"], raw_dataset["validation"]
    else:
        raw_dataset = raw_dataset["train"].train_test_split(test_size=0.05, seed=seed)
        train_dataset, val_dataset = raw_dataset["train"], raw_dataset["test"]
    return train_dataset, val_dataset


def process_dataset(dataset: datasets.Dataset, tokenizer: PreTrainedTokenizer) -> datasets.Dataset:
    def group_texts(examples: Dict[str, list], block_size: int = 512) -> BatchEncoding:
        # Concatenate all tokenized sequences, drop the remainder, and cut the result
        # into fixed-size blocks for causal language modeling.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return BatchEncoding(result)

    def format_inputs(examples):
        # MNBVC stores each document as a "段落" (paragraph) list whose items carry the
        # text under "内容" (content); join them into one newline-separated string.
        merged = "".join(line["内容"] + "\n" for line in examples["段落"])
        return {"text": merged}

    def tokenize_inputs(
        examples: Dict[str, list],
        tokenizer: PreTrainedTokenizer,
        column_name: str = "text",
    ) -> BatchEncoding:
        encoded = tokenizer(examples[column_name], return_attention_mask=False)
        return encoded

    # Flatten each document to plain text, dropping the original MNBVC columns.
    dataset_column_names = list(dataset.features)
    dataset = dataset.map(
        format_inputs,
        batched=False,
        num_proc=num_proc,
        remove_columns=dataset_column_names,
    )
    # Tokenize the text column, then pack the token ids into fixed-length blocks.
    dataset_column_names = list(dataset.features)
    dataset = dataset.map(
        partial(tokenize_inputs, tokenizer=tokenizer),
        batched=True,
        num_proc=num_proc,
        remove_columns=dataset_column_names,
    )
    dataset = dataset.map(
        partial(group_texts, block_size=tokenizer.model_max_length),
        batched=True,
        num_proc=num_proc,
    )

    return dataset


if __name__ == "__main__":
    set_seed(seed)
    tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")
    train_dataset_list = []
    val_dataset_list = []
    for dn in dataset_name:
        datanames = dn.split(".")
        # .jsonl.gz shards and plain .json files both go through the "json" loader;
        # anything else is treated as a dataset directory or hub name.
        if (datanames[-1] == "gz" and datanames[-2] == "jsonl") or datanames[-1] == "json":
            raw_dataset = datasets.load_dataset("json", data_files=dn)
        else:
            raw_dataset = datasets.load_dataset(dn)
        train_dataset, val_dataset = split_raw_dataset(raw_dataset)
        train_dataset = process_dataset(train_dataset, tokenizer)
        val_dataset = process_dataset(val_dataset, tokenizer)
        train_dataset_list.append(train_dataset)
        val_dataset_list.append(val_dataset)
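
The script ends after filling train_dataset_list and val_dataset_list, even though ConcatDataset, DataLoader, and DefaultDataCollator are imported but never used in this file. A minimal sketch of how the collected splits might be turned into loaders, assuming that is the intended follow-up; the batch size and variable names below are illustrative and not part of this commit:

    # Illustrative continuation, not part of the committed file; assumes the
    # train_dataset_list / val_dataset_list built above and a hypothetical batch size.
    train_data = ConcatDataset(train_dataset_list)  # merge the per-source train splits
    val_data = ConcatDataset(val_dataset_list)      # merge the per-source validation splits
    collator = DefaultDataCollator()                # stacks input_ids/labels lists into tensors

    train_loader = DataLoader(train_data, batch_size=8, shuffle=True, collate_fn=collator)
    val_loader = DataLoader(val_data, batch_size=8, shuffle=False, collate_fn=collator)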