Add mamba.

Colin 2024-04-02 15:38:49 +08:00
parent 7a8815cceb
commit e2b48c0ab4
4 changed files with 694 additions and 0 deletions

mamba/demo.ipynb (Normal file, 236 lines added)

@@ -0,0 +1,236 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "531467a2-5160-4073-a990-0d81d574b014",
"metadata": {},
"source": [
"## (1) Load model"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "d9337043-4e7a-4b20-9d89-6c6257245334",
"metadata": {},
"outputs": [],
"source": [
"from model import Mamba, ModelArgs\n",
"from transformers import AutoTokenizer\n",
"\n",
"# One of:\n",
"# 'state-spaces/mamba-2.8b-slimpj'\n",
"# 'state-spaces/mamba-2.8b'\n",
"# 'state-spaces/mamba-1.4b'\n",
"# 'state-spaces/mamba-790m'\n",
"# 'state-spaces/mamba-370m'\n",
"# 'state-spaces/mamba-130m'\n",
"pretrained_model_name = 'state-spaces/mamba-370m'\n",
"\n",
"model = Mamba.from_pretrained(pretrained_model_name)\n",
"tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')"
]
},
{
"cell_type": "markdown",
"id": "0b2efb17-37ad-472b-b029-9567acf17629",
"metadata": {},
"source": [
"## (2) Generate Text"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c4b2d62d-0d95-4a3f-bd98-aa37e3f26b39",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"def generate(model,\n",
" tokenizer,\n",
" prompt: str,\n",
" n_tokens_to_gen: int = 50,\n",
" sample: bool = True,\n",
" top_k: int = 40):\n",
" model.eval()\n",
" \n",
" input_ids = tokenizer(prompt, return_tensors='pt').input_ids\n",
" \n",
" for token_n in range(n_tokens_to_gen):\n",
" with torch.no_grad():\n",
" indices_to_input = input_ids\n",
" next_token_logits = model(indices_to_input)[:, -1]\n",
" \n",
" probs = F.softmax(next_token_logits, dim=-1)\n",
" (batch, vocab_size) = probs.shape\n",
" \n",
" if top_k is not None:\n",
" (values, indices) = torch.topk(probs, k=top_k)\n",
" probs[probs < values[:, -1, None]] = 0\n",
" probs = probs / probs.sum(axis=1, keepdims=True)\n",
" \n",
" if sample:\n",
" next_indices = torch.multinomial(probs, num_samples=1)\n",
" else:\n",
" next_indices = torch.argmax(probs, dim=-1)[:, None]\n",
" \n",
" input_ids = torch.cat([input_ids, next_indices], dim=1)\n",
"\n",
" output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]\n",
" \n",
" return output_completions"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "ee877143-2042-4579-8042-a96db6200517",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mamba is the world's longest venomous snake with an estimated length of over 150 m. With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite)\n"
]
}
],
"source": [
"print(generate(model, tokenizer, 'Mamba is the'))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "65d70549-597f-49ca-9185-2184d2576f7d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"John: Hi!\n",
"Sally: Hey!\n",
"John: So, when's the wedding?\n",
"Sally: We haven't decided.\n",
"John: It's in September.\n",
"Sally: Yeah, we were thinking July or\n",
"August.\n",
"John: I'm not too\n"
]
}
],
"source": [
"print(generate(model, tokenizer, 'John: Hi!\\nSally:'))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "6d419fc9-066b-4818-812c-2f1952528bc6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The meaning of life is \n",
"just this: It is the best you can do.\n",
"\n",
"--K.J.\n",
"\n",
"And finally: How to handle your emotions. \n",
"\n",
"<|endoftext|>Q:\n",
"\n",
"Error creating an EntityManager instance in JavaEE 7\n",
"\n",
"This is\n"
]
}
],
"source": [
"print(generate(model, tokenizer, 'The meaning of life is '))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "2b189e6e-6a96-4770-88cf-7c5de22cb321",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"def reverse_string(text, result):\n",
" # find the position of the start of the string.\n",
" start = text.index(text[0:-1])\n",
" # find the position where the string begins changing.\n",
" end = text.index\n"
]
}
],
"source": [
"print(generate(model, tokenizer, 'def reverse_string('))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "be3afb51-5093-4c64-ac3f-43c2e6b20b10",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "6531acc0-b18f-472a-8e99-cee64dd51cd8",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0efe197-891a-4ab8-8cea-413d1fb1acda",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e99509b-df7b-4bac-b6a2-669f601ec1c8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

mamba/demo.py (Normal file, 50 lines added)

@@ -0,0 +1,50 @@
from model import Mamba, ModelArgs
from transformers import AutoTokenizer
import torch
import torch.nn.functional as F

# One of:
# 'state-spaces/mamba-2.8b-slimpj'
# 'state-spaces/mamba-2.8b'
# 'state-spaces/mamba-1.4b'
# 'state-spaces/mamba-790m'
# 'state-spaces/mamba-370m'
# 'state-spaces/mamba-130m'
pretrained_model_name = "state-spaces/mamba-370m"
model = Mamba.from_pretrained(pretrained_model_name)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")


def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 50, sample: bool = True, top_k: int = 40):
    model.eval()
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    for token_n in range(n_tokens_to_gen):
        with torch.no_grad():
            indices_to_input = input_ids
            next_token_logits = model(indices_to_input)[:, -1]
        probs = F.softmax(next_token_logits, dim=-1)
        (batch, vocab_size) = probs.shape
        if top_k is not None:
            (values, indices) = torch.topk(probs, k=top_k)
            probs[probs < values[:, -1, None]] = 0
            probs = probs / probs.sum(axis=1, keepdims=True)
        if sample:
            next_indices = torch.multinomial(probs, num_samples=1)
        else:
            next_indices = torch.argmax(probs, dim=-1)[:, None]
        input_ids = torch.cat([input_ids, next_indices], dim=1)
    output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]
    return output_completions


print(generate(model, tokenizer, "Mamba is the"))
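
# A minimal illustrative sketch of the top-k filtering step used in generate() above
# (toy probabilities chosen only so the result is easy to verify by hand): keep the
# k largest entries per row, zero the rest, then renormalise before sampling.
toy_probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]])
(toy_values, _) = torch.topk(toy_probs, k=2)         # two largest values: 0.4, 0.3
toy_probs[toy_probs < toy_values[:, -1, None]] = 0   # zero out everything smaller
toy_probs = toy_probs / toy_probs.sum(dim=1, keepdim=True)
print(toy_probs)  # expected: tensor([[0.0000, 0.0000, 0.4286, 0.5714]])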

mamba/model.py (Normal file, 341 lines added)

@@ -0,0 +1,341 @@
"""Simple, minimal implementation of Mamba in one file of PyTorch.
Suggest reading the following before/while reading the code:
[1] Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu and Tri Dao)
https://arxiv.org/abs/2312.00752
[2] The Annotated S4 (Sasha Rush and Sidd Karamcheti)
https://srush.github.io/annotated-s4
Glossary:
b: batch size (`B` in Mamba paper [1] Algorithm 2)
l: sequence length (`L` in [1] Algorithm 2)
d or d_model: hidden dim
n or d_state: latent state dim (`N` in [1] Algorithm 2)
expand: expansion factor (`E` in [1] Section 3.4)
d_in or d_inner: d * expand (`D` in [1] Algorithm 2)
A, B, C, D: state space parameters (See any state space representation formula)
(B, C are input-dependent (aka selective, a key innovation in Mamba); A, D are not)
Δ or delta: input-dependent step size
dt_rank: rank of Δ (See [1] Section 3.6 "Parameterization of ∆")
"""
from __future__ import annotations
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Union
from einops import rearrange, repeat, einsum

@dataclass
class ModelArgs:
    d_model: int
    n_layer: int
    vocab_size: int
    d_state: int = 16
    expand: int = 2
    dt_rank: Union[int, str] = 'auto'
    d_conv: int = 4
    pad_vocab_size_multiple: int = 8
    conv_bias: bool = True
    bias: bool = False

    def __post_init__(self):
        self.d_inner = int(self.expand * self.d_model)

        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)

        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (self.pad_vocab_size_multiple
                                - self.vocab_size % self.pad_vocab_size_multiple)
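
# Worked example of the arithmetic above (a hypothetical toy config, not one of the
# pretrained checkpoints): ModelArgs(d_model=8, n_layer=2, vocab_size=100) gives
# d_inner = 2 * 8 = 16, dt_rank = ceil(8 / 16) = 1, and vocab_size padded from 100
# up to 104, the next multiple of pad_vocab_size_multiple = 8.
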
class Mamba(nn.Module):
    def __init__(self, args: ModelArgs):
        """Full Mamba model."""
        super().__init__()
        self.args = args

        self.embedding = nn.Embedding(args.vocab_size, args.d_model)
        self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
        self.norm_f = RMSNorm(args.d_model)

        self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # Tie output projection to embedding weights.
                                                     # See "Weight Tying" paper

    def forward(self, input_ids):
        """
        Args:
            input_ids (long tensor): shape (b, l)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            logits: shape (b, l, vocab_size)

        Official Implementation:
            class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173
        """
        x = self.embedding(input_ids)

        for layer in self.layers:
            x = layer(x)

        x = self.norm_f(x)
        logits = self.lm_head(x)

        return logits
    @staticmethod
    def from_pretrained(pretrained_model_name: str):
        """Load pretrained weights from HuggingFace into model.

        Args:
            pretrained_model_name: One of
                * 'state-spaces/mamba-2.8b-slimpj'
                * 'state-spaces/mamba-2.8b'
                * 'state-spaces/mamba-1.4b'
                * 'state-spaces/mamba-790m'
                * 'state-spaces/mamba-370m'
                * 'state-spaces/mamba-130m'

        Returns:
            model: Mamba model with weights loaded
        """
        from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
        from transformers.utils.hub import cached_file

        def load_config_hf(model_name):
            resolved_archive_file = cached_file(model_name, CONFIG_NAME,
                                                _raise_exceptions_for_missing_entries=False)
            return json.load(open(resolved_archive_file))

        def load_state_dict_hf(model_name, device=None, dtype=None):
            resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
                                                _raise_exceptions_for_missing_entries=False)
            return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True)

        config_data = load_config_hf(pretrained_model_name)
        args = ModelArgs(
            d_model=config_data['d_model'],
            n_layer=config_data['n_layer'],
            vocab_size=config_data['vocab_size']
        )
        model = Mamba(args)

        state_dict = load_state_dict_hf(pretrained_model_name)
        new_state_dict = {}
        for key in state_dict:
            new_key = key.replace('backbone.', '')
            new_state_dict[new_key] = state_dict[key]
        model.load_state_dict(new_state_dict)

        return model
class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """Simple block wrapping Mamba block with normalization and residual connection."""
        super().__init__()
        self.args = args
        self.mixer = MambaBlock(args)
        self.norm = RMSNorm(args.d_model)

    def forward(self, x):
        """
        Args:
            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d)

        Official Implementation:
            Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297

            Note: the official repo chains residual blocks that look like
                [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
            where the first Add is a no-op. This is purely for performance reasons as this
            allows them to fuse the Add->Norm.

            We instead implement our blocks as the more familiar, simpler, and numerically equivalent
                [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ....
        """
        output = self.mixer(self.norm(x)) + x

        return output
class MambaBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1]."""
        super().__init__()
        self.args = args

        self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)

        self.conv1d = nn.Conv1d(
            in_channels=args.d_inner,
            out_channels=args.d_inner,
            bias=args.conv_bias,
            kernel_size=args.d_conv,
            groups=args.d_inner,
            padding=args.d_conv - 1,
        )

        # x_proj takes in `x` and outputs the input-specific Δ, B, C
        self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)

        # dt_proj projects Δ from dt_rank to d_in
        self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)

        A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(args.d_inner))
        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)

    def forward(self, x):
        """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].

        Args:
            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d)

        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """
        (b, l, d) = x.shape

        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
        (x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)

        x = rearrange(x, 'b l d_in -> b d_in l')
        x = self.conv1d(x)[:, :, :l]
        x = rearrange(x, 'b d_in l -> b l d_in')

        x = F.silu(x)

        y = self.ssm(x)

        y = y * F.silu(res)

        output = self.out_proj(y)

        return output
    def ssm(self, x):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Args:
            x: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """
        (d_in, n) = self.A_log.shape

        # Compute ∆ A B C D, the state space parameters.
        #     A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
        #     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
        #                                  and is why Mamba is called **selective** state spaces)
        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()

        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)

        (delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)

        y = self.selective_scan(x, delta, A, B, C, D)  # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]

        return y
    def selective_scan(self, u, delta, A, B, C, D):
        """Does selective scan algorithm. See:
            - Section 2 State Space Models in the Mamba paper [1]
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        This is the classic discrete state space formula:
            x(t + 1) = Ax(t) + Bu(t)
            y(t)     = Cx(t) + Du(t)
        except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t).

        Args:
            u: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
            delta: shape (b, l, d_in)
            A: shape (d_in, n)
            B: shape (b, l, n)
            C: shape (b, l, n)
            D: shape (d_in,)

        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
            Note: I refactored some parts out of `selective_scan_ref`, so the functionality doesn't match exactly.
        """
        (b, l, d_in) = u.shape
        n = A.shape[1]

        # Discretize continuous parameters (A, B)
        # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
        # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
        #   "A is the more important term and the performance doesn't change much with the simplification on B"
        deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
        deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')

        # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
        # Note that the below is sequential, while the official implementation does a much faster parallel scan that
        # is additionally hardware-aware (like FlashAttention).
        x = torch.zeros((b, d_in, n), device=deltaA.device)
        ys = []
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
            ys.append(y)
        y = torch.stack(ys, dim=1)  # shape (b, l, d_in)

        y = y + u * D

        return y
class RMSNorm(nn.Module):
    def __init__(self,
                 d_model: int,
                 eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

        return output
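
# A minimal shape-check sketch on a tiny, randomly initialised config (the sizes below
# are arbitrary toy values, not one of the pretrained checkpoints). It makes the
# glossary at the top concrete: input ids of shape (b, l) produce logits of shape
# (b, l, vocab_size). Guarded so it does not run when demo.py imports this module.
if __name__ == '__main__':
    toy_args = ModelArgs(d_model=16, n_layer=2, vocab_size=100)
    toy_model = Mamba(toy_args)
    toy_ids = torch.randint(0, toy_args.vocab_size, (2, 5))  # (b=2, l=5)
    toy_logits = toy_model(toy_ids)
    print(toy_logits.shape)  # expected: torch.Size([2, 5, 104]), after vocab padding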

test/einsum.py (Normal file, 67 lines added)

@@ -0,0 +1,67 @@
import torch
import numpy as np

# Example 5: vector inner product (dot product)
A = torch.randn(10)
B = torch.randn(10)
# C = torch.dot(A, B)
C = torch.einsum("i,i->", A, B)
print("before:", A.shape, B.shape)
print("after:", C.shape)

# Example 6: vector outer product
A = torch.randn(10)
B = torch.randn(5)
# C = torch.outer(A, B)
C = torch.einsum("i,j->ij", A, B)
print("before:", A.shape, B.shape)
print("after:", C.shape)

# Example 7: matrix multiplication
A = torch.randn(5, 4)
B = torch.randn(4, 6)
# C = torch.matmul(A, B)
C = torch.einsum("ik,kj->ij", A, B)
print("before:", A.shape, B.shape)
print("after:", C.shape)

# Example 8: tensor contraction
A = torch.randn(3, 4, 5)
B = torch.randn(4, 3, 6)
# C = torch.tensordot(A, B, dims=[(0, 1), (1, 0)])
C = torch.einsum("ijk,jih->kh", A, B)
print("before:", A.shape, B.shape)
print("after:", C.shape)

a = torch.randn(2, 3)
b = torch.randn(5, 3, 7)
c = torch.randn(2, 7)
# i = 2, k = 3, j = 5, l = 7
torch_ein_out = torch.einsum('ik,jkl,il->ij', [a, b, c]).numpy()
m = torch.nn.Bilinear(3, 7, 5, bias=False)
m.weight.data = b
torch_org_out = m(a, c).detach().numpy()

np_a = a.numpy()
np_b = b.numpy()
np_c = c.numpy()
np_out = np.empty((2, 5), dtype=np.float32)

# Outer loops over the free indices, here i and j
for i in range(0, 2):
    for j in range(0, 5):
        # Inner loops over the summation indices, here k and l
        sum_result = 0
        for k in range(0, 3):
            for l in range(0, 7):
                sum_result += np_a[i, k] * np_b[j, k, l] * np_c[i, l]
        np_out[i, j] = sum_result

# print("matrix a:\n", np_a)
# print("matrix b:\n", np_b)
print("torch ein out: \n", torch_ein_out)
print("torch org out: \n", torch_org_out)
print("numpy out: \n", np_out)
print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out))
print("is torch_org_out == torch_ein_out ?", np.allclose(torch_ein_out, torch_org_out))