Refine meaning dataset document.

Colin 2025-02-18 19:35:23 +08:00
parent 383125edc9
commit f8480678d8
5 changed files with 60 additions and 24 deletions

@@ -44,8 +44,8 @@ class MeaningDatasetConfig:
self.level = 5
self.dataset_level = 3
self.min_subitem = 2
self.mask_level = [0, 1, 2]
self.mask_idx = [0, 0, -1]
self.mask_level = None
self.mask_idx = None
class DatasetConfig:
def __init__(self):
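The new defaults above turn token masking off, so the previous behaviour ([0, 1, 2] / [0, 0, -1]) now has to be requested explicitly. A minimal sketch of opting back in, assuming the fields are set on the TrainConfig object the same way the training script further below sets them:

```
# Sketch only: re-enable the old masking behaviour by setting both fields
# before the dataset is built (the exact attribute path is an assumption).
conf = configuration.TrainConfig()
conf.mask_level = [0, 1, 2]  # levels that constrain a token
conf.mask_idx = [0, 0, -1]   # required index at each of those levels
```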

@@ -1,6 +1,8 @@
from dataset.meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader
from dataset.special_dataset import SpecialDataset
from torch.utils.data import random_split, DataLoader
import torch
import os
def InitDataset(config):
@@ -31,10 +33,27 @@ def InitDataset(config):
vocab = config.model_config.vocab_size
start = vocab * (conf.level_ratio**conf.level)
size = vocab * int((conf.level_ratio**conf.dataset_level))
raw_dataset = MeaningDataset(start, start + size, vocab, None, conf.level_ratio, conf.min_subitem)
# print(raw_dataset.token_frequency())
raw_dataset.set_mask(conf.mask_level, conf.mask_idx)
train_dataset, val_dataset = raw_dataset.split(0.9)
path = "./data/"
trainfile = path + f"MeaningDataset_train_v{size}_s{start}_s{size}_lr{conf.level_ratio}_ms{conf.min_subitem}.pt"
valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_lr{conf.level_ratio}_ms{conf.min_subitem}.pt"
if not os.path.exists(path):
os.mkdir(path)
if os.path.exists(trainfile) and os.path.exists(valfile):
print(f"INFO: Load dataset from {trainfile}")
print(f"INFO: Load dataset from {valfile}")
train_dataset = torch.load(trainfile)
val_dataset = torch.load(valfile)
print(f"INFO: Load dataset end")
else:
raw_dataset = MeaningDataset(start, start + size, vocab, None, conf.level_ratio, conf.min_subitem)
print("INFO: raw_dataset.token_frequency" + raw_dataset.token_frequency())
raw_dataset.set_mask(conf.mask_level, conf.mask_idx)
train_dataset, val_dataset = raw_dataset.split(0.9)
torch.save(train_dataset, trainfile)
torch.save(val_dataset, valfile)
print(f"INFO: Build and save dataset end")
train_dataloader = BatchGroupMeaningDataloader(train_dataset, train_batch_size).dataloader(
config.dataloader_works
)
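The new InitDataset body builds the MeaningDataset once, splits it, and serializes both splits with torch.save so that later runs only pay the torch.load cost. The same cache-or-build pattern in isolation, as a minimal sketch (the helper name is illustrative, not part of the repository):

```
import os
import torch

def load_or_build(cache_file, build_fn):
    # Reuse a previously serialized object when the cache file exists;
    # otherwise build it and persist it for the next run.
    if os.path.exists(cache_file):
        print(f"INFO: Load dataset from {cache_file}")
        return torch.load(cache_file)
    obj = build_fn()
    torch.save(obj, cache_file)
    return obj
```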

@@ -41,7 +41,7 @@ class MeaningMap:
and os.path.exists(file_rank_all)
and use_cache
):
print("Load from disk cache: " + file)
print("Mapping Load from disk cache: " + file)
slhwm = np.load(file_prop)
self.ms_map = slhwm[:, 4:]
self.ms_data = np.load(file_data)
@@ -52,9 +52,9 @@ class MeaningMap:
self.ms_rank_all = np.load(file_rank_all)
self.ms_height = slhwm[:, 2]
self.ms_weight = slhwm[:, 3]
print("Load end, elapsed:" + str(time.time() - start_time) + "s")
print("Mapping Load end, elapsed:" + str(time.time() - start_time) + "s")
else:
print("Disk cache miss, build new one. size:" + str(size))
print("Mapping Disk cache miss, build new one. size:" + str(size))
map = np.empty((size, max_subitem), dtype=np.int32)
@@ -169,7 +169,7 @@ class MeaningMap:
self.ms_len = ms_len
self.ms_height = ms_height
self.ms_weight = ms_weight
print("Disk cache build end, elapsed:" + str(time.time() - start_time) + "s")
print("Mapping Disk cache build end, elapsed:" + str(time.time() - start_time) + "s")
def get_sequence(self, meaning): # return sequence[meaning]
start = self.ms_start[meaning]
@@ -267,7 +267,7 @@ class MeaningDataset(Dataset):
np.random.seed(seed)
map = MeaningMap(end, vocab_size, max_subitem, min_subitem, use_cache=use_cache)
np.random.seed(seed)
print("Build MeaningDataset from MeaningMap.")
self.mask_level = None
self.mask_idx = None
self.tree = []
@@ -319,6 +319,7 @@ class MeaningDataset(Dataset):
self.rank_all.append(rank_all)
unique, counts = np.unique(seq_len, return_counts=True)
print("Build MeaningDataset end.")
print("----------------------------------------------------------------")
print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))
print("MeaningDataset size:" + str(len(seq_len)))

@@ -17,6 +17,7 @@ The meaning dataset is a dataset that imitates natural language and abstract expression.
11. get_seq_mask returns, for each token of a sequence, whether it sits at the given index of the given level; level=0 is the bottom level, index=-1 means the last one, index=0 means the first one
12. meaning_height: the total height of the current meaning
13. meaning_weight: the total width of the current meaning
14. mask_level, mask_idx: define the mask over the tokens used for training; mask_level=[0, 1, 2] with mask_idx=[0, 0, -1] means that only tokens that are the 0th item at level 0, the 0th item at level 1, and the last item at level 2 take part in training (see the sketch after this block)
```
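Item 14 above implies that every per-level constraint must hold for a token to take part in training. The following is a hypothetical illustration of that logical-AND combination; combine_level_masks stands in for whatever MeaningDataset does internally with get_seq_mask, whose exact signature is not shown here:

```
import numpy as np

def combine_level_masks(level_masks):
    # level_masks: one boolean array per (mask_level, mask_idx) pair, each
    # flagging the tokens that sit at the required index of that level.
    mask = np.ones_like(level_masks[0], dtype=bool)
    for m in level_masks:
        mask &= m
    return mask  # True only where every level constraint holds

# Example: a 5-token sequence with constraints at levels 0, 1 and 2.
level0 = np.array([True, False, True, False, True])   # index 0 at level 0
level1 = np.array([True, True, False, False, True])   # index 0 at level 1
level2 = np.array([False, True, True, False, True])   # index -1 at level 2
print(combine_level_masks([level0, level1, level2]))  # only the last token trains
```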

@@ -3,43 +3,58 @@ import torch
from model.lit_module import LitModule
from wit.model.tokenization_qwen import QWenTokenizer
from logger import MLFLogger
from logger import MLFLogger, TBLogger
import configuration
import dataset.dataset as ds
if __name__ == "__main__":
train_config = configuration.TrainConfig()
config = train_config.model_config
conf = configuration.TrainConfig()
config = conf.model_config
torch.manual_seed(train_config.seed)
conf.name = "bigger" # current train process name
conf.pretrain_model_name = None # "qwen/Qwen-1_8B-Chat"
conf.learning_rate = 0.0001
conf.use_tril_attention_mask = None
conf.precision = "32-true" # "precision:bf16-mixed,16-mixed,32-true"
conf.train_batch_size = 8
conf.val_batch_size = 4
conf.num_proc = 8
conf.max_epochs = 1000
conf.strategy = "auto"
conf.resume_from_ckpt_path = None
conf.seed = 42
conf.dataloader_works = 2
conf.mask_level = None # [0, 1, 2]
conf.mask_idx = None # [0, 0, -1]
config.vocab_size = 256
config.hidden_size = 128 # 128 1024 2048 32
config.num_hidden_layers = 6 # 6 12 24 3
config.num_attention_heads = 16 # 8 8 16
lit_module = LitModule(
train_config.pretrain_model_name, train_config.learning_rate, config, train_config.use_tril_attention_mask
)
torch.manual_seed(conf.seed)
lit_module = LitModule(conf.pretrain_model_name, conf.learning_rate, config, conf.use_tril_attention_mask)
tokenizer = QWenTokenizer("./model/wit_b64.tiktoken", "./model/wit_char.tiktoken")
train_dataloader, val_dataloader = ds.InitDataset(train_config)
train_dataloader, val_dataloader = ds.InitDataset(conf)
# for i in range(len(train_dataloader)):
# print(train_dataloader.print_mapping(i))
torch.set_float32_matmul_precision("medium")
lit_trainer = pl.Trainer(
accelerator="cuda",
precision=train_config.precision,
logger=MLFLogger("./log/", run_name=train_config.name),
strategy=train_config.strategy,
max_epochs=train_config.max_epochs,
precision=conf.precision,
# logger=MLFLogger("./log/", run_name=conf.name),
logger=TBLogger("./log/", name=conf.name),
strategy=conf.strategy,
max_epochs=conf.max_epochs,
)
lit_trainer.fit(
lit_module,
train_dataloaders=train_dataloader,
val_dataloaders=val_dataloader,
ckpt_path=train_config.resume_from_ckpt_path,
ckpt_path=conf.resume_from_ckpt_path,
)
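The training run above replaces the MLFlow logger with TBLogger. Assuming TBLogger is a thin wrapper around Lightning's TensorBoard logger (an assumption, not something this diff confirms), the equivalent with the stock class would look like this, reusing pl and conf from the script above:

```
from pytorch_lightning.loggers import TensorBoardLogger

# Stock Lightning TensorBoard logger with the same log directory and run
# name as the TBLogger call above; logs land under ./log/<name>/version_<n>/.
logger = TensorBoardLogger("./log/", name=conf.name)
lit_trainer = pl.Trainer(
    accelerator="cuda",
    precision=conf.precision,
    logger=logger,
    strategy=conf.strategy,
    max_epochs=conf.max_epochs,
)
```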