Update rwkv train.
This commit is contained in:
parent
0600d46f2f
commit
1efda9fe25
|
@ -12,6 +12,7 @@ if __name__ == "__main__":
|
|||
checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
|
||||
checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
|
||||
checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
|
||||
# checkpoint_path = "log/bigger/version_11/checkpoints/epoch=25-step=128336.ckpt"
|
||||
|
||||
qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
|
||||
qwen.eval()
|
||||
|
@ -41,4 +42,4 @@ if __name__ == "__main__":
|
|||
if item[i] != next_token:
|
||||
node.set_seq_prop(i, "ERR_" + str(next_token))
|
||||
print(str(item[i]) + " " + str(next_token) + " ERROR")
|
||||
# node.print()
|
||||
node.print()
|
||||
|
|
|
@ -6,24 +6,23 @@ import torch.nn.functional as F
|
|||
import torch.utils.checkpoint
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch import nn
|
||||
|
||||
import torch.nn.init as init
|
||||
|
||||
# for 0.1B
|
||||
n_layer = 12
|
||||
n_embd = 768
|
||||
n_layer = 3
|
||||
n_embd = 256
|
||||
D_DECAY_LORA = 64
|
||||
D_AAA_LORA = 64
|
||||
D_MV_LORA = 32
|
||||
D_GATE_LORA = 128
|
||||
|
||||
dim_att = 768
|
||||
dim_ffn = 3072
|
||||
dim_att = n_embd
|
||||
dim_ffn = n_embd * 4
|
||||
|
||||
# vocab_size = 65536
|
||||
vocab_size = 65536
|
||||
vocab_size = 32
|
||||
|
||||
# DTYPE = torch.bfloat16
|
||||
DTYPE = torch.half # better
|
||||
DTYPE = torch.float32 # better
|
||||
|
||||
head_size_a = 64 # don't change
|
||||
HS = head_size_a
|
||||
|
@ -172,8 +171,8 @@ class RWKV_CMix_x070(nn.Module):
|
|||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
|
||||
with torch.no_grad():
|
||||
self.x_k = nn.Parameter(torch.empty(1, 1, n_embd))
|
||||
# with torch.no_grad():
|
||||
self.x_k = nn.Parameter(torch.empty(1, 1, n_embd))
|
||||
|
||||
self.key = nn.Linear(n_embd, dim_ffn, bias=False)
|
||||
self.value = nn.Linear(dim_ffn, n_embd, bias=False)
|
||||
|
@ -218,8 +217,6 @@ class Block(nn.Module):
|
|||
class RWKV(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
dim_att = n_embd
|
||||
dim_ffn = n_embd * 4
|
||||
self.emb = nn.Embedding(vocab_size, n_embd)
|
||||
|
||||
self.blocks = nn.ModuleList([Block(i) for i in range(n_layer)])
|
||||
|
@ -245,10 +242,13 @@ class RWKVLMHeadModel(nn.Module):
|
|||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.transformer = RWKV()
|
||||
self.rwkv = RWKV()
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.hook_attention = None
|
||||
|
||||
for name, param in self.rwkv.named_parameters():
|
||||
init.normal_(param.data)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
|
@ -256,7 +256,7 @@ class RWKVLMHeadModel(nn.Module):
|
|||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
lm_logits = self.transformer(input_ids)
|
||||
lm_logits = self.rwkv(input_ids)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
|
|
@ -13,7 +13,6 @@ import pytorch_lightning as pl
|
|||
import torch
|
||||
import torchmetrics
|
||||
|
||||
from model.modeling_wit import QWenLMHeadModel
|
||||
from configuration import ModelConfig, TrainConfig
|
||||
|
||||
|
||||
|
@ -166,7 +165,7 @@ class ModelRunner:
|
|||
|
||||
|
||||
class QwenModule(pl.LightningModule):
|
||||
def __init__(self, conf: TrainConfig = None):
|
||||
def __init__(self, conf: TrainConfig, model):
|
||||
self.config = conf
|
||||
pretrained_model_dir = conf.pretrain_model_name
|
||||
learning_rate = conf.learning_rate
|
||||
|
@ -176,7 +175,6 @@ class QwenModule(pl.LightningModule):
|
|||
self.save_hyperparameters()
|
||||
if mconf == None:
|
||||
mconf = ModelConfig()
|
||||
model = QWenLMHeadModel(mconf)
|
||||
if pretrained_model_dir != None:
|
||||
from modelscope import snapshot_download
|
||||
|
||||
|
|
|
@ -2,8 +2,9 @@ import pytorch_lightning as pl
|
|||
import torch
|
||||
|
||||
from model.qwen_module import QwenModule
|
||||
from model.tokenization_qwen import QWenTokenizer
|
||||
from logger import MLFLogger, TBLogger
|
||||
from model.modeling_wit import QWenLMHeadModel
|
||||
from model.modeling_rwkv7 import RWKVLMHeadModel
|
||||
from logger import TBLogger
|
||||
|
||||
import configuration
|
||||
import dataset.dataset as ds
|
||||
|
@ -42,7 +43,9 @@ if __name__ == "__main__":
|
|||
|
||||
torch.manual_seed(conf.seed)
|
||||
np.random.seed(conf.seed)
|
||||
qwen = QwenModule(conf)
|
||||
model = QWenLMHeadModel(conf.model_config)
|
||||
# model = RWKVLMHeadModel(conf.model_config)
|
||||
qwen = QwenModule(conf, model)
|
||||
|
||||
train_dataloader, val_dataloader = ds.InitDataset(conf)
|
||||
# for i in range(len(train_dataloader)):
|
||||
|
|
Loading…
Reference in New Issue