diff --git a/mamba/demo.ipynb b/mamba/demo.ipynb new file mode 100644 index 0000000..d612674 --- /dev/null +++ b/mamba/demo.ipynb @@ -0,0 +1,236 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "531467a2-5160-4073-a990-0d81d574b014", + "metadata": {}, + "source": [ + "## (1) Load model" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "d9337043-4e7a-4b20-9d89-6c6257245334", + "metadata": {}, + "outputs": [], + "source": [ + "from model import Mamba, ModelArgs\n", + "from transformers import AutoTokenizer\n", + "\n", + "# One of:\n", + "# 'state-spaces/mamba-2.8b-slimpj'\n", + "# 'state-spaces/mamba-2.8b'\n", + "# 'state-spaces/mamba-1.4b'\n", + "# 'state-spaces/mamba-790m'\n", + "# 'state-spaces/mamba-370m'\n", + "# 'state-spaces/mamba-130m'\n", + "pretrained_model_name = 'state-spaces/mamba-370m'\n", + "\n", + "model = Mamba.from_pretrained(pretrained_model_name)\n", + "tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')" + ] + }, + { + "cell_type": "markdown", + "id": "0b2efb17-37ad-472b-b029-9567acf17629", + "metadata": {}, + "source": [ + "## (2) Generate Text" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c4b2d62d-0d95-4a3f-bd98-aa37e3f26b39", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn.functional as F\n", + "\n", + "\n", + "def generate(model,\n", + " tokenizer,\n", + " prompt: str,\n", + " n_tokens_to_gen: int = 50,\n", + " sample: bool = True,\n", + " top_k: int = 40):\n", + " model.eval()\n", + " \n", + " input_ids = tokenizer(prompt, return_tensors='pt').input_ids\n", + " \n", + " for token_n in range(n_tokens_to_gen):\n", + " with torch.no_grad():\n", + " indices_to_input = input_ids\n", + " next_token_logits = model(indices_to_input)[:, -1]\n", + " \n", + " probs = F.softmax(next_token_logits, dim=-1)\n", + " (batch, vocab_size) = probs.shape\n", + " \n", + " if top_k is not None:\n", + " (values, indices) = torch.topk(probs, k=top_k)\n", + " probs[probs < values[:, -1, None]] = 0\n", + " probs = probs / probs.sum(axis=1, keepdims=True)\n", + " \n", + " if sample:\n", + " next_indices = torch.multinomial(probs, num_samples=1)\n", + " else:\n", + " next_indices = torch.argmax(probs, dim=-1)[:, None]\n", + " \n", + " input_ids = torch.cat([input_ids, next_indices], dim=1)\n", + "\n", + " output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]\n", + " \n", + " return output_completions" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "ee877143-2042-4579-8042-a96db6200517", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mamba is the world's longest venomous snake with an estimated length of over 150 m. 
With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite)\n" + ] + } + ], + "source": [ + "print(generate(model, tokenizer, 'Mamba is the'))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "65d70549-597f-49ca-9185-2184d2576f7d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "John: Hi!\n", + "Sally: Hey!\n", + "John: So, when's the wedding?\n", + "Sally: We haven't decided.\n", + "John: It's in September.\n", + "Sally: Yeah, we were thinking July or\n", + "August.\n", + "John: I'm not too\n" + ] + } + ], + "source": [ + "print(generate(model, tokenizer, 'John: Hi!\\nSally:'))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "6d419fc9-066b-4818-812c-2f1952528bc6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The meaning of life is \n", + "just this: It is the best you can do.\n", + "\n", + "--K.J.\n", + "\n", + "And finally: How to handle your emotions. \n", + "\n", + "<|endoftext|>Q:\n", + "\n", + "Error creating an EntityManager instance in JavaEE 7\n", + "\n", + "This is\n" + ] + } + ], + "source": [ + "print(generate(model, tokenizer, 'The meaning of life is '))" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "2b189e6e-6a96-4770-88cf-7c5de22cb321", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "def reverse_string(text, result):\n", + " # find the position of the start of the string.\n", + " start = text.index(text[0:-1])\n", + " # find the position where the string begins changing.\n", + " end = text.index\n" + ] + } + ], + "source": [ + "print(generate(model, tokenizer, 'def reverse_string('))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be3afb51-5093-4c64-ac3f-43c2e6b20b10", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6531acc0-b18f-472a-8e99-cee64dd51cd8", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d0efe197-891a-4ab8-8cea-413d1fb1acda", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2e99509b-df7b-4bac-b6a2-669f601ec1c8", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/mamba/demo.py b/mamba/demo.py new file mode 100644 index 0000000..d652b9b --- /dev/null +++ b/mamba/demo.py @@ -0,0 +1,50 @@ +from model import Mamba, ModelArgs +from transformers import AutoTokenizer +import torch +import torch.nn.functional as F + + +# One of: +# 'state-spaces/mamba-2.8b-slimpj' +# 'state-spaces/mamba-2.8b' +# 'state-spaces/mamba-1.4b' +# 'state-spaces/mamba-790m' +# 'state-spaces/mamba-370m' +# 'state-spaces/mamba-130m' +pretrained_model_name = "state-spaces/mamba-370m" + +model = Mamba.from_pretrained(pretrained_model_name) +tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") + + +def 
generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 50, sample: bool = True, top_k: int = 40): + model.eval() + + input_ids = tokenizer(prompt, return_tensors="pt").input_ids + + for token_n in range(n_tokens_to_gen): + with torch.no_grad(): + indices_to_input = input_ids + next_token_logits = model(indices_to_input)[:, -1] + + probs = F.softmax(next_token_logits, dim=-1) + (batch, vocab_size) = probs.shape + + if top_k is not None: + (values, indices) = torch.topk(probs, k=top_k) + probs[probs < values[:, -1, None]] = 0 + probs = probs / probs.sum(axis=1, keepdims=True) + + if sample: + next_indices = torch.multinomial(probs, num_samples=1) + else: + next_indices = torch.argmax(probs, dim=-1)[:, None] + + input_ids = torch.cat([input_ids, next_indices], dim=1) + + output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0] + + return output_completions + + +print(generate(model, tokenizer, "Mamba is the")) diff --git a/mamba/model.py b/mamba/model.py new file mode 100644 index 0000000..ab24d22 --- /dev/null +++ b/mamba/model.py @@ -0,0 +1,341 @@ +"""Simple, minimal implementation of Mamba in one file of PyTorch. + +Suggest reading the following before/while reading the code: + [1] Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu and Tri Dao) + https://arxiv.org/abs/2312.00752 + [2] The Annotated S4 (Sasha Rush and Sidd Karamcheti) + https://srush.github.io/annotated-s4 + +Glossary: + b: batch size (`B` in Mamba paper [1] Algorithm 2) + l: sequence length (`L` in [1] Algorithm 2) + d or d_model: hidden dim + n or d_state: latent state dim (`N` in [1] Algorithm 2) + expand: expansion factor (`E` in [1] Section 3.4) + d_in or d_inner: d * expand (`D` in [1] Algorithm 2) + A, B, C, D: state space parameters (See any state space representation formula) + (B, C are input-dependent (aka selective, a key innovation in Mamba); A, D are not) + Δ or delta: input-dependent step size + dt_rank: rank of Δ (See [1] Section 3.6 "Parameterization of ∆") + +""" +from __future__ import annotations +import math +import json +import torch +import torch.nn as nn +import torch.nn.functional as F +from dataclasses import dataclass +from einops import rearrange, repeat, einsum + + +@dataclass +class ModelArgs: + d_model: int + n_layer: int + vocab_size: int + d_state: int = 16 + expand: int = 2 + dt_rank: Union[int, str] = 'auto' + d_conv: int = 4 + pad_vocab_size_multiple: int = 8 + conv_bias: bool = True + bias: bool = False + + def __post_init__(self): + self.d_inner = int(self.expand * self.d_model) + + if self.dt_rank == 'auto': + self.dt_rank = math.ceil(self.d_model / 16) + + if self.vocab_size % self.pad_vocab_size_multiple != 0: + self.vocab_size += (self.pad_vocab_size_multiple + - self.vocab_size % self.pad_vocab_size_multiple) + + +class Mamba(nn.Module): + def __init__(self, args: ModelArgs): + """Full Mamba model.""" + super().__init__() + self.args = args + + self.embedding = nn.Embedding(args.vocab_size, args.d_model) + self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)]) + self.norm_f = RMSNorm(args.d_model) + + self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False) + self.lm_head.weight = self.embedding.weight # Tie output projection to embedding weights. + # See "Weight Tying" paper + + + def forward(self, input_ids): + """ + Args: + input_ids (long tensor): shape (b, l) (See Glossary at top for definitions of b, l, d_in, n...) 
+ + Returns: + logits: shape (b, l, vocab_size) + + Official Implementation: + class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173 + + """ + x = self.embedding(input_ids) + + for layer in self.layers: + x = layer(x) + + x = self.norm_f(x) + logits = self.lm_head(x) + + return logits + + + @staticmethod + def from_pretrained(pretrained_model_name: str): + """Load pretrained weights from HuggingFace into model. + + Args: + pretrained_model_name: One of + * 'state-spaces/mamba-2.8b-slimpj' + * 'state-spaces/mamba-2.8b' + * 'state-spaces/mamba-1.4b' + * 'state-spaces/mamba-790m' + * 'state-spaces/mamba-370m' + * 'state-spaces/mamba-130m' + + Returns: + model: Mamba model with weights loaded + + """ + from transformers.utils import WEIGHTS_NAME, CONFIG_NAME + from transformers.utils.hub import cached_file + + def load_config_hf(model_name): + resolved_archive_file = cached_file(model_name, CONFIG_NAME, + _raise_exceptions_for_missing_entries=False) + return json.load(open(resolved_archive_file)) + + + def load_state_dict_hf(model_name, device=None, dtype=None): + resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, + _raise_exceptions_for_missing_entries=False) + return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True) + + config_data = load_config_hf(pretrained_model_name) + args = ModelArgs( + d_model=config_data['d_model'], + n_layer=config_data['n_layer'], + vocab_size=config_data['vocab_size'] + ) + model = Mamba(args) + + state_dict = load_state_dict_hf(pretrained_model_name) + new_state_dict = {} + for key in state_dict: + new_key = key.replace('backbone.', '') + new_state_dict[new_key] = state_dict[key] + model.load_state_dict(new_state_dict) + + return model + + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs): + """Simple block wrapping Mamba block with normalization and residual connection.""" + super().__init__() + self.args = args + self.mixer = MambaBlock(args) + self.norm = RMSNorm(args.d_model) + + + def forward(self, x): + """ + Args: + x: shape (b, l, d) (See Glossary at top for definitions of b, l, d_in, n...) + + Returns: + output: shape (b, l, d) + + Official Implementation: + Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297 + + Note: the official repo chains residual blocks that look like + [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ... + where the first Add is a no-op. This is purely for performance reasons as this + allows them to fuse the Add->Norm. + + We instead implement our blocks as the more familiar, simpler, and numerically equivalent + [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> .... 
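+
+        For example, unrolling two blocks, both orderings end up computing
+            x1 = x0 + mixer(norm(x0))
+            x2 = x1 + mixer(norm(x1))
+        the official version just performs each Add at the start of the following block (fused with that block's Norm).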
+ + """ + output = self.mixer(self.norm(x)) + x + + return output + + +class MambaBlock(nn.Module): + def __init__(self, args: ModelArgs): + """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1].""" + super().__init__() + self.args = args + + self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias) + + self.conv1d = nn.Conv1d( + in_channels=args.d_inner, + out_channels=args.d_inner, + bias=args.conv_bias, + kernel_size=args.d_conv, + groups=args.d_inner, + padding=args.d_conv - 1, + ) + + # x_proj takes in `x` and outputs the input-specific Δ, B, C + self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False) + + # dt_proj projects Δ from dt_rank to d_in + self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True) + + A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner) + self.A_log = nn.Parameter(torch.log(A)) + self.D = nn.Parameter(torch.ones(args.d_inner)) + self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias) + + + def forward(self, x): + """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1]. + + Args: + x: shape (b, l, d) (See Glossary at top for definitions of b, l, d_in, n...) + + Returns: + output: shape (b, l, d) + + Official Implementation: + class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119 + mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311 + + """ + (b, l, d) = x.shape + + x_and_res = self.in_proj(x) # shape (b, l, 2 * d_in) + (x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1) + + x = rearrange(x, 'b l d_in -> b d_in l') + x = self.conv1d(x)[:, :, :l] + x = rearrange(x, 'b d_in l -> b l d_in') + + x = F.silu(x) + + y = self.ssm(x) + + y = y * F.silu(res) + + output = self.out_proj(y) + + return output + + + def ssm(self, x): + """Runs the SSM. See: + - Algorithm 2 in Section 3.2 in the Mamba paper [1] + - run_SSM(A, B, C, u) in The Annotated S4 [2] + + Args: + x: shape (b, l, d_in) (See Glossary at top for definitions of b, l, d_in, n...) + + Returns: + output: shape (b, l, d_in) + + Official Implementation: + mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311 + + """ + (d_in, n) = self.A_log.shape + + # Compute ∆ A B C D, the state space parameters. + # A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + # ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + # and is why Mamba is called **selective** state spaces) + + A = -torch.exp(self.A_log.float()) # shape (d_in, n) + D = self.D.float() + + x_dbl = self.x_proj(x) # (b, l, dt_rank + 2*n) + + (delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1) # delta: (b, l, dt_rank). B, C: (b, l, n) + delta = F.softplus(self.dt_proj(delta)) # (b, l, d_in) + + y = self.selective_scan(x, delta, A, B, C, D) # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2] + + return y + + + def selective_scan(self, u, delta, A, B, C, D): + """Does selective scan algorithm. 
See:
+            - Section 2 State Space Models in the Mamba paper [1]
+            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
+            - run_SSM(A, B, C, u) in The Annotated S4 [2]
+
+        This is the classic discrete state space formula:
+            x(t + 1) = Ax(t) + Bu(t)
+            y(t)     = Cx(t) + Du(t)
+        except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t).
+
+        Args:
+            u: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
+            delta: shape (b, l, d_in)
+            A: shape (d_in, n)
+            B: shape (b, l, n)
+            C: shape (b, l, n)
+            D: shape (d_in,)
+
+        Returns:
+            output: shape (b, l, d_in)
+
+        Official Implementation:
+            selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
+            Note: I refactored some parts out of `selective_scan_ref`, so the functionality doesn't match exactly.
+
+        """
+        (b, l, d_in) = u.shape
+        n = A.shape[1]
+
+        # Discretize continuous parameters (A, B)
+        # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
+        # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
+        #   "A is the more important term and the performance doesn't change much with the simplification on B"
+        deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
+        deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
+
+        # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
+        # Note that the below is sequential, while the official implementation does a much faster parallel scan that
+        # is additionally hardware-aware (like FlashAttention).
+        x = torch.zeros((b, d_in, n), device=deltaA.device)
+        ys = []
+        for i in range(l):
+            x = deltaA[:, i] * x + deltaB_u[:, i]
+            y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
+            ys.append(y)
+        y = torch.stack(ys, dim=1)  # shape (b, l, d_in)
+
+        y = y + u * D
+
+        return y
+
+
+class RMSNorm(nn.Module):
+    def __init__(self,
+                 d_model: int,
+                 eps: float = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(d_model))
+
+
+    def forward(self, x):
+        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
+
+        return output
+
diff --git a/test/einsum.py b/test/einsum.py
new file mode 100644
index 0000000..ac744a8
--- /dev/null
+++ b/test/einsum.py
@@ -0,0 +1,67 @@
+import torch
+import numpy as np
+
+# Example 5: vector inner (dot) product
+A = torch.randn(10)
+B = torch.randn(10)
+#C=torch.dot(A,B)
+C = torch.einsum("i,i->",A,B)
+print("before:",A.shape, B.shape)
+print("after:",C.shape)
+
+# Example 6: vector outer product
+A = torch.randn(10)
+B = torch.randn(5)
+#C = torch.outer(A,B)
+C = torch.einsum("i,j->ij",A,B)
+print("before:",A.shape, B.shape)
+print("after:",C.shape)
+
+# Example 7: matrix multiplication
+A = torch.randn(5,4)
+B = torch.randn(4,6)
+#C = torch.matmul(A,B)
+C = torch.einsum("ik,kj->ij",A,B)
+print("before:",A.shape, B.shape)
+print("after:",C.shape)
+
+# Example 8: tensor contraction
+A = torch.randn(3,4,5)
+B = torch.randn(4,3,6)
+#C = torch.tensordot(A,B,dims=[(0,1),(1,0)])
+C = torch.einsum("ijk,jih->kh",A,B)
+print("before:",A.shape, B.shape)
+print("after:",C.shape)
+
+
+
+a = torch.randn(2,3)
+b = torch.randn(5,3,7)
+c = torch.randn(2,7)
+# i = 2, k = 3, j = 5, l = 7
+torch_ein_out = torch.einsum('ik,jkl,il->ij', [a, b, c]).numpy()
+m = torch.nn.Bilinear(3, 7, 5, bias=False)
+m.weight.data = b
+torch_org_out = m(a, c).detach().numpy()
+
+np_a = a.numpy()
+np_b = b.numpy()
+np_c = c.numpy()
+np_out = np.empty((2, 5), dtype=np.float32)
+# Outer loops over the free indices, here i and j
+for i in range(0, 2):
+    for j in range(0, 5):
+        # Inner loops over the summation indices, here k and l
+        sum_result = 0
+        for k in range(0, 3):
+            for l in range(0, 7):
+                sum_result += np_a[i, k] * np_b[j, k, l] * np_c[i, l]
+        np_out[i, j] = sum_result
+
+# print("matrix a:\n", np_a)
+# print("matrix b:\n", np_b)
+print("torch ein out: \n", torch_ein_out)
+print("torch org out: \n", torch_org_out)
+print("numpy out: \n", np_out)
+print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out))
+print("is torch_org_out == torch_ein_out ?", np.allclose(torch_ein_out, torch_org_out))
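+
+# Extra, illustrative check: the named-axis einsum patterns that mamba/model.py uses in
+# selective_scan() (via einops.einsum, already a dependency of model.py) can be verified
+# the same way. The shapes below (bsz=2, seqlen=3, d_in=4, n=5) are made up for illustration.
+from einops import einsum as named_einsum
+
+bsz, seqlen, d_in, n = 2, 3, 4, 5
+delta = torch.randn(bsz, seqlen, d_in)
+A_ssm = torch.randn(d_in, n)
+B_ssm = torch.randn(bsz, seqlen, n)
+u = torch.randn(bsz, seqlen, d_in)
+
+# 'b l d_in, d_in n -> b l d_in n': every index appears in the output, so nothing is summed;
+# this is just an outer product over the trailing axes, i.e. plain broadcasting.
+deltaA_named = named_einsum(delta, A_ssm, 'b l d_in, d_in n -> b l d_in n')
+deltaA_broadcast = delta.unsqueeze(-1) * A_ssm
+print("is deltaA (einops) == broadcasting ?", torch.allclose(deltaA_named, deltaA_broadcast))
+
+# The three-operand pattern 'b l d_in, b l n, b l d_in -> b l d_in n' works the same way;
+# here it is written with torch.einsum's single-letter index notation for comparison.
+deltaBu_named = named_einsum(delta, B_ssm, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
+deltaBu_torch = torch.einsum('bld,bln,bld->bldn', delta, B_ssm, u)
+print("is deltaB_u (einops) == torch.einsum ?", torch.allclose(deltaBu_named, deltaBu_torch))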