Upload 8 files
Browse files- agents/agent.py +60 -10
- agents/configuration.py +64 -0
- agents/graph.py +316 -0
- agents/prompts.py +102 -0
- agents/state.py +50 -0
- agents/tools_and_schemas.py +23 -0
- agents/utils.py +166 -0
agents/agent.py
CHANGED
@@ -4,6 +4,7 @@ import logging
|
|
4 |
from google import genai
|
5 |
from google.genai.types import GenerateContentConfig
|
6 |
from ratelimit import limits, sleep_and_retry
|
|
|
7 |
|
8 |
RPM = 15
|
9 |
TPM = 1_000_000
|
@@ -13,10 +14,11 @@ SYSTEM_PROMPT_GAIA = "You are a general AI assistant. I will ask you a question.
|
|
13 |
logging.basicConfig(
|
14 |
level=logging.INFO,
|
15 |
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
16 |
-
datefmt="%Y-%m-%d %H:%M:%S"
|
17 |
)
|
18 |
logger = logging.getLogger(__name__)
|
19 |
|
|
|
20 |
class BasicAgent:
|
21 |
def __init__(self):
|
22 |
logger.info("BasicAgent initialized.")
|
@@ -27,8 +29,11 @@ class BasicAgent:
|
|
27 |
logger.info(f"Agent returning fixed answer: {fixed_answer}")
|
28 |
return fixed_answer
|
29 |
|
|
|
30 |
class SimpleGeminiAgent(BasicAgent):
|
31 |
-
def __init__(
|
|
|
|
|
32 |
super().__init__()
|
33 |
gemini_key = os.getenv("GEMINI_API_KEY")
|
34 |
self.client = genai.Client(api_key=gemini_key)
|
@@ -45,7 +50,7 @@ class SimpleGeminiAgent(BasicAgent):
|
|
45 |
if now - self.minute_start >= 60:
|
46 |
self.tokens_this_minute = 0
|
47 |
self.minute_start = now
|
48 |
-
|
49 |
# Enforce tokens per minute
|
50 |
if self.tokens_this_minute + self.token_count > TPM:
|
51 |
sleep_time = max(0, 60 - (now - self.minute_start))
|
@@ -53,18 +58,63 @@ class SimpleGeminiAgent(BasicAgent):
|
|
53 |
self.tokens_this_minute = 0
|
54 |
self.minute_start = time.time()
|
55 |
|
56 |
-
response = self.client.models.generate_content(
|
57 |
-
|
58 |
-
|
|
|
|
|
59 |
self.tokens_this_minute += response.usage_metadata.total_token_count
|
60 |
self.token_count += response.usage_metadata.total_token_count
|
61 |
-
logger.info(
|
|
|
|
|
62 |
logger.info(f"AdvancedAgent returning answer: {response.text}")
|
63 |
return response.text
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
if __name__ == "__main__":
|
66 |
# Example usage
|
67 |
-
agent = SimpleGeminiAgent()
|
68 |
-
|
|
|
|
|
69 |
answer = agent(question)
|
70 |
-
print(f"Question: {question}\nAnswer: {answer}")
|
|
|
4 |
from google import genai
|
5 |
from google.genai.types import GenerateContentConfig
|
6 |
from ratelimit import limits, sleep_and_retry
|
7 |
+
from graph import build_graph
|
8 |
|
9 |
RPM = 15
|
10 |
TPM = 1_000_000
|
|
|
14 |
logging.basicConfig(
|
15 |
level=logging.INFO,
|
16 |
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
17 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
18 |
)
|
19 |
logger = logging.getLogger(__name__)
|
20 |
|
21 |
+
|
22 |
class BasicAgent:
|
23 |
def __init__(self):
|
24 |
logger.info("BasicAgent initialized.")
|
|
|
29 |
logger.info(f"Agent returning fixed answer: {fixed_answer}")
|
30 |
return fixed_answer
|
31 |
|
32 |
+
|
33 |
class SimpleGeminiAgent(BasicAgent):
|
34 |
+
def __init__(
|
35 |
+
self, model="gemini-2.5-flash-preview-05-20"
|
36 |
+
): # model="gemini-2.0-flash"):
|
37 |
super().__init__()
|
38 |
gemini_key = os.getenv("GEMINI_API_KEY")
|
39 |
self.client = genai.Client(api_key=gemini_key)
|
|
|
50 |
if now - self.minute_start >= 60:
|
51 |
self.tokens_this_minute = 0
|
52 |
self.minute_start = now
|
53 |
+
|
54 |
# Enforce tokens per minute
|
55 |
if self.tokens_this_minute + self.token_count > TPM:
|
56 |
sleep_time = max(0, 60 - (now - self.minute_start))
|
|
|
58 |
self.tokens_this_minute = 0
|
59 |
self.minute_start = time.time()
|
60 |
|
61 |
+
response = self.client.models.generate_content(
|
62 |
+
model=self.model,
|
63 |
+
contents=question,
|
64 |
+
config=GenerateContentConfig(system_instruction=SYSTEM_PROMPT_GAIA),
|
65 |
+
)
|
66 |
self.tokens_this_minute += response.usage_metadata.total_token_count
|
67 |
self.token_count += response.usage_metadata.total_token_count
|
68 |
+
logger.info(
|
69 |
+
f"AdvancedAgent received question (first 50 chars): {question[:50]}..."
|
70 |
+
)
|
71 |
logger.info(f"AdvancedAgent returning answer: {response.text}")
|
72 |
return response.text
|
73 |
|
74 |
+
|
75 |
+
class DeepResearchGeminiAgent(BasicAgent):
|
76 |
+
def __init__(self): # model="gemini-2.0-flash"):
|
77 |
+
super().__init__()
|
78 |
+
self.graph = build_graph()
|
79 |
+
logger.info("Deep Research Agent initialized.")
|
80 |
+
self.minute_start = time.time()
|
81 |
+
self.tokens_this_minute = 0
|
82 |
+
self.token_count = 0
|
83 |
+
|
84 |
+
@sleep_and_retry
|
85 |
+
@limits(calls=RPM, period=PER_MINUTE)
|
86 |
+
def __call__(self, question: str) -> str:
|
87 |
+
now = time.time()
|
88 |
+
if now - self.minute_start >= 60:
|
89 |
+
self.tokens_this_minute = 0
|
90 |
+
self.minute_start = now
|
91 |
+
|
92 |
+
# Enforce tokens per minute
|
93 |
+
if self.tokens_this_minute + self.token_count > TPM:
|
94 |
+
sleep_time = max(0, 60 - (now - self.minute_start))
|
95 |
+
time.sleep(sleep_time)
|
96 |
+
self.tokens_this_minute = 0
|
97 |
+
self.minute_start = time.time()
|
98 |
+
|
99 |
+
inputs = {"messages": [{"role": "user", "content": question}]}
|
100 |
+
|
101 |
+
output = self.graph.invoke(inputs)
|
102 |
+
final_answer_message = output["messages"][-1]
|
103 |
+
|
104 |
+
self.tokens_this_minute += final_answer_message.usage_metadata["total_tokens"]
|
105 |
+
self.token_count += final_answer_message.usage_metadata["total_tokens"]
|
106 |
+
logger.info(
|
107 |
+
f"AdvancedAgent received question (first 50 chars): {question[:50]}..."
|
108 |
+
)
|
109 |
+
logger.info(f"AdvancedAgent returning answer: {final_answer_message.content}")
|
110 |
+
return final_answer_message.content
|
111 |
+
|
112 |
+
|
113 |
if __name__ == "__main__":
|
114 |
# Example usage
|
115 |
+
# agent = SimpleGeminiAgent()
|
116 |
+
agent = DeepResearchGeminiAgent()
|
117 |
+
# question = "What is the capital of France?"
|
118 |
+
question = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia"
|
119 |
answer = agent(question)
|
120 |
+
print(f"Question: {question}\nAnswer: {answer}")
|
agents/configuration.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pydantic import BaseModel, Field
|
3 |
+
from typing import Any, Optional
|
4 |
+
|
5 |
+
from langchain_core.runnables import RunnableConfig
|
6 |
+
|
7 |
+
|
8 |
+
class Configuration(BaseModel):
|
9 |
+
"""The configuration for the agent."""
|
10 |
+
|
11 |
+
query_generator_model: str = Field(
|
12 |
+
# default="gemini-2.0-flash",
|
13 |
+
default="gemini-2.5-flash-preview-05-20",
|
14 |
+
metadata={
|
15 |
+
"description": "The name of the language model to use for the agent's query generation."
|
16 |
+
},
|
17 |
+
)
|
18 |
+
|
19 |
+
reasoning_model: str = Field(
|
20 |
+
default="gemini-2.5-flash-preview-05-20",
|
21 |
+
# default="gemini-2.5-pro-experimental-03-25",
|
22 |
+
metadata={
|
23 |
+
"description": "The name of the language model to use for the agent's reflection."
|
24 |
+
},
|
25 |
+
)
|
26 |
+
|
27 |
+
answer_model: str = Field(
|
28 |
+
default="gemini-2.5-flash-preview-05-20",
|
29 |
+
# default="gemini-2.5-pro-preview-05-06",
|
30 |
+
# default="gemini-2.5-flash-preview-05-20",
|
31 |
+
metadata={
|
32 |
+
"description": "The name of the language model to use for the agent's answer."
|
33 |
+
},
|
34 |
+
)
|
35 |
+
|
36 |
+
number_of_initial_queries: int = Field(
|
37 |
+
default=5,
|
38 |
+
metadata={"description": "The number of initial search queries to generate."},
|
39 |
+
)
|
40 |
+
|
41 |
+
max_research_loops: int = Field(
|
42 |
+
default=5,
|
43 |
+
metadata={"description": "The maximum number of research loops to perform."},
|
44 |
+
)
|
45 |
+
|
46 |
+
@classmethod
|
47 |
+
def from_runnable_config(
|
48 |
+
cls, config: Optional[RunnableConfig] = None
|
49 |
+
) -> "Configuration":
|
50 |
+
"""Create a Configuration instance from a RunnableConfig."""
|
51 |
+
configurable = (
|
52 |
+
config["configurable"] if config and "configurable" in config else {}
|
53 |
+
)
|
54 |
+
|
55 |
+
# Get raw values from environment or config
|
56 |
+
raw_values: dict[str, Any] = {
|
57 |
+
name: os.environ.get(name.upper(), configurable.get(name))
|
58 |
+
for name in cls.model_fields.keys()
|
59 |
+
}
|
60 |
+
|
61 |
+
# Filter out None values
|
62 |
+
values = {k: v for k, v in raw_values.items() if v is not None}
|
63 |
+
|
64 |
+
return cls(**values)
|
agents/graph.py
ADDED
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
from tools_and_schemas import SearchQueryList, Reflection
|
4 |
+
from dotenv import load_dotenv
|
5 |
+
from langchain_core.messages import AIMessage
|
6 |
+
from langgraph.types import Send
|
7 |
+
from langgraph.graph import StateGraph
|
8 |
+
from langgraph.graph import START, END
|
9 |
+
from langchain_core.runnables import RunnableConfig
|
10 |
+
from google.genai import Client
|
11 |
+
|
12 |
+
from state import (
|
13 |
+
OverallState,
|
14 |
+
QueryGenerationState,
|
15 |
+
ReflectionState,
|
16 |
+
WebSearchState,
|
17 |
+
)
|
18 |
+
from configuration import Configuration
|
19 |
+
from prompts import (
|
20 |
+
get_current_date,
|
21 |
+
query_writer_instructions,
|
22 |
+
web_searcher_instructions,
|
23 |
+
reflection_instructions,
|
24 |
+
answer_instructions,
|
25 |
+
gaia_system_instructions,
|
26 |
+
)
|
27 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
28 |
+
from utils import (
|
29 |
+
get_citations,
|
30 |
+
get_research_topic,
|
31 |
+
insert_citation_markers,
|
32 |
+
resolve_urls,
|
33 |
+
)
|
34 |
+
|
35 |
+
load_dotenv()
|
36 |
+
|
37 |
+
if os.getenv("GEMINI_API_KEY") is None:
|
38 |
+
raise ValueError("GEMINI_API_KEY is not set")
|
39 |
+
|
40 |
+
# Used for Google Search API
|
41 |
+
genai_client = Client(api_key=os.getenv("GEMINI_API_KEY"))
|
42 |
+
|
43 |
+
logging.basicConfig(
|
44 |
+
level=logging.INFO,
|
45 |
+
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
46 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
47 |
+
)
|
48 |
+
logger = logging.getLogger(__name__)
|
49 |
+
|
50 |
+
|
51 |
+
# Nodes
|
52 |
+
def generate_query(state: OverallState, config: RunnableConfig) -> QueryGenerationState:
|
53 |
+
"""LangGraph node that generates a search queries based on the User's question.
|
54 |
+
|
55 |
+
Uses Gemini 2.0 Flash to create an optimized search query for web research based on
|
56 |
+
the User's question.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
state: Current graph state containing the User's question
|
60 |
+
config: Configuration for the runnable, including LLM provider settings
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
Dictionary with state update, including search_query key containing the generated query
|
64 |
+
"""
|
65 |
+
configurable = Configuration.from_runnable_config(config)
|
66 |
+
|
67 |
+
# check for custom initial search query count
|
68 |
+
if state.get("initial_search_query_count") is None:
|
69 |
+
state["initial_search_query_count"] = configurable.number_of_initial_queries
|
70 |
+
|
71 |
+
# init Gemini 2.0 Flash
|
72 |
+
llm = ChatGoogleGenerativeAI(
|
73 |
+
model=configurable.query_generator_model,
|
74 |
+
temperature=2.0,
|
75 |
+
max_retries=2,
|
76 |
+
api_key=os.getenv("GEMINI_API_KEY"),
|
77 |
+
)
|
78 |
+
structured_llm = llm.with_structured_output(SearchQueryList)
|
79 |
+
|
80 |
+
# Format the prompt
|
81 |
+
current_date = get_current_date()
|
82 |
+
formatted_prompt = query_writer_instructions.format(
|
83 |
+
current_date=current_date,
|
84 |
+
research_topic=get_research_topic(state["messages"]),
|
85 |
+
number_queries=state["initial_search_query_count"],
|
86 |
+
)
|
87 |
+
# Generate the search queries
|
88 |
+
result = structured_llm.invoke(formatted_prompt)
|
89 |
+
return {"query_list": result.query}
|
90 |
+
|
91 |
+
|
92 |
+
def continue_to_web_research(state: QueryGenerationState):
|
93 |
+
"""LangGraph node that sends the search queries to the web research node.
|
94 |
+
|
95 |
+
This is used to spawn n number of web research nodes, one for each search query.
|
96 |
+
"""
|
97 |
+
return [
|
98 |
+
Send("web_research", {"search_query": search_query, "id": int(idx)})
|
99 |
+
for idx, search_query in enumerate(state["query_list"])
|
100 |
+
]
|
101 |
+
|
102 |
+
|
103 |
+
def web_research(state: WebSearchState, config: RunnableConfig) -> OverallState:
|
104 |
+
"""LangGraph node that performs web research using the native Google Search API tool.
|
105 |
+
|
106 |
+
Executes a web search using the native Google Search API tool in combination with Gemini 2.0 Flash.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
state: Current graph state containing the search query and research loop count
|
110 |
+
config: Configuration for the runnable, including search API settings
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
Dictionary with state update, including sources_gathered, research_loop_count, and web_research_results
|
114 |
+
"""
|
115 |
+
# Configure
|
116 |
+
configurable = Configuration.from_runnable_config(config)
|
117 |
+
formatted_prompt = web_searcher_instructions.format(
|
118 |
+
current_date=get_current_date(),
|
119 |
+
research_topic=state["search_query"],
|
120 |
+
)
|
121 |
+
|
122 |
+
# Uses the google genai client as the langchain client doesn't return grounding metadata
|
123 |
+
response = genai_client.models.generate_content(
|
124 |
+
model=configurable.query_generator_model,
|
125 |
+
contents=formatted_prompt,
|
126 |
+
config={
|
127 |
+
"tools": [{"google_search": {}}],
|
128 |
+
"temperature": 0,
|
129 |
+
},
|
130 |
+
)
|
131 |
+
# resolve the urls to short urls for saving tokens and time
|
132 |
+
resolved_urls = resolve_urls(
|
133 |
+
response.candidates[0].grounding_metadata.grounding_chunks, state["id"]
|
134 |
+
)
|
135 |
+
# Gets the citations and adds them to the generated text
|
136 |
+
citations = get_citations(response, resolved_urls)
|
137 |
+
modified_text = insert_citation_markers(response.text, citations)
|
138 |
+
sources_gathered = [item for citation in citations for item in citation["segments"]]
|
139 |
+
|
140 |
+
return {
|
141 |
+
"sources_gathered": sources_gathered,
|
142 |
+
"search_query": [state["search_query"]],
|
143 |
+
"web_research_result": [modified_text],
|
144 |
+
"web_research_result": [response.text],
|
145 |
+
}
|
146 |
+
|
147 |
+
|
148 |
+
def reflection(state: OverallState, config: RunnableConfig) -> ReflectionState:
|
149 |
+
"""LangGraph node that identifies knowledge gaps and generates potential follow-up queries.
|
150 |
+
|
151 |
+
Analyzes the current summary to identify areas for further research and generates
|
152 |
+
potential follow-up queries. Uses structured output to extract
|
153 |
+
the follow-up query in JSON format.
|
154 |
+
|
155 |
+
Args:
|
156 |
+
state: Current graph state containing the running summary and research topic
|
157 |
+
config: Configuration for the runnable, including LLM provider settings
|
158 |
+
|
159 |
+
Returns:
|
160 |
+
Dictionary with state update, including search_query key containing the generated follow-up query
|
161 |
+
"""
|
162 |
+
configurable = Configuration.from_runnable_config(config)
|
163 |
+
# Increment the research loop count and get the reasoning model
|
164 |
+
state["research_loop_count"] = state.get("research_loop_count", 0) + 1
|
165 |
+
reasoning_model = state.get("reasoning_model") or configurable.reasoning_model
|
166 |
+
|
167 |
+
# Format the prompt
|
168 |
+
current_date = get_current_date()
|
169 |
+
formatted_prompt = reflection_instructions.format(
|
170 |
+
current_date=current_date,
|
171 |
+
research_topic=get_research_topic(state["messages"]),
|
172 |
+
summaries="\n\n---\n\n".join(state["web_research_result"]),
|
173 |
+
)
|
174 |
+
# init Reasoning Model
|
175 |
+
llm = ChatGoogleGenerativeAI(
|
176 |
+
model=reasoning_model,
|
177 |
+
temperature=0.3,
|
178 |
+
max_retries=2,
|
179 |
+
api_key=os.getenv("GEMINI_API_KEY"),
|
180 |
+
)
|
181 |
+
logger.info(f"Reflection node invoked with research prompt:\n{formatted_prompt}")
|
182 |
+
result = llm.with_structured_output(Reflection).invoke(formatted_prompt)
|
183 |
+
|
184 |
+
return {
|
185 |
+
"is_sufficient": result.is_sufficient,
|
186 |
+
"knowledge_gap": result.knowledge_gap,
|
187 |
+
"follow_up_queries": result.follow_up_queries,
|
188 |
+
"research_loop_count": state["research_loop_count"],
|
189 |
+
"number_of_ran_queries": len(state["search_query"]),
|
190 |
+
}
|
191 |
+
|
192 |
+
|
193 |
+
def evaluate_research(
|
194 |
+
state: ReflectionState,
|
195 |
+
config: RunnableConfig,
|
196 |
+
) -> OverallState:
|
197 |
+
"""LangGraph routing function that determines the next step in the research flow.
|
198 |
+
|
199 |
+
Controls the research loop by deciding whether to continue gathering information
|
200 |
+
or to finalize the summary based on the configured maximum number of research loops.
|
201 |
+
|
202 |
+
Args:
|
203 |
+
state: Current graph state containing the research loop count
|
204 |
+
config: Configuration for the runnable, including max_research_loops setting
|
205 |
+
|
206 |
+
Returns:
|
207 |
+
String literal indicating the next node to visit ("web_research" or "finalize_summary")
|
208 |
+
"""
|
209 |
+
configurable = Configuration.from_runnable_config(config)
|
210 |
+
max_research_loops = (
|
211 |
+
state.get("max_research_loops")
|
212 |
+
if state.get("max_research_loops") is not None
|
213 |
+
else configurable.max_research_loops
|
214 |
+
)
|
215 |
+
if state["is_sufficient"] or state["research_loop_count"] >= max_research_loops:
|
216 |
+
return "finalize_answer"
|
217 |
+
else:
|
218 |
+
return [
|
219 |
+
Send(
|
220 |
+
"web_research",
|
221 |
+
{
|
222 |
+
"search_query": follow_up_query,
|
223 |
+
"id": state["number_of_ran_queries"] + int(idx),
|
224 |
+
},
|
225 |
+
)
|
226 |
+
for idx, follow_up_query in enumerate(state["follow_up_queries"])
|
227 |
+
]
|
228 |
+
|
229 |
+
|
230 |
+
def finalize_answer(state: OverallState, config: RunnableConfig):
|
231 |
+
"""LangGraph node that finalizes the research summary.
|
232 |
+
|
233 |
+
Prepares the final output by deduplicating and formatting sources, then
|
234 |
+
combining them with the running summary to create a well-structured
|
235 |
+
research report with proper citations.
|
236 |
+
|
237 |
+
Args:
|
238 |
+
state: Current graph state containing the running summary and sources gathered
|
239 |
+
|
240 |
+
Returns:
|
241 |
+
Dictionary with state update, including running_summary key containing the formatted final summary with sources
|
242 |
+
"""
|
243 |
+
configurable = Configuration.from_runnable_config(config)
|
244 |
+
reasoning_model = state.get("reasoning_model") or configurable.reasoning_model
|
245 |
+
|
246 |
+
# Format the prompt
|
247 |
+
current_date = get_current_date()
|
248 |
+
formatted_prompt = answer_instructions.format(
|
249 |
+
current_date=current_date,
|
250 |
+
research_topic=get_research_topic(state["messages"]),
|
251 |
+
summaries="\n---\n\n".join(state["web_research_result"]),
|
252 |
+
)
|
253 |
+
|
254 |
+
# init Reasoning Model, default to Gemini 2.5 Flash
|
255 |
+
llm = ChatGoogleGenerativeAI(
|
256 |
+
model=reasoning_model,
|
257 |
+
temperature=0,
|
258 |
+
max_retries=5,
|
259 |
+
api_key=os.getenv("GEMINI_API_KEY"),
|
260 |
+
)
|
261 |
+
result = llm.invoke(formatted_prompt)
|
262 |
+
|
263 |
+
# Use the previous result as context for the GAIA final answer
|
264 |
+
gaia_question = state["messages"][-1].content
|
265 |
+
messages = [
|
266 |
+
("system", gaia_system_instructions),
|
267 |
+
("user", f"Context: {result.content}\nQuestion: {gaia_question}"),
|
268 |
+
]
|
269 |
+
gaia_result = llm.invoke(messages)
|
270 |
+
|
271 |
+
# Replace the short urls with the original urls and add all used urls to the sources_gathered
|
272 |
+
# unique_sources = []
|
273 |
+
# for source in state["sources_gathered"]:
|
274 |
+
# if source["short_url"] in result.content:
|
275 |
+
# result.content = result.content.replace(
|
276 |
+
# source["short_url"], source["value"]
|
277 |
+
# )
|
278 |
+
# unique_sources.append(source)
|
279 |
+
|
280 |
+
# return gaia_result
|
281 |
+
return {
|
282 |
+
# "messages": [AIMessage(content=gaia_result.content)],
|
283 |
+
"messages": [gaia_result],
|
284 |
+
# "sources_gathered": unique_sources,
|
285 |
+
}
|
286 |
+
|
287 |
+
|
288 |
+
def build_graph():
|
289 |
+
# Create our Agent Graph
|
290 |
+
builder = StateGraph(OverallState, config_schema=Configuration)
|
291 |
+
|
292 |
+
# Define the nodes we will cycle between
|
293 |
+
builder.add_node("generate_query", generate_query)
|
294 |
+
builder.add_node("web_research", web_research)
|
295 |
+
builder.add_node("reflection", reflection)
|
296 |
+
builder.add_node("finalize_answer", finalize_answer)
|
297 |
+
builder.add_node("evaluate_research", evaluate_research)
|
298 |
+
|
299 |
+
# Set the entrypoint as `generate_query`
|
300 |
+
# This means that this node is the first one called
|
301 |
+
builder.add_edge(START, "generate_query")
|
302 |
+
# Add conditional edge to continue with search queries in a parallel branch
|
303 |
+
builder.add_conditional_edges(
|
304 |
+
"generate_query", continue_to_web_research, ["web_research"]
|
305 |
+
)
|
306 |
+
# Reflect on the web research
|
307 |
+
builder.add_edge("web_research", "reflection")
|
308 |
+
# Evaluate the research
|
309 |
+
builder.add_conditional_edges(
|
310 |
+
"reflection", evaluate_research, ["web_research", "finalize_answer"]
|
311 |
+
)
|
312 |
+
# Finalize the answer
|
313 |
+
builder.add_edge("finalize_answer", END)
|
314 |
+
|
315 |
+
graph = builder.compile(name="pro-search-agent")
|
316 |
+
return graph
|
agents/prompts.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
|
3 |
+
|
4 |
+
# Get current date in a readable format
|
5 |
+
def get_current_date():
|
6 |
+
return datetime.now().strftime("%B %d, %Y")
|
7 |
+
|
8 |
+
|
9 |
+
query_writer_instructions = """Your goal is to generate sophisticated and diverse web search queries. These queries are intended for an advanced automated web research tool capable of analyzing complex results, following links, and synthesizing information.
|
10 |
+
|
11 |
+
Instructions:
|
12 |
+
- Always prefer a single search query, only add another query if the original question requests multiple aspects or elements and one query is not enough.
|
13 |
+
- Each query should focus on one specific aspect of the original question.
|
14 |
+
- Don't produce more than {number_queries} queries.
|
15 |
+
- Queries should be diverse, if the topic is broad, generate more than 1 query.
|
16 |
+
- Don't generate multiple similar queries, 1 is enough.
|
17 |
+
- By default the query should ensure that the most current information is gathered. The current date is {current_date}.
|
18 |
+
- However you should override the most up to date query if the user asks for historical data or trends, in that case you must respect the user's specified revisions and versions.
|
19 |
+
- You must respect and make sure that the users contrains are respected in the queries otherwise you will provide wrong answers.
|
20 |
+
|
21 |
+
Format:
|
22 |
+
- Format your response as a JSON object with ALL three of these exact keys:
|
23 |
+
- "rationale": Brief explanation of why these queries are relevant
|
24 |
+
- "query": A list of search queries
|
25 |
+
|
26 |
+
Example:
|
27 |
+
|
28 |
+
Topic: What revenue grew more last year apple stock or the number of people buying an iphone
|
29 |
+
```json
|
30 |
+
{{
|
31 |
+
"rationale": "To answer this comparative growth question accurately, we need specific data points on Apple's stock performance and iPhone sales metrics. These queries target the precise financial information needed: company revenue trends, product-specific unit sales figures, and stock price movement over the same fiscal period for direct comparison.",
|
32 |
+
"query": ["Apple total revenue growth fiscal year 2024", "iPhone unit sales growth fiscal year 2024", "Apple stock price growth fiscal year 2024"],
|
33 |
+
}}
|
34 |
+
```
|
35 |
+
|
36 |
+
Context: {research_topic}"""
|
37 |
+
|
38 |
+
|
39 |
+
web_searcher_instructions = """Conduct targeted Google Searches to gather the most recent, credible information on "{research_topic}" and synthesize it into a verifiable text artifact.
|
40 |
+
|
41 |
+
Instructions:
|
42 |
+
- By default the query should ensure that the most current information is gathered. The current date is {current_date}.
|
43 |
+
- However you should override the most up to date query if the user asks for historical data or trends, in that case you must respect the user's specified revisions and versions.
|
44 |
+
- You must respect and make sure that the users contrains are respected in the queries otherwise you will provide wrong answers These include specific revisions of certain sources. That is a filter and a different facet of a query
|
45 |
+
- Conduct multiple, diverse searches to gather comprehensive information.
|
46 |
+
- Consolidate key findings while meticulously tracking the source(s) for each specific piece of information.
|
47 |
+
- The output should be a well-written summary or report based on your search findings.
|
48 |
+
- Only include the information found in the search results, don't make up any information.
|
49 |
+
|
50 |
+
Research Topic:
|
51 |
+
{research_topic}
|
52 |
+
"""
|
53 |
+
|
54 |
+
reflection_instructions = """You are an expert research assistant analyzing summaries about "{research_topic}".
|
55 |
+
|
56 |
+
Instructions:
|
57 |
+
- Identify knowledge gaps or areas that need deeper exploration and generate a follow-up query. (1 or multiple).
|
58 |
+
- If provided summaries are sufficient to answer the user's question, don't generate a follow-up query.
|
59 |
+
- If there is a knowledge gap, generate a follow-up query that would help expand your understanding.
|
60 |
+
- Focus on technical details, implementation specifics, or emerging trends that weren't fully covered.
|
61 |
+
|
62 |
+
Requirements:
|
63 |
+
- Ensure the follow-up query is self-contained and includes necessary context for web search.
|
64 |
+
|
65 |
+
Output Format:
|
66 |
+
- Format your response as a JSON object with these exact keys:
|
67 |
+
- "is_sufficient": true or false
|
68 |
+
- "knowledge_gap": Describe what information is missing or needs clarification
|
69 |
+
- "follow_up_queries": Write a specific question to address this gap
|
70 |
+
|
71 |
+
Example:
|
72 |
+
```json
|
73 |
+
{{
|
74 |
+
"is_sufficient": true, // or false
|
75 |
+
"knowledge_gap": "The summary lacks information about performance metrics and benchmarks", // "" if is_sufficient is true
|
76 |
+
"follow_up_queries": ["What are typical performance benchmarks and metrics used to evaluate [specific technology]?"] // [] if is_sufficient is true
|
77 |
+
}}
|
78 |
+
```
|
79 |
+
|
80 |
+
Reflect carefully on the Summaries to identify knowledge gaps and produce a follow-up query. Then, produce your output following this JSON format:
|
81 |
+
|
82 |
+
Summaries:
|
83 |
+
{summaries}
|
84 |
+
"""
|
85 |
+
|
86 |
+
answer_instructions = """Generate a high-quality answer to the user's question based on the provided summaries.
|
87 |
+
|
88 |
+
Instructions:
|
89 |
+
- The current date is {current_date}.
|
90 |
+
- You are the final step of a multi-step research process, don't mention that you are the final step.
|
91 |
+
- You have access to all the information gathered from the previous steps.
|
92 |
+
- You have access to the user's question.
|
93 |
+
- Generate a high-quality answer to the user's question based on the provided summaries and the user's question.
|
94 |
+
- you MUST include all the citations from the summaries in the answer correctly.
|
95 |
+
|
96 |
+
User Context:
|
97 |
+
- {research_topic}
|
98 |
+
|
99 |
+
Summaries:
|
100 |
+
{summaries}"""
|
101 |
+
|
102 |
+
gaia_system_instructions = """You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."""
|
agents/state.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from dataclasses import dataclass, field
|
4 |
+
from typing import TypedDict
|
5 |
+
|
6 |
+
from langgraph.graph import add_messages
|
7 |
+
from typing_extensions import Annotated
|
8 |
+
|
9 |
+
|
10 |
+
import operator
|
11 |
+
from dataclasses import dataclass, field
|
12 |
+
from typing_extensions import Annotated
|
13 |
+
|
14 |
+
|
15 |
+
class OverallState(TypedDict):
|
16 |
+
messages: Annotated[list, add_messages]
|
17 |
+
search_query: Annotated[list, operator.add]
|
18 |
+
web_research_result: Annotated[list, operator.add]
|
19 |
+
sources_gathered: Annotated[list, operator.add]
|
20 |
+
initial_search_query_count: int
|
21 |
+
max_research_loops: int
|
22 |
+
research_loop_count: int
|
23 |
+
reasoning_model: str
|
24 |
+
|
25 |
+
|
26 |
+
class ReflectionState(TypedDict):
|
27 |
+
is_sufficient: bool
|
28 |
+
knowledge_gap: str
|
29 |
+
follow_up_queries: Annotated[list, operator.add]
|
30 |
+
research_loop_count: int
|
31 |
+
number_of_ran_queries: int
|
32 |
+
|
33 |
+
|
34 |
+
class Query(TypedDict):
|
35 |
+
query: str
|
36 |
+
rationale: str
|
37 |
+
|
38 |
+
|
39 |
+
class QueryGenerationState(TypedDict):
|
40 |
+
query_list: list[Query]
|
41 |
+
|
42 |
+
|
43 |
+
class WebSearchState(TypedDict):
|
44 |
+
search_query: str
|
45 |
+
id: str
|
46 |
+
|
47 |
+
|
48 |
+
@dataclass(kw_only=True)
|
49 |
+
class SearchStateOutput:
|
50 |
+
running_summary: str = field(default=None) # Final report
|
agents/tools_and_schemas.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
from pydantic import BaseModel, Field
|
3 |
+
|
4 |
+
|
5 |
+
class SearchQueryList(BaseModel):
|
6 |
+
query: List[str] = Field(
|
7 |
+
description="A list of search queries to be used for web research."
|
8 |
+
)
|
9 |
+
rationale: str = Field(
|
10 |
+
description="A brief explanation of why these queries are relevant to the research topic."
|
11 |
+
)
|
12 |
+
|
13 |
+
|
14 |
+
class Reflection(BaseModel):
|
15 |
+
is_sufficient: bool = Field(
|
16 |
+
description="Whether the provided summaries are sufficient to answer the user's question."
|
17 |
+
)
|
18 |
+
knowledge_gap: str = Field(
|
19 |
+
description="A description of what information is missing or needs clarification."
|
20 |
+
)
|
21 |
+
follow_up_queries: List[str] = Field(
|
22 |
+
description="A list of follow-up queries to address the knowledge gap."
|
23 |
+
)
|
agents/utils.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List
|
2 |
+
from langchain_core.messages import AnyMessage, AIMessage, HumanMessage
|
3 |
+
|
4 |
+
|
5 |
+
def get_research_topic(messages: List[AnyMessage]) -> str:
|
6 |
+
"""
|
7 |
+
Get the research topic from the messages.
|
8 |
+
"""
|
9 |
+
# check if request has a history and combine the messages into a single string
|
10 |
+
if len(messages) == 1:
|
11 |
+
research_topic = messages[-1].content
|
12 |
+
else:
|
13 |
+
research_topic = ""
|
14 |
+
for message in messages:
|
15 |
+
if isinstance(message, HumanMessage):
|
16 |
+
research_topic += f"User: {message.content}\n"
|
17 |
+
elif isinstance(message, AIMessage):
|
18 |
+
research_topic += f"Assistant: {message.content}\n"
|
19 |
+
return research_topic
|
20 |
+
|
21 |
+
|
22 |
+
def resolve_urls(urls_to_resolve: List[Any], id: int) -> Dict[str, str]:
|
23 |
+
"""
|
24 |
+
Create a map of the vertex ai search urls (very long) to a short url with a unique id for each url.
|
25 |
+
Ensures each original URL gets a consistent shortened form while maintaining uniqueness.
|
26 |
+
"""
|
27 |
+
prefix = f"https://vertexaisearch.cloud.google.com/id/"
|
28 |
+
urls = [site.web.uri for site in urls_to_resolve]
|
29 |
+
|
30 |
+
# Create a dictionary that maps each unique URL to its first occurrence index
|
31 |
+
resolved_map = {}
|
32 |
+
for idx, url in enumerate(urls):
|
33 |
+
if url not in resolved_map:
|
34 |
+
resolved_map[url] = f"{prefix}{id}-{idx}"
|
35 |
+
|
36 |
+
return resolved_map
|
37 |
+
|
38 |
+
|
39 |
+
def insert_citation_markers(text, citations_list):
|
40 |
+
"""
|
41 |
+
Inserts citation markers into a text string based on start and end indices.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
text (str): The original text string.
|
45 |
+
citations_list (list): A list of dictionaries, where each dictionary
|
46 |
+
contains 'start_index', 'end_index', and
|
47 |
+
'segment_string' (the marker to insert).
|
48 |
+
Indices are assumed to be for the original text.
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
str: The text with citation markers inserted.
|
52 |
+
"""
|
53 |
+
# Sort citations by end_index in descending order.
|
54 |
+
# If end_index is the same, secondary sort by start_index descending.
|
55 |
+
# This ensures that insertions at the end of the string don't affect
|
56 |
+
# the indices of earlier parts of the string that still need to be processed.
|
57 |
+
sorted_citations = sorted(
|
58 |
+
citations_list, key=lambda c: (c["end_index"], c["start_index"]), reverse=True
|
59 |
+
)
|
60 |
+
|
61 |
+
modified_text = text
|
62 |
+
for citation_info in sorted_citations:
|
63 |
+
# These indices refer to positions in the *original* text,
|
64 |
+
# but since we iterate from the end, they remain valid for insertion
|
65 |
+
# relative to the parts of the string already processed.
|
66 |
+
end_idx = citation_info["end_index"]
|
67 |
+
marker_to_insert = ""
|
68 |
+
for segment in citation_info["segments"]:
|
69 |
+
marker_to_insert += f" [{segment['label']}]({segment['short_url']})"
|
70 |
+
# Insert the citation marker at the original end_idx position
|
71 |
+
modified_text = (
|
72 |
+
modified_text[:end_idx] + marker_to_insert + modified_text[end_idx:]
|
73 |
+
)
|
74 |
+
|
75 |
+
return modified_text
|
76 |
+
|
77 |
+
|
78 |
+
def get_citations(response, resolved_urls_map):
|
79 |
+
"""
|
80 |
+
Extracts and formats citation information from a Gemini model's response.
|
81 |
+
|
82 |
+
This function processes the grounding metadata provided in the response to
|
83 |
+
construct a list of citation objects. Each citation object includes the
|
84 |
+
start and end indices of the text segment it refers to, and a string
|
85 |
+
containing formatted markdown links to the supporting web chunks.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
response: The response object from the Gemini model, expected to have
|
89 |
+
a structure including `candidates[0].grounding_metadata`.
|
90 |
+
It also relies on a `resolved_map` being available in its
|
91 |
+
scope to map chunk URIs to resolved URLs.
|
92 |
+
|
93 |
+
Returns:
|
94 |
+
list: A list of dictionaries, where each dictionary represents a citation
|
95 |
+
and has the following keys:
|
96 |
+
- "start_index" (int): The starting character index of the cited
|
97 |
+
segment in the original text. Defaults to 0
|
98 |
+
if not specified.
|
99 |
+
- "end_index" (int): The character index immediately after the
|
100 |
+
end of the cited segment (exclusive).
|
101 |
+
- "segments" (list[str]): A list of individual markdown-formatted
|
102 |
+
links for each grounding chunk.
|
103 |
+
- "segment_string" (str): A concatenated string of all markdown-
|
104 |
+
formatted links for the citation.
|
105 |
+
Returns an empty list if no valid candidates or grounding supports
|
106 |
+
are found, or if essential data is missing.
|
107 |
+
"""
|
108 |
+
citations = []
|
109 |
+
|
110 |
+
# Ensure response and necessary nested structures are present
|
111 |
+
if not response or not response.candidates:
|
112 |
+
return citations
|
113 |
+
|
114 |
+
candidate = response.candidates[0]
|
115 |
+
if (
|
116 |
+
not hasattr(candidate, "grounding_metadata")
|
117 |
+
or not candidate.grounding_metadata
|
118 |
+
or not hasattr(candidate.grounding_metadata, "grounding_supports")
|
119 |
+
):
|
120 |
+
return citations
|
121 |
+
|
122 |
+
for support in candidate.grounding_metadata.grounding_supports:
|
123 |
+
citation = {}
|
124 |
+
|
125 |
+
# Ensure segment information is present
|
126 |
+
if not hasattr(support, "segment") or support.segment is None:
|
127 |
+
continue # Skip this support if segment info is missing
|
128 |
+
|
129 |
+
start_index = (
|
130 |
+
support.segment.start_index
|
131 |
+
if support.segment.start_index is not None
|
132 |
+
else 0
|
133 |
+
)
|
134 |
+
|
135 |
+
# Ensure end_index is present to form a valid segment
|
136 |
+
if support.segment.end_index is None:
|
137 |
+
continue # Skip if end_index is missing, as it's crucial
|
138 |
+
|
139 |
+
# Add 1 to end_index to make it an exclusive end for slicing/range purposes
|
140 |
+
# (assuming the API provides an inclusive end_index)
|
141 |
+
citation["start_index"] = start_index
|
142 |
+
citation["end_index"] = support.segment.end_index
|
143 |
+
|
144 |
+
citation["segments"] = []
|
145 |
+
if (
|
146 |
+
hasattr(support, "grounding_chunk_indices")
|
147 |
+
and support.grounding_chunk_indices
|
148 |
+
):
|
149 |
+
for ind in support.grounding_chunk_indices:
|
150 |
+
try:
|
151 |
+
chunk = candidate.grounding_metadata.grounding_chunks[ind]
|
152 |
+
resolved_url = resolved_urls_map.get(chunk.web.uri, None)
|
153 |
+
citation["segments"].append(
|
154 |
+
{
|
155 |
+
"label": chunk.web.title.split(".")[:-1][0],
|
156 |
+
"short_url": resolved_url,
|
157 |
+
"value": chunk.web.uri,
|
158 |
+
}
|
159 |
+
)
|
160 |
+
except (IndexError, AttributeError, NameError):
|
161 |
+
# Handle cases where chunk, web, uri, or resolved_map might be problematic
|
162 |
+
# For simplicity, we'll just skip adding this particular segment link
|
163 |
+
# In a production system, you might want to log this.
|
164 |
+
pass
|
165 |
+
citations.append(citation)
|
166 |
+
return citations
|