# Shreyas094's picture
# Upload 528 files
# 372531f verified
# raw
# history blame
# 7.28 kB
import asyncio
from datetime import datetime
from typing import Any, Dict, List, Optional

from langgraph.graph import StateGraph, END

from .utils.views import print_agent_output
from .utils.llms import call_model
from ..memory.draft import DraftState
from . import ResearchAgent, ReviewerAgent, ReviserAgent
class EditorAgent:
    """Plan a research outline and run per-section research workflows.

    ``plan_research`` asks the model for a title/date/sections outline, and
    ``run_parallel_research`` runs one researcher -> reviewer -> reviser
    graph per planned section concurrently.
    """

    def __init__(self, websocket=None, stream_output=None, headers=None):
        """Store the streaming hooks and the headers shared with sub-agents.

        :param websocket: Optional connection object handed to ``stream_output``.
        :param stream_output: Optional callable invoked as
            ``stream_output("logs", key, message, websocket)``; its result is
            scheduled with ``asyncio.create_task``, so it must return a coroutine.
        :param headers: Optional mapping forwarded to the sub-agents and to every
            task input; a falsy value is replaced by an empty dict.
        """
        self.websocket = websocket
        self.stream_output = stream_output
        self.headers = headers or {}
        # The event loop only keeps weak references to tasks, so fire-and-forget
        # log tasks are held here until they finish.
        self._background_tasks = set()

    async def plan_research(self, research_state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Plan the research outline based on initial research and task parameters.

        :param research_state: Dictionary containing research state information.
            Reads the keys ``initial_research``, ``task`` and ``human_feedback``;
            ``task`` must be a mapping (``include_human_feedback``,
            ``max_sections`` and ``model`` are read from it).
        :return: Dictionary with title, date, and planned sections taken from
            the model's JSON answer (``None`` for any field the answer lacks).
        """
        initial_research = research_state.get("initial_research")
        task = research_state.get("task")
        include_human_feedback = task.get("include_human_feedback")
        human_feedback = research_state.get("human_feedback")
        max_sections = task.get("max_sections")

        prompt = self._create_planning_prompt(
            initial_research, include_human_feedback, human_feedback, max_sections)

        print_agent_output(
            "Planning an outline layout based on initial research...", agent="EDITOR")
        plan = await call_model(
            prompt=prompt,
            model=task.get("model"),
            response_format="json",
        )

        return {
            "title": plan.get("title"),
            "date": plan.get("date"),
            "sections": plan.get("sections"),
        }

    async def run_parallel_research(self, research_state: Dict[str, Any]) -> Dict[str, List[str]]:
        """
        Execute parallel research tasks for each section.

        :param research_state: Dictionary containing research state information.
            Reads ``sections`` (the section headers to research), ``title`` and
            ``task``.
        :return: Dictionary whose ``research_data`` entry holds the ``draft`` of
            every section workflow, in the same order as ``sections``. The list
            is empty when no sections were planned.
        """
        # A missing or None outline is treated as "nothing to research"
        # instead of failing when iterated.
        queries = research_state.get("sections") or []
        title = research_state.get("title")

        self._log_parallel_research(queries)
        if not queries:
            # No section to research: skip building and compiling the graph.
            return {"research_data": []}

        chain = self._create_workflow().compile()
        pending_drafts = [
            chain.ainvoke(self._create_task_input(research_state, query, title))
            for query in queries
        ]
        # gather() yields results in the order the awaitables were passed.
        finished = await asyncio.gather(*pending_drafts)
        return {"research_data": [result["draft"] for result in finished]}

    def _create_planning_prompt(self, initial_research: str, include_human_feedback: bool,
                                human_feedback: Optional[str], max_sections: int) -> List[Dict[str, str]]:
        """Create the prompt for research planning.

        :return: A two-message chat prompt: a fixed system message followed by
            the user message built by ``_format_planning_instructions``.
        """
        return [
            {
                "role": "system",
                "content": "You are a research editor. Your goal is to oversee the research project "
                           "from inception to completion. Your main task is to plan the article section "
                           "layout based on an initial research summary.\n ",
            },
            {
                "role": "user",
                "content": self._format_planning_instructions(initial_research, include_human_feedback,
                                                              human_feedback, max_sections),
            },
        ]

    def _format_planning_instructions(self, initial_research: str, include_human_feedback: bool,
                                      human_feedback: Optional[str], max_sections: int) -> str:
        """Format the instructions for research planning.

        The human-feedback sentence is included only when feedback is enabled,
        non-empty and not the literal string ``'no'``; otherwise that line of
        the prompt is left empty.
        """
        today = datetime.now().strftime('%d/%m/%Y')
        feedback_instruction = (
            f"Human feedback: {human_feedback}. You must plan the sections based on the human feedback."
            if include_human_feedback and human_feedback and human_feedback != 'no'
            else ''
        )
        # Assembled line by line so the prompt text does not depend on the
        # indentation of this method.
        return (
            f"Today's date is {today}\n"
            f"Research summary report: '{initial_research}'\n"
            f"{feedback_instruction}\n"
            "\nYour task is to generate an outline of sections headers for the research project\n"
            "based on the research summary report above.\n"
            f"You must generate a maximum of {max_sections} section headers.\n"
            "You must focus ONLY on related research topics for subheaders and do NOT include "
            "introduction, conclusion and references.\n"
            "You must return nothing but a JSON with the fields 'title' (str) and\n"
            f"'sections' (maximum {max_sections} section headers) with the following structure:\n"
            "'{title: string research title, date: today's date,\n"
            "sections: ['section header 1', 'section header 2', 'section header 3' ...]}'."
        )

    def _initialize_agents(self) -> Dict[str, Any]:
        """Initialize the research, reviewer, and reviser skills.

        :return: Mapping with the keys ``research``, ``reviewer`` and
            ``reviser``, each built with this editor's websocket, stream_output
            and headers.
        """
        return {
            "research": ResearchAgent(self.websocket, self.stream_output, self.headers),
            "reviewer": ReviewerAgent(self.websocket, self.stream_output, self.headers),
            "reviser": ReviserAgent(self.websocket, self.stream_output, self.headers),
        }

    def _create_workflow(self) -> StateGraph:
        """Create the workflow for the research process.

        The graph starts at the researcher, always goes on to the reviewer, and
        loops reviewer -> reviser -> reviewer until the draft's ``review`` entry
        is ``None``, at which point it ends.

        :return: The uncompiled ``StateGraph`` over ``DraftState``.
        """
        agents = self._initialize_agents()
        workflow = StateGraph(DraftState)

        workflow.add_node("researcher", agents["research"].run_depth_research)
        workflow.add_node("reviewer", agents["reviewer"].run)
        workflow.add_node("reviser", agents["reviser"].run)

        workflow.set_entry_point("researcher")
        workflow.add_edge("researcher", "reviewer")
        workflow.add_edge("reviser", "reviewer")
        workflow.add_conditional_edges(
            "reviewer",
            # No review notes means the draft is accepted as it stands.
            lambda draft: "accept" if draft["review"] is None else "revise",
            {"accept": END, "revise": "reviser"},
        )
        return workflow

    def _log_parallel_research(self, queries: List[str]) -> None:
        """Log the start of parallel research tasks.

        With both a websocket and a stream_output callable the message is
        streamed in a background task, which requires a running event loop;
        otherwise it is printed through ``print_agent_output``.
        """
        if self.websocket and self.stream_output:
            log_task = asyncio.create_task(self.stream_output(
                "logs",
                "parallel_research",
                f"Running parallel research for the following queries: {queries}",
                self.websocket,
            ))
            # Keep the task referenced until it completes so it cannot be
            # garbage-collected before it has run.
            self._background_tasks.add(log_task)
            log_task.add_done_callback(self._background_tasks.discard)
        else:
            print_agent_output(
                f"Running the following research tasks in parallel: {queries}...",
                agent="EDITOR",
            )

    def _create_task_input(self, research_state: Dict[str, Any], query: str, title: str) -> Dict[str, Any]:
        """Create the input for a single research task.

        :param research_state: State whose ``task`` entry is passed through.
        :param query: Section header to research, stored under ``topic``.
        :param title: Overall research title.
        :return: Dictionary with the keys ``task``, ``topic``, ``title`` and
            ``headers``.
        """
        return {
            "task": research_state.get("task"),
            "topic": query,
            "title": title,
            "headers": self.headers,
        }