Witllm/chatglm/Readme.md

5.3 KiB

data flow


                        query  ->  "你好"       
                          |
                      tokenizer  -> [1, 6]
                          |
 rotary_pos_emb       embedding  ->  [6, 1, 4096]   
               \     /
               GLMBlock x 28  ->  [6, 1, 4096]    <━━━┓
                                                      ┃
                RMSNorm  ->  [6, 1, 4096]             ┃
                                                      ┃
                 [-1:]  ->  [1, 1, 4096]              ┃
                                                      ┃
                Linear  ->  [1, 1, 65024]             ┃
                                                      ┃
                softmax  ->  [1, 65024]               ┃
                                                      ┃
               multinomial  ->  [1]                   ┃
                                                      ┃
          cat([input_ids, next_tokens])            ━━━┛


input_ids = tokenizer.build_chat_input(query, history=history, role=role)

for:
    input_ids -> [1, 6]  1:batch_num  6:sequence_length
    inputs_embeds -> [6, 1, 4096]  4096:hidden_size
    rotary_pos_emb -> [6, 1, 32, 2]  32:pos的编码维度  2:cos+sin

    hidden_states = inputs_embeds
    for layers :
        hidden_states = GLMBlock(hidden_states, rotary_pos_emb)  # 共28层, 每层输出 -> [6, 1, 4096]
    hidden_states = RMSNorm(hidden_states)  # final_layernorm  ->  [6, 1, 4096]
    hidden_states = hidden_states[-1:]  只保留序列最后一个位置(token)的输出  ->  [1, 1, 4096]
    lm_logits = Linear(hidden_states)  ->  [1, 1, 65024]
    lm_logits = lm_logits.transpose(0, 1).contiguous()  -> [1, 1, 65024]

    probs = softmax(lm_logits[:, -1, :], dim=-1)  ->  [1, 65024]  $\mathrm{Softmax}(x_i) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}$
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)  按概率采样  ->  [1]  1:batch_num

    if next_tokens == eos_token_id 推理结束退出循环

    input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)  -> [1, 7]  1:batch_num  7:新的sequence_length

outputs = input_ids[0, 6:]  # 去掉前6个prompt token, 只保留新生成的token
response = tokenizer.decode(outputs)

RMSNorm

hidden_states   -> [6, 1, 4096]
 /       \
|       pow(2)  -> [6, 1, 4096]
|        |
|       mean    -> [6, 1, 1]
|        ↓  
| rsqrt(   + eps)  -> [6, 1, 1]
 \   /
  mul              -> [6, 1, 4096]
    \     weight   -> [4096]
     \    /
      mul          -> [6, 1, 4096]

hidden_states -> [6, 1, 4096]  4096:hidden_size
variance = hidden_states.pow(2).mean(-1, keepdim=True)  -> [6, 1, 1]
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)  ->  [6, 1, 4096]  rsqrt: 平方根的倒数 1/sqrt(x)
self.weight -> [4096]
return (self.weight * hidden_states)  -> [6, 1, 4096]

MLP

        hidden_states    ->  [6, 1, 4096]
          Linear         ->  [6, 1, 27392]
          /    \
      chunk1   chunk0    ->  [6, 1, 13696]
         |      |  \
         |      |  sigmoid
         |      |  /
         |      mul
          \    /
            mul        ->  [6, 1, 13696]
           Linear      ->  [6, 1, 4096]

Linear(hidden_states)  no bias  ->  [6, 1, 27392]
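chunk0, chunk1 = chunk(2, dim=-1)  ->  [6, 1, 13696], [6, 1, 13696]
intermediate_parallel = silu(chunk0) * chunk1  ->  [6, 1, 13696]   其中 silu(x) = x * sigmoid(x)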
Linear(intermediate_parallel)  no bias  ->  [6, 1, 4096]

self_attention

                        x              -> [6, 1, 4096]
                        |
                      Linear           -> [6, 1, 4608]
                     /  |  \
[6, 1, 32, 128] <-  q   k   v
                   /    |    \
              pos_emb pos_emb \
                  |     |      |
                  |   expand  expand   -> [6, 1, 32, 128]
                   \   /       |
         ┏----      dot        |
         ┃  += attention_mask  /
attention┃        softmax     /
         ┃             \     /
         ┗----           dot           ->  [1, 32, 6, 128]  ->  [6, 1, 4096]
                        Linear         ->  [6, 1, 4096]

hidden_states: [s, b, h]
mixed_x_layer = Linear(hidden_states)  -> [6, 1, 4608]  4608:4096+256+256

(query_layer, key_layer, value_layer) = mixed_x_layer.split  -> [6, 1, 4096], [6, 1, 256], [6, 1, 256]  
query_layer = query_layer.view  ->  [6, 1, 32, 128]
key_layer = key_layer.view  ->  [6, 1, 2, 128]
value_layer = value_layer.view  ->  [6, 1, 2, 128]

query_layer = self.apply_rotary_pos_emb(query_layer, rotary_pos_emb)
key_layer = self.apply_rotary_pos_emb(key_layer, rotary_pos_emb)

key_layer = key_layer.unsqueeze(-2)  ->  [6, 1, 2, 1, 128]
key_layer = key_layer.expand  ->  [6, 1, 2, 16, 128]
key_layer = key_layer.contiguous().view  ->  [6, 1, 32, 128]

value_layer = value_layer.unsqueeze(-2)  ->  [6, 1, 2, 1, 128]
value_layer = value_layer.expand  ->  [6, 1, 2, 16, 128]
value_layer = value_layer.contiguous().view  ->  [6, 1, 32, 128]

query_layer permute(1, 2, 0, 3)  ->  [1, 32, 6, 128]
key_layer permute(1, 2, 0, 3)  ->  [1, 32, 6, 128]
value_layer permute(1, 2, 0, 3)  ->  [1, 32, 6, 128]
context_layer = scaled_dot_product_attention(query_layer, key_layer, value_layer, is_causal=True)  ->  [1, 32, 6, 128]
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  ->  (B, nh, T, T)
    att = att.masked_fill(causal_mask == 0, float('-inf'))  # 因果mask: 每个token只能看到自己及之前的token
    att = F.softmax(att, dim=-1)
    y = att @ v  ->  (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
context_layer = context_layer.permute(2, 0, 1, 3).reshape(6, 1, 4096)  ->  [6, 1, 4096]  4096 = 32 * 128

return Linear(context_layer)  ->  [6, 1, 4096]

GLMBlock

 input
 |   \
 |   RMSNorm
 |   self_attention
 |   dropout
 |   /
 Add
 |  \
 |  RMSNorm
 |  MLP
 |  dropout
 |  /
 Add

所有的输出shape都是[6, 1, 4096], 6:sequence_length 1:batch_num 4096:hidden_size