{ "cells": [ { "cell_type": "markdown", "id": "531467a2-5160-4073-a990-0d81d574b014", "metadata": {}, "source": [ "## (1) Load model" ] }, { "cell_type": "code", "execution_count": 12, "id": "d9337043-4e7a-4b20-9d89-6c6257245334", "metadata": {}, "outputs": [], "source": [ "from model import Mamba, ModelArgs\n", "from transformers import AutoTokenizer\n", "\n", "# One of:\n", "# 'state-spaces/mamba-2.8b-slimpj'\n", "# 'state-spaces/mamba-2.8b'\n", "# 'state-spaces/mamba-1.4b'\n", "# 'state-spaces/mamba-790m'\n", "# 'state-spaces/mamba-370m'\n", "# 'state-spaces/mamba-130m'\n", "pretrained_model_name = 'state-spaces/mamba-370m'\n", "\n", "model = Mamba.from_pretrained(pretrained_model_name)\n", "tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')" ] }, { "cell_type": "markdown", "id": "0b2efb17-37ad-472b-b029-9567acf17629", "metadata": {}, "source": [ "## (2) Generate Text" ] }, { "cell_type": "code", "execution_count": 3, "id": "c4b2d62d-0d95-4a3f-bd98-aa37e3f26b39", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn.functional as F\n", "\n", "\n", "def generate(model,\n", " tokenizer,\n", " prompt: str,\n", " n_tokens_to_gen: int = 50,\n", " sample: bool = True,\n", " top_k: int = 40):\n", " model.eval()\n", " \n", " input_ids = tokenizer(prompt, return_tensors='pt').input_ids\n", " \n", " for token_n in range(n_tokens_to_gen):\n", " with torch.no_grad():\n", " indices_to_input = input_ids\n", " next_token_logits = model(indices_to_input)[:, -1]\n", " \n", " probs = F.softmax(next_token_logits, dim=-1)\n", " (batch, vocab_size) = probs.shape\n", " \n", " if top_k is not None:\n", " (values, indices) = torch.topk(probs, k=top_k)\n", " probs[probs < values[:, -1, None]] = 0\n", " probs = probs / probs.sum(axis=1, keepdims=True)\n", " \n", " if sample:\n", " next_indices = torch.multinomial(probs, num_samples=1)\n", " else:\n", " next_indices = torch.argmax(probs, dim=-1)[:, None]\n", " \n", " input_ids = torch.cat([input_ids, next_indices], dim=1)\n", "\n", " output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]\n", " \n", " return output_completions" ] }, { "cell_type": "code", "execution_count": 10, "id": "ee877143-2042-4579-8042-a96db6200517", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mamba is the world's longest venomous snake with an estimated length of over 150 m. With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite)\n" ] } ], "source": [ "print(generate(model, tokenizer, 'Mamba is the'))" ] }, { "cell_type": "code", "execution_count": 9, "id": "65d70549-597f-49ca-9185-2184d2576f7d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "John: Hi!\n", "Sally: Hey!\n", "John: So, when's the wedding?\n", "Sally: We haven't decided.\n", "John: It's in September.\n", "Sally: Yeah, we were thinking July or\n", "August.\n", "John: I'm not too\n" ] } ], "source": [ "print(generate(model, tokenizer, 'John: Hi!\\nSally:'))" ] }, { "cell_type": "code", "execution_count": 8, "id": "6d419fc9-066b-4818-812c-2f1952528bc6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The meaning of life is \n", "just this: It is the best you can do.\n", "\n", "--K.J.\n", "\n", "And finally: How to handle your emotions. \n", "\n", "<|endoftext|>Q:\n", "\n", "Error creating an EntityManager instance in JavaEE 7\n", "\n", "This is\n" ] } ], "source": [ "print(generate(model, tokenizer, 'The meaning of life is '))" ] }, { "cell_type": "code", "execution_count": 11, "id": "2b189e6e-6a96-4770-88cf-7c5de22cb321", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "def reverse_string(text, result):\n", " # find the position of the start of the string.\n", " start = text.index(text[0:-1])\n", " # find the position where the string begins changing.\n", " end = text.index\n" ] } ], "source": [ "print(generate(model, tokenizer, 'def reverse_string('))" ] }, { "cell_type": "code", "execution_count": null, "id": "be3afb51-5093-4c64-ac3f-43c2e6b20b10", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "6531acc0-b18f-472a-8e99-cee64dd51cd8", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "d0efe197-891a-4ab8-8cea-413d1fb1acda", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "2e99509b-df7b-4bac-b6a2-669f601ec1c8", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 5 }