Witllm/Readme.md

4.3 KiB

data flow

input_ids = tokenizer.build_chat_input(query, history=history, role=role)

for:
input_ids       ->  [1, 6]         1:batch_num  6:sequence_length
inputs_embeds   ->  [6, 1, 4096]   4096:hidden_size
rotary_pos_emb  ->  [6, 1, 32, 2]  32:pos的编码维度  2:cos+sin

hidden_states = inputs_embeds
for layers : GLMBlock(hidden_states, rotary_pos_emb)
hidden_states = RMSNorm(hidden_states)  # final_layernorm  ->  [6, 1, 4096]
hidden_states = hidden_states[-1:]  截取sequence维度上的最后一个位置  ->  [1, 1, 4096]
lm_logits = Linear(hidden_states)  ->  [1, 1, 65024]
lm_logits = lm_logits.transpose(0, 1).contiguous()  -> [1, 1, 65024]

probs = softmax(lm_logits)  ->  [1, 65024]   $\text{Softmax}(x_i) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}$
next_tokens = torch.multinomial(probs, num_samples=1)  采样  -> [1]  1:batch_num 

if next_tokens == eos_token_id 推理结束退出循环

input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)  -> [1, 7]  1:batch_num

response = tokenizer.decode(outputs)

RMSNorm

     hidden_states     ->  [6, 1, 4096]
      /       \
     |       pow(2)    ->  [6, 1, 4096]
     |        |
     |       mean      ->  [6, 1, 1]
     |        ↓
     |   rsqrt( + eps) ->  [6, 1, 1]
      \      /
        mul            ->  [6, 1, 4096]
          \    weight  ->  [4096]
           \   /
            mul        ->  [6, 1, 4096]

hidden_states  ->  [6, 1, 4096]  4096:hidden_size
variance = hidden_states.pow(2).mean(-1, keepdim=True)  ->  [6, 1, 1]
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)  平方根倒数
self.weight  ->  [4096]
return (self.weight * hidden_states)  ->  [6, 1, 4096]

MLP

    hidden_states    ->  [6, 1, 4096]
      Linear         ->  [6, 1, 27392]
      /    \
  chunk1   chunk0    ->  [6, 1, 13696]
     |      |  \
     |      |  sigmoid
     |      |  /
     |      mul
      \    /
        mul        ->  [6, 1, 13696]
       Linear      ->  [6, 1, 4096]

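Linear(hidden_states) no bias  ->  [6, 1, 27392]
chunk0, chunk1 = chunk(2)  ->  [6, 1, 13696], [6, 1, 13696]
silu(x) = [6, 1, 13696] * sigmoid([6, 1, 13696])
intermediate_parallel = silu(chunk0) * chunk1  ->  [6, 1, 13696]
Linear(intermediate_parallel) no bias  ->  [6, 1, 4096]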

self_attention

                        x              -> [6, 1, 4096]
                        |
                      Linear           -> [6, 1, 4608]
                     /  |  \
[6, 1, 32, 128] <-  q   k   v
                   /    |    \
              pos_emb pos_emb \
                  |     |      |
                  |   expand expand    -> [6, 1, 32, 128]
                   \   /       |
                    dot        |
                  softmax     /
                      \      /
                        dot            -> [1, 32, 6, 128]
                                       -> [6, 1, 4096]
                      Linear           -> [6, 1, 4096]

hidden_states: [s, b, h]
mixed_x_layer = Linear(hidden_states)  ->  [6, 1, 4608]  4608:4096+256+256

(query_layer, key_layer, value_layer) = mixed_x_layer.split -> [6, 1, 4096], [6, 1, 256], [6, 1, 256]
query_layer = query_layer.view  ->  [6, 1, 32, 128]
key_layer = key_layer.view  ->  [6, 1, 2, 128]
value_layer = value_layer.view  ->  [6, 1, 2, 128]

query_layer = self.apply_rotary_pos_emb(query_layer, rotary_pos_emb)
key_layer = self.apply_rotary_pos_emb(key_layer, rotary_pos_emb)

key_layer = key_layer.unsqueeze(-2)  ->  [6, 1, 2, 1, 128]
key_layer = key_layer.expand  ->  [6, 1, 2, 16, 128]
key_layer = key_layer.contiguous().view  ->  [6, 1, 32, 128]

value_layer = value_layer.unsqueeze(-2)  ->  [6, 1, 2, 1, 128]
value_layer = value_layer.expand  ->  [6, 1, 2, 16, 128]
value_layer = value_layer.contiguous().view  ->  [6, 1, 32, 128]

query_layer permute(1, 2, 0, 3)  ->  [1, 32, 6, 128]
key_layer permute(1, 2, 0, 3)  ->  [1, 32, 6, 128]
value_layer permute(1, 2, 0, 3)  ->  [1, 32, 6, 128]
context_layer = scaled_dot_product_attention(query_layer, key_layer, value_layer)  ->  [1, 32, 6, 128]
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    att = F.softmax(att, dim=-1)
    y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
context_layer = context_layer.permute(2, 0, 1, 3).reshape()  ->  [6, 1, 4096]

return Linear(context_layer) -> [6, 1, 4096]

GLMBlock

        input
          | \
          |  RMSNorm
          |  self_attention
          |  dropout
          | /
         Add
          | \
          |  RMSNorm
          |  mlp
          |  dropout
          | /
         Add

所有的输出shape都是[6, 1, 4096], 6:sequence_length 1:batch_num 4096:hidden_size