DrishtiSharma committed on
Commit
4cebfdc
·
verified ·
1 Parent(s): 33c0070

Create graph.py

Browse files
Files changed (1) hide show
  1. graph.py +133 -0
graph.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import sys

# Workaround for Streamlit Cloud: its bundled sqlite3 is too old for some
# dependencies, so substitute the pysqlite3 binary build when available.
# Wrapped in try/except so local environments without pysqlite3 still run
# (the original crashed with ImportError anywhere pysqlite3 wasn't installed).
try:
    __import__('pysqlite3')
    sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
except ImportError:
    # Assume the local sqlite3 is recent enough.
    pass

from typing import TypedDict, List, Literal, Dict, Any

from langgraph.graph import StateGraph, END
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory

from pdf_writer import generate_pdf
from crew import CrewClass, Essay
15
+
16
class GraphState(TypedDict):
    """State dictionary shared by all nodes of the essay-writer graph."""

    topic: str            # latest user input, routed through the graph
    response: str         # assistant reply surfaced back to the caller
    documents: List[str]  # supporting documents — not used by any node visible here
    essay: Dict[str, Any] # essay content as a dict (presumably the crew's JSON output — confirm against crew.py)
    pdf_name: str         # filename of the PDF produced by generate_pdf
22
+
23
+
24
class RouteQuery(BaseModel):
    """Route a user query to direct answer or research."""

    # Structured-output schema for the router LLM call: the model must choose
    # exactly one of the three node names wired into the graph's conditional
    # entry point. The Field description is read by the LLM — do not edit it
    # casually.
    way: Literal["edit_essay", "write_essay", "answer"] = Field(
        ...,
        description="Given a user question choose to route it to write_essay, edit_essay or answer",
    )
31
+
32
+
33
class EssayWriter:
    """Conversational essay assistant built on a three-node LangGraph.

    A router LLM inspects each user message (plus conversation history) and
    dispatches it to one of three terminal nodes:

    * ``answer``      -- plain chat reply.
    * ``write_essay`` -- delegates essay creation to the CrewAI crew and
                         renders the result to a PDF.
    * ``edit_essay``  -- asks the LLM to rewrite the previously generated
                         essay JSON per the user's request, then re-renders
                         the PDF.

    Side effect: construction compiles the graph and writes a ``graph.png``
    diagram to the working directory.
    """

    def __init__(self):
        """Build the LLMs, memory, prompts, and compile the routing graph."""
        # Temperature 0 for deterministic routing/answers; a slightly more
        # creative model (0.5) drives the essay-writing crew.
        self.model = ChatOpenAI(model="gpt-4o-mini-2024-07-18", temperature=0)
        self.crew = CrewClass(llm=ChatOpenAI(model="gpt-4o-mini-2024-07-18", temperature=0.5))

        self.memory = ConversationBufferMemory()
        # Most recent essay as a dict; write_essay fills it, edit_essay replaces it.
        self.essay = {}

        # Router prompt (typo fixed: "Conservation" -> "Conversation History").
        self.router_prompt = """
            You are a router and your duty is to route the user to the correct expert.
            Always check conversation history and consider your move based on it.
            If topic is something about memory, or daily talk route the user to the answer expert.
            If topic starts something like can u write, or user request you write an article or essay, route the user to the write_essay expert.
            If topic is user wants to edit anything in the essay, route the user to the edit_essay expert.

            \nConversation History: {memory}
            \nTopic: {topic}
            """

        self.simple_answer_prompt = """
            You are an expert and you are providing a simple answer to the user's question.

            \nConversation History: {memory}
            \nTopic: {topic}
            """

        # Wire the graph: a conditional entry point routes directly to one of
        # the three worker nodes; each node ends the run.
        builder = StateGraph(GraphState)

        builder.add_node("answer", self.answer)
        builder.add_node("write_essay", self.write_essay)
        builder.add_node("edit_essay", self.edit_essay)

        builder.set_conditional_entry_point(
            self.router_query,
            {
                "write_essay": "write_essay",
                "answer": "answer",
                "edit_essay": "edit_essay",
            },
        )
        builder.add_edge("write_essay", END)
        builder.add_edge("edit_essay", END)
        builder.add_edge("answer", END)

        self.graph = builder.compile()
        # Debug visualization of the compiled graph (requires rendering deps).
        self.graph.get_graph().draw_mermaid_png(output_file_path="graph.png")

    def router_query(self, state: GraphState) -> str:
        """Ask the router LLM which node should handle ``state['topic']``.

        Returns one of ``"write_essay"``, ``"edit_essay"`` or ``"answer"``,
        matching the keys of the conditional entry point above.
        """
        print("**ROUTER**")
        prompt = PromptTemplate.from_template(self.router_prompt)
        memory = self.memory.load_memory_variables({})

        router_query = self.model.with_structured_output(RouteQuery)
        chain = prompt | router_query
        result: RouteQuery = chain.invoke({"topic": state["topic"], "memory": memory})

        print("Router Result: ", result.way)
        return result.way

    def answer(self, state: GraphState) -> Dict[str, Any]:
        """Produce a direct chat answer and record the turn in memory."""
        print("**ANSWER**")
        prompt = PromptTemplate.from_template(self.simple_answer_prompt)
        memory = self.memory.load_memory_variables({})
        chain = prompt | self.model | StrOutputParser()
        result = chain.invoke({"topic": state["topic"], "memory": memory})

        self.memory.save_context(inputs={"input": state["topic"]}, outputs={"output": result})
        return {"response": result}

    def write_essay(self, state: GraphState) -> Dict[str, Any]:
        """Generate a fresh essay via the crew, save a PDF, update memory.

        Now also returns the ``essay`` key (the original omitted it, leaving
        ``GraphState['essay']`` unset even though edit_essay sets it).
        """
        print("**ESSAY COMPLETION**")

        self.essay = self.crew.kickoff({"topic": state["topic"]})

        self.memory.save_context(inputs={"input": state["topic"]}, outputs={"output": str(self.essay)})

        pdf_name = generate_pdf(self.essay)
        return {"response": "Here is your essay! ", "essay": self.essay, "pdf_name": f"{pdf_name}"}

    def edit_essay(self, state: GraphState) -> Dict[str, Any]:
        """Apply the user's requested edits to the stored essay JSON.

        Sends the current ``self.essay``, the request, and the history to the
        LLM, which must return the complete edited essay JSON (parsed against
        the ``Essay`` schema); the result replaces ``self.essay`` and is
        re-rendered to PDF.
        """
        print("**ESSAY EDIT**")
        memory = self.memory.load_memory_variables({})

        user_request = state["topic"]
        parser = JsonOutputParser(pydantic_object=Essay)
        # Prompt typo fixed: "Conservation" -> "Conversation History".
        prompt = PromptTemplate(
            template=(
                "Edit the Json file as user requested, and return the new Json file."
                "\n Request:{user_request} "
                "\n Conversation History: {memory}"
                "\n Json File: {essay}"
                " \n{format_instructions}"
            ),
            input_variables=["memory", "user_request", "essay"],
            partial_variables={"format_instructions": parser.get_format_instructions()},
        )

        chain = prompt | self.model | parser

        self.essay = chain.invoke({"user_request": user_request, "memory": memory, "essay": self.essay})

        self.memory.save_context(inputs={"input": state["topic"]}, outputs={"output": str(self.essay)})
        pdf_name = generate_pdf(self.essay)
        return {"response": "Here is your edited essay! ", "essay": self.essay, "pdf_name": f"{pdf_name}"}