# Hugging Face Spaces status: Running
import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Markdown shown at the top of the demo page: a heading and a short blurb
# about the model followed by community links.
title = "# 🙋🏻♂️Tonic's ✒️InkubaLM-0.4B"

description = (
    "✒️InkubaLM has been trained from scratch using 1.9 billion tokens of data "
    "for five African languages, along with English and French data, totaling "
    "2.4 billion tokens of data. It is capable of understanding and generating "
    "content in five African languages: Swahili, Yoruba, Hausa, isiZulu, and "
    "isiXhosa, as well as English and French.\n"
    "### Join us :\n"
    "🌟TeamTonic🌟 is always making cool demos! Join our active builder's "
    "🛠️community 👻 [](https://discord.gg/GWpVpekp) "
    "On 🤗Huggingface:[MultiTransformer](https://huggingface.co/MultiTransformer) "
    "On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) "
    "& contribute to🌟 [Build Tonic](https://git.tonic-ai.com/contribute)"
    "🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the "
    "community grant 🤗\n"
)
hf_token = os.getenv("HF_TOKEN") | |
# Load the model and tokenizer | |
model_name = "lelapa/InkubaLM-0.4B" | |
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, token=hf_token) | |
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, token=hf_token) | |
# Move model to GPU if available | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
model.to(device) | |
def generate_text(prompt, max_length, repetition_penalty, temperature):
    """Sample a continuation of ``prompt`` with the loaded model.

    Args:
        prompt: Text to continue. An empty prompt yields an empty result.
        max_length: Limit on the total sequence length (prompt tokens plus
            generated tokens), forwarded to ``model.generate`` as
            ``max_length``. Cast to ``int`` because it comes from a UI slider.
        repetition_penalty: Forwarded to ``model.generate``.
        temperature: Sampling temperature; sampling is always enabled
            (``do_sample=True``).

    Returns:
        Only the newly generated text (the prompt is not echoed back), with
        special tokens removed and surrounding whitespace stripped.
    """
    if not prompt:
        return ""

    # Tokenize the prompt and move ids + attention mask to the model's device.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = inputs.input_ids.shape[1]

    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=int(max_length),
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        # Reuse EOS as padding so generate() does not warn about a missing pad id.
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
    )

    # Keep every position after the prompt. Do not trim the final position:
    # when generation stops at max_length that token is real output, and a
    # trailing EOS is already dropped by skip_special_tokens.
    new_tokens = outputs[0, prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# Create the Gradio interface | |
with gr.Blocks() as demo: | |
gr.Markdown(title) | |
gr.Markdown(description) | |
with gr.Row(): | |
with gr.Column(): | |
prompt = gr.Textbox(label="Enter your prompt here:", placeholder="Today I planned to ...") | |
max_length = gr.Slider(label="Max Length", minimum=70, maximum=1000, step=50, value=200) | |
repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.1, value=1.2) | |
temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.5) # Added slider for temperature | |
submit_button = gr.Button("Generate") | |
with gr.Column(): | |
output = gr.Textbox(label="✒️Inkuba.4B:") | |
submit_button.click(generate_text, inputs=[prompt, max_length, repetition_penalty, temperature], outputs=output) | |
# Launch the demo | |
demo.launch() |