diff --git a/Readme.md b/Readme.md
index 0760be7..af96411 100644
--- a/Readme.md
+++ b/Readme.md
@@ -5,7 +5,7 @@ input_ids = tokenizer.build_chat_input(query, history=history, role=role)
-input_ids -> [1, 6]
+input_ids -> [1, 6] 1:batch_num 6:sequence_length
 inputs_embeds -> [6, 1, 4096] 4096:hidden_size
 rotary_pos_emb -> [6, 1, 32, 2] 32:positional encoding dims 2:cos+sin
@@ -17,7 +17,7 @@ lm_logits = self.output_layer(hidden_states)
 lm_logits = lm_logits.transpose(0, 1).contiguous() -> [1, 1, 65024]
 probs = softmax(lm_logits) -> [1, 65024]
-next_tokens = torch.multinomial(probs, num_samples=1) sampling -> [1]
-input_ids = torch.cat([input_ids, next_tokens]) -> [1, 7]
+next_tokens = torch.multinomial(probs, num_samples=1) sampling -> [1] 1:batch_num
+input_ids = torch.cat([input_ids, next_tokens]) -> [1, 7] 1:batch_num
 response = tokenizer.decode(outputs)
\ No newline at end of file
diff --git a/chatglm/modeling_chatglm.py b/chatglm/modeling_chatglm.py
index 4264e92..9e16ea4 100644
--- a/chatglm/modeling_chatglm.py
+++ b/chatglm/modeling_chatglm.py
@@ -676,11 +676,9 @@ class ChatGLMForConditionalGeneration(nn.Module):
             eos_token_id = [eos_token_id]
         eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device)
 
-        unfinished_sequences = torch.ones(
+        is_finished = torch.zeros(
             input_ids.shape[0], dtype=torch.long, device=input_ids.device
         )
-
-        this_peer_finished = False  # used by synced_gpus only
         while True:
             input_ids_in = input_ids
             batch_size, seq_length = input_ids_in.shape
@@ -699,21 +697,14 @@
             probs = nn.functional.softmax(next_token_logits, dim=-1)
             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
 
-            # finished sentences should have their next token be a padding token
-            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
-                1 - unfinished_sequences
-            )
-
-            # update generated ids, model inputs, and length for next step
+            # finished sequences emit a padding token from here on
+            pad_tokens = pad_token_id * is_finished
+            next_tokens = next_tokens * (1 - is_finished) + pad_tokens
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
 
-            # if eos_token was found in one sentence, set sentence to finished
-            unfinished_sequences = unfinished_sequences.mul(
-                next_tokens.tile(eos_token_id_tensor.shape[0], 1)
-                .ne(eos_token_id_tensor.unsqueeze(1))
-                .prod(dim=0)
-            )
-            if unfinished_sequences.max() == 0:
+            # a sequence is finished once it produces any of the EOS tokens
+            is_finished = is_finished | next_tokens[:, None].eq(eos_token_id_tensor).any(dim=1).long()
+            if is_finished.min() == 1:  # every sequence has finished
                 break
 
         return input_ids
diff --git a/tensor.py b/tensor.py
new file mode 100644
index 0000000..47235bf
--- /dev/null
+++ b/tensor.py
@@ -0,0 +1,38 @@
+import torch
+
+x = torch.tensor([[1, 2], [3, 4]])
+
+# tile((2)) is the same as tile(2): repeat along the last dim -> [2, 4]
+print(x)
+print("x.tile((2)) -> ", x.tile((2)).shape)
+print(x.tile((2)))
+
+# tile((2, 1)): repeat the rows twice -> [4, 2]
+print()
+print("x.tile((2, 1)) -> ", x.tile((2, 1)).shape)
+print(x.tile((2, 1)))
+
+# more repeat dims than tensor dims: a leading dim is prepended -> [2, 2, 4]
+print()
+print("x.tile((2, 1, 2)) -> ", x.tile((2, 1, 2)).shape)
+print(x.tile((2, 1, 2)))
+
+# repeat only along the new leading dim -> [2, 2, 2]
+print()
+print("x.tile((2, 1, 1)) -> ", x.tile((2, 1, 1)).shape)
+print(x.tile((2, 1, 1)))
+
+# elementwise "not equal" -> bool mask
+print()
+y = torch.tensor([[2, 1], [3, 4]])
+print(y.ne(x))
+
+# product over dim 1 (per row) and dim 0 (per column)
+print()
+print(x.prod(1))
+print(x.prod(0))
+
+# unsqueeze inserts a size-1 dim; squeeze removes it again
+print()
+print(x.unsqueeze(1).shape)
+print(x.unsqueeze(1).squeeze(1).shape)
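
For orientation, here is a minimal self-contained sketch of the single decode step the Readme traces. The real ChatGLM model and tokenizer are replaced by random tensors; only the shapes ([1, 6] input, [1, 1, 65024] logits, [1, 7] after the concat) are meant to match the trace.

```python
import torch
import torch.nn as nn

# Stand-ins for the real ChatGLM objects; shapes follow the Readme trace.
batch_num, seq_len, vocab_size = 1, 6, 65024

input_ids = torch.randint(0, vocab_size, (batch_num, seq_len))    # [1, 6]
lm_logits = torch.randn(batch_num, 1, vocab_size)                 # [1, 1, 65024]

next_token_logits = lm_logits[:, -1, :]                           # [1, 65024]
probs = nn.functional.softmax(next_token_logits, dim=-1)          # [1, 65024]
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)  # [1], one sampled id
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)  # [1, 7]
print(input_ids.shape)  # torch.Size([1, 7])
```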
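The rewritten loop tracks an `is_finished` mask (0 = still generating, 1 = done). Below is a toy run of the two update lines from the diff; `pad_token_id = 0` and the EOS ids 2 and 7 are arbitrary values chosen for illustration.

```python
import torch

# Arbitrary values for illustration only.
pad_token_id = 0
eos_token_id_tensor = torch.tensor([2, 7])  # several possible EOS ids
is_finished = torch.tensor([0, 1, 0])       # sequence 1 already finished
next_tokens = torch.tensor([5, 9, 7])       # freshly sampled tokens

# Finished sequences emit the padding token instead of the sampled one.
next_tokens = next_tokens * (1 - is_finished) + pad_token_id * is_finished
print(next_tokens)  # tensor([5, 0, 7])

# Compare against every EOS id; [:, None] makes the shapes broadcastable
# ([3, 1] vs [2] -> [3, 2]). A bare next_tokens.eq(eos_token_id_tensor)
# would fail whenever batch size != number of EOS ids.
is_finished = is_finished | next_tokens[:, None].eq(eos_token_id_tensor).any(dim=1).long()
print(is_finished)                   # tensor([0, 1, 1])
print(bool(is_finished.min() == 1))  # False: sequence 0 keeps generating
```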
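The tile/ne/prod/unsqueeze experiments in tensor.py are exactly the ops the old `unfinished_sequences` bookkeeping relied on. As a sanity check with made-up token values, both schemes flag the same sequences as finished:

```python
import torch

next_tokens = torch.tensor([5, 2, 7])        # made-up sampled tokens, batch of 3
eos_token_id_tensor = torch.tensor([2, 7])   # made-up EOS ids

# Old scheme: unfinished_sequences, 1 = still generating.
unfinished_sequences = torch.ones(3, dtype=torch.long)
unfinished_sequences = unfinished_sequences.mul(
    next_tokens.tile(eos_token_id_tensor.shape[0], 1)  # [2, 3]: one row per EOS id
    .ne(eos_token_id_tensor.unsqueeze(1))              # False where a token hit that EOS
    .prod(dim=0)                                       # 0 as soon as any row hit
)

# New scheme: is_finished, 1 = done.
is_finished = torch.zeros(3, dtype=torch.long)
is_finished = is_finished | next_tokens[:, None].eq(eos_token_id_tensor).any(dim=1).long()

print(unfinished_sequences)  # tensor([1, 0, 0])
print(is_finished)           # tensor([0, 1, 1])
assert bool(((1 - is_finished) == unfinished_sequences).all())  # the masks agree
```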