# Workaround for the sqlite3 version error on live Streamlit:
# replace the stdlib sqlite3 module with pysqlite3.
__import__('pysqlite3')
import sys
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
from langgraph.graph import StateGraph, END
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from typing import TypedDict, List, Literal, Dict, Any
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory
from pdf_writer import generate_pdf
from crew import CrewClass, Essay


class GraphState(TypedDict):
    topic: str
    response: str
    documents: List[str]
    essay: Dict[str, Any]
    pdf_name: str


class RouteQuery(BaseModel):
    """Route a user query to direct answer or research."""

    way: Literal["edit_essay", "write_essay", "answer"] = Field(
        ...,
        description="Given a user question, choose to route it to write_essay, edit_essay, or answer",
    )


class EssayWriter:
    def __init__(self):
        self.model = ChatOpenAI(model="gpt-4o-mini-2024-07-18", temperature=0)
        self.crew = CrewClass(llm=ChatOpenAI(model="gpt-4o-mini-2024-07-18", temperature=0.5))

        self.memory = ConversationBufferMemory()
        self.essay = {}
        self.router_prompt = """
                            You are a router, and your duty is to route the user to the correct expert.
                            Always check conversation history and consider your move based on it.
                            If the topic is something about memory or daily talk, route the user to the answer expert.
                            If the topic starts with something like "Can you write" or the user requests an article or essay, route the user to the write_essay expert.
                            If the topic is about editing an essay, route the user to the edit_essay expert.
                            
                            \nConversation History: {memory}
                            \nTopic: {topic}
                            """

        self.simple_answer_prompt = """
                            You are an expert, and you are providing a simple answer to the user's question.
                            
                            \nConversation History: {memory}
                            \nTopic: {topic}
                            """

        builder = StateGraph(GraphState)

        builder.add_node("answer", self.answer)
        builder.add_node("write_essay", self.write_essay)
        builder.add_node("edit_essay", self.edit_essay)

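        # router_query returns one of the RouteQuery literals; this map sends
        # each literal to the node registered under the same name.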
        builder.set_conditional_entry_point(self.router_query, {
            "write_essay": "write_essay",
            "answer": "answer",
            "edit_essay": "edit_essay",
        })

        builder.add_edge("write_essay", END)
        builder.add_edge("edit_essay", END)
        builder.add_edge("answer", END)

        self.graph = builder.compile()
        self.save_workflow_graph()

    def router_query(self, state: GraphState):
        print("**ROUTER**")
        prompt = PromptTemplate.from_template(self.router_prompt)
        memory = self.memory.load_memory_variables({})

        router_query = self.model.with_structured_output(RouteQuery)
        chain = prompt | router_query
        result: RouteQuery = chain.invoke({"topic": state["topic"], "memory": memory})

        print("Router Result: ", result.way)
        return result.way

    def answer(self, state: GraphState):
        print("**ANSWER**")
        prompt = PromptTemplate.from_template(self.simple_answer_prompt)
        memory = self.memory.load_memory_variables({})
        chain = prompt | self.model | StrOutputParser()
        result = chain.invoke({"topic": state["topic"], "memory": memory})

        self.memory.save_context(inputs={"input": state["topic"]}, outputs={"output": result})
        return {"response": result}

    def write_essay(self, state: GraphState):
        print("**ESSAY COMPLETION**")
        # Generate the essay using the crew
        self.essay = self.crew.kickoff({"topic": state["topic"]})
        # Save the conversation context
        self.memory.save_context(inputs={"input": state["topic"]}, outputs={"output": str(self.essay)})
        # Generate the PDF and return essay content for preview
        pdf_name = generate_pdf(self.essay)
        return {
            "response": "Here is your essay! You can review it below before downloading.",
            "essay": self.essay,
            "pdf_name": pdf_name,
        }

    def edit_essay(self, state: GraphState):
        print("**ESSAY EDIT**")
        memory = self.memory.load_memory_variables({})

        user_request = state["topic"]
        parser = JsonOutputParser(pydantic_object=Essay)
        prompt = PromptTemplate(
            template=(
                "Edit the JSON file as the user requested, and return the new JSON file."
                "\n Request: {user_request} "
                "\n Conversation History: {memory}"
                "\n JSON File: {essay}"
                " \n{format_instructions}"
            ),
            input_variables=["memory", "user_request", "essay"],
            partial_variables={"format_instructions": parser.get_format_instructions()},
        )

        chain = prompt | self.model | parser

        # Update the essay with the edits
        self.essay = chain.invoke({"user_request": user_request, "memory": memory, "essay": self.essay})

        # Save the conversation context
        self.memory.save_context(inputs={"input": state["topic"]}, outputs={"output": str(self.essay)})

        # Generate the PDF and return essay content for preview
        pdf_name = generate_pdf(self.essay)
        return {
            "response": "Here is your edited essay! You can review it below before downloading.",
            "essay": self.essay,
            "pdf_name": pdf_name,
        }
        
    def save_workflow_graph(self):
        """Generate and save a dynamic LangGraph visualization to a fixed location."""
        try:
            graph_path = "/tmp/graph.png"  # Fixed output location (assumes a POSIX /tmp).

            # Generate the mermaid diagram and save it to a fixed file
            with open(graph_path, "wb") as f:
                f.write(self.graph.get_graph().draw_mermaid_png())

            print(f"✅ Workflow visualization saved at: {graph_path}")

        except Exception as e:
            print(f"❌ Error generating graph: {e}")
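

# --- Usage sketch (not part of the original module) ---------------------------
# A minimal, hypothetical driver: it assumes OPENAI_API_KEY is set in the
# environment and that crew.py and pdf_writer.py are importable. The graph
# only needs "topic" as input; the router picks a node, and each node returns
# the state keys it fills in ("response", and optionally "essay"/"pdf_name").
if __name__ == "__main__":
    writer = EssayWriter()
    result = writer.graph.invoke(
        {"topic": "Can you write an essay about renewable energy?"}
    )
    print(result["response"])  # Chat reply or essay-ready notice.
    if result.get("pdf_name"):
        print("PDF saved as:", result["pdf_name"])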