diff --git a/Readme.md b/Readme.md
index 190e9da..5bb53c5 100644
--- a/Readme.md
+++ b/Readme.md
@@ -3,6 +3,8 @@
 
 ## data flow
 
+```
+
 query
 |
 tokenizer -> input_ids
@@ -47,9 +49,11 @@ for:
 input_ids = torch.cat([input_ids, next_tokens]) -> [1, 7] 1:batch_num
 
 response = tokenizer.decode(outputs)
+```
 
 ## RMSNorm
 
+```
 hidden_states -> [6, 1, 4096]
        /        \
        |        pow(2) -> [6, 1, 4096]
@@ -68,9 +72,11 @@ variance = hidden_states.pow(2).mean(-1, keepdim=True) -> [6, 1, 1]
 hidden_states = hidden_states * torch.rsqrt(variance + self.eps)  reciprocal square root
 self.weight -> [4096]
 return (self.weight * hidden_states) -> [6, 1, 4096]
+```
 
 ## MLP
 
+```
 hidden_states -> [6, 1, 4096]
 Linear -> [6, 1, 27392]
        /            \
@@ -86,9 +92,10 @@ return (self.weight * hidden_states) -> [6, 1, 4096]
 Linear(hidden_states) no bias -> [6, 1, 27392]
 silu (x) = [6, 1, 13696] * sigmoid([6, 1, 13696])
 Linear(intermediate_parallel) no bias -> [6, 1, 4096]
+```
 
 ## self_attention
-
+```
 x -> [6, 1, 4096]
 |
 Linear -> [6, 1, 4608]
@@ -134,6 +141,7 @@ context_layer = scaled_dot_product_attention(query_layer, key_layer, value_layer
 context_layer = context_layer.permute(2, 0, 1, 3).reshape() -> [6, 1, 4096]
 
 return Linear(context_layer) -> [6, 1, 4096]
+```
 
 ## GLMBlock
 
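
For the data-flow section in the patch above, a minimal greedy-decode sketch, assuming a Hugging Face style `model` (whose forward returns `.logits`) and `tokenizer`; the `greedy_generate` helper and its `max_new_tokens` default are illustrative, not code from this repo.

```python
import torch

def greedy_generate(model, tokenizer, query: str, max_new_tokens: int = 32) -> str:
    """Sketch of the README data flow: tokenize, predict, append, decode."""
    input_ids = tokenizer(query, return_tensors="pt").input_ids      # e.g. [1, 6]
    for _ in range(max_new_tokens):
        logits = model(input_ids).logits                             # [1, seq_len, vocab]
        next_tokens = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # [1, 1]
        input_ids = torch.cat([input_ids, next_tokens], dim=-1)      # [1, seq_len + 1]
        if next_tokens.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
```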
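
For the RMSNorm section, a minimal sketch that reproduces the `pow(2).mean(-1, keepdim=True)` / `torch.rsqrt` steps and the `[6, 1, 4096]` shapes shown in the diff; the `eps` default is an assumption.

```python
import torch
from torch import nn

class RMSNorm(nn.Module):
    """Minimal RMSNorm sketch matching the shape walkthrough."""
    def __init__(self, hidden_size: int = 4096, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # [4096]
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: [seq_len, batch, hidden], e.g. [6, 1, 4096]
        variance = hidden_states.pow(2).mean(-1, keepdim=True)            # [6, 1, 1]
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)  # reciprocal square root
        return self.weight * hidden_states                                # [6, 1, 4096]

print(RMSNorm()(torch.randn(6, 1, 4096)).shape)  # torch.Size([6, 1, 4096])
```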
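
For the MLP section, a SwiGLU-style sketch matching the 4096 -> 27392 -> 13696 -> 4096 shapes (27392 = 2 x 13696, split into a gate half and a value half); the layer names `dense_h_to_4h` / `dense_4h_to_h` are illustrative.

```python
import torch
from torch import nn
import torch.nn.functional as F

class MLP(nn.Module):
    """Fused up-projection, gated SiLU activation, down-projection; no biases."""
    def __init__(self, hidden: int = 4096, ffn: int = 13696):
        super().__init__()
        self.dense_h_to_4h = nn.Linear(hidden, 2 * ffn, bias=False)  # -> 27392
        self.dense_4h_to_h = nn.Linear(ffn, hidden, bias=False)      # -> 4096

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        x = self.dense_h_to_4h(hidden_states)   # [6, 1, 27392]
        x0, x1 = x.chunk(2, dim=-1)             # 2 x [6, 1, 13696]
        x = F.silu(x0) * x1                     # silu(x) = x * sigmoid(x), gating x1
        return self.dense_4h_to_h(x)            # [6, 1, 4096]

print(MLP()(torch.randn(6, 1, 4096)).shape)  # torch.Size([6, 1, 4096])
```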
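
For the self_attention section, a grouped (multi-query) attention sketch in which the 4608-wide projection splits as 32 query heads x 128 plus 2 KV groups x 128 for keys and again for values (4096 + 256 + 256). The head counts are assumptions inferred from the shapes, rotary embeddings and the KV cache are omitted, and `F.scaled_dot_product_attention` requires PyTorch 2.0+.

```python
import torch
from torch import nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    """Multi-query attention sketch matching the 4608 = 4096 + 256 + 256 split."""
    def __init__(self, hidden: int = 4096, n_heads: int = 32, n_kv: int = 2, head_dim: int = 128):
        super().__init__()
        self.n_heads, self.n_kv, self.head_dim = n_heads, n_kv, head_dim
        self.query_key_value = nn.Linear(hidden, (n_heads + 2 * n_kv) * head_dim, bias=True)  # -> 4608
        self.dense = nn.Linear(n_heads * head_dim, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [seq_len, batch, hidden], e.g. [6, 1, 4096]
        s, b, _ = x.shape
        qkv = self.query_key_value(x)                                   # [6, 1, 4608]
        q, k, v = qkv.split([self.n_heads * self.head_dim,
                             self.n_kv * self.head_dim,
                             self.n_kv * self.head_dim], dim=-1)
        # [seq, batch, heads, head_dim] -> [batch, heads, seq, head_dim]
        q = q.view(s, b, self.n_heads, self.head_dim).permute(1, 2, 0, 3)
        k = k.view(s, b, self.n_kv, self.head_dim).permute(1, 2, 0, 3)
        v = v.view(s, b, self.n_kv, self.head_dim).permute(1, 2, 0, 3)
        # broadcast the 2 KV groups across the 32 query heads
        k = k.repeat_interleave(self.n_heads // self.n_kv, dim=1)
        v = v.repeat_interleave(self.n_heads // self.n_kv, dim=1)
        ctx = F.scaled_dot_product_attention(q, k, v, is_causal=True)   # [1, 32, 6, 128]
        ctx = ctx.permute(2, 0, 1, 3).reshape(s, b, -1)                 # [6, 1, 4096]
        return self.dense(ctx)                                          # [6, 1, 4096]

print(SelfAttention()(torch.randn(6, 1, 4096)).shape)  # torch.Size([6, 1, 4096])
```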