Update Readme.md

colin 2023-12-31 15:26:02 +08:00
parent dff2b9231f
commit d9b64e4025
1 changed file with 9 additions and 1 deletion

@@ -3,6 +3,8 @@
## data flow
```
query
|
tokenizer -> input_ids
@@ -47,9 +49,11 @@ for:
input_ids = torch.cat([input_ids, next_tokens]) -> [1, 7] 1:batch_num
response = tokenizer.decode(outputs)
```
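The loop above is plain greedy decoding: tokenize the query, append one predicted token per step, and decode the result. A minimal sketch, assuming a Hugging Face-style `tokenizer` and a causal LM whose forward pass returns logits (the loading code is not part of this README):

```
import torch

# Hedged sketch: "model" and "tokenizer" are assumed Hugging Face-style objects.
def greedy_generate(model, tokenizer, query, max_new_tokens=64):
    input_ids = tokenizer(query, return_tensors="pt").input_ids   # query -> tokenizer -> input_ids, [1, seq]
    for _ in range(max_new_tokens):
        logits = model(input_ids).logits                          # [1, seq, vocab_size]
        next_tokens = logits[:, -1, :].argmax(-1, keepdim=True)   # greedy pick, [1, 1]
        input_ids = torch.cat([input_ids, next_tokens], dim=-1)   # grows to e.g. [1, 7]; 1 = batch_num
        if next_tokens.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(input_ids[0])                         # response = tokenizer.decode(outputs)
```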
## RMSNorm
```
hidden_states -> [6, 1, 4096]
/ \
| pow(2) -> [6, 1, 4096]
@@ -68,9 +72,11 @@ variance = hidden_states.pow(2).mean(-1, keepdim=True) -> [6, 1, 1]
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)  (reciprocal square root)
self.weight -> [4096]
return (self.weight * hidden_states) -> [6, 1, 4096]
```
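The same computation as a small PyTorch module, following the lines above directly; the class name and the `eps` default are assumptions:

```
import torch
from torch import nn

class RMSNorm(nn.Module):
    def __init__(self, hidden_size=4096, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))   # [4096]
        self.eps = eps

    def forward(self, hidden_states):                          # [6, 1, 4096]
        variance = hidden_states.pow(2).mean(-1, keepdim=True) # [6, 1, 1]
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return self.weight * hidden_states                     # [6, 1, 4096]
```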
## MLP
```
hidden_states -> [6, 1, 4096]
Linear -> [6, 1, 27392]
/ \
@@ -86,9 +92,10 @@ return (self.weight * hidden_states) -> [6, 1, 4096]
Linear(hidden_states) no bias -> [6, 1, 27392]
silu (x) = [6, 1, 13696] * sigmoid([6, 1, 13696])
Linear(intermediate_parallel) no bias -> [6, 1, 4096]
```
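A sketch of the gated MLP the diagram implies: one bias-free Linear expands 4096 to 27392, the result splits into two 13696-wide halves, one half passes through SiLU and gates the other, and a second bias-free Linear maps back to 4096. The layer names and which half is gated are assumptions:

```
from torch import nn
import torch.nn.functional as F

class MLP(nn.Module):
    # Gated-SiLU MLP sketch; dense_h_to_4h / dense_4h_to_h are assumed names.
    def __init__(self, hidden_size=4096, ffn_hidden_size=13696):
        super().__init__()
        self.dense_h_to_4h = nn.Linear(hidden_size, 2 * ffn_hidden_size, bias=False)  # 4096 -> 27392
        self.dense_4h_to_h = nn.Linear(ffn_hidden_size, hidden_size, bias=False)      # 13696 -> 4096

    def forward(self, hidden_states):              # [6, 1, 4096]
        x = self.dense_h_to_4h(hidden_states)      # [6, 1, 27392]
        x0, x1 = x.chunk(2, dim=-1)                # two [6, 1, 13696] halves
        x = F.silu(x0) * x1                        # silu(x0) = x0 * sigmoid(x0), gating the other half
        return self.dense_4h_to_h(x)               # [6, 1, 4096]
```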
## self_attention
```
x -> [6, 1, 4096]
|
Linear -> [6, 1, 4608]
@@ -134,6 +141,7 @@ context_layer = scaled_dot_product_attention(query_layer, key_layer, value_layer
context_layer = context_layer.permute(2, 0, 1, 3).reshape() -> [6, 1, 4096]
return Linear(context_layer) -> [6, 1, 4096]
```
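A simplified sketch of the attention path, assuming multi-query attention with 32 query heads, 2 key/value heads, and head_dim = 128 (fused QKV projection width 4096 + 2·2·128 = 4608); rotary position embedding and the KV cache are left out, and whether the projections carry a bias is not confirmed by this README:

```
from torch import nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, hidden_size=4096, num_heads=32, num_kv_heads=2):
        super().__init__()
        self.num_heads, self.num_kv_heads = num_heads, num_kv_heads
        self.head_dim = hidden_size // num_heads                        # 128
        self.qkv = nn.Linear(hidden_size, hidden_size + 2 * num_kv_heads * self.head_dim)  # 4096 -> 4608
        self.dense = nn.Linear(hidden_size, hidden_size, bias=False)    # output projection

    def forward(self, x):                                               # [seq=6, batch=1, 4096]
        seq, batch, _ = x.shape
        q, k, v = self.qkv(x).split(
            [self.num_heads * self.head_dim,                            # q: [6, 1, 4096]
             self.num_kv_heads * self.head_dim,                         # k: [6, 1, 256]
             self.num_kv_heads * self.head_dim], dim=-1)                # v: [6, 1, 256]
        # to [batch, heads, seq, head_dim] for scaled_dot_product_attention
        q = q.reshape(seq, batch, self.num_heads, self.head_dim).permute(1, 2, 0, 3)
        k = k.reshape(seq, batch, self.num_kv_heads, self.head_dim).permute(1, 2, 0, 3)
        v = v.reshape(seq, batch, self.num_kv_heads, self.head_dim).permute(1, 2, 0, 3)
        # broadcast the 2 KV heads across the 32 query heads
        k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
        v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
        context_layer = F.scaled_dot_product_attention(q, k, v, is_causal=True)    # [1, 32, 6, 128]
        context_layer = context_layer.permute(2, 0, 1, 3).reshape(seq, batch, -1)  # [6, 1, 4096]
        return self.dense(context_layer)                                # [6, 1, 4096]
```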
## GLMBlock