
## data flow

    input_ids = tokenizer.build_chat_input(query, history=history, role=role)

    input_ids      -> [1, 6]          1: batch_num, 6: sequence_length
    inputs_embeds  -> [6, 1, 4096]    4096: hidden_size
    rotary_pos_emb -> [6, 1, 32, 2]   32: positional-encoding dimensions, 2: cos + sin
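
A shape-only sketch of this step (the random tensors and the `transpose` to a sequence-first layout are illustrative stand-ins, not the real model weights):

```python
import torch

batch_num, seq_len, hidden_size, vocab_size = 1, 6, 4096, 65024

input_ids = torch.randint(0, vocab_size, (batch_num, seq_len))  # [1, 6]
embedding = torch.nn.Embedding(vocab_size, hidden_size)
inputs_embeds = embedding(input_ids).transpose(0, 1)            # [6, 1, 4096], sequence-first
rotary_pos_emb = torch.randn(seq_len, batch_num, 32, 2)         # [6, 1, 32, 2], (cos, sin) pairs
```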

    hidden_states = inputs_embeds
    for layer in layers:                                # one GLMBlock per layer
        hidden_states = GLMBlock(hidden_states, rotary_pos_emb)
    hidden_states = RMSNorm(hidden_states)
    hidden_states = hidden_states[-1:]                  # take only the last sequence position
    lm_logits = self.output_layer(hidden_states)
    lm_logits = lm_logits.transpose(0, 1).contiguous()  # -> [1, 1, 65024]
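
A minimal runnable sketch of this decoding pass; `DummyBlock` and `LayerNorm` are placeholders for the real `GLMBlock` and `RMSNorm`, so only the shapes and the last-position slice match the description above:

```python
import torch
import torch.nn as nn

hidden_size, vocab_size = 4096, 65024

class DummyBlock(nn.Module):
    """Identity placeholder for GLMBlock; the real block attends using rotary_pos_emb."""
    def forward(self, hidden_states, rotary_pos_emb):
        return hidden_states

layers = [DummyBlock() for _ in range(2)]           # 2 layers just for the sketch
final_norm = nn.LayerNorm(hidden_size)              # placeholder for RMSNorm
output_layer = nn.Linear(hidden_size, vocab_size, bias=False)

hidden_states = torch.randn(6, 1, hidden_size)      # inputs_embeds, [seq, batch, hidden]
rotary_pos_emb = torch.randn(6, 1, 32, 2)

for layer in layers:
    hidden_states = layer(hidden_states, rotary_pos_emb)
hidden_states = final_norm(hidden_states)
hidden_states = hidden_states[-1:]                  # keep only the last position: [1, 1, 4096]
lm_logits = output_layer(hidden_states)
lm_logits = lm_logits.transpose(0, 1).contiguous()  # batch-first: [1, 1, 65024]
```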

    probs = softmax(lm_logits)                             # -> [1, 65024]
    next_tokens = torch.multinomial(probs, num_samples=1)  # sampling, -> [1], 1: batch_num
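
The sampling step can be tried in isolation; the random `lm_logits` below is a placeholder for the real model output:

```python
import torch

lm_logits = torch.randn(1, 1, 65024)                              # placeholder logits
probs = torch.softmax(lm_logits[:, -1, :], dim=-1)                # [1, 65024]
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)  # sampled token id, [1]
```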

If `next_tokens == eos_token_id`, inference is finished and the generation loop exits.

    input_ids = torch.cat([input_ids, next_tokens], dim=-1)  # -> [1, 7], 1: batch_num

    response = tokenizer.decode(outputs)
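
Putting the steps together, a hedged sketch of the whole generation loop; `forward`, `max_length`, and the `eos_token_id` value are stand-ins for the real model pass and config:

```python
import torch

eos_token_id, max_length, vocab_size = 2, 32, 65024  # illustrative values

def forward(input_ids):
    """Placeholder for the embed -> GLMBlocks -> RMSNorm -> output_layer pass above."""
    return torch.randn(input_ids.shape[0], 1, vocab_size)       # [batch, 1, vocab]

input_ids = torch.randint(0, vocab_size, (1, 6))                # [1, 6]
while input_ids.shape[-1] < max_length:
    lm_logits = forward(input_ids)
    probs = torch.softmax(lm_logits[:, -1, :], dim=-1)          # [1, 65024]
    next_tokens = torch.multinomial(probs, num_samples=1)       # [1, 1]
    if next_tokens.item() == eos_token_id:
        break                                                   # generation finished
    input_ids = torch.cat([input_ids, next_tokens], dim=-1)     # grows to [1, 7], [1, 8], ...
# response = tokenizer.decode(...)  # decode with the real tokenizer
```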

## RMSNorm

    hidden_states -> [6, 1, 4096]  4096: hidden_size
    variance = hidden_states.pow(2).mean(-1, keepdim=True)            # -> [6, 1, 1]
    hidden_states = hidden_states * torch.rsqrt(variance + self.eps)  # reciprocal square root
    self.weight -> [4096]
    return (self.weight * hidden_states)                              # -> [6, 1, 4096]
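
The same computation as a self-contained module; this is the standard RMSNorm formulation, and the `eps` default here is an assumption (the actual chatglm code may also add dtype casts):

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, hidden_size=4096, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # [4096]
        self.eps = eps

    def forward(self, hidden_states):                        # [6, 1, 4096]
        variance = hidden_states.pow(2).mean(-1, keepdim=True)            # [6, 1, 1]
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)  # 1/sqrt
        return self.weight * hidden_states                   # [6, 1, 4096]

norm = RMSNorm()
out = norm(torch.randn(6, 1, 4096))  # -> [6, 1, 4096]
```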