Update rwkv train.
parent 0600d46f2f
commit 1efda9fe25
@@ -12,6 +12,7 @@ if __name__ == "__main__":
     checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
     checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
     checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
+    # checkpoint_path = "log/bigger/version_11/checkpoints/epoch=25-step=128336.ckpt"

     qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
@@ -41,4 +42,4 @@ if __name__ == "__main__":
         if item[i] != next_token:
             node.set_seq_prop(i, "ERR_" + str(next_token))
             print(str(item[i]) + " " + str(next_token) + " ERROR")
-            # node.print()
+            node.print()
@@ -6,24 +6,23 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.nn import CrossEntropyLoss
 from torch import nn
-
+import torch.nn.init as init

 # for 0.1B
-n_layer = 12
-n_embd = 768
+n_layer = 3
+n_embd = 256
 D_DECAY_LORA = 64
 D_AAA_LORA = 64
 D_MV_LORA = 32
 D_GATE_LORA = 128

-dim_att = 768
-dim_ffn = 3072
+dim_att = n_embd
+dim_ffn = n_embd * 4

-# vocab_size = 65536
-vocab_size = 65536
+vocab_size = 32

 # DTYPE = torch.bfloat16
-DTYPE = torch.half # better
+DTYPE = torch.float32 # better

 head_size_a = 64 # don't change
 HS = head_size_a
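For reference, a quick sketch (not part of the commit) of the sizes the new constants imply, assuming the usual RWKV requirement that dim_att be a multiple of head_size_a:

    # hypothetical sanity check for the downsized config
    n_embd = 256
    head_size_a = 64                  # "don't change", per the file
    dim_att = n_embd                  # 256
    dim_ffn = n_embd * 4              # 1024, the standard 4x FFN expansion
    n_head = dim_att // head_size_a   # 4 heads
    assert dim_att % head_size_a == 0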
@@ -172,8 +171,8 @@ class RWKV_CMix_x070(nn.Module):
         super().__init__()
         self.layer_id = layer_id

-        with torch.no_grad():
-            self.x_k = nn.Parameter(torch.empty(1, 1, n_embd))
+        # with torch.no_grad():
+        self.x_k = nn.Parameter(torch.empty(1, 1, n_embd))

         self.key = nn.Linear(n_embd, dim_ffn, bias=False)
         self.value = nn.Linear(dim_ffn, n_embd, bias=False)
@@ -218,8 +217,6 @@ class Block(nn.Module):
 class RWKV(nn.Module):
     def __init__(self):
         super().__init__()
-        dim_att = n_embd
-        dim_ffn = n_embd * 4
         self.emb = nn.Embedding(vocab_size, n_embd)

         self.blocks = nn.ModuleList([Block(i) for i in range(n_layer)])
@@ -245,10 +242,13 @@ class RWKVLMHeadModel(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
-        self.transformer = RWKV()
+        self.rwkv = RWKV()
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.hook_attention = None

+        for name, param in self.rwkv.named_parameters():
+            init.normal_(param.data)
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
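The added loop re-draws every parameter of the RWKV backbone right after construction. torch.nn.init.normal_ fills a tensor in place from a normal distribution with mean 0.0 and std 1.0 by default, so the new lines behave like this sketch with the defaults spelled out:

    # equivalent form of the added initialization loop (defaults made explicit)
    for name, param in self.rwkv.named_parameters():
        init.normal_(param.data, mean=0.0, std=1.0)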
@@ -256,7 +256,7 @@ class RWKVLMHeadModel(nn.Module):
         token_type_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ):
-        lm_logits = self.transformer(input_ids)
+        lm_logits = self.rwkv(input_ids)

         loss = None
         if labels is not None:
@@ -13,7 +13,6 @@ import pytorch_lightning as pl
 import torch
 import torchmetrics

-from model.modeling_wit import QWenLMHeadModel
 from configuration import ModelConfig, TrainConfig


@@ -166,7 +165,7 @@ class ModelRunner:


 class QwenModule(pl.LightningModule):
-    def __init__(self, conf: TrainConfig = None):
+    def __init__(self, conf: TrainConfig, model):
         self.config = conf
         pretrained_model_dir = conf.pretrain_model_name
         learning_rate = conf.learning_rate
|
||||||
self.save_hyperparameters()
|
self.save_hyperparameters()
|
||||||
if mconf == None:
|
if mconf == None:
|
||||||
mconf = ModelConfig()
|
mconf = ModelConfig()
|
||||||
model = QWenLMHeadModel(mconf)
|
|
||||||
if pretrained_model_dir != None:
|
if pretrained_model_dir != None:
|
||||||
from modelscope import snapshot_download
|
from modelscope import snapshot_download
|
||||||
|
|
||||||
|
|
|
@@ -2,8 +2,9 @@ import pytorch_lightning as pl
 import torch

 from model.qwen_module import QwenModule
-from model.tokenization_qwen import QWenTokenizer
-from logger import MLFLogger, TBLogger
+from model.modeling_wit import QWenLMHeadModel
+from model.modeling_rwkv7 import RWKVLMHeadModel
+from logger import TBLogger

 import configuration
 import dataset.dataset as ds
@@ -42,7 +43,9 @@ if __name__ == "__main__":

     torch.manual_seed(conf.seed)
     np.random.seed(conf.seed)
-    qwen = QwenModule(conf)
+    model = QWenLMHeadModel(conf.model_config)
+    # model = RWKVLMHeadModel(conf.model_config)
+    qwen = QwenModule(conf, model)

     train_dataloader, val_dataloader = ds.InitDataset(conf)
     # for i in range(len(train_dataloader)):
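With the model now constructed in the training script and passed into QwenModule, switching backbones is a call-site change only. A hypothetical sketch mirroring the commented-out line above:

    # assumption: to train the RWKV variant instead, build the other head model
    model = RWKVLMHeadModel(conf.model_config)
    qwen = QwenModule(conf, model)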