diff --git a/Readme.md b/Readme.md index 09dfe4a..d30d939 100644 --- a/Readme.md +++ b/Readme.md @@ -28,6 +28,19 @@ response = tokenizer.decode(outputs) ## RMSNorm +hidden_states -> [6, 1, 4096] + / \ +| pow(2) -> [6, 1, 4096] +| | +| mean -> [6, 1, 1] +| ↓ +| rsqrt( + eps) -> [6, 1, 1] + \ / + mul -> [6, 1, 4096] + \ weight -> [4096] + \ / + mul -> [6, 1, 4096] + hidden_states -> [6, 1, 4096] 4096:hidden_size variance = hidden_states.pow(2).mean(-1, keepdim=True) -> [6, 1, 1] hidden_states = hidden_states * torch.rsqrt(variance + self.eps) 平方根倒数 @@ -36,12 +49,40 @@ return (self.weight * hidden_states) -> [6, 1, 4096] ## MLP + hidden_states -> [6, 1, 4096] + Linear -> [6, 1, 27392] + / \ + chunk1 chunk0 -> [6, 1, 13696] + | | \ + | | sigmoid + | | / + | mul + \ / + mul -> [6, 1, 13696] + Linear -> [6, 1, 4096] + Linear(hidden_states) no bias -> [6, 1, 27392] silu (x) = [6, 1, 13696] * sigmoid([6, 1, 13696]) Linear(intermediate_parallel) no bias -> [6, 1, 4096] ## self_attention + x -> [6, 1, 4096] + | + Linear -> [6, 1, 4608] + / | \ +[6, 1, 32, 128] <- q k v + / | \ + pos_emb pos_emb \ + | | \ + | expand expand -> [6, 1, 32, 128] + \ / | + dot | + softmax / + \ / + dot -> [1, 32, 6, 128] + Linear -> [6, 1, 4096] + hidden_states: [s, b, h] mixed_x_layer = Linear(hidden_states) -> [6, 1, 4608] 4608:4096+256+256 diff --git a/RMSNorm_weight.png b/generated/RMSNorm_weight.png similarity index 100% rename from RMSNorm_weight.png rename to generated/RMSNorm_weight.png diff --git a/rotary_pos_emb.png b/generated/rotary_pos_emb.png similarity index 100% rename from rotary_pos_emb.png rename to generated/rotary_pos_emb.png diff --git a/test_tokenizer.py b/test_tokenizer.py index 6902524..b05fb3e 100644 --- a/test_tokenizer.py +++ b/test_tokenizer.py @@ -23,9 +23,8 @@ init_kwargs["name_or_path"] = pretrained_model_name_or_path tokenizer = ChatGLMTokenizer(*init_inputs, **init_kwargs) -aa = tokenizer.build_chat_input("骉") -ab = tokenizer.encode("骉") -a = tokenizer.decode([236,173,140]) +a = tokenizer.encode("骉") +b = tokenizer.decode([236,173,140])