Refine meaning dataset document.
This commit is contained in:
parent
383125edc9
commit
f8480678d8
|
@ -44,8 +44,8 @@ class MeaningDatasetConfig:
|
|||
self.level = 5
|
||||
self.dataset_level = 3
|
||||
self.min_subitem = 2
|
||||
self.mask_level = [0, 1, 2]
|
||||
self.mask_idx = [0, 0, -1]
|
||||
self.mask_level = None
|
||||
self.mask_idx = None
|
||||
|
||||
class DatasetConfig:
|
||||
def __init__(self):
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
from dataset.meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader
|
||||
from dataset.special_dataset import SpecialDataset
|
||||
from torch.utils.data import random_split, DataLoader
|
||||
import torch
|
||||
import os
|
||||
|
||||
|
||||
def InitDataset(config):
|
||||
|
@ -31,10 +33,27 @@ def InitDataset(config):
|
|||
vocab = config.model_config.vocab_size
|
||||
start = vocab * (conf.level_ratio**conf.level)
|
||||
size = vocab * int((conf.level_ratio**conf.dataset_level))
|
||||
raw_dataset = MeaningDataset(start, start + size, vocab, None, conf.level_ratio, conf.min_subitem)
|
||||
# print(raw_dataset.token_frequency())
|
||||
raw_dataset.set_mask(conf.mask_level, conf.mask_idx)
|
||||
train_dataset, val_dataset = raw_dataset.split(0.9)
|
||||
|
||||
path = "./data/"
|
||||
trainfile = path + f"MeaningDataset_train_v{size}_s{start}_s{size}_lr{conf.level_ratio}_ms{conf.min_subitem}.pt"
|
||||
valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_lr{conf.level_ratio}_ms{conf.min_subitem}.pt"
|
||||
if not os.path.exists(path):
|
||||
os.mkdir(path)
|
||||
if os.path.exists(trainfile) and os.path.exists(valfile):
|
||||
print(f"INFO: Load dataset from {trainfile}")
|
||||
print(f"INFO: Load dataset from {valfile}")
|
||||
train_dataset = torch.load(trainfile)
|
||||
val_dataset = torch.load(valfile)
|
||||
print(f"INFO: Load dataset end")
|
||||
else:
|
||||
raw_dataset = MeaningDataset(start, start + size, vocab, None, conf.level_ratio, conf.min_subitem)
|
||||
print("INFO: raw_dataset.token_frequency" + raw_dataset.token_frequency())
|
||||
raw_dataset.set_mask(conf.mask_level, conf.mask_idx)
|
||||
train_dataset, val_dataset = raw_dataset.split(0.9)
|
||||
torch.save(train_dataset, trainfile)
|
||||
torch.save(val_dataset, valfile)
|
||||
print(f"INFO: Build and save dataset end")
|
||||
|
||||
train_dataloader = BatchGroupMeaningDataloader(train_dataset, train_batch_size).dataloader(
|
||||
config.dataloader_works
|
||||
)
|
||||
|
|
|
@ -41,7 +41,7 @@ class MeaningMap:
|
|||
and os.path.exists(file_rank_all)
|
||||
and use_cache
|
||||
):
|
||||
print("Load from disk cache: " + file)
|
||||
print("Mapping Load from disk cache: " + file)
|
||||
slhwm = np.load(file_prop)
|
||||
self.ms_map = slhwm[:, 4:]
|
||||
self.ms_data = np.load(file_data)
|
||||
|
@ -52,9 +52,9 @@ class MeaningMap:
|
|||
self.ms_rank_all = np.load(file_rank_all)
|
||||
self.ms_height = slhwm[:, 2]
|
||||
self.ms_weight = slhwm[:, 3]
|
||||
print("Load end, elapsed:" + str(time.time() - start_time) + "s")
|
||||
print("Mapping Load end, elapsed:" + str(time.time() - start_time) + "s")
|
||||
else:
|
||||
print("Disk cache miss, build new one. size:" + str(size))
|
||||
print("Mapping Disk cache miss, build new one. size:" + str(size))
|
||||
|
||||
map = np.empty((size, max_subitem), dtype=np.int32)
|
||||
|
||||
|
@ -169,7 +169,7 @@ class MeaningMap:
|
|||
self.ms_len = ms_len
|
||||
self.ms_height = ms_height
|
||||
self.ms_weight = ms_weight
|
||||
print("Disk cache build end, elapsed:" + str(time.time() - start_time) + "s")
|
||||
print("Mapping Disk cache build end, elapsed:" + str(time.time() - start_time) + "s")
|
||||
|
||||
def get_sequence(self, meaning): # return sequence[meaning]
|
||||
start = self.ms_start[meaning]
|
||||
|
@ -267,7 +267,7 @@ class MeaningDataset(Dataset):
|
|||
np.random.seed(seed)
|
||||
map = MeaningMap(end, vocab_size, max_subitem, min_subitem, use_cache=use_cache)
|
||||
np.random.seed(seed)
|
||||
|
||||
print("Build MeaningDataset from MeaningMap.")
|
||||
self.mask_level = None
|
||||
self.mask_idx = None
|
||||
self.tree = []
|
||||
|
@ -319,6 +319,7 @@ class MeaningDataset(Dataset):
|
|||
self.rank_all.append(rank_all)
|
||||
|
||||
unique, counts = np.unique(seq_len, return_counts=True)
|
||||
print("Build MeaningDataset end.")
|
||||
print("----------------------------------------------------------------")
|
||||
print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))
|
||||
print("MeaningDataset size:" + str(len(seq_len)))
|
||||
|
|
|
@ -17,6 +17,7 @@ meaning数据集是一个模仿自然语言,以及抽象表达的数据集。
|
|||
11. get_seq_mask返回一个sequence每个token在对应level是不是对应的index,level=0:最底层,index=-1:最后一个,index=0:第一个
|
||||
12. meaning_height 当前meaning的总高度
|
||||
13. meaning_weight 当前meaning的总宽度
|
||||
14. mask_level mask_idx: 表示用于训练的token的mask,mask_level=[0, 1, 2] mask_idx=[0, 0, -1]表示只有是第0层第0个,而且是第1层第0个,第2层最后一个的token,才参与训练
|
||||
|
||||
|
||||
```
|
||||
|
|
41
wit/train.py
41
wit/train.py
|
@ -3,43 +3,58 @@ import torch
|
|||
|
||||
from model.lit_module import LitModule
|
||||
from wit.model.tokenization_qwen import QWenTokenizer
|
||||
from logger import MLFLogger
|
||||
from logger import MLFLogger, TBLogger
|
||||
|
||||
import configuration
|
||||
import dataset.dataset as ds
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
train_config = configuration.TrainConfig()
|
||||
config = train_config.model_config
|
||||
conf = configuration.TrainConfig()
|
||||
config = conf.model_config
|
||||
|
||||
torch.manual_seed(train_config.seed)
|
||||
conf.name = "bigger" # current train process name
|
||||
conf.pretrain_model_name = None # "qwen/Qwen-1_8B-Chat"
|
||||
conf.learning_rate = 0.0001
|
||||
conf.use_tril_attention_mask = None
|
||||
conf.precision = "32-true" # "precision:bf16-mixed,16-mixed,32-true"
|
||||
conf.train_batch_size = 8
|
||||
conf.val_batch_size = 4
|
||||
conf.num_proc = 8
|
||||
conf.max_epochs = 1000
|
||||
conf.strategy = "auto"
|
||||
conf.resume_from_ckpt_path = None
|
||||
conf.seed = 42
|
||||
conf.dataloader_works = 2
|
||||
|
||||
conf.mask_level = None # [0, 1, 2]
|
||||
conf.mask_idx = None # [0, 0, -1]
|
||||
|
||||
config.vocab_size = 256
|
||||
config.hidden_size = 128 # 128 1024 2048 32
|
||||
config.num_hidden_layers = 6 # 6 12 24 3
|
||||
config.num_attention_heads = 16 # 8 8 16
|
||||
|
||||
lit_module = LitModule(
|
||||
train_config.pretrain_model_name, train_config.learning_rate, config, train_config.use_tril_attention_mask
|
||||
)
|
||||
torch.manual_seed(conf.seed)
|
||||
lit_module = LitModule(conf.pretrain_model_name, conf.learning_rate, config, conf.use_tril_attention_mask)
|
||||
tokenizer = QWenTokenizer("./model/wit_b64.tiktoken", "./model/wit_char.tiktoken")
|
||||
|
||||
train_dataloader, val_dataloader = ds.InitDataset(train_config)
|
||||
train_dataloader, val_dataloader = ds.InitDataset(conf)
|
||||
# for i in range(len(train_dataloader)):
|
||||
# print(train_dataloader.print_mapping(i))
|
||||
|
||||
torch.set_float32_matmul_precision("medium")
|
||||
lit_trainer = pl.Trainer(
|
||||
accelerator="cuda",
|
||||
precision=train_config.precision,
|
||||
logger=MLFLogger("./log/", run_name=train_config.name),
|
||||
strategy=train_config.strategy,
|
||||
max_epochs=train_config.max_epochs,
|
||||
precision=conf.precision,
|
||||
# logger=MLFLogger("./log/", run_name=conf.name),
|
||||
logger=TBLogger("./log/", name=conf.name),
|
||||
strategy=conf.strategy,
|
||||
max_epochs=conf.max_epochs,
|
||||
)
|
||||
lit_trainer.fit(
|
||||
lit_module,
|
||||
train_dataloaders=train_dataloader,
|
||||
val_dataloaders=val_dataloader,
|
||||
ckpt_path=train_config.resume_from_ckpt_path,
|
||||
ckpt_path=conf.resume_from_ckpt_path,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue