Spaces:
Runtime error
Runtime error
File size: 3,020 Bytes
5d906de db4c88c 74d73f1 db4c88c f61eeff f8e11cc 74d73f1 f61eeff f8e11cc db4c88c 74d73f1 f8e11cc db4c88c f8e11cc f61eeff db4c88c f8e11cc 77ac825 f8e11cc 58aa497 f61eeff 2d5895c 50eae3e f8e11cc 50eae3e f8e11cc f61eeff f8e11cc e18b2e5 f8e11cc f61eeff 5d906de f61eeff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import torch
import torch.nn.functional as F
import einops
from einops import rearrange
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
# Run on GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on GPU: float16 is poorly supported by CPU kernels in
# many torch versions, so use float32 when running without CUDA.
model_dtype = torch.float16 if device == "cuda" else torch.float32
# Tokenizer used to encode prompts and decode generated token ids for the
# Mamba checkpoint below.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b", device=device, dtype=model_dtype)
def pred(text_in, temperature, top_k, top_p, gen_length, cg, return_dict_in_generate, output_scores, enable_timing):
    """Generate a continuation of ``text_in`` with the Mamba model and return it as text.

    Args:
        text_in: Prompt text from the Gradio textbox.
        temperature: Sampling temperature, forwarded to ``model.generate``.
        top_k: Top-k sampling cutoff, forwarded to ``model.generate``. Cast to
            ``int`` because Gradio sliders may deliver floats.
        top_p: Nucleus-sampling threshold, forwarded to ``model.generate``.
        gen_length: Number of tokens to allow beyond the prompt. Cast to
            ``int`` and added to the prompt length to form ``max_length``.
        cg: Forwarded to ``model.generate`` as ``cg``.
        return_dict_in_generate: Forwarded to ``model.generate``. Both the
            dict-like and the bare-tensor return forms are handled below.
        output_scores: Forwarded to ``model.generate``.
        enable_timing: Forwarded to ``model.generate``.

    Returns:
        The decoded text of the first sequence in the batch, with special
        tokens skipped.
    """
    tokens = tokenizer(text_in, return_tensors="pt")
    input_ids = tokens.input_ids.to(device=device)
    # max_length counts prompt tokens too, so extend it by the requested
    # generation length. Slider values may be floats; lengths must be ints.
    max_length = input_ids.shape[1] + int(gen_length)
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=cg,
        return_dict_in_generate=return_dict_in_generate,
        output_scores=output_scores,
        enable_timing=enable_timing,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
    )
    # With return_dict_in_generate=False, generate() hands back the token
    # tensor itself rather than an object exposing ``.sequences``
    # (NOTE(review): per mamba_ssm's GenerationMixin; confirm for the pinned
    # version). Accept either form.
    sequences = getattr(out, "sequences", out)
    text_out = tokenizer.batch_decode(sequences.tolist(), skip_special_tokens=True)
    return text_out[0]
# Gradio UI. The input components are passed positionally to ``pred``, so
# their order must match pred's parameter order.
demo = gr.Interface(
    fn=pred,
    inputs=[
        gr.Textbox(label="Input Text"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.5, label="Temperature"),
        # step=1 keeps integer-valued settings integral. Without an explicit
        # step, Gradio derives a fractional step for small ranges such as 1-10.
        gr.Slider(minimum=1, maximum=10, value=10, step=1, label="Top K"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top P"),
        gr.Slider(minimum=50, maximum=650, value=200, step=1, label="Generation Length (gen_length)"),
        gr.Checkbox(value=True, label="Cache Graph (cg)"),
        gr.Checkbox(value=True, label="Return Dict in Generate"),
        gr.Checkbox(value=True, label="Output Scores"),
        gr.Checkbox(value=False, label="Enable Timing"),
    ],
    outputs="text",
    title="Welcome👋🏻to🌟Tonic's🐍Mamba 2.8B! 🚀",
    # NOTE(review): the "Duplicate Space" link points at Tonic1/VLChat rather
    # than a Mamba Space; confirm the intended URL.
    description="""🐍Mamba is quite special because it uses a unique model architecture, has reasonable🏆performance, and a👌🏻tiny size. You can use this Space to test out the current model 🐍[state-spaces/mamba-2.8b](https://huggingface.co/state-spaces/mamba-2.8b) You can also use 🐍mamba-2.8b by cloning this space. Simply click here: [Duplicate Space](https://huggingface.co/spaces/Tonic1/VLChat?duplicate=true)
Join us : 🌟TeamTonic🌟 is always making cool demos! Join our active builder's🛠️community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/GWpVpekp) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to 🌟 [DataTonic](https://github.com/Tonic-AI/DataTonic) 🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
"""
)
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    demo.launch()