Refine meaning dataset document.
commit f8480678d8
parent 383125edc9
@@ -44,8 +44,8 @@ class MeaningDatasetConfig:
         self.level = 5
         self.dataset_level = 3
         self.min_subitem = 2
-        self.mask_level = [0, 1, 2]
-        self.mask_idx = [0, 0, -1]
+        self.mask_level = None
+        self.mask_idx = None


 class DatasetConfig:
     def __init__(self):
@@ -1,6 +1,8 @@
 from dataset.meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader
 from dataset.special_dataset import SpecialDataset
 from torch.utils.data import random_split, DataLoader
+import torch
+import os


 def InitDataset(config):
@@ -31,10 +33,27 @@ def InitDataset(config):
     vocab = config.model_config.vocab_size
     start = vocab * (conf.level_ratio**conf.level)
     size = vocab * int((conf.level_ratio**conf.dataset_level))

-    raw_dataset = MeaningDataset(start, start + size, vocab, None, conf.level_ratio, conf.min_subitem)
-    # print(raw_dataset.token_frequency())
-    raw_dataset.set_mask(conf.mask_level, conf.mask_idx)
-    train_dataset, val_dataset = raw_dataset.split(0.9)
+    path = "./data/"
+    trainfile = path + f"MeaningDataset_train_v{size}_s{start}_s{size}_lr{conf.level_ratio}_ms{conf.min_subitem}.pt"
+    valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_lr{conf.level_ratio}_ms{conf.min_subitem}.pt"
+    if not os.path.exists(path):
+        os.mkdir(path)
+    if os.path.exists(trainfile) and os.path.exists(valfile):
+        print(f"INFO: Load dataset from {trainfile}")
+        print(f"INFO: Load dataset from {valfile}")
+        train_dataset = torch.load(trainfile)
+        val_dataset = torch.load(valfile)
+        print(f"INFO: Load dataset end")
+    else:
+        raw_dataset = MeaningDataset(start, start + size, vocab, None, conf.level_ratio, conf.min_subitem)
+        print("INFO: raw_dataset.token_frequency" + raw_dataset.token_frequency())
+        raw_dataset.set_mask(conf.mask_level, conf.mask_idx)
+        train_dataset, val_dataset = raw_dataset.split(0.9)
+        torch.save(train_dataset, trainfile)
+        torch.save(val_dataset, valfile)
+        print(f"INFO: Build and save dataset end")

     train_dataloader = BatchGroupMeaningDataloader(train_dataset, train_batch_size).dataloader(
         config.dataloader_works
     )
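The new InitDataset path builds the MeaningDataset once, persists the train/val splits with torch.save, and short-circuits to torch.load on later runs; the cache file name encodes size, start, level_ratio, and min_subitem, so changing any of those parameters forces a rebuild. A minimal sketch of the same cache-or-build pattern, with load_or_build and build_fn as illustrative names not taken from the repo (recent PyTorch versions may also need weights_only=False in torch.load for pickled datasets):

```
import os
import torch


def load_or_build(cache_file, build_fn):
    # Reuse the serialized object when the cache file exists;
    # otherwise build it once and persist it for the next run.
    if os.path.exists(cache_file):
        print(f"INFO: Load dataset from {cache_file}")
        return torch.load(cache_file)
    dataset = build_fn()  # expensive construction only on a cache miss
    os.makedirs(os.path.dirname(cache_file) or ".", exist_ok=True)
    torch.save(dataset, cache_file)
    print("INFO: Build and save dataset end")
    return dataset
```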
@@ -41,7 +41,7 @@ class MeaningMap:
             and os.path.exists(file_rank_all)
             and use_cache
         ):
-            print("Load from disk cache: " + file)
+            print("Mapping Load from disk cache: " + file)
            slhwm = np.load(file_prop)
            self.ms_map = slhwm[:, 4:]
            self.ms_data = np.load(file_data)
@@ -52,9 +52,9 @@ class MeaningMap:
            self.ms_rank_all = np.load(file_rank_all)
            self.ms_height = slhwm[:, 2]
            self.ms_weight = slhwm[:, 3]
-            print("Load end, elapsed:" + str(time.time() - start_time) + "s")
+            print("Mapping Load end, elapsed:" + str(time.time() - start_time) + "s")
         else:
-            print("Disk cache miss, build new one. size:" + str(size))
+            print("Mapping Disk cache miss, build new one. size:" + str(size))

            map = np.empty((size, max_subitem), dtype=np.int32)
@@ -169,7 +169,7 @@ class MeaningMap:
        self.ms_len = ms_len
        self.ms_height = ms_height
        self.ms_weight = ms_weight
-        print("Disk cache build end, elapsed:" + str(time.time() - start_time) + "s")
+        print("Mapping Disk cache build end, elapsed:" + str(time.time() - start_time) + "s")

    def get_sequence(self, meaning):  # return sequence[meaning]
        start = self.ms_start[meaning]
@@ -267,7 +267,8 @@ class MeaningDataset(Dataset):
        np.random.seed(seed)
        map = MeaningMap(end, vocab_size, max_subitem, min_subitem, use_cache=use_cache)
        np.random.seed(seed)
+        print("Build MeaningDataset from MeaningMap.")
        self.mask_level = None
        self.mask_idx = None
        self.tree = []
@@ -319,6 +319,7 @@ class MeaningDataset(Dataset):
            self.rank_all.append(rank_all)

        unique, counts = np.unique(seq_len, return_counts=True)
+        print("Build MeaningDataset end.")
        print("----------------------------------------------------------------")
        print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))
        print("MeaningDataset size:" + str(len(seq_len)))
@@ -17,6 +17,7 @@ The meaning dataset is a dataset that imitates natural language as well as abstract expression.
 11. get_seq_mask returns, for each token of a sequence, whether it sits at the given index of the given level; level=0: the bottom level, index=-1: the last item, index=0: the first item
 12. meaning_height: the total height of the current meaning
 13. meaning_weight: the total width of the current meaning
+14. mask_level mask_idx: the mask over tokens used for training; mask_level=[0, 1, 2] mask_idx=[0, 0, -1] means a token takes part in training only if it is the 0th item at level 0, the 0th item at level 1, and the last item at level 2


 ```
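To make the new item 14 concrete, here is a toy sketch of the rule it describes; token_participates and level_index are hypothetical names, and only the (mask_level, mask_idx) semantics come from the document above:

```
def token_participates(level_index, mask_level, mask_idx):
    # level_index[l] is this token's position among its siblings at level l
    # (0 = first, -1 stands for the last one). The token joins training only
    # if it matches mask_idx[i] at every level listed in mask_level.
    if mask_level is None or mask_idx is None:
        return True  # no mask configured: every token trains
    return all(level_index[l] == idx for l, idx in zip(mask_level, mask_idx))


# mask_level=[0, 1, 2], mask_idx=[0, 0, -1]: only a token that is the first
# at level 0, the first at level 1, and the last at level 2 participates.
print(token_participates({0: 0, 1: 0, 2: -1}, [0, 1, 2], [0, 0, -1]))  # True
print(token_participates({0: 0, 1: 1, 2: -1}, [0, 1, 2], [0, 0, -1]))  # False
```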
wit/train.py (41 changes)
@@ -3,43 +3,58 @@ import torch

 from model.lit_module import LitModule
 from wit.model.tokenization_qwen import QWenTokenizer
-from logger import MLFLogger
+from logger import MLFLogger, TBLogger

 import configuration
 import dataset.dataset as ds

 if __name__ == "__main__":

-    train_config = configuration.TrainConfig()
-    config = train_config.model_config
+    conf = configuration.TrainConfig()
+    config = conf.model_config

-    torch.manual_seed(train_config.seed)
+    conf.name = "bigger"  # current train process name
+    conf.pretrain_model_name = None  # "qwen/Qwen-1_8B-Chat"
+    conf.learning_rate = 0.0001
+    conf.use_tril_attention_mask = None
+    conf.precision = "32-true"  # "precision:bf16-mixed,16-mixed,32-true"
+    conf.train_batch_size = 8
+    conf.val_batch_size = 4
+    conf.num_proc = 8
+    conf.max_epochs = 1000
+    conf.strategy = "auto"
+    conf.resume_from_ckpt_path = None
+    conf.seed = 42
+    conf.dataloader_works = 2
+
+    conf.mask_level = None  # [0, 1, 2]
+    conf.mask_idx = None  # [0, 0, -1]

     config.vocab_size = 256
     config.hidden_size = 128  # 128 1024 2048 32
     config.num_hidden_layers = 6  # 6 12 24 3
     config.num_attention_heads = 16  # 8 8 16

-    lit_module = LitModule(
-        train_config.pretrain_model_name, train_config.learning_rate, config, train_config.use_tril_attention_mask
-    )
+    torch.manual_seed(conf.seed)
+    lit_module = LitModule(conf.pretrain_model_name, conf.learning_rate, config, conf.use_tril_attention_mask)
     tokenizer = QWenTokenizer("./model/wit_b64.tiktoken", "./model/wit_char.tiktoken")

-    train_dataloader, val_dataloader = ds.InitDataset(train_config)
+    train_dataloader, val_dataloader = ds.InitDataset(conf)
     # for i in range(len(train_dataloader)):
     #     print(train_dataloader.print_mapping(i))

     torch.set_float32_matmul_precision("medium")
     lit_trainer = pl.Trainer(
         accelerator="cuda",
-        precision=train_config.precision,
-        logger=MLFLogger("./log/", run_name=train_config.name),
-        strategy=train_config.strategy,
-        max_epochs=train_config.max_epochs,
+        precision=conf.precision,
+        # logger=MLFLogger("./log/", run_name=conf.name),
+        logger=TBLogger("./log/", name=conf.name),
+        strategy=conf.strategy,
+        max_epochs=conf.max_epochs,
     )
     lit_trainer.fit(
         lit_module,
         train_dataloaders=train_dataloader,
         val_dataloaders=val_dataloader,
-        ckpt_path=train_config.resume_from_ckpt_path,
+        ckpt_path=conf.resume_from_ckpt_path,
     )
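The trainer now logs to TensorBoard through the repo's TBLogger instead of the MLFlow-backed MLFLogger. TBLogger's definition is not part of this diff; assuming it wraps Lightning's built-in TensorBoardLogger, it could look roughly like the sketch below, matching the TBLogger("./log/", name=conf.name) call site:

```
# Hypothetical: TBLogger as a thin alias over Lightning's TensorBoardLogger.
from pytorch_lightning.loggers import TensorBoardLogger


def TBLogger(save_dir, name):
    # Event files land under save_dir/name/version_N and can be inspected
    # with: tensorboard --logdir ./log
    return TensorBoardLogger(save_dir=save_dir, name=name)
```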