Technologic101 commited on
Commit
e64fe22
·
1 Parent(s): b6743fd

task: begins notebook to demonstrate rag retrieval

Browse files
Files changed (1) hide show
  1. src/graph.ipynb +249 -0
src/graph.ipynb ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
# Prompt interactively for the OpenAI key so no secret is ever saved in the notebook.
import getpass
import os

os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API Key:")
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "markdown",
18
+ "metadata": {},
19
+ "source": [
20
+ "Add tools later"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": 1,
26
+ "metadata": {},
27
+ "outputs": [
28
+ {
29
+ "name": "stdout",
30
+ "output_type": "stream",
31
+ "text": [
32
+ "Loaded 82 design documents\n"
33
+ ]
34
+ }
35
+ ],
36
+ "source": [
37
from tools.design_retriever import DesignRetrieverTool
from chains.design_rag import DesignRAG

# DesignRAG owns the document index (prints "Loaded N design documents" on init);
# DesignRetrieverTool wraps it so the agent can invoke retrieval as a tool.
design_rag = DesignRAG()
design_retriever = DesignRetrieverTool(rag=design_rag)

# Tools available to the agent — extend this list as more tools are added.
tool_belt = [design_retriever]
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "markdown",
50
+ "metadata": {},
51
+ "source": [
52
+ "Pick a model good for chat and tools"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": 4,
58
+ "metadata": {},
59
+ "outputs": [
60
+ {
61
+ "data": {
62
+ "text/plain": [
63
+ "RunnableBinding(bound=ChatOpenAI(client=<openai.resources.chat.completions.completions.Completions object at 0x10d8b4750>, async_client=<openai.resources.chat.completions.completions.AsyncCompletions object at 0x10d8c1110>, root_client=<openai.OpenAI object at 0x10b08ef10>, root_async_client=<openai.AsyncOpenAI object at 0x10d8b4910>, model_name='gpt-4o-mini', temperature=0.0, model_kwargs={}, openai_api_key=SecretStr('**********')), kwargs={'tools': []}, config={}, config_factories=[])"
64
+ ]
65
+ },
66
+ "execution_count": 4,
67
+ "metadata": {},
68
+ "output_type": "execute_result"
69
+ }
70
+ ],
71
+ "source": [
72
from langchain_openai import ChatOpenAI

# gpt-4o-mini supports tool calling; temperature=0 for deterministic routing decisions.
model = ChatOpenAI(model="gpt-4o-mini", temperature=0)

# BUG FIX: `bind_tools` does not mutate the model — LangChain runnables are
# immutable and it returns a new RunnableBinding. The original cell discarded
# that return value, so `call_model` later invoked the UNBOUND model and the
# agent could never emit tool calls. Rebind the result to `model`.
model = model.bind_tools(tool_belt)
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "markdown",
81
+ "metadata": {},
82
+ "source": [
83
+ "Initialize state\n"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
+ "metadata": {},
90
+ "outputs": [],
91
+ "source": [
92
from typing import TypedDict, Annotated
from langgraph.graph.message import add_messages


class AgentState(TypedDict):
    """LangGraph state schema: a single conversation-history channel."""

    # `add_messages` is a reducer: node updates are APPENDED to the history
    # rather than replacing it.
    messages: Annotated[list, add_messages]
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "markdown",
101
+ "metadata": {},
102
+ "source": [
103
+ "Set up the nodes and graph\n"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": 8,
109
+ "metadata": {},
110
+ "outputs": [
111
+ {
112
+ "data": {
113
+ "text/plain": [
114
+ "Graph(nodes={'__start__': Node(id='__start__', name='__start__', data=<class 'langchain_core.utils.pydantic.LangGraphInput'>, metadata=None), 'agent': Node(id='agent', name='agent', data=agent(tags=None, recurse=True, explode_args=False, func_accepts_config=False, func_accepts={}), metadata=None), 'action': Node(id='action', name='action', data=tools(tags=None, recurse=True, explode_args=False, func_accepts_config=True, func_accepts={'store': ('__pregel_store', None)}, tools_by_name={}, tool_to_state_args={}, tool_to_store_arg={}, handle_tool_errors=True, messages_key='messages'), metadata=None), '__end__': Node(id='__end__', name='__end__', data=<class 'langchain_core.utils.pydantic.LangGraphOutput'>, metadata=None)}, edges=[Edge(source='__start__', target='agent', data=None, conditional=False), Edge(source='action', target='agent', data=None, conditional=False), Edge(source='agent', target='action', data=None, conditional=True), Edge(source='agent', target='__end__', data=None, conditional=True)])"
115
+ ]
116
+ },
117
+ "execution_count": 8,
118
+ "metadata": {},
119
+ "output_type": "execute_result"
120
+ }
121
+ ],
122
+ "source": [
123
from langgraph.prebuilt import ToolNode
from langgraph.graph import StateGraph, END


def call_model(state):
    """Agent node: run the chat model over the accumulated conversation."""
    response = model.invoke(state["messages"])
    # add_messages reducer appends this to the existing history.
    return {"messages": [response]}


def should_continue(state):
    """Route to the tool node while the latest AI message requests tool calls."""
    if state["messages"][-1].tool_calls:
        return "action"
    return END


# ToolNode executes whatever tool calls the agent's last message contains.
tool_node = ToolNode(tool_belt)

# agent -> (action -> agent)* -> END
uncompiled_graph = StateGraph(AgentState)
uncompiled_graph.add_node("agent", call_model)
uncompiled_graph.add_node("action", tool_node)
uncompiled_graph.set_entry_point("agent")
uncompiled_graph.add_conditional_edges("agent", should_continue)
uncompiled_graph.add_edge("action", "agent")

graph = uncompiled_graph.compile()
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "markdown",
160
+ "metadata": {},
161
+ "source": [
162
+ "Try it out!"
163
+ ]
164
+ },
165
+ {
166
+ "cell_type": "code",
167
+ "execution_count": 11,
168
+ "metadata": {},
169
+ "outputs": [
170
+ {
171
+ "data": {
172
+ "text/plain": [
173
+ "{'messages': [HumanMessage(content='Hello, how are you?', additional_kwargs={}, response_metadata={}, id='1af14df1-568c-460c-997f-11d59e70a3b3'),\n",
174
+ " AIMessage(content=\"Hello! I'm just a computer program, so I don't have feelings, but I'm here and ready to help you. How can I assist you today?\", additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 31, 'prompt_tokens': 13, 'total_tokens': 44, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_709714d124', 'finish_reason': 'stop', 'logprobs': None}, id='run-d532debb-85a6-4730-9786-43543f1cfd43-0', usage_metadata={'input_tokens': 13, 'output_tokens': 31, 'total_tokens': 44, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}})]}"
175
+ ]
176
+ },
177
+ "execution_count": 11,
178
+ "metadata": {},
179
+ "output_type": "execute_result"
180
+ }
181
+ ],
182
+ "source": [
183
from langchain_core.messages import HumanMessage

# Smoke test: a plain greeting should flow agent -> END without any tool call.
graph.invoke({"messages": [HumanMessage(content="Hello, how are you?")]})
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "markdown",
190
+ "metadata": {},
191
+ "source": [
192
+ "Let's see if the RAG tool works."
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "code",
197
+ "execution_count": null,
198
+ "metadata": {},
199
+ "outputs": [],
200
+ "source": [
201
# Exercise the design-retriever tool: the request spells out style, key
# elements, color scheme, layout, and mood, which should make the agent
# route through the "action" node.
from langchain_core.messages import HumanMessage

# NOTE: the original cell also defined an unused `test_requirements` dict that
# merely duplicated this prompt's content; it was dead code and has been removed.
test_message = HumanMessage(
    content="""I need a design that's modern and minimalist, with clean lines and plenty of white space. 
    I want it to use a monochromatic color scheme with subtle accent colors. 
    The layout should be grid-based with clear hierarchy. 
    The overall mood should be professional and sophisticated."""
)

# Invoke the graph with the test message and show the full message trace.
response = graph.invoke({"messages": [test_message]})
print(response)
225
+ ]
226
+ }
227
+ ],
228
+ "metadata": {
229
+ "kernelspec": {
230
+ "display_name": ".venv",
231
+ "language": "python",
232
+ "name": "python3"
233
+ },
234
+ "language_info": {
235
+ "codemirror_mode": {
236
+ "name": "ipython",
237
+ "version": 3
238
+ },
239
+ "file_extension": ".py",
240
+ "mimetype": "text/x-python",
241
+ "name": "python",
242
+ "nbconvert_exporter": "python",
243
+ "pygments_lexer": "ipython3",
244
+ "version": "3.11.11"
245
+ }
246
+ },
247
+ "nbformat": 4,
248
+ "nbformat_minor": 2
249
+ }