# Hugging Face Space app: Arabic chatbot backed by Mixtral-8x7B-Instruct.
# User queries are translated Arabic -> English, sent to the model, and the
# reply is translated back English -> Arabic.
from huggingface_hub import InferenceClient
import gradio as gr
from deep_translator import GoogleTranslator

# Initialize the Hugging Face inference client with the instruct model.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
# Define the translation functions
def translate_to_arabic(text):
    """Translate *text* (any source language, auto-detected) to Arabic via Google Translate."""
    return GoogleTranslator(source='auto', target='ar').translate(text)
def translate_to_english(text):
    """Translate *text* (any source language, auto-detected) to English via Google Translate."""
    return GoogleTranslator(source='auto', target='en').translate(text)
# Format the prompt for the model
def format_prompt(message, history):
    """Build a Mixtral-instruct prompt string.

    Each (user, bot) pair from *history* becomes an "[INST] ... [/INST] reply</s>"
    segment; *message* is appended as the final, unanswered instruction.

    Args:
        message: The new user message (string).
        history: Iterable of (user_prompt, bot_response) pairs; may be empty.

    Returns:
        The fully formatted prompt string starting with the "<s>" BOS marker.
    """
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
# Generate a response from the model
def generate(prompt, history=None, temperature=0.1, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    """Answer an Arabic prompt with the Mixtral model.

    The prompt is translated to English, sent to the model together with the
    conversation *history*, and the streamed reply is translated back to Arabic.

    Args:
        prompt: The user's (Arabic) query.
        history: List of (user, bot) message pairs; None means a fresh conversation.
        temperature / max_new_tokens / top_p / repetition_penalty: Sampling knobs
            forwarded to the text-generation endpoint.

    Returns:
        The model's reply, translated into Arabic.
    """
    # Avoid the shared-mutable-default pitfall: build a fresh list per call.
    if history is None:
        history = []
    # Translate the Arabic prompt to English before sending to the model
    prompt_in_english = translate_to_english(prompt)
    formatted_prompt = format_prompt(prompt_in_english, history)
    generate_kwargs = {
        # The endpoint rejects temperature <= 0, so clamp to a small positive value.
        "temperature": max(float(temperature), 1e-2),
        "max_new_tokens": int(max_new_tokens),
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        "seed": 42,  # Seed for reproducibility, remove or change if randomness is preferred
    }
    # Stream the generation and accumulate the token texts.
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""
    for response in stream:
        # With details=True the client yields TextGenerationStreamOutput dataclass
        # objects, not dicts: subscripting (response["token"]) raises TypeError.
        output += response.token.text
    # Translate the English response back to Arabic
    return translate_to_arabic(output)
# Define additional inputs for Gradio interface.
# NOTE(review): the slider defaults (0.9 / 256 / 0.90 / 1.2) differ from
# generate()'s own defaults (0.1 / 256 / 0.95 / 1.0); the UI values win.
additional_inputs = [
    gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05),
    gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=1048, step=64),
    gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.0, maximum=1.0, step=0.05),
    gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05)
]
# Set up the Gradio interface
def _generate_from_ui(arabic_query, temperature, max_new_tokens, top_p, repetition_penalty):
    """Adapter between the UI and generate().

    Gradio passes the five input values positionally, so without this adapter
    the Temperature slider value would land in generate()'s second parameter
    (the conversation history) and crash at runtime. Route each slider to its
    matching keyword slot, with an empty history.
    """
    return generate(arabic_query, [], temperature, max_new_tokens, top_p, repetition_penalty)

iface = gr.Interface(
    fn=_generate_from_ui,
    inputs=[
        gr.Textbox(lines=5, placeholder='Type your Arabic query here...', label='Arabic Query'),
        *additional_inputs
    ],
    outputs='text',
    title="DorjGPT Arabic-English Translation Chatbot",
)
# Launch the Gradio interface
iface.launch()