Akshayram1 committed on
Commit
ecd6b96
·
verified ·
1 Parent(s): 62dc939

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +198 -198
main.py CHANGED
@@ -1,198 +1,198 @@
1
- import os
2
- from dotenv import load_dotenv
3
- from langchain.document_loaders import PyPDFLoader
4
- from langchain.text_splitter import RecursiveCharacterTextSplitter
5
- from langchain.schema import Document
6
- from langchain.prompts import PromptTemplate
7
- from langchain.vectorstores import Neo4jVector
8
- from langchain.chat_models import ChatOpenAI
9
- from langchain.embeddings import OpenAIEmbeddings
10
- from langchain.graphs import Neo4jGraph
11
- from langchain_experimental.graph_transformers import LLMGraphTransformer
12
- from langchain.chains.graph_qa.cypher import GraphCypherQAChain
13
- import streamlit as st
14
- import tempfile
15
- from neo4j import GraphDatabase
16
- from openai.embeddings_utils import OpenAIEmbeddings
17
- from openai.chat_models import ChatOpenAI
18
-
19
- def main():
20
- st.set_page_config(
21
- layout="wide",
22
- page_title="MayaJal",
23
- page_icon=":graph:"
24
- )
25
- st.sidebar.image('logo.png', use_column_width=True)
26
- with st.sidebar.expander("Expand Me"):
27
- st.markdown("""
28
- This application allows you to upload a PDF file, extract its content into a Neo4j graph database, and perform queries using natural language.
29
- It leverages LangChain and OpenAI's GPT models to generate Cypher queries that interact with the Neo4j database in real-time.
30
- """)
31
- st.title("Mayajal: Realtime GraphRAG App")
32
-
33
- load_dotenv()
34
-
35
- # Set OpenAI API key
36
- if 'OPENAI_API_KEY' not in st.session_state:
37
- st.sidebar.subheader("OpenAI API Key")
38
- openai_api_key = st.sidebar.text_input("Enter your OpenAI API Key:", type='password')
39
- if openai_api_key:
40
- os.environ['OPENAI_API_KEY'] = openai_api_key
41
- st.session_state['OPENAI_API_KEY'] = openai_api_key
42
- st.sidebar.success("OpenAI API Key set successfully.")
43
- embeddings = OpenAIEmbeddings()
44
- llm = ChatOpenAI(model_name="gpt-4o") # Use model that supports function calling
45
- st.session_state['embeddings'] = embeddings
46
- st.session_state['llm'] = llm
47
- else:
48
- embeddings = st.session_state['embeddings']
49
- llm = st.session_state['llm']
50
-
51
- # Initialize variables
52
- neo4j_url = None
53
- neo4j_username = None
54
- neo4j_password = None
55
- graph = None
56
-
57
- # Set Neo4j connection details
58
- if 'neo4j_connected' not in st.session_state:
59
- st.sidebar.subheader("Connect to Neo4j Database")
60
- neo4j_url = st.sidebar.text_input("Neo4j URL:", value="neo4j+s://<your-neo4j-url>")
61
- neo4j_username = st.sidebar.text_input("Neo4j Username:", value="neo4j")
62
- neo4j_password = st.sidebar.text_input("Neo4j Password:", type='password')
63
- connect_button = st.sidebar.button("Connect")
64
- if connect_button and neo4j_password:
65
- try:
66
- graph = Neo4jGraph(
67
- url=neo4j_url,
68
- username=neo4j_username,
69
- password=neo4j_password
70
- )
71
- st.session_state['graph'] = graph
72
- st.session_state['neo4j_connected'] = True
73
- # Store connection parameters for later use
74
- st.session_state['neo4j_url'] = neo4j_url
75
- st.session_state['neo4j_username'] = neo4j_username
76
- st.session_state['neo4j_password'] = neo4j_password
77
- st.sidebar.success("Connected to Neo4j database.")
78
- except Exception as e:
79
- st.error(f"Failed to connect to Neo4j: {e}")
80
- else:
81
- graph = st.session_state['graph']
82
- neo4j_url = st.session_state['neo4j_url']
83
- neo4j_username = st.session_state['neo4j_username']
84
- neo4j_password = st.session_state['neo4j_password']
85
-
86
- # Ensure that the Neo4j connection is established before proceeding
87
- if graph is not None:
88
- # File uploader
89
- uploaded_file = st.file_uploader("Please select a PDF file.", type="pdf")
90
-
91
- if uploaded_file is not None and 'qa' not in st.session_state:
92
- with st.spinner("Processing the PDF..."):
93
- # Save uploaded file to temporary file
94
- with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
95
- tmp_file.write(uploaded_file.read())
96
- tmp_file_path = tmp_file.name
97
-
98
- # Load and split the PDF
99
- loader = PyPDFLoader(tmp_file_path)
100
- pages = loader.load_and_split()
101
-
102
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=40)
103
- docs = text_splitter.split_documents(pages)
104
-
105
- lc_docs = []
106
- for doc in docs:
107
- lc_docs.append(Document(page_content=doc.page_content.replace("\n", ""),
108
- metadata={'source': uploaded_file.name}))
109
-
110
- # Clear the graph database
111
- cypher = """
112
- MATCH (n)
113
- DETACH DELETE n;
114
- """
115
- graph.query(cypher)
116
-
117
- # Define allowed nodes and relationships
118
- allowed_nodes = ["Patient", "Disease", "Medication", "Test", "Symptom", "Doctor"]
119
- allowed_relationships = ["HAS_DISEASE", "TAKES_MEDICATION", "UNDERWENT_TEST", "HAS_SYMPTOM", "TREATED_BY"]
120
-
121
- # Transform documents into graph documents
122
- transformer = LLMGraphTransformer(
123
- llm=llm,
124
- allowed_nodes=allowed_nodes,
125
- allowed_relationships=allowed_relationships,
126
- node_properties=False,
127
- relationship_properties=False
128
- )
129
-
130
- graph_documents = transformer.convert_to_graph_documents(lc_docs)
131
- graph.add_graph_documents(graph_documents, include_source=True)
132
-
133
- # Use the stored connection parameters
134
- index = Neo4jVector.from_existing_graph(
135
- embedding=embeddings,
136
- url=neo4j_url,
137
- username=neo4j_username,
138
- password=neo4j_password,
139
- database="neo4j",
140
- node_label="Patient", # Adjust node_label as needed
141
- text_node_properties=["id", "text"],
142
- embedding_node_property="embedding",
143
- index_name="vector_index",
144
- keyword_index_name="entity_index",
145
- search_type="hybrid"
146
- )
147
-
148
- st.success(f"{uploaded_file.name} preparation is complete.")
149
-
150
- # Retrieve the graph schema
151
- schema = graph.get_schema
152
-
153
- # Set up the QA chain
154
- template = """
155
- Task: Generate a Cypher statement to query the graph database.
156
-
157
- Instructions:
158
- Use only relationship types and properties provided in schema.
159
- Do not use other relationship types or properties that are not provided.
160
-
161
- schema:
162
- {schema}
163
-
164
- Note: Do not include explanations or apologies in your answers.
165
- Do not answer questions that ask anything other than creating Cypher statements.
166
- Do not include any text other than generated Cypher statements.
167
-
168
- Question: {question}"""
169
-
170
- question_prompt = PromptTemplate(
171
- template=template,
172
- input_variables=["schema", "question"]
173
- )
174
-
175
- qa = GraphCypherQAChain.from_llm(
176
- llm=llm,
177
- graph=graph,
178
- cypher_prompt=question_prompt,
179
- verbose=True,
180
- allow_dangerous_requests=True
181
- )
182
- st.session_state['qa'] = qa
183
- else:
184
- st.warning("Please connect to the Neo4j database before you can upload a PDF.")
185
-
186
- if 'qa' in st.session_state:
187
- st.subheader("Ask a Question")
188
- with st.form(key='question_form'):
189
- question = st.text_input("Enter your question:")
190
- submit_button = st.form_submit_button(label='Submit')
191
-
192
- if submit_button and question:
193
- with st.spinner("Generating answer..."):
194
- res = st.session_state['qa'].invoke({"query": question})
195
- st.write("\n**Answer:**\n" + res['result'])
196
-
197
- if __name__ == "__main__":
198
- main()
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from langchain.document_loaders import PyPDFLoader
4
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+ from langchain.schema import Document
6
+ from langchain.prompts import PromptTemplate
7
+ from langchain.vectorstores import Neo4jVector
8
+ from langchain.chat_models import ChatOpenAI
9
+ from langchain.embeddings import OpenAIEmbeddings
10
+ from langchain.graphs import Neo4jGraph
11
+ from langchain_experimental.graph_transformers import LLMGraphTransformer
12
+ from langchain.chains.graph_qa.cypher import GraphCypherQAChain
13
+ import streamlit as st
14
+ import tempfile
15
+ from neo4j import GraphDatabase
16
+
17
+ def main():
18
+ st.set_page_config(
19
+ layout="wide",
20
+ page_title="Graphy v1",
21
+ page_icon=":graph:"
22
+ )
23
+ st.sidebar.image('logo.png', use_column_width=True)
24
+ with st.sidebar.expander("Expand Me"):
25
+ st.markdown("""
26
+ This application allows you to upload a PDF file, extract its content into a Neo4j graph database, and perform queries using natural language.
27
+ It leverages LangChain and OpenAI's GPT models to generate Cypher queries that interact with the Neo4j database in real-time.
28
+ """)
29
+ st.title("Graphy: Realtime GraphRAG App")
30
+
31
+ load_dotenv()
32
+
33
+ # Set OpenAI API key
34
+ if 'OPENAI_API_KEY' not in st.session_state:
35
+ st.sidebar.subheader("OpenAI API Key")
36
+ openai_api_key = st.sidebar.text_input("Enter your OpenAI API Key:", type='password')
37
+ if openai_api_key:
38
+ os.environ['OPENAI_API_KEY'] = openai_api_key
39
+ st.session_state['OPENAI_API_KEY'] = openai_api_key
40
+ st.sidebar.success("OpenAI API Key set successfully.")
41
+ embeddings = OpenAIEmbeddings()
42
+ llm = ChatOpenAI(model_name="gpt-4o") # Use model that supports function calling
43
+ st.session_state['embeddings'] = embeddings
44
+ st.session_state['llm'] = llm
45
+ else:
46
+ embeddings = st.session_state['embeddings']
47
+ llm = st.session_state['llm']
48
+
49
+ # Initialize variables
50
+ neo4j_url = None
51
+ neo4j_username = None
52
+ neo4j_password = None
53
+ graph = None
54
+
55
+ # Set Neo4j connection details
56
+ if 'neo4j_connected' not in st.session_state:
57
+ st.sidebar.subheader("Connect to Neo4j Database")
58
+ neo4j_url = st.sidebar.text_input("Neo4j URL:", value="neo4j+s://<your-neo4j-url>")
59
+ neo4j_username = st.sidebar.text_input("Neo4j Username:", value="neo4j")
60
+ neo4j_password = st.sidebar.text_input("Neo4j Password:", type='password')
61
+ connect_button = st.sidebar.button("Connect")
62
+ if connect_button and neo4j_password:
63
+ try:
64
+ graph = Neo4jGraph(
65
+ url=neo4j_url,
66
+ username=neo4j_username,
67
+ password=neo4j_password
68
+ )
69
+ st.session_state['graph'] = graph
70
+ st.session_state['neo4j_connected'] = True
71
+ # Store connection parameters for later use
72
+ st.session_state['neo4j_url'] = neo4j_url
73
+ st.session_state['neo4j_username'] = neo4j_username
74
+ st.session_state['neo4j_password'] = neo4j_password
75
+ st.sidebar.success("Connected to Neo4j database.")
76
+ except Exception as e:
77
+ st.error(f"Failed to connect to Neo4j: {e}")
78
+ else:
79
+ graph = st.session_state['graph']
80
+ neo4j_url = st.session_state['neo4j_url']
81
+ neo4j_username = st.session_state['neo4j_username']
82
+ neo4j_password = st.session_state['neo4j_password']
83
+
84
+ # Ensure that the Neo4j connection is established before proceeding
85
+ if graph is not None:
86
+ # File uploader
87
+ uploaded_file = st.file_uploader("Please select a PDF file.", type="pdf")
88
+
89
+ if uploaded_file is not None and 'qa' not in st.session_state:
90
+ with st.spinner("Processing the PDF..."):
91
+ # Save uploaded file to temporary file
92
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
93
+ tmp_file.write(uploaded_file.read())
94
+ tmp_file_path = tmp_file.name
95
+
96
+ # Load and split the PDF
97
+ loader = PyPDFLoader(tmp_file_path)
98
+ pages = loader.load_and_split()
99
+
100
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=40)
101
+ docs = text_splitter.split_documents(pages)
102
+
103
+ lc_docs = []
104
+ for doc in docs:
105
+ lc_docs.append(Document(page_content=doc.page_content.replace("\n", ""),
106
+ metadata={'source': uploaded_file.name}))
107
+
108
+ # Clear the graph database
109
+ cypher = """
110
+ MATCH (n)
111
+ DETACH DELETE n;
112
+ """
113
+ graph.query(cypher)
114
+
115
+ # Define allowed nodes and relationships
116
+ allowed_nodes = ["Patient", "Disease", "Medication", "Test", "Symptom", "Doctor"]
117
+ allowed_relationships = ["HAS_DISEASE", "TAKES_MEDICATION", "UNDERWENT_TEST", "HAS_SYMPTOM", "TREATED_BY"]
118
+
119
+ # Transform documents into graph documents
120
+ transformer = LLMGraphTransformer(
121
+ llm=llm,
122
+ allowed_nodes=allowed_nodes,
123
+ allowed_relationships=allowed_relationships,
124
+ node_properties=False,
125
+ relationship_properties=False
126
+ )
127
+
128
+ graph_documents = transformer.convert_to_graph_documents(lc_docs)
129
+ graph.add_graph_documents(graph_documents, include_source=True)
130
+
131
+ # Use the stored connection parameters
132
+ index = Neo4jVector.from_existing_graph(
133
+ embedding=embeddings,
134
+ url=neo4j_url,
135
+ username=neo4j_username,
136
+ password=neo4j_password,
137
+ database="neo4j",
138
+ node_label="Patient", # Adjust node_label as needed
139
+ text_node_properties=["id", "text"],
140
+ embedding_node_property="embedding",
141
+ index_name="vector_index",
142
+ keyword_index_name="entity_index",
143
+ search_type="hybrid"
144
+ )
145
+
146
+ st.success(f"{uploaded_file.name} preparation is complete.")
147
+
148
+ # Retrieve the graph schema
149
+ schema = graph.get_schema
150
+
151
+ # Set up the QA chain
152
+ template = """
153
+ Task: Generate a Cypher statement to query the graph database.
154
+
155
+ Instructions:
156
+ Use only relationship types and properties provided in schema.
157
+ Do not use other relationship types or properties that are not provided.
158
+
159
+ schema:
160
+ {schema}
161
+
162
+ Note: Do not include explanations or apologies in your answers.
163
+ Do not answer questions that ask anything other than creating Cypher statements.
164
+ Do not include any text other than generated Cypher statements.
165
+
166
+ Question: {question}"""
167
+
168
+ question_prompt = PromptTemplate(
169
+ template=template,
170
+ input_variables=["schema", "question"]
171
+ )
172
+
173
+ qa = GraphCypherQAChain.from_llm(
174
+ llm=llm,
175
+ graph=graph,
176
+ cypher_prompt=question_prompt,
177
+ verbose=True,
178
+ allow_dangerous_requests=True
179
+ )
180
+ st.session_state['qa'] = qa
181
+ else:
182
+ st.warning("Please connect to the Neo4j database before you can upload a PDF.")
183
+
184
+ if 'qa' in st.session_state:
185
+ st.subheader("Ask a Question")
186
+ with st.form(key='question_form'):
187
+ question = st.text_input("Enter your question:")
188
+ submit_button = st.form_submit_button(label='Submit')
189
+
190
+ if submit_button and question:
191
+ with st.spinner("Generating answer..."):
192
+ res = st.session_state['qa'].invoke({"query": question})
193
+ st.write("\n**Answer:**\n" + res['result'])
194
+
195
+ if __name__ == "__main__":
196
+ main()
197
+
198
+