File size: 6,966 Bytes
a9491cd
 
 
 
 
 
 
7ea0bd8
 
a9491cd
7ea0bd8
 
a9491cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ea0bd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9491cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ea0bd8
 
 
86676cc
 
 
 
 
 
 
a9491cd
e929245
a9491cd
86676cc
 
 
a9491cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3fb675
a9491cd
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176

from smolagents import Tool, CodeAgent, HfApiModel, OpenAIServerModel
import dotenv
from ac_tools import DuckDuckGoSearchToolWH
import requests
import os
from PIL import Image
import wikipedia

from transformers import pipeline
import requests
from bs4 import BeautifulSoup


# Base URL of the remote API; download_file() fetches task attachments from
# "<DEFAULT_API_URL>/files/<task_id>".
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"


def init_agent():
    """Build the default agent.

    Calls ``dotenv.load_dotenv()`` first, then wraps an ``OpenAIServerModel``
    configured with ``model_id="gpt-4o"`` in a ``BasicSmolAgent``.

    Returns:
        The newly constructed ``BasicSmolAgent``.
    """
    dotenv.load_dotenv()
    return BasicSmolAgent(model=OpenAIServerModel(model_id="gpt-4o"))


def download_file(task_id: str, filename: str) -> str:
    """
    Downloads the file associated with the given task_id and saves it in the local ``data`` folder.

    Args:
        task_id (str): The task identifier used to fetch the file.
        filename (str): The desired filename to save the file as (inside ``data``).

    Returns:
        str: The absolute path to the saved file.

    Raises:
        RuntimeError: If the HTTP request fails or returns an error status; the
            original ``requests`` exception is attached as ``__cause__``.
    """
    file_url = f"{DEFAULT_API_URL}/files/{task_id}"
    folder = 'data'
    print(f"📡 Fetching file from: {file_url}")
    try:
        response = requests.get(file_url, timeout=15)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        error_msg = f"❌ Failed to download file for task {task_id}: {e}"
        print(error_msg)
        # Chain the cause so the underlying HTTP/connection error stays visible in tracebacks
        raise RuntimeError(error_msg) from e

    # The target folder may not exist yet (e.g. fresh checkout); without this
    # open() below would fail with FileNotFoundError.
    os.makedirs(folder, exist_ok=True)

    # Save binary content to the given filename
    fpath = os.path.join(folder, filename)
    with open(fpath, "wb") as f:
        f.write(response.content)

    abs_path = os.path.abspath(fpath)
    print(f"✅ File saved as: {abs_path}")
    return abs_path


class WikipediaSearchTool(Tool):
    """Tool that searches Wikipedia and returns the text of the top-ranked article."""

    name = "wikipedia_search"
    # The tool returns the whole article body (page.content), so the description says so.
    description = "Searches Wikipedia and returns the full text content of the most relevant article."
    inputs = {
        "query": {"type": "string", "description": "The search term or topic to look up on Wikipedia."}
    }
    output_type = "string"

    def __init__(self, summary_sentences=3):
        """Create the tool.

        Args:
            summary_sentences: Stored on the instance; ``forward`` does not
                currently read it. Kept so existing callers keep working.
        """
        super().__init__()
        self.summary_sentences = summary_sentences

    def forward(self, query: str) -> str:
        """Look up ``query`` and return the first matching article.

        Args:
            query: The search term or topic.

        Returns:
            ``"**<title>**\\n\\n<content>"`` for the first search hit, a
            "no results" message when the search returns nothing, or an
            error message string if the lookup raises.
        """
        try:
            titles = wikipedia.search(query)
            if not titles:
                return "No Wikipedia results found for that query."
            # The title comes straight from the search results, so disable
            # auto-suggestion: it can silently replace an exact title with a
            # different (or non-existent) page.
            page = wikipedia.page(titles[0], auto_suggest=False)
            return f"**{page.title}**\n\n{page.content}"
        except Exception as e:
            # Best-effort tool: report the failure to the agent instead of raising
            return f"Error during Wikipedia search: {e}"


class WebpageReaderTool(Tool):
    """Tool that downloads a web page and returns its visible text."""

    name = "read_webpage"
    description = "Fetches the text content from a given URL and returns the main body text."
    inputs = {
        "url": {"type": "string", "description": "The URL of the webpage to read."}
    }
    output_type = "string"

    def forward(self, url: str) -> str:
        """Fetch ``url`` and return up to 5,000 characters of its text.

        Script, style and noscript elements are removed before the text is
        extracted; blank lines are dropped and every line is stripped.

        Args:
            url: The URL of the webpage to read.

        Returns:
            The cleaned page text truncated to 5,000 characters, or an error
            message string if anything raises.
        """
        user_agent = (
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
            "(KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36"
        )
        try:
            page = requests.get(url, headers={"User-Agent": user_agent}, timeout=10)
            page.raise_for_status()

            document = BeautifulSoup(page.text, "html.parser")

            # Drop elements whose content is never rendered as page text
            for hidden in document(["script", "style", "noscript"]):
                hidden.extract()

            kept_lines = []
            for raw_line in document.get_text(separator="\n").splitlines():
                stripped = raw_line.strip()
                if stripped:
                    kept_lines.append(stripped)

            # Cap the output at 5,000 characters
            return "\n".join(kept_lines)[:5000]
        except Exception as e:
            return f"Error reading webpage: {e}"


class BasicAgent:
    """Placeholder agent that answers every question with the same fixed string."""

    def __init__(self):
        print("BasicAgent initialized.")

    def __call__(self, question_item: dict) -> str:
        """Log a preview of the question and return the fixed answer.

        Args:
            question_item: Mapping describing the task; only the "question"
                key is read here.

        Returns:
            The constant string ``"This is a default answer."``.
        """
        # A missing or null "question" would make the preview slice below raise TypeError
        question_text = question_item.get("question") or ""
        print(f"Agent received question (first 50 chars): {question_text[:50]}...")
        fixed_answer = "This is a default answer."
        print(f"Agent returning fixed answer: {fixed_answer}")
        return fixed_answer


class BasicSmolAgent:
    """Question-answering agent built around a smolagents ``CodeAgent``.

    The wrapped agent is given a DuckDuckGo search tool, a Wikipedia tool and a
    webpage reader. Attachments are downloaded with ``download_file`` and turned
    into extra context: ``.png`` files are passed as images, ``.xlsx``/``.py``
    files are appended to the prompt as text, and ``.mp3`` files are transcribed
    with a Whisper pipeline and appended as text.
    """

    def __init__(self, model=None):
        """Create the tools, the ``CodeAgent`` and the speech-recognition pipeline.

        Args:
            model: Model object handed to ``CodeAgent``. Falls back to
                ``HfApiModel()`` when falsy.
        """
        print("BasicSmolAgent initialized.")
        if not model:
            model = HfApiModel()
        search_tool = DuckDuckGoSearchToolWH()
        wiki_tool = WikipediaSearchTool()
        webpage_tool = WebpageReaderTool()
        self.agent = CodeAgent(tools=[search_tool, wiki_tool, webpage_tool], model=model, max_steps=10, additional_authorized_imports=['pandas'])
        self.prompt = ("The question is the following:\n ```{}```"
                       " YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings."
                       " If you are asked for a number, don't use comma to write your number neither use units"
                       " such as $ or percent sign unless specified otherwise. If you are asked for a string,"
                       " don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise."
                       " If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."
                       )
        # Load the Whisper pipeline
        self.mp3_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

    def get_logs(self):
        """Return the step list stored in the wrapped agent's memory."""
        return self.agent.memory.steps

    @staticmethod
    def _read_attachment_text(fpath: str, extension: str) -> str:
        """Return the text of an attachment for inclusion in the prompt.

        Args:
            fpath: Path of the downloaded file.
            extension: Lower-cased file extension including the dot.

        Returns:
            For ``.xlsx`` files, every sheet rendered as CSV (when pandas can
            parse the workbook); otherwise the file bytes decoded as UTF-8 with
            undecodable bytes ignored.
        """
        if extension == '.xlsx':
            # An .xlsx file is a zip archive; decoding its bytes as UTF-8 yields
            # noise, so parse it properly when pandas (already authorized for
            # the agent) and an Excel engine are available.
            try:
                import pandas as pd

                sheets = pd.read_excel(fpath, sheet_name=None)
                return "\n\n".join(
                    f"Sheet: {sheet_name}\n{frame.to_csv(index=False)}"
                    for sheet_name, frame in sheets.items()
                )
            except Exception as e:
                # Best-effort: keep the previous raw-decode behaviour if parsing fails
                print("Exception occurred during xlsx parsing: ", e)
        with open(fpath, "rb") as f:
            return f.read().decode("utf-8", errors="ignore")

    def __call__(self, question_item: dict) -> str:
        """Answer one question, using its attachment (if any) as extra context.

        Args:
            question_item: Mapping with the keys "task_id", "question" and
                "file_name"; a falsy "file_name" means there is no attachment.

        Returns:
            Whatever ``self.agent.run`` returns for the assembled prompt.

        Raises:
            RuntimeError: Propagated from ``download_file`` when the attachment
                cannot be fetched.
        """
        task_id = question_item.get("task_id")
        # A missing or null "question" would make the preview slice below raise TypeError
        question_text = question_item.get("question") or ""
        file_name = question_item.get("file_name")

        print(f"Agent received question (first 50 chars): {question_text[:50]}...")
        prompted_question = self.prompt.format(question_text)

        images = []
        if file_name:
            fpath = download_file(task_id, file_name)
            # Compare on the lower-cased extension so "IMG.PNG" etc. are handled too
            extension = os.path.splitext(fpath)[1].lower()
            if extension == '.png':
                # convert() returns a fully loaded copy, so the file can be closed right away
                with Image.open(fpath) as raw_image:
                    images.append(raw_image.convert("RGB"))
            elif extension in ('.xlsx', '.py'):
                data = self._read_attachment_text(fpath, extension)
                prompted_question += f"\nThere is textual data included with the question, it is from a file {fpath} and is: ```{data}```"
            elif extension == '.mp3':
                try:
                    transcription = self.mp3_pipe(fpath, return_timestamps=True)
                    text = transcription["text"]
                    prompted_question += f"\nThere is textual data included with the question, it is from a file {fpath} and is: ```{text}```"
                except Exception as e:
                    # Best-effort: the question is still asked without the transcript
                    print("Exception occurred during mp3 transcription: ", e)

        result = self.agent.run(prompted_question, images=images)
        print(f"Agent returning answer: {result}")
        return result