from smolagents import Tool, CodeAgent, HfApiModel, OpenAIServerModel
import dotenv
from ac_tools import DuckDuckGoSearchToolWH
import requests
import os
from PIL import Image
import wikipedia
from transformers import pipeline
from bs4 import BeautifulSoup

DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"

def init_agent():
    dotenv.load_dotenv()
    model = OpenAIServerModel(model_id="gpt-4o")
    agent = BasicSmolAgent(model=model)
    return agent

def download_file(task_id: str, filename: str) -> str:
    """
    Downloads a file associated with the given task_id and saves it to the specified filename.

    Args:
        task_id (str): The task identifier used to fetch the file.
        filename (str): The desired filename to save the file as.

    Returns:
        str: The absolute path to the saved file.
    """
    api_url = DEFAULT_API_URL
    file_url = f"{api_url}/files/{task_id}"
    folder = 'data'
    os.makedirs(folder, exist_ok=True)  # make sure the download folder exists
    print(f"📡 Fetching file from: {file_url}")
    try:
        response = requests.get(file_url, timeout=15)
        response.raise_for_status()
        # Save binary content to the given filename
        fpath = os.path.join(folder, filename)
        with open(fpath, "wb") as f:
            f.write(response.content)
        abs_path = os.path.abspath(fpath)
        print(f"✅ File saved as: {abs_path}")
        return abs_path
    except requests.exceptions.RequestException as e:
        error_msg = f"❌ Failed to download file for task {task_id}: {e}"
        print(error_msg)
        raise RuntimeError(error_msg) from e

class WikipediaSearchTool(Tool):
    name = "wikipedia_search"
    description = "Searches Wikipedia and returns a short summary of the most relevant article."
    inputs = {
        "query": {"type": "string", "description": "The search term or topic to look up on Wikipedia."}
    }
    output_type = "string"

    def __init__(self, summary_sentences=3):
        super().__init__()
        self.summary_sentences = summary_sentences

    def forward(self, query: str) -> str:
        try:
            page_title = wikipedia.search(query)[0]
            # Return a short summary (as the tool description promises) instead of
            # the full article content, using the configured sentence count.
            summary = wikipedia.summary(page_title, sentences=self.summary_sentences)
            return f"**{page_title}**\n\n{summary}"
        except IndexError:
            return "No Wikipedia results found for that query."
        except Exception as e:
            return f"Error during Wikipedia search: {e}"

class WebpageReaderTool(Tool):
    name = "read_webpage"
    description = "Fetches the text content from a given URL and returns the main body text."
    inputs = {
        "url": {"type": "string", "description": "The URL of the webpage to read."}
    }
    output_type = "string"

    def forward(self, url: str) -> str:
        try:
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
                              "(KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36"
            }
            response = requests.get(url, headers=headers, timeout=10)
            response.raise_for_status()
            soup = BeautifulSoup(response.text, "html.parser")
            # Extract visible text (ignore scripts/styles)
            for tag in soup(["script", "style", "noscript"]):
                tag.extract()
            text = soup.get_text(separator="\n")
            cleaned = "\n".join(line.strip() for line in text.splitlines() if line.strip())
            return cleaned[:5000]  # Limit to 5,000 chars to keep the agent context small
        except Exception as e:
            return f"Error reading webpage: {e}"

class BasicAgent:
    def __init__(self):
        print("BasicAgent initialized.")

    def __call__(self, question_item: dict) -> str:
        task_id = question_item.get("task_id")
        question_text = question_item.get("question")
        file_name = question_item.get("file_name")
        print(f"Agent received question (first 50 chars): {question_text[:50]}...")
        fixed_answer = "This is a default answer."
        print(f"Agent returning fixed answer: {fixed_answer}")
        return fixed_answer

class BasicSmolAgent:
    def __init__(self, model=None):
        print("BasicSmolAgent initialized.")
        if not model:
            model = HfApiModel()
        search_tool = DuckDuckGoSearchToolWH()
        wiki_tool = WikipediaSearchTool()
        webpage_tool = WebpageReaderTool()
        self.agent = CodeAgent(
            tools=[search_tool, wiki_tool, webpage_tool],
            model=model,
            max_steps=10,
            additional_authorized_imports=['pandas'],
        )
        self.prompt = (
            "The question is the following:\n ```{}```"
            " YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings."
            " If you are asked for a number, don't use comma to write your number neither use units"
            " such as $ or percent sign unless specified otherwise. If you are asked for a string,"
            " don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise."
            " If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."
        )
        # Load the Whisper pipeline for transcribing attached .mp3 files
        self.mp3_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

    def get_logs(self):
        return self.agent.memory.steps

    def __call__(self, question_item: dict) -> str:
        task_id = question_item.get("task_id")
        question_text = question_item.get("question")
        file_name = question_item.get("file_name")
        print(f"Agent received question (first 50 chars): {question_text[:50]}...")
        prompted_question = self.prompt.format(question_text)
        images = []
        if file_name:
            fpath = download_file(task_id, file_name)
            if fpath.endswith(".png"):
                image = Image.open(fpath).convert("RGB")
                images.append(image)
            if fpath.endswith(".py"):
                # Attach the script's source code directly to the prompt
                with open(fpath, "rb") as f:
                    data = f.read().decode("utf-8", errors="ignore")
                prompted_question += f"\nThere is textual data included with the question, it is from a file {fpath} and is: ```{data}```"
            if fpath.endswith(".xlsx"):
                # Spreadsheets are binary, so point the agent at the file path and let it
                # load the data itself with pandas (authorized above) instead of decoding bytes as text
                prompted_question += f"\nThere is a spreadsheet included with the question at {fpath}; load it with pandas to inspect its contents."
            if fpath.endswith(".mp3"):
                try:
                    result = self.mp3_pipe(fpath, return_timestamps=True)
                    text = result["text"]
                    prompted_question += f"\nThere is textual data included with the question, it is from a file {fpath} and is: ```{text}```"
                except Exception as e:
                    print("Exception occurred during mp3 transcription: ", e)
        result = self.agent.run(prompted_question, images=images)
        print(f"Agent returning answer: {result}")
        return result
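

# Minimal usage sketch (not part of the original module): shows how init_agent() and the
# agent's __call__ fit together. The task_id and question below are made-up placeholders;
# real question items come from the scoring API at DEFAULT_API_URL.
if __name__ == "__main__":
    agent = init_agent()
    sample_item = {
        "task_id": "example-task-id",  # hypothetical id, replace with a real one from the API
        "question": "What is the capital of France?",
        "file_name": "",  # empty string means no attached file to download
    }
    answer = agent(sample_item)
    print(f"Final answer: {answer}")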