## data flow
```
query -> "你好"
|
tokenizer -> [6]
|
rotary_pos_emb embedding -> [1, 6, 4096]
\ /
GLMBlock x 28 -> [6, 1, 4096] <━━━┓
RMSNorm -> [6, 1, 4096] ┃
[-1:] -> [1, 1, 4096] ┃
Linear -> [1, 1, 65024] ┃
softmax -> [1, 65024] ┃
multinomial -> [1] ┃
cat([input_ids, next_tokens]) ━━━┛
input_ids = tokenizer.build_chat_input(query, history=history, role=role)
for:
input_ids -> [1, 6] 1:batch_num 6:sequence_length
inputs_embeds -> [6, 1, 4096] 4096:hidden_size
rotary_pos_emb -> [6, 1, 32, 2] 32:positional encoding dim 2:cos+sin
hidden_states = inputs_embeds
for layer in layers:
GLMBlock(hidden_states, rotary_pos_emb)
hidden_states = RMSNorm(hidden_states) # final_layernorm -> [6, 1, 4096]
hidden_states = hidden_states[-1:] take the last sequence position -> [1, 1, 4096]
lm_logits = Linear(hidden_states) -> [1, 1, 65024]
lm_logits = lm_logits.transpose(0, 1).contiguous() -> [1, 1, 65024]
probs = softmax(lm_logits) -> [1, 65024] Softmax(x_i) = exp(x_i) / Σ_j exp(x_j)
next_tokens = torch.multinomial(probs, num_samples=1) sampling -> [1] 1:batch_num
if next_tokens == eos_token_id, inference is finished and the loop exits
input_ids = torch.cat([input_ids, next_tokens]) -> [1, 7] 1:batch_num
response = tokenizer.decode(outputs)
```
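
Putting the loop into code, here is a minimal sketch of the sampling loop above, assuming a ChatGLM-style `model` that returns per-position logits and a `tokenizer` with `build_chat_input`; the shape comments follow the 6-token example.

```python
import torch

def generate(model, tokenizer, query, history=None, role="user", max_new_tokens=512):
    # hypothetical helper mirroring the data flow above; shapes follow the 6-token example
    inputs = tokenizer.build_chat_input(query, history=history or [], role=role)
    input_ids = inputs["input_ids"]                              # [1, 6]
    for _ in range(max_new_tokens):
        lm_logits = model(input_ids)                             # assumed [1, seq, 65024]
        probs = torch.softmax(lm_logits[:, -1, :], dim=-1)       # [1, 65024], last position only
        next_tokens = torch.multinomial(probs, num_samples=1)    # [1, 1], sampled token id
        if next_tokens.item() == tokenizer.eos_token_id:
            break                                                # inference finished
        input_ids = torch.cat([input_ids, next_tokens], dim=-1)  # [1, 7], [1, 8], ...
    return tokenizer.decode(input_ids[0].tolist())
```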
## RMSNorm
```
hidden_states -> [6, 1, 4096]
/ \
| pow(2) -> [6, 1, 4096]
| |
| mean -> [6, 1, 1]
| ↓
| rsqrt( + eps) -> [6, 1, 1]
\ /
mul -> [6, 1, 4096]
\ weight -> [4096]
\ /
mul -> [6, 1, 4096]
hidden_states -> [6, 1, 4096] 4096:hidden_size
variance = hidden_states.pow(2).mean(-1, keepdim=True) -> [6, 1, 1]
hidden_states = hidden_states * torch.rsqrt(variance + self.eps) reciprocal square root
self.weight -> [4096]
return (self.weight * hidden_states) -> [6, 1, 4096]
```
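
The same computation as a runnable PyTorch module (a sketch; `eps` is an assumed default and may differ from the actual implementation):

```python
import torch
from torch import nn

class RMSNorm(nn.Module):
    def __init__(self, hidden_size=4096, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))    # [4096]
        self.eps = eps

    def forward(self, hidden_states):                           # [6, 1, 4096]
        variance = hidden_states.pow(2).mean(-1, keepdim=True)  # [6, 1, 1]
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)  # reciprocal square root
        return self.weight * hidden_states                      # [6, 1, 4096]
```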
## MLP
```
hidden_states -> [6, 1, 4096]
Linear -> [6, 1, 27392]
/ \
chunk1 chunk0 -> [6, 1, 13696]
| | \
| | sigmoid
| | /
| mul
\ /
mul -> [6, 1, 13696]
Linear -> [6, 1, 4096]
Linear(hidden_states) no bias -> [6, 1, 27392]
swiglu: silu(chunk0) * chunk1, where silu(x) = x * sigmoid(x) -> [6, 1, 13696]
Linear(intermediate_parallel) no bias -> [6, 1, 4096]
```
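
As a sketch, the SwiGLU-style MLP above in PyTorch: one fused up-projection to 2×13696, a chunk into gate and value halves, SiLU gating, then a down-projection back to 4096 (the attribute names here are illustrative, not necessarily the model's own):

```python
import torch
from torch import nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, hidden_size=4096, ffn_hidden_size=13696):
        super().__init__()
        # fused up-projection producing both chunks, no bias
        self.up_proj = nn.Linear(hidden_size, 2 * ffn_hidden_size, bias=False)  # 4096 -> 27392
        self.down_proj = nn.Linear(ffn_hidden_size, hidden_size, bias=False)    # 13696 -> 4096

    def forward(self, hidden_states):                     # [6, 1, 4096]
        intermediate = self.up_proj(hidden_states)        # [6, 1, 27392]
        chunk0, chunk1 = intermediate.chunk(2, dim=-1)    # 2 x [6, 1, 13696]
        gated = F.silu(chunk0) * chunk1                   # silu(x) = x * sigmoid(x)
        return self.down_proj(gated)                      # [6, 1, 4096]
```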
## self_attention
```
x -> [6, 1, 4096]
|
Linear -> [6, 1, 4608]
/ | \
[6, 1, 32, 128] <- q k v
/ | \
pos_emb pos_emb \
| | |
| expand expand -> [6, 1, 32, 128]
\ / |
┏---- dot |
┃ += attention_mask /
attention┃ softmax /
┃ \ /
┗---- dot -> [1, 32, 6, 128] -> [6, 1, 4096]
Linear -> [6, 1, 4096]
hidden_states: [s, b, h]
mixed_x_layer = Linear(hidden_states) -> [6, 1, 4608] 4608:4096+256+256
(query_layer, key_layer, value_layer) = mixed_x_layer.split -> [6, 1, 4096], [6, 1, 256], [6, 1, 256]
query_layer = query_layer.view -> [6, 1, 32, 128]
key_layer = key_layer.view -> [6, 1, 2, 128]
value_layer = value_layer.view -> [6, 1, 2, 128]
query_layer = self.apply_rotary_pos_emb(query_layer, rotary_pos_emb)
key_layer = self.apply_rotary_pos_emb(key_layer, rotary_pos_emb)
key_layer = key_layer.unsqueeze(-2) -> [6, 1, 2, 1, 128]
key_layer = key_layer.expand -> [6, 1, 2, 16, 128]
key_layer = key_layer.contiguous().view -> [6, 1, 32, 128]
value_layer = value_layer.unsqueeze(-2) -> [6, 1, 2, 1, 128]
value_layer = value_layer.expand -> [6, 1, 2, 16, 128]
value_layer = value_layer.contiguous().view -> [6, 1, 32, 128]
query_layer permute(1, 2, 0, 3) -> [1, 32, 6, 128]
key_layer permute(1, 2, 0, 3) -> [1, 32, 6, 128]
value_layer permute(1, 2, 0, 3) -> [1, 32, 6, 128]
context_layer = scaled_dot_product_attention(query_layer, key_layer, value_layer) -> [1, 32, 6, 128]
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = F.softmax(att, dim=-1)
y = att @ v -> (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
context_layer = context_layer.permute(2, 0, 1, 3).reshape(6, 1, 4096) -> [6, 1, 4096]
return Linear(context_layer) -> [6, 1, 4096]
```
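
A sketch of the multi-query attention path above: 32 query heads share 2 key/value heads, the key/value heads are expanded 16× before scaled dot-product attention, and the rotary embedding step is omitted for brevity (layer names here are illustrative):

```python
import torch
from torch import nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, hidden_size=4096, num_heads=32, num_kv_heads=2, head_dim=128):
        super().__init__()
        self.num_heads, self.num_kv_heads, self.head_dim = num_heads, num_kv_heads, head_dim
        qkv_size = num_heads * head_dim + 2 * num_kv_heads * head_dim  # 4096 + 256 + 256 = 4608
        self.qkv_proj = nn.Linear(hidden_size, qkv_size, bias=True)
        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, hidden_states, attention_mask=None):             # [s=6, b=1, 4096]
        s, b, _ = hidden_states.shape
        mixed = self.qkv_proj(hidden_states)                            # [6, 1, 4608]
        q, k, v = mixed.split([self.num_heads * self.head_dim,
                               self.num_kv_heads * self.head_dim,
                               self.num_kv_heads * self.head_dim], dim=-1)
        q = q.view(s, b, self.num_heads, self.head_dim)                 # [6, 1, 32, 128]
        k = k.view(s, b, self.num_kv_heads, self.head_dim)              # [6, 1, 2, 128]
        v = v.view(s, b, self.num_kv_heads, self.head_dim)              # [6, 1, 2, 128]
        # apply_rotary_pos_emb(q, rotary_pos_emb) and apply_rotary_pos_emb(k, rotary_pos_emb) go here

        # expand the 2 kv heads so all 32 query heads have a matching kv head
        groups = self.num_heads // self.num_kv_heads                    # 16
        k = k.unsqueeze(-2).expand(-1, -1, -1, groups, -1).reshape(s, b, self.num_heads, self.head_dim)
        v = v.unsqueeze(-2).expand(-1, -1, -1, groups, -1).reshape(s, b, self.num_heads, self.head_dim)

        q, k, v = (t.permute(1, 2, 0, 3) for t in (q, k, v))            # [1, 32, 6, 128]
        context = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask, is_causal=attention_mask is None)  # [1, 32, 6, 128]
        context = context.permute(2, 0, 1, 3).reshape(s, b, -1)         # [6, 1, 4096]
        return self.out_proj(context)                                   # [6, 1, 4096]
```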
## GLMBlock
```
input
| \
| RMSNorm
| self_attention
| dropout
| /
Add
| \
| RMSNorm
| MLP
| dropout
| /
Add
```
All output shapes are [6, 1, 4096], 6:sequence_length 1:batch_num 4096:hidden_size
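
A sketch of one GLMBlock, reusing the RMSNorm, SelfAttention and MLP sketches above and following the pre-norm residual layout in the diagram (dropout omitted):

```python
from torch import nn

class GLMBlock(nn.Module):
    def __init__(self, hidden_size=4096):
        super().__init__()
        self.input_layernorm = RMSNorm(hidden_size)          # RMSNorm sketch above
        self.self_attention = SelfAttention(hidden_size)     # self_attention sketch above
        self.post_attention_layernorm = RMSNorm(hidden_size)
        self.mlp = MLP(hidden_size)                          # MLP sketch above

    def forward(self, hidden_states, attention_mask=None):   # [6, 1, 4096]
        # attention sub-block: RMSNorm -> self_attention -> Add (residual)
        residual = hidden_states
        hidden_states = self.self_attention(self.input_layernorm(hidden_states), attention_mask)
        hidden_states = residual + hidden_states

        # MLP sub-block: RMSNorm -> MLP -> Add (residual)
        residual = hidden_states
        hidden_states = self.mlp(self.post_attention_layernorm(hidden_states))
        return residual + hidden_states                      # [6, 1, 4096]
```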