From 0ae63298b2191e0a3e1a1990998b6d12f000177a Mon Sep 17 00:00:00 2001
From: Colin
Date: Thu, 14 Mar 2024 13:28:40 +0800
Subject: [PATCH] use custom vocab_size.

---
 wit/lightning_logs/version_0/hparams.yaml | 37 -----------------------
 wit/lit_module.py                         |  3 +-
 wit/special_dataset.py                    |  4 +--
 wit/train.py                              |  4 +--
 4 files changed, 5 insertions(+), 43 deletions(-)
 delete mode 100644 wit/lightning_logs/version_0/hparams.yaml

diff --git a/wit/lightning_logs/version_0/hparams.yaml b/wit/lightning_logs/version_0/hparams.yaml
deleted file mode 100644
index 258ea96..0000000
--- a/wit/lightning_logs/version_0/hparams.yaml
+++ /dev/null
@@ -1,37 +0,0 @@
-config: !!python/object:wit.configuration.ModelConfig
-  attn_dropout_prob: 0.0
-  bf16: false
-  chat_format: chatml
-  do_sample: true
-  emb_dropout_prob: 0.0
-  fp16: false
-  fp32: false
-  hidden_size: 128
-  initializer_range: 0.02
-  intermediate_size: 5504
-  layer_norm_epsilon: 1.0e-06
-  max_new_tokens: 512
-  max_position_embeddings: 8192
-  max_window_size: 6144
-  model_max_length: 8192
-  no_bias: true
-  num_attention_heads: 8
-  num_hidden_layers: 6
-  repetition_penalty: 1.1
-  rotary_emb_base: 10000
-  rotary_pct: 1.0
-  scale_attn_weights: true
-  softmax_in_fp32: false
-  tie_word_embeddings: false
-  top_k: 0
-  top_p: 0.8
-  use_cache: true
-  use_cache_kernel: false
-  use_cache_quantization: false
-  use_dynamic_ntk: true
-  use_flash_attn: auto
-  use_logn_attn: true
-  vocab_size: 4096
-learning_rate: 0.0001
-pretrained_model_dir: null
-use_tril_attention_mask: null
diff --git a/wit/lit_module.py b/wit/lit_module.py
index 07890a4..ced90cb 100644
--- a/wit/lit_module.py
+++ b/wit/lit_module.py
@@ -9,7 +9,6 @@ from modeling_wit import QWenLMHeadModel
 from wit.configuration import ModelConfig
 
 from transformers import AutoConfig
-from modelscope import snapshot_download
 
 
 class LitModule(pl.LightningModule):
@@ -26,6 +25,8 @@ class LitModule(pl.LightningModule):
             config = ModelConfig()
         model = QWenLMHeadModel(config)
         if pretrained_model_dir != None:
+            from modelscope import snapshot_download
+
             model = model.from_pretrained(snapshot_download(pretrained_model_dir))
         self.llm = self.register_core_module(model)
         self.learning_rate = learning_rate
diff --git a/wit/special_dataset.py b/wit/special_dataset.py
index b9a6dbc..243f333 100644
--- a/wit/special_dataset.py
+++ b/wit/special_dataset.py
@@ -10,7 +10,7 @@ from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split, S
 
 
 class SpecialDataset(Dataset):
-    def __init__(self, start=1, end=320, size=32768):  # 1048576 32768
+    def __init__(self, start=1, end=128, size=32768):  # 1048576 32768
         self.size = size
         self.features = []
         a = torch.randint(start, end, [size])
@@ -20,7 +20,7 @@
         z = torch.zeros([size]).long()
         # self.data = torch.stack([a, b, a + b, a + b, a + b * 2]).permute(1, 0)
         # self.data = torch.stack([a, b, a, a + b / 4]).permute(1, 0).long()
-        self.data = torch.stack([a, a + 1, a + 2]).permute(1, 0).long()
+        self.data = torch.stack([a, a + a, a + a]).permute(1, 0).long()
         # self.data = torch.stack([a, b, a]).permute(1, 0).long()
         # self.data = torch.stack([a, b, a, a + a / 8, a + a / 4, a + a / 2, a + a]).permute(1, 0).long()
 
diff --git a/wit/train.py b/wit/train.py
index bfa8968..b54ccd9 100644
--- a/wit/train.py
+++ b/wit/train.py
@@ -9,9 +9,7 @@ import torch
 from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split, Subset
 
 from transformers import (
-    BatchEncoding,
     DefaultDataCollator,
-    PreTrainedTokenizer,
     set_seed,
 )
 from lit_module import LitModule
@@ -33,7 +31,7 @@ max_epochs = 1000
 strategy = "auto"
 resume_from_ckpt_path = None
 seed = 42
-vocab_size = 4096
+vocab_size = 256
 
 
 if __name__ == "__main__":
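Not part of the patch: a minimal standalone sketch (assuming only that PyTorch is installed) that reproduces the new data pattern from the special_dataset.py hunk above, to show why vocab_size can drop from 4096 to 256 in train.py. With start=1 and end=128, the largest generated token id is 2 * 127 = 254, which still fits in a 256-entry vocabulary.

# sketch.py -- illustrative only, mirrors the patched defaults; not repo code
import torch

start, end, size = 1, 128, 32768   # patched __init__ defaults in special_dataset.py
vocab_size = 256                   # patched value in train.py

a = torch.randint(start, end, [size])                        # a in [1, 127]
data = torch.stack([a, a + a, a + a]).permute(1, 0).long()   # rows of (a, 2a, 2a)

assert 0 <= data.min() and data.max() < vocab_size           # max is 2 * 127 = 254
print(data.shape, int(data.max()))                           # torch.Size([32768, 3]) 254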