use custom vocab_size.
commit 0ae63298b2 (parent 05f17b1221)
@@ -1,37 +0,0 @@
-config: !!python/object:wit.configuration.ModelConfig
-  attn_dropout_prob: 0.0
-  bf16: false
-  chat_format: chatml
-  do_sample: true
-  emb_dropout_prob: 0.0
-  fp16: false
-  fp32: false
-  hidden_size: 128
-  initializer_range: 0.02
-  intermediate_size: 5504
-  layer_norm_epsilon: 1.0e-06
-  max_new_tokens: 512
-  max_position_embeddings: 8192
-  max_window_size: 6144
-  model_max_length: 8192
-  no_bias: true
-  num_attention_heads: 8
-  num_hidden_layers: 6
-  repetition_penalty: 1.1
-  rotary_emb_base: 10000
-  rotary_pct: 1.0
-  scale_attn_weights: true
-  softmax_in_fp32: false
-  tie_word_embeddings: false
-  top_k: 0
-  top_p: 0.8
-  use_cache: true
-  use_cache_kernel: false
-  use_cache_quantization: false
-  use_dynamic_ntk: true
-  use_flash_attn: auto
-  use_logn_attn: true
-  vocab_size: 4096
-learning_rate: 0.0001
-pretrained_model_dir: null
-use_tril_attention_mask: null
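The deleted file above looks like a dumped hyperparameter record: a wit.configuration.ModelConfig object (vocab_size still 4096 here) plus a few top-level training options. As a rough sketch of reading such a record back, assuming the file is named hparams.yaml (an illustrative name, not from this diff): the !!python/object tag means PyYAML must be allowed to construct Python objects, so a plain safe_load would reject it.

import yaml

# "hparams.yaml" is an assumed file name for the deleted record shown above;
# the wit package must be importable so PyYAML can rebuild the tagged object.
with open("hparams.yaml") as f:
    # UnsafeLoader (not safe_load) is required because of the !!python/object tag.
    hparams = yaml.load(f, Loader=yaml.UnsafeLoader)

config = hparams["config"]   # a wit.configuration.ModelConfig instance
print(config.vocab_size)     # 4096 in the deleted file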
@@ -9,7 +9,6 @@ from modeling_wit import QWenLMHeadModel
 from wit.configuration import ModelConfig
 
 from transformers import AutoConfig
-from modelscope import snapshot_download
 
 
 class LitModule(pl.LightningModule):
@@ -26,6 +25,8 @@ class LitModule(pl.LightningModule):
         config = ModelConfig()
         model = QWenLMHeadModel(config)
         if pretrained_model_dir != None:
+            from modelscope import snapshot_download
+
             model = model.from_pretrained(snapshot_download(pretrained_model_dir))
         self.llm = self.register_core_module(model)
         self.learning_rate = learning_rate
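The two hunks above move the modelscope dependency out of the module-level imports and into the branch that actually needs it, so modelscope is only required when a pretrained checkpoint directory is given. A minimal sketch of that deferred-import pattern, with build_model as a hypothetical helper name (the real code lives inside LitModule.__init__):

from modeling_wit import QWenLMHeadModel
from wit.configuration import ModelConfig


def build_model(config: ModelConfig, pretrained_model_dir=None):
    model = QWenLMHeadModel(config)
    if pretrained_model_dir is not None:
        # Deferred import: modelscope is only pulled in when a checkpoint is downloaded.
        from modelscope import snapshot_download

        model = model.from_pretrained(snapshot_download(pretrained_model_dir))
    return model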
@@ -10,7 +10,7 @@ from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split, S
 
 
 class SpecialDataset(Dataset):
-    def __init__(self, start=1, end=320, size=32768):  # 1048576 32768
+    def __init__(self, start=1, end=128, size=32768):  # 1048576 32768
        self.size = size
        self.features = []
        a = torch.randint(start, end, [size])
@@ -20,7 +20,7 @@ class SpecialDataset(Dataset):
         z = torch.zeros([size]).long()
         # self.data = torch.stack([a, b, a + b, a + b, a + b * 2]).permute(1, 0)
         # self.data = torch.stack([a, b, a, a + b / 4]).permute(1, 0).long()
-        self.data = torch.stack([a, a + 1, a + 2]).permute(1, 0).long()
+        self.data = torch.stack([a, a + a, a + a]).permute(1, 0).long()
         # self.data = torch.stack([a, b, a]).permute(1, 0).long()
         # self.data = torch.stack([a, b, a, a + a / 8, a + a / 4, a + a / 2, a + a]).permute(1, 0).long()
 
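The narrower sampling range (end drops from 320 to 128) and the new a + a targets fit together: a is drawn from [1, 128), so the largest generated token id is 2 * 127 = 254, which stays inside the 256-token vocabulary introduced below. A standalone check of that bound, mirroring the dataset code:

import torch

start, end, size = 1, 128, 32768
a = torch.randint(start, end, [size])
data = torch.stack([a, a + a, a + a]).permute(1, 0).long()
assert int(data.max()) <= 255  # every id is a valid index into a 256-token vocabulary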
@@ -9,9 +9,7 @@ import torch
 from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split, Subset
 
 from transformers import (
-    BatchEncoding,
     DefaultDataCollator,
-    PreTrainedTokenizer,
     set_seed,
 )
 from lit_module import LitModule
@@ -33,7 +31,7 @@ max_epochs = 1000
 strategy = "auto"
 resume_from_ckpt_path = None
 seed = 42
-vocab_size = 4096
+vocab_size = 256
 
 
 if __name__ == "__main__":
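The module-level vocab_size drops from 4096 to 256, matching the commit title. How the training script pushes that value into the model is outside this diff; a plausible sketch, assuming ModelConfig exposes the vocab_size field seen in the deleted YAML and that the model is built as in the LitModule hunk above:

from modeling_wit import QWenLMHeadModel
from wit.configuration import ModelConfig

vocab_size = 256

config = ModelConfig()
config.vocab_size = vocab_size   # field name taken from the deleted YAML above
model = QWenLMHeadModel(config)  # embedding and output head now sized for 256 tokens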