237 lines
6.0 KiB
Plaintext
237 lines
6.0 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "531467a2-5160-4073-a990-0d81d574b014",
|
|
"metadata": {},
|
|
"source": [
|
|
"## (1) Load model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "d9337043-4e7a-4b20-9d89-6c6257245334",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from model import Mamba, ModelArgs\n",
|
|
"from transformers import AutoTokenizer\n",
|
|
"\n",
|
|
"# One of:\n",
|
|
"# 'state-spaces/mamba-2.8b-slimpj'\n",
|
|
"# 'state-spaces/mamba-2.8b'\n",
|
|
"# 'state-spaces/mamba-1.4b'\n",
|
|
"# 'state-spaces/mamba-790m'\n",
|
|
"# 'state-spaces/mamba-370m'\n",
|
|
"# 'state-spaces/mamba-130m'\n",
|
|
"pretrained_model_name = 'state-spaces/mamba-370m'\n",
|
|
"\n",
|
|
"model = Mamba.from_pretrained(pretrained_model_name)\n",
|
|
"tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "0b2efb17-37ad-472b-b029-9567acf17629",
|
|
"metadata": {},
|
|
"source": [
|
|
"## (2) Generate Text"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "c4b2d62d-0d95-4a3f-bd98-aa37e3f26b39",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"import torch.nn.functional as F\n",
|
|
"\n",
|
|
"\n",
|
|
"def generate(model,\n",
|
|
" tokenizer,\n",
|
|
" prompt: str,\n",
|
|
" n_tokens_to_gen: int = 50,\n",
|
|
" sample: bool = True,\n",
|
|
" top_k: int = 40):\n",
|
|
" model.eval()\n",
|
|
" \n",
|
|
" input_ids = tokenizer(prompt, return_tensors='pt').input_ids\n",
|
|
" \n",
|
|
" for token_n in range(n_tokens_to_gen):\n",
|
|
" with torch.no_grad():\n",
|
|
" indices_to_input = input_ids\n",
|
|
" next_token_logits = model(indices_to_input)[:, -1]\n",
|
|
" \n",
|
|
" probs = F.softmax(next_token_logits, dim=-1)\n",
|
|
" (batch, vocab_size) = probs.shape\n",
|
|
" \n",
|
|
" if top_k is not None:\n",
|
|
" (values, indices) = torch.topk(probs, k=top_k)\n",
|
|
" probs[probs < values[:, -1, None]] = 0\n",
|
|
" probs = probs / probs.sum(axis=1, keepdims=True)\n",
|
|
" \n",
|
|
" if sample:\n",
|
|
" next_indices = torch.multinomial(probs, num_samples=1)\n",
|
|
" else:\n",
|
|
" next_indices = torch.argmax(probs, dim=-1)[:, None]\n",
|
|
" \n",
|
|
" input_ids = torch.cat([input_ids, next_indices], dim=1)\n",
|
|
"\n",
|
|
" output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]\n",
|
|
" \n",
|
|
" return output_completions"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "ee877143-2042-4579-8042-a96db6200517",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Mamba is the world's longest venomous snake with an estimated length of over 150 m. With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(generate(model, tokenizer, 'Mamba is the'))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "65d70549-597f-49ca-9185-2184d2576f7d",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"John: Hi!\n",
|
|
"Sally: Hey!\n",
|
|
"John: So, when's the wedding?\n",
|
|
"Sally: We haven't decided.\n",
|
|
"John: It's in September.\n",
|
|
"Sally: Yeah, we were thinking July or\n",
|
|
"August.\n",
|
|
"John: I'm not too\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(generate(model, tokenizer, 'John: Hi!\\nSally:'))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "6d419fc9-066b-4818-812c-2f1952528bc6",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"The meaning of life is \n",
|
|
"just this: It is the best you can do.\n",
|
|
"\n",
|
|
"--K.J.\n",
|
|
"\n",
|
|
"And finally: How to handle your emotions. \n",
|
|
"\n",
|
|
"<|endoftext|>Q:\n",
|
|
"\n",
|
|
"Error creating an EntityManager instance in JavaEE 7\n",
|
|
"\n",
|
|
"This is\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(generate(model, tokenizer, 'The meaning of life is '))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "2b189e6e-6a96-4770-88cf-7c5de22cb321",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"def reverse_string(text, result):\n",
|
|
" # find the position of the start of the string.\n",
|
|
" start = text.index(text[0:-1])\n",
|
|
" # find the position where the string begins changing.\n",
|
|
" end = text.index\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(generate(model, tokenizer, 'def reverse_string('))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "be3afb51-5093-4c64-ac3f-43c2e6b20b10",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "6531acc0-b18f-472a-8e99-cee64dd51cd8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "d0efe197-891a-4ab8-8cea-413d1fb1acda",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "2e99509b-df7b-4bac-b6a2-669f601ec1c8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.5"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|