# Shreyas094's picture
# Upload 528 files
# 372531f verified
import asyncio
from datetime import datetime
from typing import Any, Dict, List, Optional

from langgraph.graph import StateGraph, END

from .utils.views import print_agent_output
from .utils.llms import call_model
from ..memory.draft import DraftState
from . import ResearchAgent, ReviewerAgent, ReviserAgent
class EditorAgent:
    """Editor agent that plans a research outline and fans out section research.

    ``plan_research`` asks the model for a title, date and list of section
    headers; ``run_parallel_research`` then runs one
    researcher -> reviewer -> reviser workflow per section concurrently.
    """

    def __init__(self, websocket=None, stream_output=None, headers=None):
        """
        Store the optional streaming hooks and the headers for sub-agents.

        :param websocket: Optional connection object passed to ``stream_output``
        :param stream_output: Optional coroutine function used to stream log messages
        :param headers: Optional headers forwarded to sub-agents and task inputs;
            a falsy value is replaced by an empty dict
        """
        self.websocket = websocket
        self.stream_output = stream_output
        self.headers = headers or {}
        # The event loop only keeps weak references to tasks, so fire-and-forget
        # logging tasks are held here until they finish.
        self._background_tasks = set()

    async def plan_research(self, research_state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Plan the research outline based on initial research and task parameters.

        :param research_state: Dictionary containing research state information;
            the keys "initial_research", "task" and "human_feedback" are read
        :return: Dictionary with title, date, and planned sections
        """
        task = research_state.get("task")
        prompt = self._create_planning_prompt(
            research_state.get("initial_research"),
            task.get("include_human_feedback"),
            research_state.get("human_feedback"),
            task.get("max_sections"),
        )

        print_agent_output(
            "Planning an outline layout based on initial research...", agent="EDITOR")
        plan = await call_model(
            prompt=prompt,
            model=task.get("model"),
            response_format="json",
        )

        # Only the fields the rest of the pipeline consumes are passed on.
        return {
            "title": plan.get("title"),
            "date": plan.get("date"),
            "sections": plan.get("sections"),
        }

    async def run_parallel_research(self, research_state: Dict[str, Any]) -> Dict[str, List[str]]:
        """
        Execute parallel research tasks for each section.

        :param research_state: Dictionary containing research state information;
            the keys "sections", "title" and "task" are read
        :return: Dictionary whose "research_data" entry lists the "draft" of
            every section workflow, in section order
        """
        # A missing or None "sections" entry means there is nothing to research;
        # treat it as an empty list instead of failing on iteration.
        queries = research_state.get("sections") or []
        title = research_state.get("title")

        # _create_workflow builds its own agents, so none are created here.
        chain = self._create_workflow().compile()

        self._log_parallel_research(queries)

        pending_drafts = [
            chain.ainvoke(self._create_task_input(research_state, query, title))
            for query in queries
        ]
        # gather preserves input order, so results line up with the sections.
        finished = await asyncio.gather(*pending_drafts)
        return {"research_data": [result["draft"] for result in finished]}

    def _create_planning_prompt(self, initial_research: str, include_human_feedback: bool,
                                human_feedback: Optional[str], max_sections: int) -> List[Dict[str, str]]:
        """
        Create the chat prompt for research planning.

        :param initial_research: Initial research summary embedded in the prompt
        :param include_human_feedback: Whether human feedback may be included
        :param human_feedback: Feedback text, if any
        :param max_sections: Maximum number of section headers to request
        :return: A system message followed by a user message
        """
        return [
            {
                "role": "system",
                "content": "You are a research editor. Your goal is to oversee the research project "
                           "from inception to completion. Your main task is to plan the article section "
                           "layout based on an initial research summary.\n ",
            },
            {
                "role": "user",
                "content": self._format_planning_instructions(
                    initial_research, include_human_feedback, human_feedback, max_sections),
            },
        ]

    def _format_planning_instructions(self, initial_research: str, include_human_feedback: bool,
                                      human_feedback: Optional[str], max_sections: int) -> str:
        """
        Format the user instructions for research planning.

        The human-feedback sentence is included only when feedback is enabled,
        non-empty and not the literal string 'no'.

        :param initial_research: Initial research summary embedded in the prompt
        :param include_human_feedback: Whether human feedback may be included
        :param human_feedback: Feedback text, if any
        :param max_sections: Maximum number of section headers to request
        :return: The instruction text, starting with today's date (dd/mm/YYYY)
        """
        today = datetime.now().strftime('%d/%m/%Y')
        feedback_instruction = (
            f"Human feedback: {human_feedback}. You must plan the sections based on the human feedback."
            if include_human_feedback and human_feedback and human_feedback != 'no'
            else ''
        )

        # The continuation lines start at column 0 on purpose: they are part of
        # the prompt text and must not pick up source indentation.
        return f"""Today's date is {today}
Research summary report: '{initial_research}'
{feedback_instruction}
\nYour task is to generate an outline of sections headers for the research project
based on the research summary report above.
You must generate a maximum of {max_sections} section headers.
You must focus ONLY on related research topics for subheaders and do NOT include introduction, conclusion and references.
You must return nothing but a JSON with the fields 'title' (str) and
'sections' (maximum {max_sections} section headers) with the following structure:
'{{title: string research title, date: today's date,
sections: ['section header 1', 'section header 2', 'section header 3' ...]}}'."""

    def _initialize_agents(self) -> Dict[str, Any]:
        """
        Initialize the research, reviewer, and reviser agents.

        :return: Dictionary with the keys "research", "reviewer" and "reviser",
            each agent built from this editor's websocket, stream_output and headers
        """
        return {
            "research": ResearchAgent(self.websocket, self.stream_output, self.headers),
            "reviewer": ReviewerAgent(self.websocket, self.stream_output, self.headers),
            "reviser": ReviserAgent(self.websocket, self.stream_output, self.headers),
        }

    def _create_workflow(self) -> "StateGraph":
        """
        Create the (uncompiled) per-section workflow graph.

        The graph starts at the researcher and hands its draft to the reviewer.
        If the reviewer leaves ``draft["review"]`` as None the draft is accepted
        and the graph ends; otherwise the reviser runs and the draft is
        reviewed again.

        :return: A StateGraph over DraftState
        """
        agents = self._initialize_agents()

        workflow = StateGraph(DraftState)
        workflow.add_node("researcher", agents["research"].run_depth_research)
        workflow.add_node("reviewer", agents["reviewer"].run)
        workflow.add_node("reviser", agents["reviser"].run)

        workflow.set_entry_point("researcher")
        workflow.add_edge("researcher", "reviewer")
        workflow.add_edge("reviser", "reviewer")
        workflow.add_conditional_edges(
            "reviewer",
            # No review notes means the reviewer accepted the draft.
            lambda draft: "accept" if draft["review"] is None else "revise",
            {"accept": END, "revise": "reviser"},
        )
        return workflow

    def _log_parallel_research(self, queries: List[str]) -> None:
        """
        Log the start of parallel research tasks.

        Streams through ``stream_output`` when both it and a websocket are set
        (this requires a running event loop); otherwise prints to the console.

        :param queries: Section headers about to be researched
        """
        if self.websocket and self.stream_output:
            log_task = asyncio.create_task(self.stream_output(
                "logs",
                "parallel_research",
                f"Running parallel research for the following queries: {queries}",
                self.websocket,
            ))
            # Keep a strong reference until the task completes so it cannot be
            # garbage-collected mid-flight, then drop it.
            self._background_tasks.add(log_task)
            log_task.add_done_callback(self._background_tasks.discard)
        else:
            print_agent_output(
                f"Running the following research tasks in parallel: {queries}...",
                agent="EDITOR",
            )

    def _create_task_input(self, research_state: Dict[str, Any], query: str, title: str) -> Dict[str, Any]:
        """
        Create the input for a single research task.

        :param research_state: Research state whose "task" entry is forwarded
        :param query: Section header to research, stored under "topic"
        :param title: Overall research title
        :return: Dictionary with the keys "task", "topic", "title" and "headers"
        """
        return {
            "task": research_state.get("task"),
            "topic": query,
            "title": title,
            "headers": self.headers,
        }