|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import mimetypes |
|
import os |
|
import re |
|
import shutil |
|
from typing import Optional |
|
|
|
from smolagents.agent_types import AgentAudio, AgentImage, AgentText, handle_agent_output_types |
|
from smolagents.agents import ActionStep, MultiStepAgent |
|
from smolagents.memory import MemoryStep |
|
from smolagents.utils import _is_package_available |
|
|
|
|
|
def pull_messages_from_step(
    step_log: MemoryStep,
):
    """Extract ChatMessage objects from agent steps with proper nesting.

    Only ``ActionStep`` instances produce output; for any other step type the
    generator finishes without yielding anything. For an action step the
    messages are emitted in this order:

    1. a ``***Thinking***`` banner (only when ``step_number == 1``) and the
       step heading,
    2. the model output, with ``<end_code>`` markers that sit next to a code
       fence removed,
    3. the first tool call as a parent message, followed by the execution
       logs and the error (if any) nested under it through ``parent_id``,
    4. a standalone error message when the step has an error but no tool call,
    5. a footnote (step, token counts, duration) and a ``-----`` divider.

    Args:
        step_log: The memory step to convert.

    Yields:
        ``gradio.ChatMessage`` objects with ``role="assistant"``.
    """
    import gradio as gr

    if not isinstance(step_log, ActionStep):
        return

    # Step heading; the very first step is preceded by a "Thinking" banner.
    step_number = f"Step {step_log.step_number}" if step_log.step_number is not None else ""
    if step_log.step_number == 1:
        yield gr.ChatMessage(role="assistant", content="***Thinking***")
    yield gr.ChatMessage(role="assistant", content=f"**{step_number}**")

    # The thought/reasoning produced by the LLM.
    raw_output = getattr(step_log, "model_output", None)
    if raw_output is not None:
        model_output = raw_output.strip()
        # Drop <end_code> markers that appear right before/after a closing fence.
        model_output = re.sub(r"```\s*<end_code>", "```", model_output)
        model_output = re.sub(r"<end_code>\s*```", "```", model_output)
        model_output = re.sub(r"```\s*\n\s*<end_code>", "```", model_output)
        yield gr.ChatMessage(role="assistant", content=model_output.strip())

    tool_calls = getattr(step_log, "tool_calls", None)
    error = getattr(step_log, "error", None)

    # Truthiness (not `is not None`) so an empty list cannot raise IndexError;
    # such a step falls through to the standalone-error branch instead.
    if tool_calls:
        first_tool_call = tool_calls[0]
        used_code = first_tool_call.name == "python_interpreter"
        parent_id = f"call_{len(tool_calls)}"

        # Dict arguments display their "answer" entry when present, otherwise
        # the whole dict; anything else is shown via str().
        args = first_tool_call.arguments
        if isinstance(args, dict):
            content = str(args.get("answer", str(args)))
        else:
            content = str(args).strip()

        if used_code:
            # Remove fence-opening lines and <end_code> tags, then re-wrap the
            # bare code in a single python fence.
            content = re.sub(r"```.*?\n", "", content)
            content = re.sub(r"\s*<end_code>\s*", "", content)
            content = content.strip()
            if not content.startswith("```python"):
                content = f"```python\n{content}\n```"

        parent_message_tool = gr.ChatMessage(
            role="assistant",
            content=content,
            metadata={
                "title": f"🛠️ Used tool {first_tool_call.name}",
                "id": parent_id,
                "status": "pending",
            },
        )
        yield parent_message_tool

        # Execution logs are nested under the tool call, only when non-empty.
        observations = getattr(step_log, "observations", None)
        log_content = observations.strip() if observations is not None else ""
        if log_content:
            log_content = re.sub(r"^Execution logs:\s*", "", log_content)
            yield gr.ChatMessage(
                role="assistant",
                content=log_content,
                metadata={"title": "📝 Execution Logs", "parent_id": parent_id, "status": "done"},
            )

        # Errors raised during the tool call are nested under it as well.
        if error is not None:
            yield gr.ChatMessage(
                role="assistant",
                content=str(error),
                metadata={"title": "💥 Error", "parent_id": parent_id, "status": "done"},
            )

        # Flip the already-yielded parent to "done" in place; no new message.
        parent_message_tool.metadata["status"] = "done"

    elif error is not None:
        # Error that is not attached to any tool call.
        yield gr.ChatMessage(role="assistant", content=str(error), metadata={"title": "💥 Error"})

    # Footnote with step, token and duration information. Each part is added
    # only when its value is available, so a missing/None value can no longer
    # raise TypeError while formatting or concatenating.
    step_footnote = step_number
    input_tokens = getattr(step_log, "input_token_count", None)
    output_tokens = getattr(step_log, "output_token_count", None)
    if input_tokens is not None and output_tokens is not None:
        step_footnote += f" | Input-tokens:{input_tokens:,} | Output-tokens:{output_tokens:,}"
    duration = getattr(step_log, "duration", None)
    if duration:
        step_footnote += f" | Duration: {round(float(duration), 2)}"
    step_footnote = f"""<span style="color: #bbbbc2; font-size: 12px;">{step_footnote}</span> """
    yield gr.ChatMessage(role="assistant", content=step_footnote)
    yield gr.ChatMessage(role="assistant", content="-----")
|
|
|
|
|
def stream_to_gradio(
    agent,
    task: str,
    reset_agent_memory: bool = False,
    additional_args: Optional[dict] = None,
):
    """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages.

    Every item produced by ``agent.run(..., stream=True)`` is converted with
    ``pull_messages_from_step``. The last item produced by the run is then
    passed through ``handle_agent_output_types`` and emitted as the final
    answer (text, image, audio, or its ``str()`` form for anything else).

    Args:
        agent: Object exposing ``run(task, stream=True, reset=..., additional_args=...)``
            and a ``model`` attribute.
        task: The task handed to ``agent.run``.
        reset_agent_memory: Forwarded to ``agent.run`` as ``reset``.
        additional_args: Forwarded to ``agent.run`` unchanged.

    Yields:
        ``gradio.ChatMessage`` objects.

    Raises:
        ModuleNotFoundError: If the ``gradio`` package is not available. Since
            this is a generator, the error surfaces when iteration starts.
    """
    if not _is_package_available("gradio"):
        raise ModuleNotFoundError(
            "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[gradio]'`"
        )
    import gradio as gr

    step_log = None
    produced_any_step = False

    for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args):
        produced_any_step = True

        # Attach the model's token counts to the step only when the model
        # actually reports both; a missing or None count is skipped rather
        # than crashing on arithmetic/formatting further down the line.
        input_tokens = getattr(agent.model, "last_input_token_count", None)
        output_tokens = getattr(agent.model, "last_output_token_count", None)
        if isinstance(step_log, ActionStep) and input_tokens is not None and output_tokens is not None:
            step_log.input_token_count = input_tokens
            step_log.output_token_count = output_tokens

        yield from pull_messages_from_step(step_log)

    if not produced_any_step:
        # The run produced nothing, so there is no final answer to report.
        return

    # The last item yielded by the run is treated as the final answer.
    final_answer = handle_agent_output_types(step_log)

    if isinstance(final_answer, AgentText):
        yield gr.ChatMessage(
            role="assistant",
            content=f"**Final answer:**\n{final_answer.to_string()}\n",
        )
    elif isinstance(final_answer, AgentImage):
        yield gr.ChatMessage(
            role="assistant",
            content={"path": final_answer.to_string(), "mime_type": "image/png"},
        )
    elif isinstance(final_answer, AgentAudio):
        yield gr.ChatMessage(
            role="assistant",
            content={"path": final_answer.to_string(), "mime_type": "audio/wav"},
        )
    else:
        yield gr.ChatMessage(role="assistant", content=f"***Final answer:*** {final_answer!s}")
|
|
|
|
|
class GradioUI:
    """A one-line interface to launch your agent in Gradio.

    Wraps an agent in a ``gr.Blocks`` chat application: the user's prompt is
    stored by ``log_user_message``, then ``interact_with_agent`` streams the
    agent's messages into the chatbot. ``upload_file`` copies an uploaded file
    into ``file_upload_folder``.
    NOTE(review): ``launch`` does not attach ``upload_file`` to any component.
    """

    # MIME types accepted by `upload_file` when the caller passes none
    # (.pdf, .docx and .txt). A tuple, so the default cannot be mutated.
    _DEFAULT_ALLOWED_FILE_TYPES = (
        "application/pdf",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "text/plain",
    )

    # Value stored under the "class" metadata key of the temporary placeholder
    # message shown while the agent has not produced anything yet.
    _LOADING_CLASS = "loading-message"

    # `MultiStepAgent` is quoted so the annotation is not evaluated when the
    # class body runs; `Optional[...]` matches the rest of the file.
    def __init__(self, agent: "MultiStepAgent", file_upload_folder: Optional[str] = None):
        """Store the agent and create the upload folder when one is given.

        Args:
            agent: The agent that will answer the prompts.
            file_upload_folder: Directory that receives uploaded files, or
                ``None`` when no upload folder is configured.

        Raises:
            ModuleNotFoundError: If the ``gradio`` package is not available.
        """
        if not _is_package_available("gradio"):
            raise ModuleNotFoundError(
                "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[gradio]'`"
            )
        self.agent = agent
        self.file_upload_folder = file_upload_folder
        if file_upload_folder is not None and not os.path.exists(file_upload_folder):
            # makedirs also creates missing parents, and exist_ok tolerates the
            # folder appearing between the existence check and the creation.
            os.makedirs(file_upload_folder, exist_ok=True)

    @classmethod
    def _is_loading_message(cls, message) -> bool:
        """Return True when *message* is the temporary loading placeholder.

        Accepts both objects with a ``metadata`` attribute and plain dicts
        with a ``"metadata"`` key; missing or non-dict metadata means False.
        """
        if isinstance(message, dict):
            metadata = message.get("metadata")
        else:
            metadata = getattr(message, "metadata", None)
        return isinstance(metadata, dict) and metadata.get("class") == cls._LOADING_CLASS

    @classmethod
    def _pop_loading_message(cls, messages) -> None:
        """Remove the trailing loading placeholder from *messages*, if any."""
        if messages and cls._is_loading_message(messages[-1]):
            messages.pop()

    @staticmethod
    def _sanitize_upload_name(original_path, mime_type):
        """Build the stored file name for an upload.

        Characters other than word characters, ``-`` and ``.`` in the base
        name become ``_``; every dot in the stem is dropped, and the extension
        is replaced by the first one ``mimetypes.types_map`` lists for
        *mime_type*. When the type is unknown to that map, the file's own
        extension is kept instead of raising ``KeyError``.
        """
        sanitized_name = re.sub(r"[^\w\-.]", "_", os.path.basename(original_path))

        stem, dot, suffix = sanitized_name.rpartition(".")
        if not dot:
            # No extension at all: the whole name is the stem.
            stem, suffix = suffix, ""
        stem = stem.replace(".", "")

        extension = next(
            (ext for ext, known_type in mimetypes.types_map.items() if known_type == mime_type),
            None,
        )
        if extension is None:
            extension = f".{suffix}" if suffix else ""
        return stem + extension

    def interact_with_agent(self, prompt, messages):
        """Run the agent on *prompt* and stream the growing chat history.

        Args:
            prompt: The user's message.
            messages: The current chat history; it is mutated in place.

        Yields:
            The same ``messages`` list after each update: first with the user
            message, then with a loading placeholder, then once per message
            produced by ``stream_to_gradio``.
        """
        import gradio as gr

        messages.append(gr.ChatMessage(role="user", content=prompt))
        yield messages

        # Temporary placeholder, replaced by the first real agent message.
        messages.append(
            gr.ChatMessage(
                role="assistant",
                content="🤔 Thinking...",
                metadata={"class": self._LOADING_CLASS},
            )
        )
        yield messages

        for msg in stream_to_gradio(self.agent, task=prompt, reset_agent_memory=False):
            self._pop_loading_message(messages)
            messages.append(msg)
            yield messages

        # Covers the case where the agent produced no message at all.
        self._pop_loading_message(messages)
        yield messages

    def upload_file(
        self,
        file,
        file_uploads_log,
        allowed_file_types=None,
    ):
        """
        Handle file uploads, default allowed types are .pdf, .docx, and .txt

        Args:
            file: Uploaded file object exposing the source path as ``.name``,
                or ``None`` when nothing was uploaded.
            file_uploads_log: List of paths uploaded so far; it is not mutated.
            allowed_file_types: Accepted MIME types; ``None`` selects the
                defaults listed above.

        Returns:
            A ``(gr.Textbox, list)`` pair: a status textbox and the upload log,
            extended with the new path on success and unchanged otherwise.
        """
        import gradio as gr

        if allowed_file_types is None:
            allowed_file_types = self._DEFAULT_ALLOWED_FILE_TYPES

        if file is None:
            return gr.Textbox("No file uploaded", visible=True), file_uploads_log

        try:
            mime_type, _ = mimetypes.guess_type(file.name)
        except Exception as e:
            return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log

        if mime_type not in allowed_file_types:
            return gr.Textbox("File type disallowed", visible=True), file_uploads_log

        if self.file_upload_folder is None:
            # Without a destination, os.path.join would raise TypeError.
            return gr.Textbox("Error: no upload folder was configured", visible=True), file_uploads_log

        # Sanitize the name and make the extension agree with the MIME type.
        file_path = os.path.join(
            self.file_upload_folder,
            self._sanitize_upload_name(file.name, mime_type),
        )
        shutil.copy(file.name, file_path)

        return gr.Textbox(f"File uploaded: {file_path}", visible=True), file_uploads_log + [file_path]

    def log_user_message(self, text_input, file_uploads_log):
        """Return the prompt to store and the new (empty) textbox value.

        When files have been uploaded, a note listing them is appended to the
        prompt.
        """
        files_note = (
            f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}"
            if len(file_uploads_log) > 0
            else ""
        )
        return text_input + files_note, ""

    def launch(self, **kwargs):
        """Build the chat interface and launch it.

        Args:
            **kwargs: Forwarded to ``demo.launch``. ``debug`` and ``share``
                default to ``True`` and may be overridden here.
        """
        import gradio as gr

        with gr.Blocks(
            title="Festival Finding Agent",
            fill_height=True,
            css="""
            .loading-message {
                animation: pulse 2s cubic-bezier(.4,0,.6,1) infinite;
            }
            @keyframes pulse {
                0%, 100% { opacity: 1; }
                50% { opacity: 0.5; }
            }
            """,
        ) as demo:
            gr.Markdown(
                """
                # Multimodal Festival Finding Agent
                I can help you find information and generate images about upcoming festivals, fair and celebration events in the location of your interest. Ask me for the recommendation on the upcoming festivals and their details!
                """
            )
            stored_messages = gr.State([])
            file_uploads_log = gr.State([])
            chatbot = gr.Chatbot(
                label="Agent",
                type="messages",
                avatar_images=(
                    None,
                    "https://pbs.twimg.com/profile_images/1443448585870139392/A-BW_5Yd_400x400.jpg",
                ),
                resizeable=True,
                scale=1,
            )
            text_input = gr.Textbox(
                label="Ask me anything about festivals, fairs and celebrations",
                placeholder="Type your question here...",
            )
            submit_button = gr.Button("Send")
            gr.Examples(
                examples=[
                    "Generate the image of the Chinese New Year's festival",
                    "Can you recommend me a festival in Tokyo that will happen in June 2025?",
                    "How does the Toronto Black Film Festival look like?",
                    "What is the most popular celebration event in Lithuania during spring?",
                    "What is the soonest upcoming festival in Toronto, Canada?",
                ],
                inputs=text_input,
                label="Example Questions",
            )

            # Pressing Enter and clicking "Send" run the same two-stage chain:
            # store the prompt and clear the textbox, then stream the answer.
            for register_trigger in (text_input.submit, submit_button.click):
                register_trigger(
                    self.log_user_message,
                    [text_input, file_uploads_log],
                    [stored_messages, text_input],
                ).then(self.interact_with_agent, [stored_messages, chatbot], [chatbot])

        # Merge instead of passing debug/share explicitly next to **kwargs, so
        # a caller supplying either one overrides the default rather than
        # triggering "got multiple values for keyword argument".
        launch_options = {"debug": True, "share": True, **kwargs}
        demo.launch(**launch_options)
|
|
|
|
|
# Public API of this module.
__all__ = ["stream_to_gradio", "GradioUI"]