Refine sample code.
This commit is contained in:
parent
10268c4414
commit
72787b9268
|
@ -5,7 +5,7 @@
|
|||
|
||||
input_ids = tokenizer.build_chat_input(query, history=history, role=role)
|
||||
|
||||
input_ids -> [1, 6]
|
||||
input_ids -> [1, 6] 1:batch_num 6:sequence_length
|
||||
inputs_embeds -> [6, 1, 4096] 4096:hidden_size
|
||||
rotary_pos_emb -> [6, 1, 32, 2] 32:pos的编码维度 2:cos+sin
|
||||
|
||||
|
@ -17,7 +17,7 @@ lm_logits = self.output_layer(hidden_states)
|
|||
lm_logits = lm_logits.transpose(0, 1).contiguous() -> [1, 1, 65024]
|
||||
|
||||
probs = softmax(lm_logits) -> [1, 65024]
|
||||
next_tokens = torch.multinomial(probs, num_samples=1) 采样 -> [1]
|
||||
input_ids = torch.cat([input_ids, next_tokens]) -> [1, 7]
|
||||
next_tokens = torch.multinomial(probs, num_samples=1) 采样 -> [1] 1:batch_num
|
||||
input_ids = torch.cat([input_ids, next_tokens]) -> [1, 7] 1:batch_num
|
||||
|
||||
response = tokenizer.decode(outputs)
|
|
@ -676,11 +676,9 @@ class ChatGLMForConditionalGeneration(nn.Module):
|
|||
eos_token_id = [eos_token_id]
|
||||
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device)
|
||||
|
||||
unfinished_sequences = torch.ones(
|
||||
isFinished = torch.zeros(
|
||||
input_ids.shape[0], dtype=torch.long, device=input_ids.device
|
||||
)
|
||||
|
||||
this_peer_finished = False # used by synced_gpus only
|
||||
while True:
|
||||
input_ids_in = input_ids
|
||||
batch_size, seq_length = input_ids_in.shape
|
||||
|
@ -699,21 +697,13 @@ class ChatGLMForConditionalGeneration(nn.Module):
|
|||
probs = nn.functional.softmax(next_token_logits, dim=-1)
|
||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
|
||||
# finished sentences should have their next token be a padding token
|
||||
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
|
||||
1 - unfinished_sequences
|
||||
)
|
||||
|
||||
# update generated ids, model inputs, and length for next step
|
||||
# finished sentences should add a padding token to next
|
||||
pad_token = pad_token_id * isFinished
|
||||
next_tokens = next_tokens * (1 - isFinished) + pad_token
|
||||
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
unfinished_sequences = unfinished_sequences.mul(
|
||||
next_tokens.tile(eos_token_id_tensor.shape[0], 1)
|
||||
.ne(eos_token_id_tensor.unsqueeze(1))
|
||||
.prod(dim=0)
|
||||
)
|
||||
if unfinished_sequences.max() == 0:
|
||||
isFinished = isFinished | next_tokens.eq(eos_token_id_tensor)
|
||||
if isFinished.min() == 1: # all batch is finish
|
||||
break
|
||||
|
||||
return input_ids
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
import torch
|
||||
|
||||
x = torch.tensor([[1, 2], [3, 4]])
|
||||
|
||||
print(x)
|
||||
print("x.tile((2)) -> ", x.tile((2)).shape)
|
||||
print(x.tile((2)))
|
||||
|
||||
print()
|
||||
print("x.tile((2, 1)) -> ", x.tile((2, 1)).shape)
|
||||
print(x.tile((2, 1)))
|
||||
|
||||
print()
|
||||
print("x.tile((2, 1, 2)) -> ", x.tile((2, 1, 2)).shape)
|
||||
print(x.tile((2, 1, 2)))
|
||||
|
||||
print()
|
||||
print("x.tile((2, 1, 1)) -> ", x.tile((2, 1, 1)).shape)
|
||||
print(x.tile((2, 1, 1)))
|
||||
|
||||
print()
|
||||
y = torch.tensor([[2, 1], [3, 4]])
|
||||
print(y.ne(x))
|
||||
|
||||
print()
|
||||
print(x.prod(1))
|
||||
print(x.prod(0))
|
||||
|
||||
print()
|
||||
print(x.unsqueeze(1).shape)
|
||||
print(x.unsqueeze(1).squeeze(1).shape)
|
Loading…
Reference in New Issue