## data flow
input_ids = tokenizer.build_chat_input(query, history=history, role=role)

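This step relies on ChatGLM3's custom tokenizer, which exposes `build_chat_input` via `trust_remote_code`. A minimal sketch of producing `input_ids`, assuming the `THUDM/chatglm3-6b` checkpoint from the Hugging Face hub (the query text and empty history are placeholders):

```python
from transformers import AutoTokenizer

# Assumption: the ChatGLM3-6B tokenizer from the Hugging Face hub.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)

query = "hello"   # placeholder user message
history = []      # earlier turns, e.g. [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
role = "user"

# build_chat_input wraps the query in ChatGLM3's chat template and tokenizes it.
inputs = tokenizer.build_chat_input(query, history=history, role=role)
input_ids = inputs["input_ids"]   # [1, sequence_length], e.g. [1, 6] for a very short prompt
print(input_ids.shape)
```
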
for each decoding step (one forward pass per new token):

input_ids -> [1, 6]  1: batch_num, 6: sequence_length

inputs_embeds -> [6, 1, 4096]  4096: hidden_size

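The jump from `[1, 6]` to `[6, 1, 4096]` is an embedding lookup followed by a transpose into ChatGLM's sequence-first layout. A self-contained sketch of just that reshaping (`word_embeddings` here is a freshly initialized stand-in for the model's embedding table, not the real weights):

```python
import torch
import torch.nn as nn

vocab_size, hidden_size = 65024, 4096
word_embeddings = nn.Embedding(vocab_size, hidden_size)      # stand-in for the model's embedding table

input_ids = torch.randint(0, vocab_size, (1, 6))             # [batch=1, seq=6]
inputs_embeds = word_embeddings(input_ids)                   # [1, 6, 4096]
inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()   # sequence-first -> [6, 1, 4096]
print(inputs_embeds.shape)
```
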
rotary_pos_emb -> [6, 1, 32, 2]  32: rotary position-encoding dim, 2: cos+sin

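The rotary table holds one (cos, sin) pair per position and per frequency channel; 32 channels correspond to a rotary dimension of 64, i.e. half of the 128-dim attention head. A generic RoPE-cache sketch that reproduces the `[6, 1, 32, 2]` shape (standard RoPE formulation, not necessarily ChatGLM's exact code):

```python
import torch

def rotary_cache(seq_len: int, rot_dim: int = 64, base: float = 10000.0) -> torch.Tensor:
    """Return a [seq_len, rot_dim // 2, 2] table of (cos, sin) factors."""
    # One frequency per channel pair: theta_i = base^(-2i / rot_dim)
    inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2).float() / rot_dim))  # [32]
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)               # [seq_len, 32]
    return torch.stack([angles.cos(), angles.sin()], dim=-1)                    # [seq_len, 32, 2]

rotary_pos_emb = rotary_cache(6)[:, None, :, :]   # add the batch axis -> [6, 1, 32, 2]
print(rotary_pos_emb.shape)
```
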
hidden_states = inputs_embeds

for each layer: hidden_states = GLMBlock(hidden_states, rotary_pos_emb)

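Each GLMBlock is a pre-norm transformer layer: normalize, self-attention (where the rotary embedding is applied to queries and keys), residual add, then normalize, MLP, residual add. A runnable structural sketch of that loop with toy sub-modules standing in for the real attention and MLP (module names and sizes are illustrative, not ChatGLM's exact implementation):

```python
import torch
import torch.nn as nn

class ToyGLMBlock(nn.Module):
    """Structural stand-in for GLMBlock: pre-norm attention and MLP with residual adds.
    The real block applies the rotary embedding inside self-attention and uses RMSNorm
    and a gated MLP; here the sub-modules are simple placeholders so the flow is runnable."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.self_attention = nn.Linear(hidden_size, hidden_size)   # placeholder for attention + RoPE
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.SiLU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, hidden_states, rotary_pos_emb):
        # pre-norm self-attention + residual (rotary_pos_emb is unused by the placeholder)
        hidden_states = hidden_states + self.self_attention(self.input_layernorm(hidden_states))
        # pre-norm MLP + residual
        return hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))

hidden_size = 128                                        # scaled down from 4096 for the toy run
layers = nn.ModuleList([ToyGLMBlock(hidden_size) for _ in range(2)])
hidden_states = torch.randn(6, 1, hidden_size)           # [seq, batch, hidden]
rotary_pos_emb = torch.zeros(6, 1, 32, 2)

for layer in layers:
    hidden_states = layer(hidden_states, rotary_pos_emb)
print(hidden_states.shape)                               # torch.Size([6, 1, 128])
```
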
hidden_states = self.final_layernorm(hidden_states)

hidden_states = hidden_states[-1:] -> [1, 1, 4096]  keep only the last position

lm_logits = self.output_layer(hidden_states)

lm_logits = lm_logits.transpose(0, 1).contiguous() -> [1, 1, 65024]  65024: vocab_size

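Only the last position is needed to predict the next token, which is why the `[6, 1, 4096]` states are sliced to `[-1:]` before the vocabulary projection; 65024 is ChatGLM3-6B's padded vocabulary size. A sketch of these four lines with freshly initialized stand-in modules:

```python
import torch
import torch.nn as nn

hidden_size, vocab_size = 4096, 65024
final_layernorm = nn.LayerNorm(hidden_size)                    # stand-in for the model's RMSNorm
output_layer = nn.Linear(hidden_size, vocab_size, bias=False)  # stand-in for the LM head

hidden_states = torch.randn(6, 1, hidden_size)       # [seq, batch, hidden] out of the last GLMBlock
hidden_states = final_layernorm(hidden_states)
hidden_states = hidden_states[-1:]                   # keep only the last position -> [1, 1, 4096]
lm_logits = output_layer(hidden_states)              # [1, 1, 65024]
lm_logits = lm_logits.transpose(0, 1).contiguous()   # swap back to batch-first -> [1, 1, 65024]
print(lm_logits.shape)
```
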
probs = softmax(lm_logits[:, -1], dim=-1) -> [1, 65024]

next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)  sampling -> [1]  1: batch_num

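Sampling follows the usual pattern: softmax over the vocabulary, then one draw per batch element; the real generation path also applies temperature and top-p filtering to the logits before the softmax. A minimal sketch of just the step above:

```python
import torch

lm_logits = torch.randn(1, 1, 65024)                              # [batch, 1, vocab]
next_token_logits = lm_logits[:, -1, :]                           # [1, 65024]
probs = torch.softmax(next_token_logits, dim=-1)                  # [1, 65024]
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)  # [1], one sampled id per batch element
print(next_tokens)
```
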
if next_tokens == eos_token_id: generation is finished, exit the loop

input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) -> [1, 7]  1: batch_num

response = tokenizer.decode(outputs)

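Putting the steps together, a hand-rolled sampling loop over the whole flow, where `outputs` is the slice of newly generated ids after the prompt. It assumes the `THUDM/chatglm3-6b` checkpoint on a CUDA GPU and skips the KV cache, temperature/top-p, and the extra stop tokens that `model.chat()` uses, so it is a sketch rather than the library's implementation:

```python
import torch
from transformers import AutoTokenizer, AutoModel

model_name = "THUDM/chatglm3-6b"   # assumption: ChatGLM3-6B from the Hugging Face hub
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).half().cuda().eval()

query, history, role = "hello", [], "user"
input_ids = tokenizer.build_chat_input(query, history=history, role=role)["input_ids"].cuda()
prompt_len = input_ids.shape[1]

with torch.no_grad():
    for _ in range(512):                                           # max new tokens
        # No KV cache for simplicity: each step re-runs the full prefix.
        logits = model(input_ids=input_ids).logits                 # [1, seq_len, 65024]
        probs = torch.softmax(logits[:, -1, :].float(), dim=-1)    # [1, 65024]
        next_tokens = torch.multinomial(probs, num_samples=1)      # [1, 1]
        if next_tokens.item() == tokenizer.eos_token_id:           # stop on EOS
            break
        input_ids = torch.cat([input_ids, next_tokens], dim=-1)    # [1, seq_len + 1]

outputs = input_ids[0, prompt_len:]      # only the newly generated ids
response = tokenizer.decode(outputs)
print(response)
```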