# AstroAlien 3000 — Gradio chat Space backed by the Hugging Face Inference API.
# (Removed non-Python page residue: "Spaces: / Runtime error / Runtime error".)
# Standard library
import os
from typing import Generator, List, Optional, Tuple, Union

# Third-party
from dotenv import load_dotenv
import gradio as gr
from huggingface_hub import InferenceClient

# Load environment variables from a local .env file (e.g. HUGGING_FACE_HUB_TOKEN).
load_dotenv()

# Read the system prompt from prompt.md located next to this script.
prompt_path = os.path.join(os.path.dirname(__file__), "prompt.md")
with open(prompt_path, "r", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read()

# Initialize the Hugging Face Inference client with the token from the environment.
# NOTE(review): no explicit model is set, so the client's default model is used — confirm intended.
client = InferenceClient(token=os.getenv("HUGGING_FACE_HUB_TOKEN"))
def format_prompt(message: str, history: Union[List[Tuple[str, str]], List[dict]]) -> str:
    """Format the conversation prompt for the model.

    Args:
        message: The current user message.
        history: Past interactions, either as (user, assistant) tuples or as
            Gradio "messages"-style dicts with "role"/"content" keys — the
            format delivered by gr.ChatInterface(type="messages") used below.

    Returns:
        str: The formatted prompt for the model.
    """
    prompt = f"<|system|>\n{SYSTEM_PROMPT}</s>"
    # Normalize history into (user, assistant) pairs. The UI is configured with
    # type="messages", so entries arrive as role/content dicts; the original
    # tuple unpacking raised ValueError on any non-empty messages history.
    pairs: List[Tuple[str, str]] = []
    if history and isinstance(history[0], dict):
        pending_user: Optional[str] = None
        for msg in history:
            if msg.get("role") == "user":
                pending_user = msg.get("content", "")
            elif msg.get("role") == "assistant" and pending_user is not None:
                pairs.append((pending_user, msg.get("content", "")))
                pending_user = None
    else:
        pairs = list(history)
    for user, assistant in pairs:
        prompt += f"<|user|>\n{user}</s><|assistant|>\n{assistant}</s>"
    prompt += f"<|user|>\n{message}</s><|assistant|>\n"
    return prompt
def predict(
    message: str,
    history: Union[List[Tuple[str, str]], List[dict]],
    temperature: float = 0.7,
    max_new_tokens: int = 1024,
    top_p: float = 0.9,
) -> str:
    """Generate a complete response to the user message.

    Args:
        message: The current user message.
        history: Past interactions (tuple pairs or "messages"-style dicts),
            forwarded to format_prompt.
        temperature: Sampling temperature.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Top-p (nucleus) sampling parameter.

    Returns:
        str: Complete generated response from the inference endpoint.
    """
    # Fixed seed + repetition penalty keep outputs reproducible across calls
    # with identical inputs while discouraging loops.
    response = client.text_generation(
        format_prompt(message, history),
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=1.2,
        do_sample=True,
        seed=42,
    )
    return response
# Main Gradio interface: chat window plus sampling-parameter sliders, all
# wired into predict() via ChatInterface's additional_inputs.
with gr.Blocks(css="styles.css", title="AstroAlien 3000") as demo:
    gr.Markdown("# 👽🛸 Bem-vindo ao AstroAlien 3000! 🌌✨", elem_classes="glass-card")
    chatbot = gr.Chatbot(type="messages", elem_classes="glass-card")
    with gr.Row():
        temperature = gr.Slider(0, 1, value=0.7, label="Temperatura", elem_classes="glass-card")
        max_tokens = gr.Slider(1, 2048, value=1024, step=1, label="Máximos Tokens", elem_classes="glass-card")
        top_p = gr.Slider(0, 1, value=0.9, label="Top-p", elem_classes="glass-card")
    # type="messages" means predict() receives history as role/content dicts.
    gr.ChatInterface(
        predict,
        chatbot=chatbot,
        additional_inputs=[temperature, max_tokens, top_p],
        type="messages",
    )

# Guard launch so the module can be imported (e.g. by tests) without starting
# a server; running the file as a script behaves exactly as before.
if __name__ == "__main__":
    demo.launch()