Witllm/mamba/demo.ipynb

237 lines
6.0 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "531467a2-5160-4073-a990-0d81d574b014",
"metadata": {},
"source": [
"## (1) Load model"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "d9337043-4e7a-4b20-9d89-6c6257245334",
"metadata": {},
"outputs": [],
"source": [
"from model import Mamba, ModelArgs\n",
"from transformers import AutoTokenizer\n",
"\n",
"# Checkpoint identifier handed to Mamba.from_pretrained below. One of:\n",
"# 'state-spaces/mamba-2.8b-slimpj'\n",
"# 'state-spaces/mamba-2.8b'\n",
"# 'state-spaces/mamba-1.4b'\n",
"# 'state-spaces/mamba-790m'\n",
"# 'state-spaces/mamba-370m'\n",
"# 'state-spaces/mamba-130m'\n",
"pretrained_model_name = 'state-spaces/mamba-370m'\n",
"\n",
"# `model` and `tokenizer` are notebook-level names; the generation cells further down read them.\n",
"model = Mamba.from_pretrained(pretrained_model_name)\n",
"# The tokenizer id is fixed and does not change with the checkpoint chosen above.\n",
"# NOTE(review): presumably every listed checkpoint shares this tokenizer's vocabulary -- confirm.\n",
"tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')"
]
},
{
"cell_type": "markdown",
"id": "0b2efb17-37ad-472b-b029-9567acf17629",
"metadata": {},
"source": [
"## (2) Generate text"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c4b2d62d-0d95-4a3f-bd98-aa37e3f26b39",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"def generate(model,\n",
"             tokenizer,\n",
"             prompt: str,\n",
"             n_tokens_to_gen: int = 50,\n",
"             sample: bool = True,\n",
"             top_k: int = 40):\n",
"    \"\"\"Autoregressively extend `prompt` and return the decoded text.\n",
"\n",
"    Every step feeds the whole sequence produced so far back through `model`,\n",
"    takes the logits at the last position and appends one new token.\n",
"\n",
"    Args:\n",
"        model: Object providing `.eval()` and callable as `model(input_ids)`;\n",
"            the result is indexed with `[:, -1]` to get next-token logits.\n",
"        tokenizer: Callable as `tokenizer(prompt, return_tensors='pt')`, whose\n",
"            result exposes `.input_ids`; must also provide `.decode(ids)`.\n",
"        prompt: Text to continue.\n",
"        n_tokens_to_gen: Number of tokens to append to the prompt.\n",
"        sample: If True, draw the next token from the (optionally top-k\n",
"            filtered) distribution; if False, take the most probable token.\n",
"        top_k: Keep only the `top_k` most probable tokens before choosing.\n",
"            `None` disables filtering. Values above the vocabulary size are\n",
"            clamped to it.\n",
"\n",
"    Returns:\n",
"        The decoded text (prompt plus generated tokens) of the first sequence\n",
"        in the batch.\n",
"\n",
"    Raises:\n",
"        ValueError: If `top_k` is given but is smaller than 1.\n",
"    \"\"\"\n",
"    if top_k is not None and top_k < 1:\n",
"        raise ValueError(f'top_k must be a positive integer or None, got {top_k}')\n",
"\n",
"    model.eval()\n",
"\n",
"    input_ids = tokenizer(prompt, return_tensors='pt').input_ids\n",
"\n",
"    # Inference only: keep autograd disabled for the whole decoding loop.\n",
"    with torch.no_grad():\n",
"        for _ in range(n_tokens_to_gen):\n",
"            next_token_logits = model(input_ids)[:, -1]\n",
"            probs = F.softmax(next_token_logits, dim=-1)\n",
"\n",
"            if top_k is not None:\n",
"                # torch.topk raises when k exceeds the last dimension, so clamp it.\n",
"                k = min(top_k, probs.shape[-1])\n",
"                top_values = torch.topk(probs, k=k).values\n",
"                # Zero everything below the k-th largest probability, then renormalize.\n",
"                probs[probs < top_values[:, -1, None]] = 0\n",
"                probs = probs / probs.sum(dim=1, keepdim=True)\n",
"\n",
"            if sample:\n",
"                next_indices = torch.multinomial(probs, num_samples=1)\n",
"            else:\n",
"                next_indices = torch.argmax(probs, dim=-1, keepdim=True)\n",
"\n",
"            input_ids = torch.cat([input_ids, next_indices], dim=1)\n",
"\n",
"    # Only the first sequence in the batch is decoded and returned.\n",
"    return tokenizer.decode(input_ids[0].tolist())"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "ee877143-2042-4579-8042-a96db6200517",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mamba is the world's longest venomous snake with an estimated length of over 150 m. With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite)\n"
]
}
],
"source": [
"# Continue a short declarative prompt using generate()'s default settings.\n",
"print(generate(model=model, tokenizer=tokenizer, prompt='Mamba is the'))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "65d70549-597f-49ca-9185-2184d2576f7d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"John: Hi!\n",
"Sally: Hey!\n",
"John: So, when's the wedding?\n",
"Sally: We haven't decided.\n",
"John: It's in September.\n",
"Sally: Yeah, we were thinking July or\n",
"August.\n",
"John: I'm not too\n"
]
}
],
"source": [
"# Continue a two-speaker dialogue; the prompt ends right after the second speaker's label.\n",
"print(generate(model=model, tokenizer=tokenizer, prompt='John: Hi!\\nSally:'))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "6d419fc9-066b-4818-812c-2f1952528bc6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The meaning of life is \n",
"just this: It is the best you can do.\n",
"\n",
"--K.J.\n",
"\n",
"And finally: How to handle your emotions. \n",
"\n",
"<|endoftext|>Q:\n",
"\n",
"Error creating an EntityManager instance in JavaEE 7\n",
"\n",
"This is\n"
]
}
],
"source": [
"# Open-ended prompt; note that it ends with a trailing space.\n",
"print(generate(model=model, tokenizer=tokenizer, prompt='The meaning of life is '))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "2b189e6e-6a96-4770-88cf-7c5de22cb321",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"def reverse_string(text, result):\n",
" # find the position of the start of the string.\n",
" start = text.index(text[0:-1])\n",
" # find the position where the string begins changing.\n",
" end = text.index\n"
]
}
],
"source": [
"# Code-completion prompt: the start of a Python function signature.\n",
"print(generate(model=model, tokenizer=tokenizer, prompt='def reverse_string('))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "be3afb51-5093-4c64-ac3f-43c2e6b20b10",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "6531acc0-b18f-472a-8e99-cee64dd51cd8",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0efe197-891a-4ab8-8cea-413d1fb1acda",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e99509b-df7b-4bac-b6a2-669f601ec1c8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}