from model import Mamba, ModelArgs
from transformers import AutoTokenizer
import torch
import torch.nn.functional as F

# One of:
#     'state-spaces/mamba-2.8b-slimpj'
#     'state-spaces/mamba-2.8b'
#     'state-spaces/mamba-1.4b'
#     'state-spaces/mamba-790m'
#     'state-spaces/mamba-370m'
#     'state-spaces/mamba-130m'
pretrained_model_name = "state-spaces/mamba-370m"

model = Mamba.from_pretrained(pretrained_model_name)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
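# Note: the Mamba checkpoints do not ship a tokenizer of their own; they were
# trained on text tokenized with the GPT-NeoX tokenizer, hence the separate
# EleutherAI/gpt-neox-20b load above.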


def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 50, sample: bool = True, top_k: int = 40):
    model.eval()

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
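
    # Autoregressive decoding: each step appends one new token to input_ids
    # and feeds the grown sequence back through the model.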
    for token_n in range(n_tokens_to_gen):
        with torch.no_grad():
            # This minimal implementation keeps no recurrent state between
            # steps: the full sequence is re-run through the model on every
            # iteration, and only the logits for the last position are used.
            indices_to_input = input_ids
            next_token_logits = model(indices_to_input)[:, -1]

        probs = F.softmax(next_token_logits, dim=-1)
        (batch, vocab_size) = probs.shape

        if top_k is not None:
            # Top-k filtering: zero out every probability below the k-th
            # largest value in each row, then renormalize the survivors.
            (values, indices) = torch.topk(probs, k=top_k)
            probs[probs < values[:, -1, None]] = 0
            probs = probs / probs.sum(dim=-1, keepdim=True)

        if sample:
            next_indices = torch.multinomial(probs, num_samples=1)
        else:
            next_indices = torch.argmax(probs, dim=-1)[:, None]

        input_ids = torch.cat([input_ids, next_indices], dim=1)

    output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]

    return output_completions


print(generate(model, tokenizer, "Mamba is the"))
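
# Greedy decoding instead of sampling, for a deterministic completion:
# print(generate(model, tokenizer, "Mamba is the", n_tokens_to_gen=20, sample=False))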