Format rwkv/RWKV-v7/rwkv_v7_demo.py
This commit is contained in:
parent
002f132818
commit
4f18296e40
|
@ -5,7 +5,9 @@
|
||||||
import torch, types, os, gc, math, json
|
import torch, types, os, gc, math, json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from torch.nn import Module
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
@ -14,9 +16,9 @@ torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
|
# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
|
||||||
torch._C._jit_set_autocast_mode(False)
|
torch._C._jit_set_autocast_mode(False)
|
||||||
|
|
||||||
'''
|
"""
|
||||||
This will load RWKV-7 "Goose" x070 and inference in GPT-mode (slower than RNN-mode for autoregressive generation)
|
This will load RWKV-7 "Goose" x070 and inference in GPT-mode (slower than RNN-mode for autoregressive generation)
|
||||||
'''
|
"""
|
||||||
|
|
||||||
args = types.SimpleNamespace()
|
args = types.SimpleNamespace()
|
||||||
|
|
||||||
|
@ -35,35 +37,33 @@ D_GATE_LORA = 128
|
||||||
args.vocab_size = 65536
|
args.vocab_size = 65536
|
||||||
|
|
||||||
# DTYPE = torch.bfloat16
|
# DTYPE = torch.bfloat16
|
||||||
DTYPE = torch.half # better
|
DTYPE = torch.half # better
|
||||||
|
|
||||||
args.head_size_a = 64 # don't change
|
args.head_size_a = 64 # don't change
|
||||||
HEAD_SIZE = args.head_size_a
|
HEAD_SIZE = args.head_size_a
|
||||||
|
|
||||||
USE_CUDA_KERNEL = True # False => UNOPTIMIZED, VERY SLOW
|
USE_CUDA_KERNEL = True # False => UNOPTIMIZED, VERY SLOW
|
||||||
|
|
||||||
MyModule = torch.jit.ScriptModule
|
|
||||||
MyFunction = torch.jit.script_method
|
|
||||||
MyStatic = torch.jit.script
|
|
||||||
|
|
||||||
########################################################################################################
|
########################################################################################################
|
||||||
# RWKV Tokenizer (slow version)
|
# RWKV Tokenizer (slow version)
|
||||||
########################################################################################################
|
########################################################################################################
|
||||||
|
|
||||||
class RWKV_TOKENIZER():
|
|
||||||
|
class RWKV_TOKENIZER:
|
||||||
table: list[list[list[bytes]]]
|
table: list[list[list[bytes]]]
|
||||||
good: list[set[int]]
|
good: list[set[int]]
|
||||||
wlen: list[int]
|
wlen: list[int]
|
||||||
|
|
||||||
def __init__(self, file_name):
|
def __init__(self, file_name):
|
||||||
self.idx2token = {}
|
self.idx2token = {}
|
||||||
sorted = [] # must be already sorted
|
sorted = [] # must be already sorted
|
||||||
lines = open(file_name, "r", encoding="utf-8").readlines()
|
lines = open(file_name, "r", encoding="utf-8").readlines()
|
||||||
for l in lines:
|
for l in lines:
|
||||||
idx = int(l[:l.index(' ')])
|
idx = int(l[: l.index(" ")])
|
||||||
x = eval(l[l.index(' '):l.rindex(' ')])
|
x = eval(l[l.index(" ") : l.rindex(" ")])
|
||||||
x = x.encode("utf-8") if isinstance(x, str) else x
|
x = x.encode("utf-8") if isinstance(x, str) else x
|
||||||
assert isinstance(x, bytes)
|
assert isinstance(x, bytes)
|
||||||
assert len(x) == int(l[l.rindex(' '):])
|
assert len(x) == int(l[l.rindex(" ") :])
|
||||||
sorted += [x]
|
sorted += [x]
|
||||||
self.idx2token[idx] = x
|
self.idx2token[idx] = x
|
||||||
|
|
||||||
|
@ -76,7 +76,7 @@ class RWKV_TOKENIZER():
|
||||||
self.good = [set() for i in range(256)]
|
self.good = [set() for i in range(256)]
|
||||||
self.wlen = [0 for i in range(256)]
|
self.wlen = [0 for i in range(256)]
|
||||||
|
|
||||||
for i in reversed(range(len(sorted))): # reverse order - match longer tokens first
|
for i in reversed(range(len(sorted))): # reverse order - match longer tokens first
|
||||||
s = sorted[i]
|
s = sorted[i]
|
||||||
if len(s) >= 2:
|
if len(s) >= 2:
|
||||||
s0 = int(s[0])
|
s0 = int(s[0])
|
||||||
|
@ -107,25 +107,26 @@ class RWKV_TOKENIZER():
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
def decodeBytes(self, tokens):
|
def decodeBytes(self, tokens):
|
||||||
return b''.join(map(lambda i: self.idx2token[i], tokens))
|
return b"".join(map(lambda i: self.idx2token[i], tokens))
|
||||||
|
|
||||||
def encode(self, src: str):
|
def encode(self, src: str):
|
||||||
return self.encodeBytes(src.encode("utf-8"))
|
return self.encodeBytes(src.encode("utf-8"))
|
||||||
|
|
||||||
def decode(self, tokens):
|
def decode(self, tokens):
|
||||||
return self.decodeBytes(tokens).decode('utf-8')
|
return self.decodeBytes(tokens).decode("utf-8")
|
||||||
|
|
||||||
def printTokens(self, tokens):
|
def printTokens(self, tokens):
|
||||||
for i in tokens:
|
for i in tokens:
|
||||||
s = self.idx2token[i]
|
s = self.idx2token[i]
|
||||||
try:
|
try:
|
||||||
s = s.decode('utf-8')
|
s = s.decode("utf-8")
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
print(f'{repr(s)}{i}', end=' ')
|
print(f"{repr(s)}{i}", end=" ")
|
||||||
# print(repr(s), i)
|
# print(repr(s), i)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
tokenizer = RWKV_TOKENIZER("rwkv_vocab_v20230424.txt")
|
tokenizer = RWKV_TOKENIZER("rwkv_vocab_v20230424.txt")
|
||||||
|
|
||||||
########################################################################################################
|
########################################################################################################
|
||||||
|
@ -136,8 +137,21 @@ if USE_CUDA_KERNEL:
|
||||||
|
|
||||||
from torch.utils.cpp_extension import load
|
from torch.utils.cpp_extension import load
|
||||||
|
|
||||||
load(name="wkv7", sources=["cuda/wkv7_op.cpp", f"cuda/wkv7.cu"], is_python_module=False,
|
load(
|
||||||
verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}"])
|
name="wkv7",
|
||||||
|
sources=["cuda/wkv7_op.cpp", f"cuda/wkv7.cu"],
|
||||||
|
is_python_module=False,
|
||||||
|
verbose=True,
|
||||||
|
extra_cuda_cflags=[
|
||||||
|
"-res-usage",
|
||||||
|
"--use_fast_math",
|
||||||
|
"-O3",
|
||||||
|
"-Xptxas -O3",
|
||||||
|
"--extra-device-vectorization",
|
||||||
|
f"-D_N_={HEAD_SIZE}",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
class WKV_7(torch.autograd.Function):
|
class WKV_7(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, r, w, k, v, a, b):
|
def forward(ctx, r, w, k, v, a, b):
|
||||||
|
@ -186,7 +200,7 @@ else:
|
||||||
vv = v[:, t, :].view(B, H, N, 1)
|
vv = v[:, t, :].view(B, H, N, 1)
|
||||||
aa = a[:, t, :].view(B, H, N, 1)
|
aa = a[:, t, :].view(B, H, N, 1)
|
||||||
bb = b[:, t, :].view(B, H, 1, N)
|
bb = b[:, t, :].view(B, H, 1, N)
|
||||||
state = state * w[: , t, :, None, :] + state @ aa @ bb + vv @ kk
|
state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk
|
||||||
out[:, t, :] = (state @ rr).view(B, H, N)
|
out[:, t, :] = (state @ rr).view(B, H, N)
|
||||||
|
|
||||||
# another method using einsum
|
# another method using einsum
|
||||||
|
@ -202,11 +216,13 @@ else:
|
||||||
|
|
||||||
return out.view(B, T, C).to(dtype=DTYPE)
|
return out.view(B, T, C).to(dtype=DTYPE)
|
||||||
|
|
||||||
|
|
||||||
########################################################################################################
|
########################################################################################################
|
||||||
# RWKV TimeMix
|
# RWKV TimeMix
|
||||||
########################################################################################################
|
########################################################################################################
|
||||||
|
|
||||||
class RWKV_Tmix_x070(MyModule):
|
|
||||||
|
class RWKV_Tmix_x070(Module):
|
||||||
def __init__(self, args, layer_id):
|
def __init__(self, args, layer_id):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.args = args
|
self.args = args
|
||||||
|
@ -220,40 +236,39 @@ class RWKV_Tmix_x070(MyModule):
|
||||||
N = self.head_size
|
N = self.head_size
|
||||||
C = args.n_embd
|
C = args.n_embd
|
||||||
|
|
||||||
self.x_r = nn.Parameter(torch.empty(1,1,C))
|
self.x_r = nn.Parameter(torch.empty(1, 1, C))
|
||||||
self.x_w = nn.Parameter(torch.empty(1,1,C))
|
self.x_w = nn.Parameter(torch.empty(1, 1, C))
|
||||||
self.x_k = nn.Parameter(torch.empty(1,1,C))
|
self.x_k = nn.Parameter(torch.empty(1, 1, C))
|
||||||
self.x_v = nn.Parameter(torch.empty(1,1,C))
|
self.x_v = nn.Parameter(torch.empty(1, 1, C))
|
||||||
self.x_a = nn.Parameter(torch.empty(1,1,C))
|
self.x_a = nn.Parameter(torch.empty(1, 1, C))
|
||||||
self.x_g = nn.Parameter(torch.empty(1,1,C))
|
self.x_g = nn.Parameter(torch.empty(1, 1, C))
|
||||||
|
|
||||||
self.w0 = nn.Parameter(torch.empty(1,1,C))
|
self.w0 = nn.Parameter(torch.empty(1, 1, C))
|
||||||
self.w1 = nn.Parameter(torch.empty(C, D_DECAY_LORA))
|
self.w1 = nn.Parameter(torch.empty(C, D_DECAY_LORA))
|
||||||
self.w2 = nn.Parameter(torch.empty(D_DECAY_LORA, C))
|
self.w2 = nn.Parameter(torch.empty(D_DECAY_LORA, C))
|
||||||
|
|
||||||
self.a0 = nn.Parameter(torch.empty(1,1,C))
|
self.a0 = nn.Parameter(torch.empty(1, 1, C))
|
||||||
self.a1 = nn.Parameter(torch.empty(C, D_AAA_LORA))
|
self.a1 = nn.Parameter(torch.empty(C, D_AAA_LORA))
|
||||||
self.a2 = nn.Parameter(torch.empty(D_AAA_LORA, C))
|
self.a2 = nn.Parameter(torch.empty(D_AAA_LORA, C))
|
||||||
|
|
||||||
self.v0 = nn.Parameter(torch.empty(1,1,C))
|
self.v0 = nn.Parameter(torch.empty(1, 1, C))
|
||||||
self.v1 = nn.Parameter(torch.empty(C, D_MV_LORA))
|
self.v1 = nn.Parameter(torch.empty(C, D_MV_LORA))
|
||||||
self.v2 = nn.Parameter(torch.empty(D_MV_LORA, C))
|
self.v2 = nn.Parameter(torch.empty(D_MV_LORA, C))
|
||||||
|
|
||||||
self.g1 = nn.Parameter(torch.empty(C, D_GATE_LORA))
|
self.g1 = nn.Parameter(torch.empty(C, D_GATE_LORA))
|
||||||
self.g2 = nn.Parameter(torch.empty(D_GATE_LORA, C))
|
self.g2 = nn.Parameter(torch.empty(D_GATE_LORA, C))
|
||||||
|
|
||||||
self.k_k = nn.Parameter(torch.empty(1,1,C))
|
self.k_k = nn.Parameter(torch.empty(1, 1, C))
|
||||||
self.k_a = nn.Parameter(torch.empty(1,1,C))
|
self.k_a = nn.Parameter(torch.empty(1, 1, C))
|
||||||
self.r_k = nn.Parameter(torch.empty(H,N))
|
self.r_k = nn.Parameter(torch.empty(H, N))
|
||||||
|
|
||||||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||||||
self.receptance = nn.Linear(C, C, bias=False)
|
self.receptance = nn.Linear(C, C, bias=False)
|
||||||
self.key = nn.Linear(C, C, bias=False)
|
self.key = nn.Linear(C, C, bias=False)
|
||||||
self.value = nn.Linear(C, C, bias=False)
|
self.value = nn.Linear(C, C, bias=False)
|
||||||
self.output = nn.Linear(C, C, bias=False)
|
self.output = nn.Linear(C, C, bias=False)
|
||||||
self.ln_x = nn.GroupNorm(H, C, eps=64e-5) # !!! notice eps value !!!
|
self.ln_x = nn.GroupNorm(H, C, eps=64e-5) # !!! notice eps value !!!
|
||||||
|
|
||||||
@MyFunction
|
|
||||||
def forward(self, x, v_first):
|
def forward(self, x, v_first):
|
||||||
B, T, C = x.size()
|
B, T, C = x.size()
|
||||||
H = self.n_head
|
H = self.n_head
|
||||||
|
@ -267,32 +282,36 @@ class RWKV_Tmix_x070(MyModule):
|
||||||
xg = x + xx * self.x_g
|
xg = x + xx * self.x_g
|
||||||
|
|
||||||
r = self.receptance(xr)
|
r = self.receptance(xr)
|
||||||
w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5 # soft-clamp to (-inf, -0.5)
|
w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5 # soft-clamp to (-inf, -0.5)
|
||||||
k = self.key(xk)
|
k = self.key(xk)
|
||||||
v = self.value(xv)
|
v = self.value(xv)
|
||||||
if self.layer_id == 0:
|
if self.layer_id == 0:
|
||||||
v_first = v # store the v of the first layer
|
v_first = v # store the v of the first layer
|
||||||
else:
|
else:
|
||||||
v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2) # add value residual
|
v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2) # add value residual
|
||||||
a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) # a is "in-context learning rate"
|
a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) # a is "in-context learning rate"
|
||||||
g = torch.sigmoid(xg @ self.g1) @ self.g2
|
g = torch.sigmoid(xg @ self.g1) @ self.g2
|
||||||
|
|
||||||
kk = k * self.k_k
|
kk = k * self.k_k
|
||||||
kk = F.normalize(kk.view(B,T,H,-1), dim=-1, p=2.0).view(B,T,C)
|
kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C)
|
||||||
k = k * (1 + (a-1) * self.k_a)
|
k = k * (1 + (a - 1) * self.k_a)
|
||||||
|
|
||||||
x = RWKV7_OP(r, w, k, v, -kk, kk*a)
|
x = RWKV7_OP(r, w, k, v, -kk, kk * a)
|
||||||
x = self.ln_x(x.view(B * T, C)).view(B, T, C)
|
x = self.ln_x(x.view(B * T, C)).view(B, T, C)
|
||||||
|
|
||||||
x = x + ((r.view(B,T,H,-1)*k.view(B,T,H,-1)*self.r_k).sum(dim=-1, keepdim=True) * v.view(B,T,H,-1)).view(B,T,C)
|
x = x + (
|
||||||
|
(r.view(B, T, H, -1) * k.view(B, T, H, -1) * self.r_k).sum(dim=-1, keepdim=True) * v.view(B, T, H, -1)
|
||||||
|
).view(B, T, C)
|
||||||
x = self.output(x * g)
|
x = self.output(x * g)
|
||||||
return x, v_first
|
return x, v_first
|
||||||
|
|
||||||
|
|
||||||
########################################################################################################
|
########################################################################################################
|
||||||
# RWKV ChannelMix
|
# RWKV ChannelMix
|
||||||
########################################################################################################
|
########################################################################################################
|
||||||
|
|
||||||
class RWKV_CMix_x070(MyModule):
|
|
||||||
|
class RWKV_CMix_x070(Module):
|
||||||
def __init__(self, args, layer_id):
|
def __init__(self, args, layer_id):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.args = args
|
self.args = args
|
||||||
|
@ -305,32 +324,32 @@ class RWKV_CMix_x070(MyModule):
|
||||||
self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
|
self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
|
||||||
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
|
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
|
||||||
|
|
||||||
@MyFunction
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
xx = self.time_shift(x) - x
|
xx = self.time_shift(x) - x
|
||||||
|
|
||||||
k = x + xx * self.x_k
|
k = x + xx * self.x_k
|
||||||
k = torch.relu(self.key(k)) ** 2
|
k = torch.relu(self.key(k)) ** 2
|
||||||
return self.value(k)
|
return self.value(k)
|
||||||
|
|
||||||
|
|
||||||
########################################################################################################
|
########################################################################################################
|
||||||
# RWKV Block
|
# RWKV Block
|
||||||
########################################################################################################
|
########################################################################################################
|
||||||
|
|
||||||
class Block(MyModule):
|
|
||||||
|
class Block(Module):
|
||||||
def __init__(self, args, layer_id):
|
def __init__(self, args, layer_id):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.args = args
|
self.args = args
|
||||||
self.layer_id = layer_id
|
self.layer_id = layer_id
|
||||||
|
|
||||||
self.ln0 = nn.LayerNorm(args.n_embd) # only used in block 0, should be fused with emb
|
self.ln0 = nn.LayerNorm(args.n_embd) # only used in block 0, should be fused with emb
|
||||||
self.ln1 = nn.LayerNorm(args.n_embd)
|
self.ln1 = nn.LayerNorm(args.n_embd)
|
||||||
self.ln2 = nn.LayerNorm(args.n_embd)
|
self.ln2 = nn.LayerNorm(args.n_embd)
|
||||||
|
|
||||||
self.att = RWKV_Tmix_x070(args, layer_id)
|
self.att = RWKV_Tmix_x070(args, layer_id)
|
||||||
self.ffn = RWKV_CMix_x070(args, layer_id)
|
self.ffn = RWKV_CMix_x070(args, layer_id)
|
||||||
|
|
||||||
@MyFunction
|
|
||||||
def forward(self, x, v_first):
|
def forward(self, x, v_first):
|
||||||
|
|
||||||
if self.layer_id == 0:
|
if self.layer_id == 0:
|
||||||
|
@ -342,10 +361,12 @@ class Block(MyModule):
|
||||||
|
|
||||||
return x, v_first
|
return x, v_first
|
||||||
|
|
||||||
|
|
||||||
########################################################################################################
|
########################################################################################################
|
||||||
# RWKV Model
|
# RWKV Model
|
||||||
########################################################################################################
|
########################################################################################################
|
||||||
|
|
||||||
|
|
||||||
class RWKV(nn.Module):
|
class RWKV(nn.Module):
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -371,6 +392,7 @@ class RWKV(nn.Module):
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
########################################################################################################
|
########################################################################################################
|
||||||
# RWKV Inference
|
# RWKV Inference
|
||||||
########################################################################################################
|
########################################################################################################
|
||||||
|
@ -380,38 +402,38 @@ model_params = torch.load(MODEL_PATH, map_location="cpu")
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
||||||
model = RWKV(args).to(dtype=DTYPE).cuda()
|
model = RWKV(args).to(dtype=DTYPE).cuda()
|
||||||
model.load_state_dict(model_params, strict=False) # we will ignore blocks.0.att.v0/v1/v2
|
model.load_state_dict(model_params, strict=False) # we will ignore blocks.0.att.v0/v1/v2
|
||||||
|
|
||||||
########################################################################################################
|
########################################################################################################
|
||||||
|
|
||||||
prompt = "中国的首都是在"
|
prompt = "中国的首都是在"
|
||||||
input = tokenizer.encode(prompt)
|
input = tokenizer.encode(prompt)
|
||||||
print(f'\nInput:\n{input}')
|
print(f"\nInput:\n{input}")
|
||||||
|
|
||||||
out = model.forward(torch.tensor(input).reshape(1,-1).cuda())
|
out = model.forward(torch.tensor(input).reshape(1, -1).cuda())
|
||||||
print(f'\nOutput:\n{out}')
|
print(f"\nOutput:\n{out}")
|
||||||
|
|
||||||
# logits of the last token => prediction for the next token
|
# logits of the last token => prediction for the next token
|
||||||
out = out[0, -1]
|
out = out[0, -1]
|
||||||
|
|
||||||
probs = F.softmax(out.float(), dim=-1) # compute softmax in float (more accurate)
|
|
||||||
|
|
||||||
print(f'\n{prompt}')
|
probs = F.softmax(out.float(), dim=-1) # compute softmax in float (more accurate)
|
||||||
|
|
||||||
_, indices = torch.topk(probs, 10) # print top-10 possibilities
|
print(f"\n{prompt}")
|
||||||
|
|
||||||
|
_, indices = torch.topk(probs, 10) # print top-10 possibilities
|
||||||
for i in range(len(indices)):
|
for i in range(len(indices)):
|
||||||
token_id = indices[i].item()
|
token_id = indices[i].item()
|
||||||
token = tokenizer.decode([token_id])
|
token = tokenizer.decode([token_id])
|
||||||
token_prob = probs[token_id].item()
|
token_prob = probs[token_id].item()
|
||||||
print(token, f'[probability {token_prob:.2%}]')
|
print(token, f"[probability {token_prob:.2%}]")
|
||||||
|
|
||||||
########################################################################################################
|
########################################################################################################
|
||||||
|
|
||||||
with open(f"misc/lambada_test.jsonl", "r", encoding="utf-8") as f:
|
with open(f"misc/lambada_test.jsonl", "r", encoding="utf-8") as f:
|
||||||
todo = [json.loads(line) for line in f]
|
todo = [json.loads(line) for line in f]
|
||||||
todo = [[doc['text'].rsplit(' ', 1)[0], " " + doc['text'].rsplit(' ', 1)[1]] for doc in todo]
|
todo = [[doc["text"].rsplit(" ", 1)[0], " " + doc["text"].rsplit(" ", 1)[1]] for doc in todo]
|
||||||
|
|
||||||
print('\nCheck LAMBADA...')
|
print("\nCheck LAMBADA...")
|
||||||
xsum = 0
|
xsum = 0
|
||||||
xcnt = 0
|
xcnt = 0
|
||||||
xacc = 0
|
xacc = 0
|
||||||
|
@ -421,9 +443,9 @@ with torch.no_grad():
|
||||||
|
|
||||||
logits = 0
|
logits = 0
|
||||||
correct = True
|
correct = True
|
||||||
out = model.forward(torch.tensor(src+dst).reshape(1,-1).cuda())
|
out = model.forward(torch.tensor(src + dst).reshape(1, -1).cuda())
|
||||||
for i in range(len(dst)):
|
for i in range(len(dst)):
|
||||||
ooo = out[0,len(src)-1+i].float()
|
ooo = out[0, len(src) - 1 + i].float()
|
||||||
probs = F.softmax(ooo, dim=-1)
|
probs = F.softmax(ooo, dim=-1)
|
||||||
logits += math.log(probs[dst[i]])
|
logits += math.log(probs[dst[i]])
|
||||||
if torch.argmax(probs).item() != dst[i]:
|
if torch.argmax(probs).item() != dst[i]:
|
||||||
|
@ -433,4 +455,4 @@ with torch.no_grad():
|
||||||
xsum += logits
|
xsum += logits
|
||||||
xacc += 1 if correct else 0
|
xacc += 1 if correct else 0
|
||||||
if xcnt % 100 == 0 or xcnt == len(todo):
|
if xcnt % 100 == 0 or xcnt == len(todo):
|
||||||
print(xcnt, 'ppl', round(math.exp(-xsum / xcnt), 2), 'acc', round(xacc/xcnt*100, 2))
|
print(xcnt, "ppl", round(math.exp(-xsum / xcnt), 2), "acc", round(xacc / xcnt * 100, 2))
|
||||||
|
|
Loading…
Reference in New Issue