From ebe48f8efcd97965edb5d4e2cd7d4f256b65cd9c Mon Sep 17 00:00:00 2001 From: Colin Date: Fri, 22 Dec 2023 20:01:09 +0800 Subject: [PATCH] Update readme. --- RMSNorm_weight.png | Bin 0 -> 2506 bytes Readme.md | 16 ++++++++++++---- chatglm/modeling_chatglm.py | 1 + tensor.py | 5 +++++ 4 files changed, 18 insertions(+), 4 deletions(-) create mode 100644 RMSNorm_weight.png diff --git a/RMSNorm_weight.png b/RMSNorm_weight.png new file mode 100644 index 0000000000000000000000000000000000000000..3d4b231447c93931b8cae3457a3309bf3aecd409 GIT binary patch literal 2506 zcmV;*2{rbKP)|wxU+IBUZ``gt7X%0+Sh(>TsRC+RwI%xbEJ>EawhaEmH=h-d?88K0Qy(}lDoIZowhxThaWAWAkGl`)Ga|?#?IMPe&1UHec@NHW~$d|9j3hMR4RFz2dKbxXGyH zm>W)WjJwGNXDxU(umo4j<QftFFxSTT85IPQHJ!&yPFtQ)u$}bp4?L*zH??{S^C;1xVol7 z%_g8CInC$RZiXaGhc`0TDv#+XF5mc-nGIll@T@~rI-yE9Wz1}tEt@Ih@gh*vXl;J- z1GKZ>EEb(#Iip5UEHzY2CondZ*{UtoIsIuX6H##e?B97dkj49tEnj@hjc+IdN7qEF zp3{(M8P%q1TxOXe*M-w;fL`!Zm(uEb^ABgFq4=tlW+ktlCBk&QYjvpk$;&!TV3VvY z-BRR=`TV2kqfcPz{^09f$uvWo(M~85?xP&%^Y^#fZ_SG*0uI4W<30=vktqf~-apS* zqo)b&M#kBe8`4pK{>5#yHYu%EeTf*J+vyz(938^|i)0#K7Fb@S70S{_PcJJ4)7Izy zb9Q+dac;Sto+cTv=;fp3<7-@e-2r3vxg36r{inVuS_)-8s}J9C4G@BTw2Q4=0MC8Z zYsa+x5MNTh+TiSW*RoZOe0J1QtI^Xi?4UF8&Dzsb3+0!E zWq{g@<-&?ogD}G~3ar&twzew4u6JJImnBoe&&a#e03=!P^pTgu^AopCC86ASI<+PI z%F78l#3eqI4|Jb?1bmJAS!gRT`ShuqyI?$OThE@tion`N8otK?E0T*}y{7hG_?r}M zMkKQ-zwi&pzLrmH>FEVj_2D$ECZxj(iI@&9K4h#I=$1rLuG#07Vy zL{_tcU{4o``qXz7D8zsQeew9of73FWG6sA*Ggep1{A}6=`Ob(DmAtI8*$;kOPL<8p zR~&)ErLB-2NS+O#u+2iO^fG4E#)mBTf)%%Gq+@g+@^*_A#MuQKY@@LVBE49JNUxY} z^z_9y6zzK9VIh3-e-Wn}W%Z-~PSpPCcaP^Fn~ z)7gzGn}(jt$3>Fy%OxI{8M8q)awE^Mu|Ms}#NAju5juhuqE}h^04>s$lJE`ygG8<+(Ps=_g-ubSUwwe^%mD%&jn*^LM{iJ!M$gy?jLZ z;u)Z=D66Xe#lKXUi!W2#Tn=E%eopPhvkUIWiV%h&vS%7h7;Yj+JU(M`xh*w5sDAie z2xZ7irWZIjr7bs~F~TSBpTH`|*iH|j4UWyuUW2>D2VyzK<~kd+9c`FHQ{h#H1l1?l zYqq1@;kU~){`}wGk2M4RBp+-6EEmKO4|IF+q2>NDXoW~`Lxe5Y?6q35<@B~teWhYK z-v};KtI=Wts9@PO8fIGfdq0%DrX_s_6{mAmmyNdiT9>0`n|wQJ_Ebd~d3JHN)zZ%U z#|~h&GOY5x3|NIIE3RH;`;!9c!3U9cw)2ER#>i_sr)91*&O3v6>A?TgPbo>d#37}pXa$h z_M%$`j%}GD$>BcRvJ6R1_}WfFRL)HlE+=0tdB*fKL(>^><~^_lGxSS;1wLM8;KG-jh%`SDkr=UhDhlLiN0-mGo!e838|ljMGcsxhKo*2> zzLP@kXB!?@i{+~!TfwHUmbo83dTeJ6FWY%m856x=o*G>QUoIB6_|))-ZmpM978Y6cWl)Sj!4jO(SslYMFhv>eMh@6J6n&6=Hj z3+ZpjcX&cn_> [1, 1, 65024] probs = softmax(lm_logits) -> [1, 65024] - next_tokens = torch.multinomial(probs, num_samples=1) 采样 -> [1] 1:batch_num + next_tokens = torch.multinomial(probs, num_samples=1) 采样 -> [1] 1:batch_num if next_tokens == eos_token_id 推理结束退出循环 input_ids = torch.cat([input_ids, next_tokens) -> [1, 7] 1:batch_num -response = tokenizer.decode(outputs) \ No newline at end of file +response = tokenizer.decode(outputs) + +## RMSNorm + +hidden_states -> [6, 1, 4096] 4096:hidden_size +variance = hidden_states.pow(2).mean(-1, keepdim=True) -> [6, 1, 1] +hidden_states = hidden_states * torch.rsqrt(variance + self.eps) 平方根倒数 +self.weight -> [4096] +return (self.weight * hidden_states) -> [6, 1, 4096] \ No newline at end of file diff --git a/chatglm/modeling_chatglm.py b/chatglm/modeling_chatglm.py index a32728c..8e60b6c 100644 --- a/chatglm/modeling_chatglm.py +++ b/chatglm/modeling_chatglm.py @@ -68,6 +68,7 @@ class RMSNorm(torch.nn.Module): input_dtype = hidden_states.dtype variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + # show.DumpTensorToImage(self.weight, "RMSNorm_weight.png") return (self.weight * hidden_states).to(input_dtype) diff --git a/tensor.py b/tensor.py index 47235bf..4ef7566 100644 --- a/tensor.py +++ b/tensor.py @@ -29,3 +29,8 @@ print(x.prod(0)) print() print(x.unsqueeze(1).shape) print(x.unsqueeze(1).squeeze(1).shape) + +x = torch.tensor([[1, 2], [3, 4]]).to(float) +print(x.mean(1)) +print(x.mean(0)) +print(x.mean(0, keepdim=True))