from model import Mamba
from transformers import AutoTokenizer
import torch
import torch.nn.functional as F

# One of:
#     'state-spaces/mamba-2.8b-slimpj'
#     'state-spaces/mamba-2.8b'
#     'state-spaces/mamba-1.4b'
#     'state-spaces/mamba-790m'
#     'state-spaces/mamba-370m'
#     'state-spaces/mamba-130m'
pretrained_model_name = "state-spaces/mamba-370m"

model = Mamba.from_pretrained(pretrained_model_name)
# Mamba checkpoints use the GPT-NeoX-20B tokenizer.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")


def generate(model,
             tokenizer,
             prompt: str,
             n_tokens_to_gen: int = 50,
             sample: bool = True,
             top_k: int = 40):
    model.eval()

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    for token_n in range(n_tokens_to_gen):
        with torch.no_grad():
            # Run the full sequence through the model and keep only the
            # logits for the final position.
            next_token_logits = model(input_ids)[:, -1]

        probs = F.softmax(next_token_logits, dim=-1)

        if top_k is not None:
            # Zero out every token outside the top-k, then renormalize
            # so the remaining probabilities sum to 1.
            (values, indices) = torch.topk(probs, k=top_k)
            probs[probs < values[:, -1, None]] = 0
            probs = probs / probs.sum(dim=-1, keepdim=True)

        if sample:
            # Draw the next token from the (truncated) distribution.
            next_indices = torch.multinomial(probs, num_samples=1)
        else:
            # Greedy decoding: always take the most likely token.
            next_indices = torch.argmax(probs, dim=-1)[:, None]

        input_ids = torch.cat([input_ids, next_indices], dim=1)

    output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]

    return output_completions


print(generate(model, tokenizer, "Mamba is the"))
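
# Optional usage sketch: a deterministic (greedy) completion for comparison.
# This simply reuses generate() above with sampling disabled, so the argmax
# branch is taken at every step and repeated runs give the same output.
print(generate(model, tokenizer, "Mamba is the", sample=False))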