From d9b64e40259fcc9e6d02ce571458fbacd9f5d155 Mon Sep 17 00:00:00 2001
From: colin
Date: Sun, 31 Dec 2023 15:26:02 +0800
Subject: [PATCH] Update Readme.md
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 Readme.md | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/Readme.md b/Readme.md
index 190e9da..5bb53c5 100644
--- a/Readme.md
+++ b/Readme.md
@@ -3,6 +3,8 @@
 
 ## data flow
 
+```
+
 query
   |
 tokenizer -> input_ids
@@ -47,9 +49,11 @@
 for:
     input_ids = torch.cat([input_ids, next_tokens]) -> [1, 7] 1:batch_num
 response = tokenizer.decode(outputs)
+```
 
 ## RMSNorm
 
+```
 hidden_states -> [6, 1, 4096]
     /        \
    |       pow(2) -> [6, 1, 4096]
@@ -68,9 +72,11 @@
 variance = hidden_states.pow(2).mean(-1, keepdim=True) -> [6, 1, 1]
 hidden_states = hidden_states * torch.rsqrt(variance + self.eps)   reciprocal of the square root
 self.weight -> [4096]
 return (self.weight * hidden_states) -> [6, 1, 4096]
+```
 
 ## MLP
 
+```
 hidden_states -> [6, 1, 4096]
   Linear -> [6, 1, 27392]
     /       \
@@ -86,9 +92,10 @@
 Linear(hidden_states) no bias -> [6, 1, 27392]
 silu(x) = [6, 1, 13696] * sigmoid([6, 1, 13696])
 Linear(intermediate_parallel) no bias -> [6, 1, 4096]
+```
 
 ## self_attention
-
+```
 x -> [6, 1, 4096]
   |
 Linear -> [6, 1, 4608]
@@ -134,6 +141,7 @@
 context_layer = scaled_dot_product_attention(query_layer, key_layer, value_layer)
 context_layer = context_layer.permute(2, 0, 1, 3).reshape() -> [6, 1, 4096]
 return Linear(context_layer) -> [6, 1, 4096]
+```
 
 ## GLMBlock
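The data flow section above boils down to a tokenize -> forward -> argmax -> concat -> decode loop. A minimal sketch in PyTorch, assuming a Hugging Face-style tokenizer and a model whose forward returns logits of shape [batch, seq_len, vocab_size]; the helper name and generation length are illustrative, not the repo's code:

```python
# Minimal greedy-decoding sketch of the data flow above (not the repo's exact code).
# Assumes `model(input_ids).logits` has shape [batch, seq_len, vocab_size] and a
# Hugging Face-style tokenizer with an `eos_token_id`.
import torch

def greedy_generate(model, tokenizer, query: str, max_new_tokens: int = 32) -> str:
    input_ids = tokenizer(query, return_tensors="pt").input_ids          # [1, seq_len]
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(input_ids).logits                             # [1, seq_len, vocab]
            next_tokens = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # [1, 1]
            input_ids = torch.cat([input_ids, next_tokens], dim=-1)      # e.g. [1, 6] -> [1, 7]
            if next_tokens.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
```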
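The RMSNorm walkthrough maps directly onto a small nn.Module. A minimal sketch built from the formulas above, assuming hidden_size = 4096 and eps = 1e-5; the real layer may upcast to float32 before taking the mean:

```python
import torch
from torch import nn

class RMSNorm(nn.Module):
    """Root-mean-square norm following the shape walkthrough above."""

    def __init__(self, hidden_size: int = 4096, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # [4096]
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: [seq_len, batch, hidden], e.g. [6, 1, 4096]
        variance = hidden_states.pow(2).mean(-1, keepdim=True)            # [6, 1, 1]
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)  # reciprocal sqrt
        return self.weight * hidden_states                                # [6, 1, 4096]
```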
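The MLP section describes a gated (SwiGLU-style) feed-forward block: one bias-free projection to 27392, a split into two 13696 halves, silu(first) * second, and a bias-free projection back to 4096. A minimal sketch under those shapes; the layer names dense_h_to_4h / dense_4h_to_h follow common GLM naming but are only illustrative here, as is the chunk order of gate and up:

```python
import torch
from torch import nn
import torch.nn.functional as F

class GatedMLP(nn.Module):
    """4096 -> 27392 -> (2 x 13696, silu-gated) -> 4096, matching the shapes above."""

    def __init__(self, hidden_size: int = 4096, ffn_hidden_size: int = 13696):
        super().__init__()
        self.dense_h_to_4h = nn.Linear(hidden_size, 2 * ffn_hidden_size, bias=False)  # -> 27392
        self.dense_4h_to_h = nn.Linear(ffn_hidden_size, hidden_size, bias=False)      # -> 4096

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: [6, 1, 4096]
        projected = self.dense_h_to_4h(hidden_states)      # [6, 1, 27392]
        gate, up = projected.chunk(2, dim=-1)              # two [6, 1, 13696] halves
        intermediate_parallel = F.silu(gate) * up          # silu(x) = x * sigmoid(x)
        return self.dense_4h_to_h(intermediate_parallel)   # [6, 1, 4096]
```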
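The self_attention shapes imply multi-query attention: a fused QKV projection from 4096 to 4608 = 32 query heads x 128, plus 2 key/value groups x 128 each for K and V. A minimal sketch of that split and the scaled_dot_product_attention call; rotary position embeddings, the KV cache, and masking details are omitted, and the module and parameter names are assumptions rather than the repo's identifiers:

```python
import torch
from torch import nn
import torch.nn.functional as F

class MultiQuerySelfAttention(nn.Module):
    """Fused QKV 4096 -> 4608, 32 query heads and 2 k/v groups with head_dim 128."""

    def __init__(self, hidden_size: int = 4096, num_heads: int = 32,
                 num_kv_groups: int = 2, head_dim: int = 128):
        super().__init__()
        self.num_heads, self.num_kv_groups, self.head_dim = num_heads, num_kv_groups, head_dim
        qkv_size = (num_heads + 2 * num_kv_groups) * head_dim         # 4096 + 512 = 4608
        self.query_key_value = nn.Linear(hidden_size, qkv_size, bias=True)
        self.dense = nn.Linear(num_heads * head_dim, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len, batch, _ = x.shape                                    # [6, 1, 4096]
        mixed = self.query_key_value(x)                                # [6, 1, 4608]
        q, k, v = mixed.split([self.num_heads * self.head_dim,         # [6, 1, 4096]
                               self.num_kv_groups * self.head_dim,     # [6, 1, 256]
                               self.num_kv_groups * self.head_dim],    # [6, 1, 256]
                              dim=-1)
        # [seq, batch, heads*dim] -> [batch, heads, seq, dim] for scaled_dot_product_attention
        q = q.view(seq_len, batch, self.num_heads, self.head_dim).permute(1, 2, 0, 3)
        k = k.view(seq_len, batch, self.num_kv_groups, self.head_dim).permute(1, 2, 0, 3)
        v = v.view(seq_len, batch, self.num_kv_groups, self.head_dim).permute(1, 2, 0, 3)
        # expand the 2 k/v groups so each of the 32 query heads sees a matching k/v head
        k = k.repeat_interleave(self.num_heads // self.num_kv_groups, dim=1)
        v = v.repeat_interleave(self.num_heads // self.num_kv_groups, dim=1)
        context_layer = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # [1, 32, 6, 128]
        # back to [seq, batch, hidden], mirroring permute(2, 0, 1, 3).reshape() above
        context_layer = context_layer.permute(2, 0, 1, 3).reshape(seq_len, batch, -1)  # [6, 1, 4096]
        return self.dense(context_layer)                               # [6, 1, 4096]
```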