Update RWKV training.

Colin 2025-03-10 16:26:53 +08:00
parent 0600d46f2f
commit 1efda9fe25
4 changed files with 23 additions and 21 deletions

File 1/4: checkpoint evaluation script (path not shown)

@@ -12,6 +12,7 @@ if __name__ == "__main__":
     checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
     checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
     checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
+    # checkpoint_path = "log/bigger/version_11/checkpoints/epoch=25-step=128336.ckpt"
     qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
@@ -41,4 +42,4 @@ if __name__ == "__main__":
             if item[i] != next_token:
                 node.set_seq_prop(i, "ERR_" + str(next_token))
                 print(str(item[i]) + " " + str(next_token) + " ERROR")
-    # node.print()
+    node.print()

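The hunk above re-enables `node.print()` at the end of a greedy verification pass. For readers without the full file, a minimal sketch of that pattern, where `model`, `item`, and `verify_sequence` are illustrative names rather than code from this repo:

```python
import torch

@torch.no_grad()
def verify_sequence(model, item):
    """Re-predict each position greedily and flag mismatches,
    mirroring the `if item[i] != next_token` check above."""
    errors = []
    for i in range(1, len(item)):
        logits = model(item[:i].unsqueeze(0))     # (1, i, vocab) from the prefix
        next_token = int(logits[0, -1].argmax())  # greedy next-token choice
        if next_token != int(item[i]):
            errors.append((i, int(item[i]), next_token))
            print(str(int(item[i])) + " " + str(next_token) + " ERROR")
    return errors
```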
File 2/4: model/modeling_rwkv7.py

@@ -6,24 +6,23 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.nn import CrossEntropyLoss
 from torch import nn
+import torch.nn.init as init
 # for 0.1B
-n_layer = 12
-n_embd = 768
+n_layer = 3
+n_embd = 256
 D_DECAY_LORA = 64
 D_AAA_LORA = 64
 D_MV_LORA = 32
 D_GATE_LORA = 128
-dim_att = 768
-dim_ffn = 3072
+dim_att = n_embd
+dim_ffn = n_embd * 4
 # vocab_size = 65536
-vocab_size = 65536
+vocab_size = 32
 # DTYPE = torch.bfloat16
-DTYPE = torch.half # better
+DTYPE = torch.float32 # better
 head_size_a = 64 # don't change
 HS = head_size_a
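A rough sense of what the shrink above does, assuming the usual RWKV layout where each block carries a 4x FFN and roughly four square attention projections. This is a back-of-the-envelope estimate, not an exact count; the LoRA-style extras and small per-layer tensors are ignored:

```python
def approx_params(n_layer, n_embd, vocab_size):
    ffn = 2 * n_embd * (4 * n_embd)   # key: n_embd -> 4*n_embd, value: back down
    att = 4 * n_embd * n_embd         # r/k/v/output projections, roughly square
    emb = vocab_size * n_embd         # input embedding; the LM head is about the same
    return n_layer * (ffn + att) + 2 * emb

print(approx_params(12, 768, 65536))  # old config: ~1.9e8, the "0.1B" scale
print(approx_params(3, 256, 32))      # new config: ~2.4e6, a toy/debug model
```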
@@ -172,8 +171,8 @@ class RWKV_CMix_x070(nn.Module):
         super().__init__()
         self.layer_id = layer_id
-        with torch.no_grad():
-            self.x_k = nn.Parameter(torch.empty(1, 1, n_embd))
+        # with torch.no_grad():
+        self.x_k = nn.Parameter(torch.empty(1, 1, n_embd))
         self.key = nn.Linear(n_embd, dim_ffn, bias=False)
         self.value = nn.Linear(dim_ffn, n_embd, bias=False)
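Dropping the `torch.no_grad()` guard above is safe: a tensor from `torch.empty()` does not track gradients, and tracking only starts once it is wrapped in `nn.Parameter`. The guard matters only when initializing an already-registered parameter in place. A quick self-contained check:

```python
import torch
from torch import nn

x_k = nn.Parameter(torch.empty(1, 1, 256))  # construction needs no guard
print(x_k.requires_grad)                    # True: Parameter turns tracking on

with torch.no_grad():                       # needed only for in-place init;
    x_k.uniform_(-0.5, 0.5)                 # mutating a live leaf parameter
                                            # without it raises a RuntimeError
```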
@@ -218,8 +217,6 @@ class Block(nn.Module):
 class RWKV(nn.Module):
     def __init__(self):
         super().__init__()
-        dim_att = n_embd
-        dim_ffn = n_embd * 4
         self.emb = nn.Embedding(vocab_size, n_embd)
         self.blocks = nn.ModuleList([Block(i) for i in range(n_layer)])
@@ -245,10 +242,13 @@ class RWKVLMHeadModel(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
-        self.transformer = RWKV()
+        self.rwkv = RWKV()
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.hook_attention = None
+        for name, param in self.rwkv.named_parameters():
+            init.normal_(param.data)

     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -256,7 +256,7 @@ class RWKVLMHeadModel(nn.Module):
         token_type_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ):
-        lm_logits = self.transformer(input_ids)
+        lm_logits = self.rwkv(input_ids)
         loss = None
         if labels is not None:

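Taken together, the RWKVLMHeadModel hunks rename the backbone attribute (`transformer` to `rwkv`) and re-draw every backbone weight from a standard normal after construction. A sketch of that re-init as a standalone helper; `head_model` and the `std=0.02` variant are illustrative, not something this commit does. Note that `init.normal_` defaults to std=1.0, a hot initialization that is usually only tolerable for tiny debug models like the 3-layer config above:

```python
import torch.nn.init as init

def reinit_normal(module, std=1.0):
    """Re-draw every parameter of `module` from N(0, std^2) in place."""
    for name, param in module.named_parameters():
        init.normal_(param, mean=0.0, std=std)  # init fns run under no_grad

# reinit_normal(head_model.rwkv)            # what the diff does (std=1.0)
# reinit_normal(head_model.rwkv, std=0.02)  # GPT-style scale, for comparison
```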
File 3/4: model/qwen_module.py

@@ -13,7 +13,6 @@ import pytorch_lightning as pl
 import torch
 import torchmetrics
-from model.modeling_wit import QWenLMHeadModel
 from configuration import ModelConfig, TrainConfig
@@ -166,7 +165,7 @@ class ModelRunner:
 class QwenModule(pl.LightningModule):
-    def __init__(self, conf: TrainConfig = None):
+    def __init__(self, conf: TrainConfig, model):
         self.config = conf
         pretrained_model_dir = conf.pretrain_model_name
         learning_rate = conf.learning_rate
@@ -176,7 +175,6 @@ class QwenModule(pl.LightningModule):
         self.save_hyperparameters()
         if mconf == None:
             mconf = ModelConfig()
-        model = QWenLMHeadModel(mconf)
         if pretrained_model_dir != None:
             from modelscope import snapshot_download

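The qwen_module.py hunks invert a dependency: QwenModule no longer imports and constructs QWenLMHeadModel itself; the caller builds whichever head model it wants and injects it. A minimal sketch of the resulting shape, where only the signature comes from the diff and the `self.llm` attribute name is an assumption:

```python
import pytorch_lightning as pl
from configuration import TrainConfig

class QwenModule(pl.LightningModule):
    def __init__(self, conf: TrainConfig, model):
        super().__init__()
        self.config = conf
        self.llm = model  # injected backbone; any LM-head-style module works
```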
File 4/4: training entry script (path not shown)

@@ -2,8 +2,9 @@ import pytorch_lightning as pl
 import torch
 from model.qwen_module import QwenModule
 from model.tokenization_qwen import QWenTokenizer
-from logger import MLFLogger, TBLogger
+from model.modeling_wit import QWenLMHeadModel
+from model.modeling_rwkv7 import RWKVLMHeadModel
+from logger import TBLogger
 import configuration
 import dataset.dataset as ds
@@ -42,7 +43,9 @@ if __name__ == "__main__":
     torch.manual_seed(conf.seed)
     np.random.seed(conf.seed)
-    qwen = QwenModule(conf)
+    model = QWenLMHeadModel(conf.model_config)
+    # model = RWKVLMHeadModel(conf.model_config)
+    qwen = QwenModule(conf, model)
     train_dataloader, val_dataloader = ds.InitDataset(conf)
     # for i in range(len(train_dataloader)):
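End to end, the updated entry point now reads like the sketch below. `TBLogger` and `ds.InitDataset` are the repo's own helpers as imported above, and `conf` comes from `configuration`; the `pl.Trainer` arguments and the `TBLogger("log", name="bigger")` call are illustrative assumptions (the name matches the `log/bigger/version_N` checkpoint paths seen in file 1), not taken from the diff:

```python
import numpy as np
import pytorch_lightning as pl

torch.manual_seed(conf.seed)
np.random.seed(conf.seed)

model = QWenLMHeadModel(conf.model_config)
# model = RWKVLMHeadModel(conf.model_config)  # drop-in alternative backbone
qwen = QwenModule(conf, model)

train_dataloader, val_dataloader = ds.InitDataset(conf)
trainer = pl.Trainer(max_epochs=5, logger=TBLogger("log", name="bigger"))
trainer.fit(qwen, train_dataloader, val_dataloader)
```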