Spaces:
Running
Running
from typing import Optional, List, Dict, Any, Set | |
import json | |
from .config import Config | |
from .memory import Memory | |
from .utils.enum import ReportSource, ReportType, Tone | |
from .llm_provider import GenericLLMProvider | |
from .vector_store import VectorStoreWrapper | |
# Research skills | |
from .skills.researcher import ResearchConductor | |
from .skills.writer import ReportGenerator | |
from .skills.context_manager import ContextManager | |
from .skills.browser import BrowserManager | |
from .skills.curator import SourceCurator | |
from .actions import ( | |
add_references, | |
extract_headers, | |
extract_sections, | |
table_of_contents, | |
get_retrievers, | |
choose_agent | |
) | |
class GPTResearcher:
    """Facade that wires together the research skills for a single query.

    The instance holds the configuration, the LLM provider, the embedding
    memory and the per-run state (context, visited URLs, sources, images,
    accumulated cost), and delegates the actual work to the skill components
    created at the end of ``__init__`` (``ResearchConductor``,
    ``ReportGenerator``, ``ContextManager``, ``BrowserManager`` and
    ``SourceCurator``), each of which receives this instance.
    """

    def __init__(
        self,
        query: str,
        report_type: str = ReportType.ResearchReport.value,
        report_format: str = "markdown",
        report_source: str = ReportSource.Web.value,
        tone: Tone = Tone.Objective,
        source_urls=None,
        document_urls=None,
        complement_source_urls=False,
        documents=None,
        vector_store=None,
        vector_store_filter=None,
        config_path=None,
        websocket=None,
        agent=None,
        role=None,
        parent_query: str = "",
        subtopics: Optional[list] = None,
        visited_urls: Optional[set] = None,
        verbose: bool = True,
        context=None,
        headers: Optional[dict] = None,
        max_subtopics: int = 5,
        log_handler=None,
    ):
        """Store the run parameters and build the helper components.

        Args:
            query: The research query.
            report_type: Report type identifier; stored as given.
            report_format: Report format identifier; stored as given.
            report_source: Source identifier. When falsy, the ``report_source``
                attribute of the loaded config is used instead (``None`` if
                the config has no such attribute).
            tone: Report tone. Anything that is not a ``Tone`` member is
                replaced by ``Tone.Objective``.
            source_urls, document_urls, complement_source_urls, documents,
            vector_store_filter, websocket, agent, role, parent_query,
            verbose, max_subtopics: Stored unchanged on the instance for the
                skill components to read.
            vector_store: When truthy, wrapped in ``VectorStoreWrapper``;
                otherwise ``self.vector_store`` is ``None``.
            config_path: Passed to ``Config``.
            subtopics: Stored on the instance. ``None`` (the default) yields
                a fresh empty list per instance.
            visited_urls: Set of URLs. A caller-supplied set is stored
                as-is (not copied), so it can be shared deliberately;
                ``None`` yields a fresh empty set per instance.
            context: Initial research context. ``None`` yields a fresh
                empty list per instance.
            headers: Passed to ``get_retrievers`` and ``choose_agent``;
                ``None`` or an empty dict becomes ``{}``.
            log_handler: Optional object whose async ``on_tool_start``,
                ``on_agent_action`` and ``on_research_step`` methods are
                invoked by ``_log_event``.
        """
        self.query = query
        self.report_type = report_type
        self.cfg = Config(config_path)
        self.llm = GenericLLMProvider(self.cfg)
        self.report_source = report_source if report_source else getattr(self.cfg, 'report_source', None)
        self.report_format = report_format
        self.max_subtopics = max_subtopics
        self.tone = tone if isinstance(tone, Tone) else Tone.Objective
        self.source_urls = source_urls
        self.document_urls = document_urls
        self.complement_source_urls: bool = complement_source_urls
        self.research_sources = []  # The list of scraped sources including title, content and images
        self.research_images = []  # The list of selected research images
        self.documents = documents
        self.vector_store = VectorStoreWrapper(vector_store) if vector_store else None
        self.vector_store_filter = vector_store_filter
        self.websocket = websocket
        self.agent = agent
        self.role = role
        self.parent_query = parent_query
        # Mutable containers are created per instance. Using `[]` / `set()`
        # as parameter defaults would share one object between every
        # researcher that did not pass its own.
        self.subtopics = subtopics if subtopics is not None else []
        self.visited_urls = visited_urls if visited_urls is not None else set()
        self.verbose = verbose
        self.context = context if context is not None else []
        self.headers = headers or {}
        self.research_costs = 0.0
        self.retrievers = get_retrievers(self.headers, self.cfg)
        self.memory = Memory(
            self.cfg.embedding_provider, self.cfg.embedding_model, **self.cfg.embedding_kwargs
        )
        self.log_handler = log_handler
        # Strong references to fire-and-forget logging tasks scheduled by
        # add_costs(); the event loop itself only keeps weak references.
        self._background_tasks: set = set()

        # Initialize components
        self.research_conductor: ResearchConductor = ResearchConductor(self)
        self.report_generator: ReportGenerator = ReportGenerator(self)
        self.context_manager: ContextManager = ContextManager(self)
        self.scraper_manager: BrowserManager = BrowserManager(self)
        self.source_curator: SourceCurator = SourceCurator(self)

    async def _log_event(self, event_type: str, **kwargs):
        """Forward an event to ``log_handler`` and mirror it to the 'research' logger.

        Does nothing when no ``log_handler`` is configured. Any exception
        raised while dispatching or serialising the event is caught and
        logged as an error, so logging never interrupts the research flow.

        Args:
            event_type: ``"tool"``, ``"action"`` or ``"research"`` select the
                handler method to call; any other value is only written to
                the 'research' logger.
            **kwargs: Event payload. ``tool_name``, ``action``, ``step`` and
                ``details`` are read depending on ``event_type``.
        """
        if not self.log_handler:
            return

        import logging  # function-scope import, as in the original code

        research_logger = logging.getLogger('research')
        try:
            # NOTE(review): for "tool" and "action" the first argument is
            # passed positionally AND is still present in **kwargs. If the
            # handler names that parameter `tool_name` / `action`, the call
            # raises TypeError (caught below) - confirm against the handler.
            if event_type == "tool":
                await self.log_handler.on_tool_start(kwargs.get('tool_name', ''), **kwargs)
            elif event_type == "action":
                await self.log_handler.on_agent_action(kwargs.get('action', ''), **kwargs)
            elif event_type == "research":
                await self.log_handler.on_research_step(kwargs.get('step', ''), kwargs.get('details', {}))

            # Add direct logging as backup
            research_logger.info(f"{event_type}: {json.dumps(kwargs, default=str)}")
        except Exception as e:
            research_logger.error(f"Error in _log_event: {e}", exc_info=True)

    async def conduct_research(self):
        """Select an agent if needed, run the research conductor, return the context.

        When either ``self.agent`` or ``self.role`` is falsy, both are
        replaced by the result of ``choose_agent``. The value returned by
        ``research_conductor.conduct_research()`` is stored on
        ``self.context`` and returned.
        """
        await self._log_event("research", step="start", details={
            "query": self.query,
            "report_type": self.report_type,
            "agent": self.agent,
            "role": self.role
        })

        if not (self.agent and self.role):
            await self._log_event("action", action="choose_agent")
            self.agent, self.role = await choose_agent(
                query=self.query,
                cfg=self.cfg,
                parent_query=self.parent_query,
                cost_callback=self.add_costs,
                headers=self.headers,
            )
            await self._log_event("action", action="agent_selected", details={
                "agent": self.agent,
                "role": self.role
            })

        await self._log_event("research", step="conducting_research", details={
            "agent": self.agent,
            "role": self.role
        })
        self.context = await self.research_conductor.conduct_research()

        await self._log_event("research", step="research_completed", details={
            "context_length": len(self.context)
        })
        return self.context

    async def write_report(
        self,
        existing_headers: Optional[list] = None,
        relevant_written_contents: Optional[list] = None,
        ext_context=None,
    ) -> str:
        """Generate the report through the report generator.

        Args:
            existing_headers: Forwarded to the report generator; ``None``
                becomes a fresh empty list.
            relevant_written_contents: Forwarded to the report generator;
                ``None`` becomes a fresh empty list.
            ext_context: Context to write from. When falsy, ``self.context``
                is used instead.

        Returns:
            The report produced by ``report_generator.write_report``.
        """
        # Fresh lists per call instead of shared mutable defaults.
        if existing_headers is None:
            existing_headers = []
        if relevant_written_contents is None:
            relevant_written_contents = []

        await self._log_event("research", step="writing_report", details={
            "existing_headers": existing_headers,
            "context_source": "external" if ext_context else "internal"
        })

        report = await self.report_generator.write_report(
            existing_headers,
            relevant_written_contents,
            ext_context or self.context
        )

        await self._log_event("research", step="report_completed", details={
            "report_length": len(report)
        })
        return report

    async def write_report_conclusion(self, report_body: str) -> str:
        """Return the conclusion the report generator writes for ``report_body``."""
        await self._log_event("research", step="writing_conclusion")
        conclusion = await self.report_generator.write_report_conclusion(report_body)
        await self._log_event("research", step="conclusion_completed")
        return conclusion

    async def write_introduction(self):
        """Return the introduction written by the report generator."""
        await self._log_event("research", step="writing_introduction")
        intro = await self.report_generator.write_introduction()
        await self._log_event("research", step="introduction_completed")
        return intro

    async def get_subtopics(self):
        """Delegate to ``report_generator.get_subtopics``."""
        return await self.report_generator.get_subtopics()

    async def get_draft_section_titles(self, current_subtopic: str):
        """Delegate to ``report_generator.get_draft_section_titles``."""
        return await self.report_generator.get_draft_section_titles(current_subtopic)

    async def get_similar_written_contents_by_draft_section_titles(
        self,
        current_subtopic: str,
        draft_section_titles: List[str],
        written_contents: List[Dict],
        max_results: int = 10
    ) -> List[str]:
        """Delegate to the context manager method of the same name."""
        return await self.context_manager.get_similar_written_contents_by_draft_section_titles(
            current_subtopic,
            draft_section_titles,
            written_contents,
            max_results
        )

    # Utility methods
    def get_research_images(self, top_k=10) -> List[Dict[str, Any]]:
        """Return the first ``top_k`` collected research images."""
        return self.research_images[:top_k]

    def add_research_images(self, images: List[Dict[str, Any]]) -> None:
        """Append ``images`` to the collected research images."""
        self.research_images.extend(images)

    def get_research_sources(self) -> List[Dict[str, Any]]:
        """Return the list of collected research sources (not a copy)."""
        return self.research_sources

    def add_research_sources(self, sources: List[Dict[str, Any]]) -> None:
        """Append ``sources`` to the collected research sources."""
        self.research_sources.extend(sources)

    def add_references(self, report_markdown: str, visited_urls: set) -> str:
        """Delegate to the module-level ``add_references`` action."""
        return add_references(report_markdown, visited_urls)

    def extract_headers(self, markdown_text: str) -> List[Dict]:
        """Delegate to the module-level ``extract_headers`` action."""
        return extract_headers(markdown_text)

    def extract_sections(self, markdown_text: str) -> List[Dict]:
        """Delegate to the module-level ``extract_sections`` action."""
        return extract_sections(markdown_text)

    def table_of_contents(self, markdown_text: str) -> str:
        """Delegate to the module-level ``table_of_contents`` action."""
        return table_of_contents(markdown_text)

    def get_source_urls(self) -> list:
        """Return the visited URLs as a new list."""
        return list(self.visited_urls)

    def get_research_context(self) -> list:
        """Return the current research context."""
        return self.context

    def get_costs(self) -> float:
        """Return the cost accumulated through ``add_costs``."""
        return self.research_costs

    def set_verbose(self, verbose: bool):
        """Set the ``verbose`` flag."""
        self.verbose = verbose

    def add_costs(self, cost: float) -> None:
        """Add ``cost`` to the running total and emit a ``cost_update`` event.

        This method is synchronous (it is handed out as ``cost_callback``),
        while ``_log_event`` is a coroutine function. The event is therefore
        scheduled as a task on the running event loop. When no loop is
        running, the event is skipped; the cost total is still updated.

        Raises:
            ValueError: If ``cost`` is not an ``int`` or ``float``.
        """
        if not isinstance(cost, (float, int)):
            raise ValueError("Cost must be an integer or float")
        self.research_costs += cost

        if not self.log_handler:
            return

        import asyncio  # function-scope import, matching the file's local-import style

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop: nothing could ever await the coroutine, so do
            # not create it (avoids a "never awaited" RuntimeWarning).
            return

        task = loop.create_task(
            self._log_event("research", step="cost_update", details={
                "cost": cost,
                "total_cost": self.research_costs
            })
        )
        # Hold a reference until the task finishes so it cannot be
        # garbage-collected mid-execution.
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)