
Readme.md

data flow

input_ids = tokenizer.build_chat_input(query, history=history, role=role)

for each generation step:
    input_ids -> [1, 6]    1: batch_num, 6: sequence_length
    inputs_embeds -> [6, 1, 4096]    4096: hidden_size
    rotary_pos_emb -> [6, 1, 32, 2]    32: positional encoding dimension, 2: cos + sin

    hidden_states = inputs_embeds
    for each layer: hidden_states = GLMBlock(hidden_states, rotary_pos_emb)
    hidden_states = RMSNorm(hidden_states)
    hidden_states = hidden_states[-1:]    keep only the last sequence position
    lm_logits = self.output_layer(hidden_states)
    lm_logits = lm_logits.transpose(0, 1).contiguous() -> [1, 1, 65024]

    probs = softmax(lm_logits[:, -1, :]) -> [1, 65024]
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) -> [1]    sampling, 1: batch_num

    if next_tokens == eos_token_id: break    inference is finished, exit the loop

    input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) -> [1, 7]    1: batch_num, 7: sequence_length after appending

response = tokenizer.decode(outputs)
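The sample / check-EOS / append loop above can be sketched without any model. This is a minimal pure-Python illustration, not the repository's code: `toy_logits`, `VOCAB_SIZE`, `EOS_TOKEN_ID`, and `max_new_tokens` are hypothetical stand-ins, and `random.choices` plays the role of `torch.multinomial(probs, num_samples=1)`.

```python
import math
import random

VOCAB_SIZE = 8      # toy value; the README's vocabulary size is 65024
EOS_TOKEN_ID = 0    # hypothetical end-of-sequence id


def toy_logits(input_ids):
    """Hypothetical stand-in for the model: logits for the next token only."""
    rng = random.Random(sum(input_ids))  # deterministic for a given prefix
    return [rng.uniform(-1.0, 1.0) for _ in range(VOCAB_SIZE)]


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def generate(input_ids, max_new_tokens=16, seed=0):
    rng = random.Random(seed)
    input_ids = list(input_ids)
    for _ in range(max_new_tokens):
        probs = softmax(toy_logits(input_ids))
        # Draw one token id according to probs (multinomial, num_samples=1).
        next_token = rng.choices(range(VOCAB_SIZE), weights=probs, k=1)[0]
        if next_token == EOS_TOKEN_ID:
            break  # inference is finished
        input_ids.append(next_token)  # sequence length grows by one per step
    return input_ids


prompt = [3, 5, 2, 7, 1, 4]          # sequence_length 6, as in the README
output_ids = generate(prompt)
```

The real loop feeds the whole `input_ids` tensor back into the model on each step; the toy version only keeps the control flow.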

RMSNorm

hidden_states -> [6, 1, 4096]    4096: hidden_size
variance = hidden_states.pow(2).mean(-1, keepdim=True) -> [6, 1, 1]
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)    rsqrt: reciprocal square root
self.weight -> [4096]
return (self.weight * hidden_states) -> [6, 1, 4096]
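For one hidden vector, the RMSNorm steps above reduce to a few lines. A minimal pure-Python sketch with a toy hidden size of 4; the `eps` default here is an assumed value, not taken from the repository.

```python
import math


def rms_norm(x, weight, eps=1e-5):
    """Scale one hidden vector by the reciprocal of its root mean square."""
    variance = sum(v * v for v in x) / len(x)    # mean of squares over the last dim
    inv_rms = 1.0 / math.sqrt(variance + eps)    # the rsqrt step
    return [w * v * inv_rms for w, v in zip(weight, x)]


hidden = [1.0, -2.0, 3.0, -4.0]   # toy hidden_size of 4 instead of 4096
weight = [1.0, 1.0, 1.0, 1.0]     # learned per-channel scale, all ones here
normed = rms_norm(hidden, weight)
mean_square = sum(v * v for v in normed) / len(normed)
```

With a weight of all ones, the mean square of the output is close to 1, which is the point of the normalization.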

MLP

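Linear(hidden_states) no bias -> [6, 1, 27392]
split into two halves x0, x1 -> [6, 1, 13696] each
silu(x0) = x0 * sigmoid(x0) -> [6, 1, 13696]
intermediate_parallel = silu(x0) * x1 -> [6, 1, 13696]
Linear(intermediate_parallel) no bias -> [6, 1, 4096]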
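The shapes above (27392 going in, 13696 coming out of the activation) suggest a gated, SwiGLU-style MLP. The sketch below assumes the wide projection is split into two halves, with `silu` of the first half gating the second. Sizes are toy values and the weight matrices are made up for illustration.

```python
import math


def silu(x):
    return x * (1.0 / (1.0 + math.exp(-x)))  # x * sigmoid(x)


def linear(x, w):
    """Bias-free linear layer; w has shape [out_features][in_features]."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def mlp(x, w_up, w_down):
    up = linear(x, w_up)                 # hidden -> 2 * ffn (4096 -> 27392 in the README)
    half = len(up) // 2
    gate, value = up[:half], up[half:]   # two halves of width ffn (13696 each)
    inter = [silu(g) * v for g, v in zip(gate, value)]
    return linear(inter, w_down)         # ffn -> hidden (13696 -> 4096)


# Toy sizes: hidden = 2, ffn = 3
w_up = [[1, 0], [0, 1], [1, 1], [1, 0], [0, 1], [1, -1]]
w_down = [[1, 0, 0], [0, 1, 1]]
out = mlp([1.0, 2.0], w_up, w_down)
```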

core_attention

query_layer = query_layer.permute(1, 2, 0, 3) -> [1, 32, 6, 128]
key_layer = key_layer.permute(1, 2, 0, 3) -> [1, 32, 6, 128]
value_layer = value_layer.permute(1, 2, 0, 3) -> [1, 32, 6, 128]
context_layer = scaled_dot_product_attention(query_layer, key_layer, value_layer) -> [1, 32, 6, 128]
    softmax(QK^T / sqrt(in_dim)) V
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    att = F.softmax(att, dim=-1)
    y = att @ v    (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
context_layer = context_layer.permute(2, 0, 1, 3)
context_layer = context_layer.reshape() -> [6, 1, 4096]
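The three `att` / `y` lines above can be worked through for a single head with nested lists. This is a minimal sketch with toy sizes (T = 3, hs = 2). The `causal` flag is an assumption added for illustration, since decoder-style models usually mask future positions; the README's formula does not show a mask.

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]


def attention(q, k, v, causal=False):
    """Single-head attention; q, k, v are [T][hs] nested lists."""
    scale = 1.0 / math.sqrt(len(k[0]))            # 1 / sqrt(k.size(-1))
    out = []
    for i, qi in enumerate(q):
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        if causal:                                # assumption: hide future positions
            scores = [s if j <= i else -math.inf for j, s in enumerate(scores)]
        att = softmax(scores)                     # row i of the [T, T] matrix
        out.append([sum(att[j] * v[j][d] for j in range(len(v)))
                    for d in range(len(v[0]))])   # att @ v
    return out


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
y = attention(q, k, v, causal=True)
```

With the causal mask on, the first position can only attend to itself, so `y[0]` equals `v[0]`.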

self_attention

hidden_states: [s, b, h]
mixed_x_layer = Linear(hidden_states) -> [6, 1, 4608]    4608: 4096 + 256 + 256

(query_layer, key_layer, value_layer) = mixed_x_layer.split -> [6, 1, 4096], [6, 1, 256], [6, 1, 256]

query_layer = query_layer.view -> [6, 1, 32, 128]
key_layer = key_layer.view -> [6, 1, 2, 128]
value_layer = value_layer.view -> [6, 1, 2, 128]
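The 4608 = 4096 + 256 + 256 arithmetic follows from 32 query heads and 2 key/value heads of width 128. A small sketch for one token's vector; the constant names (`NUM_HEADS`, `NUM_KV_GROUPS`, `HEAD_DIM`) are hypothetical labels for the numbers in the shapes above.

```python
HIDDEN_SIZE = 4096
NUM_HEADS = 32          # query heads
NUM_KV_GROUPS = 2       # key/value heads shared across query heads
HEAD_DIM = HIDDEN_SIZE // NUM_HEADS          # 128

q_width = NUM_HEADS * HEAD_DIM               # 4096
kv_width = NUM_KV_GROUPS * HEAD_DIM          # 256
qkv_width = q_width + 2 * kv_width           # 4608


def split_qkv(mixed):
    """Split one token's fused projection into flat q, k, v slices."""
    q = mixed[:q_width]
    k = mixed[q_width:q_width + kv_width]
    v = mixed[q_width + kv_width:]
    return q, k, v


def to_heads(flat, n_heads):
    """The .view step: cut a flat vector into n_heads chunks of HEAD_DIM."""
    return [flat[i * HEAD_DIM:(i + 1) * HEAD_DIM] for i in range(n_heads)]


mixed = list(range(qkv_width))               # placeholder values
q, k, v = split_qkv(mixed)
q_heads = to_heads(q, NUM_HEADS)             # 32 x 128
k_heads = to_heads(k, NUM_KV_GROUPS)         # 2 x 128
```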

query_layer = self.apply_rotary_pos_emb(query_layer, rotary_pos_emb)
key_layer = self.apply_rotary_pos_emb(key_layer, rotary_pos_emb)
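The README does not spell out what `apply_rotary_pos_emb` does. One common convention, assumed here, rotates adjacent pairs of dimensions by a per-position angle and passes the remaining dimensions through unchanged. That reading fits the `[6, 1, 32, 2]` shape: 32 (cos, sin) pairs cover 64 of the 128 head dimensions. The sketch handles one head vector at one position; how the angles are generated is left out.

```python
import math


def apply_rotary(head_vec, rope):
    """Rotate the first 2 * len(rope) dims of head_vec; rope holds (cos, sin) pairs."""
    rot_dim = 2 * len(rope)
    out = []
    for i, (c, s) in enumerate(rope):
        x0, x1 = head_vec[2 * i], head_vec[2 * i + 1]
        out += [x0 * c - x1 * s, x1 * c + x0 * s]   # 2-D rotation of the pair
    return out + list(head_vec[rot_dim:])           # untouched pass-through dims


angle = math.pi / 2
rope = [(math.cos(angle), math.sin(angle))]         # one pair, toy size
rotated = apply_rotary([1.0, 0.0, 5.0, 6.0], rope)
```

A rotation keeps the length of each pair, so the operation changes relative phase between positions without rescaling the vector.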

key_layer = key_layer.unsqueeze(-2) -> [6, 1, 2, 1, 128]
key_layer = key_layer.expand -> [6, 1, 2, 16, 128]
key_layer = key_layer.contiguous().view -> [6, 1, 32, 128]

value_layer = value_layer.unsqueeze(-2) -> [6, 1, 2, 1, 128]
value_layer = value_layer.expand -> [6, 1, 2, 16, 128]
value_layer = value_layer.contiguous().view -> [6, 1, 32, 128]
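The unsqueeze / expand / view sequence repeats each of the 2 key/value heads 16 times so they line up with the 32 query heads. For one token, the resulting head order can be sketched with plain lists; `expand_kv` is a hypothetical helper name.

```python
def expand_kv(kv_heads, num_q_heads):
    """Repeat each kv head so that consecutive query heads share it."""
    repeat = num_q_heads // len(kv_heads)      # 32 // 2 = 16
    # Outer loop over kv heads, inner loop over repeats: head 0 fills
    # slots 0..15, head 1 fills slots 16..31, matching expand + view.
    return [head for head in kv_heads for _ in range(repeat)]


kv_heads = [["kv0"], ["kv1"]]                  # placeholders for two 128-dim heads
expanded = expand_kv(kv_heads, 32)
```

Query heads 0 to 15 therefore attend with the first key/value head, and heads 16 to 31 with the second.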

context_layer = self.core_attention(query_layer, key_layer, value_layer) -> [6, 1, 4096]
return Linear(context_layer) -> [6, 1, 4096]

GLMBlock

input
  | \
  |  RMSNorm
  |  self_attention
  |  dropout
  | /
 Add
  | \
  |  RMSNorm
  |  mlp
  |  dropout
  | /
 Add

Every output has shape [6, 1, 4096]: 6 = sequence_length, 1 = batch_num, 4096 = hidden_size.
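The two residual branches of GLMBlock can be written as a short function. This sketch takes the sub-layers as plain callables on one hidden vector and treats dropout as the identity, as it is at inference time. The toy sub-layers are made up to keep the arithmetic easy to follow.

```python
def glm_block(x, norm1, attn, norm2, mlp):
    """Pre-norm residual block: x + attn(norm1(x)), then x + mlp(norm2(x))."""
    x = [a + b for a, b in zip(x, attn(norm1(x)))]   # first Add
    x = [a + b for a, b in zip(x, mlp(norm2(x)))]    # second Add
    return x


# Toy sub-layers (hypothetical): identity norms, attention doubles, mlp halves.
def identity(h):
    return list(h)


def double(h):
    return [2.0 * v for v in h]


def halve(h):
    return [0.5 * v for v in h]


block_out = glm_block([1.0, 2.0], identity, double, identity, halve)
```

Because each branch is added back to its input, the output keeps the input's shape, which is why every stage stays at [6, 1, 4096].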