## (1) Load model

In [12]:
from model import Mamba, ModelArgs
from transformers import AutoTokenizer

# Pick the Mamba checkpoint to load. Published options, largest to smallest:
#     'state-spaces/mamba-2.8b-slimpj'
#     'state-spaces/mamba-2.8b'
#     'state-spaces/mamba-1.4b'
#     'state-spaces/mamba-790m'
#     'state-spaces/mamba-370m'
#     'state-spaces/mamba-130m'
pretrained_model_name = 'state-spaces/mamba-370m'

# NOTE(review): the tokenizer is loaded from a separate repo rather than from
# `pretrained_model_name`; presumably the Mamba checkpoints were trained with
# the GPT-NeoX tokenizer — confirm before swapping checkpoints.
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
model = Mamba.from_pretrained(pretrained_model_name)

## (2) Generate Text

In [3]:
import torch
import torch.nn.functional as F


def generate(model,
             tokenizer,
             prompt: str,
             n_tokens_to_gen: int = 50,
             sample: bool = True,
             top_k: int = 40):
    """Autoregressively extend `prompt` by `n_tokens_to_gen` tokens.

    At every step the whole token sequence generated so far is fed back
    through `model`, the logits at the last position are turned into a
    probability distribution, optionally restricted to the `top_k` most likely
    tokens, and the next token is either sampled or picked greedily.

    Args:
        model: module called as ``model(input_ids)``; assumed to return logits
            shaped (batch, seq_len, vocab_size) — only ``[:, -1]`` is read.
        tokenizer: tokenizer used as ``tokenizer(prompt, return_tensors='pt').input_ids``
            for encoding and ``tokenizer.decode(list_of_ids)`` for decoding.
        prompt: text to continue.
        n_tokens_to_gen: number of new tokens to append.
        sample: if True, draw the next token with ``torch.multinomial``;
            otherwise take the most probable token (greedy decoding).
        top_k: keep only the ``top_k`` most probable tokens before choosing.
            ``None`` or a value <= 0 disables the filter; values larger than
            the vocabulary are clamped to the vocabulary size.

    Returns:
        The decoded prompt followed by the generated continuation (first
        batch row) as a single string.
    """
    model.eval()

    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    # Put the prompt on the same device as the model's weights so that a model
    # moved to e.g. CUDA does not fail with a device-mismatch error.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_ids = input_ids.to(first_param.device)

    with torch.no_grad():
        for _ in range(n_tokens_to_gen):
            next_token_logits = model(input_ids)[:, -1]
            probs = F.softmax(next_token_logits, dim=-1)
            vocab_size = probs.shape[-1]

            if top_k is not None and top_k > 0:
                # torch.topk raises if k exceeds the number of candidates.
                k = min(top_k, vocab_size)
                kth_largest = torch.topk(probs, k=k).values[:, -1, None]
                # Tokens tied with the k-th probability are kept.
                probs = probs.masked_fill(probs < kth_largest, 0.0)
                probs = probs / probs.sum(dim=-1, keepdim=True)

            if sample:
                next_indices = torch.multinomial(probs, num_samples=1)
            else:
                next_indices = torch.argmax(probs, dim=-1, keepdim=True)

            input_ids = torch.cat([input_ids, next_indices], dim=1)

    # Only a single prompt is encoded, so the first row is the whole result.
    return tokenizer.decode(input_ids[0].tolist())

In [10]:
# generate() samples by default, so re-running this cell yields a different completion.
print(generate(model, tokenizer, prompt='Mamba is the'))

Mamba is the world's longest venomous snake with an estimated length of over 150 m. With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite)


In [9]:
# Dialogue-style prompt; output is sampled and will vary between runs.
print(generate(model, tokenizer, prompt='John: Hi!\nSally:'))

John: Hi!
Sally: Hey!
John: So, when's the wedding?
Sally: We haven't decided.
John: It's in September.
Sally: Yeah, we were thinking July or
August.
John: I'm not too


In [8]:
# Open-ended prompt; sampling means the completion differs on each run.
print(generate(model, tokenizer, prompt='The meaning of life is '))

The meaning of life is 
just this: It is the best you can do.

--K.J.

And finally: How to handle your emotions. 

<|endoftext|>Q:

Error creating an EntityManager instance in JavaEE 7

This is


In [11]:
# Code-completion prompt; output is sampled and not reproducible run to run.
print(generate(model, tokenizer, prompt='def reverse_string('))

def reverse_string(text, result):
 # find the position of the start of the string.
 start = text.index(text[0:-1])
 # find the position where the string begins changing.
 end = text.index
