diff --git a/wit/inference.py b/wit/inference.py
index 956b827..e9dc575 100644
--- a/wit/inference.py
+++ b/wit/inference.py
@@ -12,6 +12,7 @@ if __name__ == "__main__":
     checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
     checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
     checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
+    # checkpoint_path = "log/bigger/version_11/checkpoints/epoch=25-step=128336.ckpt"

     qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
@@ -41,4 +42,4 @@ if __name__ == "__main__":
            if item[i] != next_token:
                node.set_seq_prop(i, "ERR_" + str(next_token))
                print(str(item[i]) + " " + str(next_token) + " ERROR")
-    # node.print()
+    node.print()
diff --git a/wit/model/modeling_rwkv7.py b/wit/model/modeling_rwkv7.py
index a099a3b..dd24c65 100644
--- a/wit/model/modeling_rwkv7.py
+++ b/wit/model/modeling_rwkv7.py
@@ -6,24 +6,23 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.nn import CrossEntropyLoss
 from torch import nn
-
+import torch.nn.init as init

 # for 0.1B
-n_layer = 12
-n_embd = 768
+n_layer = 3
+n_embd = 256
 D_DECAY_LORA = 64
 D_AAA_LORA = 64
 D_MV_LORA = 32
 D_GATE_LORA = 128
-dim_att = 768
-dim_ffn = 3072
+dim_att = n_embd
+dim_ffn = n_embd * 4

-# vocab_size = 65536
-vocab_size = 65536
+vocab_size = 32

 # DTYPE = torch.bfloat16
-DTYPE = torch.half # better
+DTYPE = torch.float32 # better

 head_size_a = 64 # don't change
 HS = head_size_a

@@ -172,8 +171,8 @@ class RWKV_CMix_x070(nn.Module):
         super().__init__()
         self.layer_id = layer_id

-        with torch.no_grad():
-            self.x_k = nn.Parameter(torch.empty(1, 1, n_embd))
+        # with torch.no_grad():
+        self.x_k = nn.Parameter(torch.empty(1, 1, n_embd))

         self.key = nn.Linear(n_embd, dim_ffn, bias=False)
         self.value = nn.Linear(dim_ffn, n_embd, bias=False)
@@ -218,8 +217,6 @@ class Block(nn.Module):
 class RWKV(nn.Module):
     def __init__(self):
         super().__init__()
-        dim_att = n_embd
-        dim_ffn = n_embd * 4

         self.emb = nn.Embedding(vocab_size, n_embd)
         self.blocks = nn.ModuleList([Block(i) for i in range(n_layer)])
@@ -245,10 +242,13 @@ class RWKVLMHeadModel(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
-        self.transformer = RWKV()
+        self.rwkv = RWKV()
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.hook_attention = None

+        for name, param in self.rwkv.named_parameters():
+            init.normal_(param.data)
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -256,7 +256,7 @@
         token_type_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ):
-        lm_logits = self.transformer(input_ids)
+        lm_logits = self.rwkv(input_ids)

         loss = None
         if labels is not None:
diff --git a/wit/model/qwen_module.py b/wit/model/qwen_module.py
index 44bb67d..3f2177d 100644
--- a/wit/model/qwen_module.py
+++ b/wit/model/qwen_module.py
@@ -13,7 +13,6 @@ import pytorch_lightning as pl
 import torch
 import torchmetrics

-from model.modeling_wit import QWenLMHeadModel
 from configuration import ModelConfig, TrainConfig

@@ -166,7 +165,7 @@ class ModelRunner:


 class QwenModule(pl.LightningModule):
-    def __init__(self, conf: TrainConfig = None):
+    def __init__(self, conf: TrainConfig, model):
         self.config = conf
         pretrained_model_dir = conf.pretrain_model_name
         learning_rate = conf.learning_rate
@@ -176,7 +175,6 @@ class QwenModule(pl.LightningModule):
         self.save_hyperparameters()
         if mconf == None:
             mconf = ModelConfig()
-        model = QWenLMHeadModel(mconf)
         if pretrained_model_dir != None:
             from modelscope import snapshot_download
diff --git a/wit/train.py b/wit/train.py
index 43974cb..fea5f67 100644
--- a/wit/train.py
+++ b/wit/train.py
@@ -2,8 +2,9 @@ import pytorch_lightning as pl
 import torch

 from model.qwen_module import QwenModule
-from model.tokenization_qwen import QWenTokenizer
-from logger import MLFLogger, TBLogger
+from model.modeling_wit import QWenLMHeadModel
+from model.modeling_rwkv7 import RWKVLMHeadModel
+from logger import TBLogger

 import configuration
 import dataset.dataset as ds
@@ -42,7 +43,9 @@ if __name__ == "__main__":
     torch.manual_seed(conf.seed)
     np.random.seed(conf.seed)

-    qwen = QwenModule(conf)
+    model = QWenLMHeadModel(conf.model_config)
+    # model = RWKVLMHeadModel(conf.model_config)
+    qwen = QwenModule(conf, model)

     train_dataloader, val_dataloader = ds.InitDataset(conf)
     # for i in range(len(train_dataloader)):
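For reference, the net effect of the qwen_module.py and train.py hunks is a dependency-injection refactor: QwenModule no longer hard-codes QWenLMHeadModel in its __init__, and the caller constructs whichever backbone it wants and passes it in. A minimal sketch of the resulting wiring, assuming the signatures shown in the hunks above (build_module and the use_rwkv flag are hypothetical helpers, not part of the diff):

# Sketch only: wiring implied by the hunks above. `build_module` and
# `use_rwkv` are illustrative; TrainConfig, QwenModule, QWenLMHeadModel
# and RWKVLMHeadModel are the classes touched by this diff.
from configuration import TrainConfig
from model.qwen_module import QwenModule
from model.modeling_wit import QWenLMHeadModel
from model.modeling_rwkv7 import RWKVLMHeadModel


def build_module(conf: TrainConfig, use_rwkv: bool = False) -> QwenModule:
    # The caller now owns model construction, so swapping backbones is a
    # one-line change instead of an edit inside QwenModule.__init__.
    if use_rwkv:
        model = RWKVLMHeadModel(conf.model_config)
    else:
        model = QWenLMHeadModel(conf.model_config)
    return QwenModule(conf, model)

Note that call sites which rebuild the module from saved hyperparameters, such as QwenModule.load_from_checkpoint in wit/inference.py, will likely need to supply the new model argument as well, since the injected model is not stored by save_hyperparameters().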