Update rwkv train.

Colin 2025-03-10 16:26:53 +08:00
parent 0600d46f2f
commit 1efda9fe25
4 changed files with 23 additions and 21 deletions

View File

@@ -12,6 +12,7 @@ if __name__ == "__main__":
     checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
     checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
     checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
+    # checkpoint_path = "log/bigger/version_11/checkpoints/epoch=25-step=128336.ckpt"
     qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
@@ -41,4 +42,4 @@ if __name__ == "__main__":
             if item[i] != next_token:
                 node.set_seq_prop(i, "ERR_" + str(next_token))
                 print(str(item[i]) + " " + str(next_token) + " ERROR")
-        # node.print()
+        node.print()

View File

@@ -6,24 +6,23 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.nn import CrossEntropyLoss
 from torch import nn
+import torch.nn.init as init
 # for 0.1B
-n_layer = 12
-n_embd = 768
+n_layer = 3
+n_embd = 256
 D_DECAY_LORA = 64
 D_AAA_LORA = 64
 D_MV_LORA = 32
 D_GATE_LORA = 128
-dim_att = 768
-dim_ffn = 3072
-# vocab_size = 65536
-vocab_size = 65536
+dim_att = n_embd
+dim_ffn = n_embd * 4
+vocab_size = 32
 # DTYPE = torch.bfloat16
-DTYPE = torch.half  # better
+DTYPE = torch.float32  # better
 head_size_a = 64  # don't change
 HS = head_size_a
@@ -172,8 +171,8 @@ class RWKV_CMix_x070(nn.Module):
         super().__init__()
         self.layer_id = layer_id
-        with torch.no_grad():
-            self.x_k = nn.Parameter(torch.empty(1, 1, n_embd))
+        # with torch.no_grad():
+        self.x_k = nn.Parameter(torch.empty(1, 1, n_embd))
         self.key = nn.Linear(n_embd, dim_ffn, bias=False)
         self.value = nn.Linear(dim_ffn, n_embd, bias=False)
@@ -218,8 +217,6 @@ class Block(nn.Module):
 class RWKV(nn.Module):
     def __init__(self):
         super().__init__()
-        dim_att = n_embd
-        dim_ffn = n_embd * 4
         self.emb = nn.Embedding(vocab_size, n_embd)
         self.blocks = nn.ModuleList([Block(i) for i in range(n_layer)])
@@ -245,10 +242,13 @@ class RWKVLMHeadModel(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
-        self.transformer = RWKV()
+        self.rwkv = RWKV()
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.hook_attention = None
+        for name, param in self.rwkv.named_parameters():
+            init.normal_(param.data)
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -256,7 +256,7 @@ class RWKVLMHeadModel(nn.Module):
         token_type_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ):
-        lm_logits = self.transformer(input_ids)
+        lm_logits = self.rwkv(input_ids)
         loss = None
         if labels is not None:
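
Note on the added initialization loop: torch.nn.init.normal_ fills a tensor in place from a normal distribution and defaults to mean 0.0 and std 1.0, so the loop in RWKVLMHeadModel.__init__ overwrites every parameter of the RWKV backbone, including x_k, which is now created with torch.empty and has no other initialization. A minimal sketch of the effect, using a hypothetical two-layer stand-in for the real RWKV module:

import torch.nn as nn
import torch.nn.init as init

# Hypothetical stand-in for the RWKV backbone defined in this file.
toy = nn.Sequential(nn.Embedding(32, 256), nn.Linear(256, 256, bias=False))

for name, param in toy.named_parameters():
    init.normal_(param.data)  # in-place N(0, 1), the same call the commit adds

# Each parameter's std is now close to 1.0, regardless of its module's default init.
print({name: round(p.std().item(), 3) for name, p in toy.named_parameters()})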

View File

@@ -13,7 +13,6 @@ import pytorch_lightning as pl
 import torch
 import torchmetrics
-from model.modeling_wit import QWenLMHeadModel
 from configuration import ModelConfig, TrainConfig
@@ -166,7 +165,7 @@ class ModelRunner:
 class QwenModule(pl.LightningModule):
-    def __init__(self, conf: TrainConfig = None):
+    def __init__(self, conf: TrainConfig, model):
         self.config = conf
         pretrained_model_dir = conf.pretrain_model_name
         learning_rate = conf.learning_rate
@@ -176,7 +175,6 @@ class QwenModule(pl.LightningModule):
         self.save_hyperparameters()
         if mconf == None:
             mconf = ModelConfig()
-        model = QWenLMHeadModel(mconf)
         if pretrained_model_dir != None:
             from modelscope import snapshot_download

View File

@@ -2,8 +2,9 @@ import pytorch_lightning as pl
 import torch
 from model.qwen_module import QwenModule
-from model.tokenization_qwen import QWenTokenizer
-from logger import MLFLogger, TBLogger
+from model.modeling_wit import QWenLMHeadModel
+from model.modeling_rwkv7 import RWKVLMHeadModel
+from logger import TBLogger
 import configuration
 import dataset.dataset as ds
@@ -42,7 +43,9 @@ if __name__ == "__main__":
     torch.manual_seed(conf.seed)
     np.random.seed(conf.seed)
-    qwen = QwenModule(conf)
+    model = QWenLMHeadModel(conf.model_config)
+    # model = RWKVLMHeadModel(conf.model_config)
+    qwen = QwenModule(conf, model)
     train_dataloader, val_dataloader = ds.InitDataset(conf)
     # for i in range(len(train_dataloader)):
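
Taken together, the last two files move model construction out of QwenModule and into the training entry point, so the backbone can be swapped without touching the LightningModule. A rough usage sketch, assuming configuration.TrainConfig() can be built with defaults; the trainer arguments are illustrative and are not changed by this commit:

import pytorch_lightning as pl

import configuration
import dataset.dataset as ds
from model.modeling_rwkv7 import RWKVLMHeadModel
from model.modeling_wit import QWenLMHeadModel
from model.qwen_module import QwenModule

conf = configuration.TrainConfig()            # assumption: default construction, as in train.py
model = QWenLMHeadModel(conf.model_config)    # default transformer head
# model = RWKVLMHeadModel(conf.model_config)  # uncomment to train the RWKV v7 head instead
qwen = QwenModule(conf, model)                # the model is injected, not built inside QwenModule

train_dataloader, val_dataloader = ds.InitDataset(conf)
trainer = pl.Trainer(max_epochs=5)            # illustrative; real trainer settings live elsewhere in train.py
trainer.fit(qwen, train_dataloader, val_dataloader)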