diff --git a/wit/configuration.py b/wit/configuration.py
index 2eb6d09..dc7ba37 100644
--- a/wit/configuration.py
+++ b/wit/configuration.py
@@ -44,8 +44,8 @@ class MeaningDatasetConfig:
         self.level = 5
         self.dataset_level = 3
         self.min_subitem = 2
-        self.mask_level = [0, 1, 2]
-        self.mask_idx = [0, 0, -1]
+        self.mask_level = None
+        self.mask_idx = None

 class DatasetConfig:
     def __init__(self):
diff --git a/wit/dataset/dataset.py b/wit/dataset/dataset.py
index 3979b06..9d07a4d 100644
--- a/wit/dataset/dataset.py
+++ b/wit/dataset/dataset.py
@@ -1,6 +1,8 @@
 from dataset.meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader
 from dataset.special_dataset import SpecialDataset
 from torch.utils.data import random_split, DataLoader
+import torch
+import os


 def InitDataset(config):
@@ -31,10 +33,27 @@ def InitDataset(config):
         vocab = config.model_config.vocab_size
         start = vocab * (conf.level_ratio**conf.level)
         size = vocab * int((conf.level_ratio**conf.dataset_level))
-        raw_dataset = MeaningDataset(start, start + size, vocab, None, conf.level_ratio, conf.min_subitem)
-        # print(raw_dataset.token_frequency())
-        raw_dataset.set_mask(conf.mask_level, conf.mask_idx)
-        train_dataset, val_dataset = raw_dataset.split(0.9)
+
+        path = "./data/"
+        trainfile = path + f"MeaningDataset_train_v{size}_s{start}_s{size}_lr{conf.level_ratio}_ms{conf.min_subitem}.pt"
+        valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_lr{conf.level_ratio}_ms{conf.min_subitem}.pt"
+        if not os.path.exists(path):
+            os.mkdir(path)
+        if os.path.exists(trainfile) and os.path.exists(valfile):
+            print(f"INFO: Load dataset from {trainfile}")
+            print(f"INFO: Load dataset from {valfile}")
+            train_dataset = torch.load(trainfile)
+            val_dataset = torch.load(valfile)
+            print(f"INFO: Load dataset end")
+        else:
+            raw_dataset = MeaningDataset(start, start + size, vocab, None, conf.level_ratio, conf.min_subitem)
+            print("INFO: raw_dataset.token_frequency " + str(raw_dataset.token_frequency()))
+            raw_dataset.set_mask(conf.mask_level, conf.mask_idx)
+            train_dataset, val_dataset = raw_dataset.split(0.9)
+            torch.save(train_dataset, trainfile)
+            torch.save(val_dataset, valfile)
+            print(f"INFO: Build and save dataset end")
+
         train_dataloader = BatchGroupMeaningDataloader(train_dataset, train_batch_size).dataloader(
             config.dataloader_works
         )
diff --git a/wit/dataset/meaning_dataset.py b/wit/dataset/meaning_dataset.py
index 415febb..d0fb767 100644
--- a/wit/dataset/meaning_dataset.py
+++ b/wit/dataset/meaning_dataset.py
@@ -41,7 +41,7 @@ class MeaningMap:
             and os.path.exists(file_rank_all)
             and use_cache
         ):
-            print("Load from disk cache: " + file)
+            print("Mapping Load from disk cache: " + file)
             slhwm = np.load(file_prop)
             self.ms_map = slhwm[:, 4:]
             self.ms_data = np.load(file_data)
@@ -52,9 +52,9 @@ class MeaningMap:
             self.ms_rank_all = np.load(file_rank_all)
             self.ms_height = slhwm[:, 2]
             self.ms_weight = slhwm[:, 3]
-            print("Load end, elapsed:" + str(time.time() - start_time) + "s")
+            print("Mapping Load end, elapsed:" + str(time.time() - start_time) + "s")
         else:
-            print("Disk cache miss, build new one. size:" + str(size))
+            print("Mapping Disk cache miss, build new one. size:" + str(size))


             map = np.empty((size, max_subitem), dtype=np.int32)
@@ -169,7 +169,7 @@ class MeaningMap:
         self.ms_len = ms_len
         self.ms_height = ms_height
         self.ms_weight = ms_weight
-        print("Disk cache build end, elapsed:" + str(time.time() - start_time) + "s")
+        print("Mapping Disk cache build end, elapsed:" + str(time.time() - start_time) + "s")

     def get_sequence(self, meaning):  # return sequence[meaning]
         start = self.ms_start[meaning]
@@ -267,7 +267,7 @@ class MeaningDataset(Dataset):
         np.random.seed(seed)
         map = MeaningMap(end, vocab_size, max_subitem, min_subitem, use_cache=use_cache)
         np.random.seed(seed)
-
+        print("Build MeaningDataset from MeaningMap.")
         self.mask_level = None
         self.mask_idx = None
         self.tree = []
@@ -319,6 +319,7 @@ class MeaningDataset(Dataset):
         self.rank_all.append(rank_all)

         unique, counts = np.unique(seq_len, return_counts=True)
+        print("Build MeaningDataset end.")
         print("----------------------------------------------------------------")
         print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))
         print("MeaningDataset size:" + str(len(seq_len)))
diff --git a/wit/doc/meaning_dataset.md b/wit/doc/meaning_dataset.md
index e1ae483..62e5bd5 100644
--- a/wit/doc/meaning_dataset.md
+++ b/wit/doc/meaning_dataset.md
@@ -17,6 +17,7 @@ The meaning dataset is a dataset that imitates natural language and abstract expression.
 11. get_seq_mask returns, for each token of a sequence, whether that token is at the given index of the given level; level=0: the lowest level, index=-1: the last one, index=0: the first one
 12. meaning_height: the total height of the current meaning
 13. meaning_weight: the total width of the current meaning
+14. mask_level, mask_idx: the mask that selects which tokens are used for training; mask_level=[0, 1, 2] with mask_idx=[0, 0, -1] means that only a token that is the 0th item at level 0, the 0th item at level 1, and the last item at level 2 takes part in training

 ```

diff --git a/wit/train.py b/wit/train.py
index f8b6539..bac1e3a 100644
--- a/wit/train.py
+++ b/wit/train.py
@@ -3,43 +3,58 @@
 import torch

 from model.lit_module import LitModule
 from wit.model.tokenization_qwen import QWenTokenizer
-from logger import MLFLogger
+from logger import MLFLogger, TBLogger

 import configuration
 import dataset.dataset as ds


 if __name__ == "__main__":
-    train_config = configuration.TrainConfig()
-    config = train_config.model_config
-
-    torch.manual_seed(train_config.seed)
+    conf = configuration.TrainConfig()
+    config = conf.model_config
+
+    conf.name = "bigger"  # current train process name
+    conf.pretrain_model_name = None  # "qwen/Qwen-1_8B-Chat"
+    conf.learning_rate = 0.0001
+    conf.use_tril_attention_mask = None
+    conf.precision = "32-true"  # "precision:bf16-mixed,16-mixed,32-true"
+    conf.train_batch_size = 8
+    conf.val_batch_size = 4
+    conf.num_proc = 8
+    conf.max_epochs = 1000
+    conf.strategy = "auto"
+    conf.resume_from_ckpt_path = None
+    conf.seed = 42
+    conf.dataloader_works = 2
+    conf.mask_level = None  # [0, 1, 2]
+    conf.mask_idx = None  # [0, 0, -1]

     config.vocab_size = 256
     config.hidden_size = 128  # 128 1024 2048 32
     config.num_hidden_layers = 6  # 6 12 24 3
     config.num_attention_heads = 16  # 8 8 16

-    lit_module = LitModule(
-        train_config.pretrain_model_name, train_config.learning_rate, config, train_config.use_tril_attention_mask
-    )
+    torch.manual_seed(conf.seed)
+    lit_module = LitModule(conf.pretrain_model_name, conf.learning_rate, config, conf.use_tril_attention_mask)
     tokenizer = QWenTokenizer("./model/wit_b64.tiktoken", "./model/wit_char.tiktoken")

-    train_dataloader, val_dataloader = ds.InitDataset(train_config)
+    train_dataloader, val_dataloader = ds.InitDataset(conf)

     # for i in range(len(train_dataloader)):
     #     print(train_dataloader.print_mapping(i))

     torch.set_float32_matmul_precision("medium")
     lit_trainer = pl.Trainer(
         accelerator="cuda",
-        precision=train_config.precision,
-        logger=MLFLogger("./log/", run_name=train_config.name),
-        strategy=train_config.strategy,
-        max_epochs=train_config.max_epochs,
+        precision=conf.precision,
+        # logger=MLFLogger("./log/", run_name=conf.name),
+        logger=TBLogger("./log/", name=conf.name),
+        strategy=conf.strategy,
+        max_epochs=conf.max_epochs,
     )

     lit_trainer.fit(
         lit_module,
         train_dataloaders=train_dataloader,
         val_dataloaders=val_dataloader,
-        ckpt_path=train_config.resume_from_ckpt_path,
+        ckpt_path=conf.resume_from_ckpt_path,
     )
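Note on the `mask_level` / `mask_idx` semantics documented in `doc/meaning_dataset.md` above: the listed levels combine as a logical AND, so a token participates in training only if it matches the configured index at every listed level. The sketch below is an illustration of that reading only, not code from this patch; it assumes a `get_seq_mask(seq, level, idx)` helper that returns a boolean array as described in item 11 of the document.

```python
# Illustrative sketch (not part of this patch): combine mask_level / mask_idx
# into a single per-token training mask. get_seq_mask(seq, level, idx) is an
# assumed helper that marks, for every token in seq, whether it sits at
# position `idx` of level `level`.
import numpy as np


def combined_train_mask(get_seq_mask, seq, mask_level, mask_idx):
    mask = np.ones(len(seq), dtype=bool)
    if mask_level is None or mask_idx is None:
        return mask  # no mask configured: every token is trained on
    for level, idx in zip(mask_level, mask_idx):
        # a token must satisfy the constraint at every listed level (logical AND)
        mask = mask & np.asarray(get_seq_mask(seq, level, idx), dtype=bool)
    return mask


# Example with the values from the documentation:
# combined_train_mask(get_seq_mask, seq, [0, 1, 2], [0, 0, -1])
# keeps only tokens that are the 0th item at level 0, the 0th item at level 1,
# and the last item at level 2.
```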