Gampanut committed
Commit 6d575f5 · verified · 1 Parent(s): d7c913a

Create app.py

Files changed (1): app.py +178 -0
app.py ADDED
@@ -0,0 +1,178 @@
+ import gradio as gr
+ from langchain_groq import ChatGroq
+ from langchain_community.graphs.networkx_graph import NetworkxEntityGraph
+ from langchain.chains import GraphQAChain
+ from langchain.document_loaders import TextLoader
+ from langchain.text_splitter import CharacterTextSplitter
+ from langchain.vectorstores import Pinecone
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.schema.runnable import RunnablePassthrough
+ from langchain.schema.output_parser import StrOutputParser
+ from langchain import PromptTemplate
+ from neo4j import GraphDatabase
+ import networkx as nx
+ import pinecone
+ import os
+
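+ # This app wires up two question-answering pipelines over the same Thai-rice
+ # corpus: vector RAG backed by Pinecone, and graph RAG over a Neo4j graph
+ # mirrored into NetworkX. A Gradio UI A/B-tests the two side by side.
+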
+ # RAG Setup
+ # NOTE: absolute Windows path; the corpus file must exist on the machine that
+ # runs the app (as written, this will fail on a Linux host such as a Space)
+ text_path = r"C:\Users\USER\Downloads\RAG_langchain\text_chunks.txt"
+ loader = TextLoader(text_path, encoding='utf-8')
+ documents = loader.load()
+ text_splitter = CharacterTextSplitter(chunk_size=3000, chunk_overlap=4)
+ docs = text_splitter.split_documents(documents)
+ embeddings = HuggingFaceEmbeddings()
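+ # chunk_overlap=4 is effectively no overlap at chunk_size=3000 characters.
+ # HuggingFaceEmbeddings defaults to sentence-transformers/all-mpnet-base-v2,
+ # whose 768-dimensional vectors match the dimension=768 index created below.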
+
+ pinecone.init(
+     api_key=os.getenv('PINECONE_API_KEY'),  # read from the environment; never ship a hardcoded fallback key
+     environment='gcp-starter'
+ )
+
+ index_name = "langchain-demo"
+ if index_name not in pinecone.list_indexes():
+     pinecone.create_index(name=index_name, metric="cosine", dimension=768)
+     docsearch = Pinecone.from_documents(docs, embeddings, index_name=index_name)
+ else:
+     docsearch = Pinecone.from_existing_index(index_name, embeddings)
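+ # First run creates the index and upserts the embedded chunks; later runs
+ # reconnect to the existing index instead of re-embedding the corpus.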
+
+ rag_llm = ChatGroq(
+     model="Llama3-8b-8192",
+     temperature=0,
+     max_tokens=None,
+     timeout=None,
+     max_retries=5,
+     groq_api_key=os.getenv('GROQ_API_KEY')  # never commit a live key; ChatGroq can also read GROQ_API_KEY itself
+ )
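+ # Both pipelines use the same Groq-hosted Llama3-8b-8192 model at
+ # temperature 0, so differences between the two answers reflect the
+ # retrieval strategy rather than sampling noise.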
+
+ rag_prompt = PromptTemplate(
+     template="""
+ You are a Thai rice assistant that gives concise and direct answers.
+ Do not explain the process; just provide the answer, and answer only in Thai.
+
+ Context: {context}
+ Question: {question}
+ Answer:
+ """,
+     input_variables=["context", "question"]
+ )
+
+ rag_chain = (
+     {"context": docsearch.as_retriever(), "question": RunnablePassthrough()}
+     | rag_prompt
+     | rag_llm
+     | StrOutputParser()
+ )
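+ # LCEL pipeline: the retriever fills {context} with matching chunks while
+ # RunnablePassthrough() forwards the raw question; StrOutputParser() unwraps
+ # the model's chat message into a plain string.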
+
+ graphrag_llm = ChatGroq(
+     model="Llama3-8b-8192",
+     temperature=0,
+     max_tokens=None,
+     timeout=None,
+     max_retries=5,
+     groq_api_key=os.getenv('GROQ_API_KEY')
+ )
+
+ # Neo4j Aura connection; keep the password out of source control
+ uri = "neo4j+s://46084f1a.databases.neo4j.io"
+ user = "neo4j"
+ password = os.getenv('NEO4J_PASSWORD')
+ driver = GraphDatabase.driver(uri, auth=(user, password))
+
+ def fetch_nodes(tx):
+     query = "MATCH (n) RETURN id(n) AS id, labels(n) AS labels"
+     result = tx.run(query)
+     return result.data()
+
+ def fetch_relationships(tx):
+     query = "MATCH (n)-[r]->(m) RETURN id(n) AS source, id(m) AS target, type(r) AS relation"
+     result = tx.run(query)
+     return result.data()
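+ # id() returns Neo4j's internal integer ids (deprecated in Neo4j 5 in favor
+ # of elementId()); they double as the NetworkX node keys below.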
+
+ def populate_networkx_graph():
+     G = nx.Graph()
+     with driver.session() as session:
+         nodes = session.read_transaction(fetch_nodes)
+         relationships = session.read_transaction(fetch_relationships)
+         for node in nodes:
+             G.add_node(node['id'], labels=node['labels'])
+         for relationship in relationships:
+             G.add_edge(
+                 relationship['source'],
+                 relationship['target'],
+                 relation=relationship['relation']
+             )
+     return G
+
+ networkx_graph = populate_networkx_graph()
+ graph = NetworkxEntityGraph()
+ graph._graph = networkx_graph
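+ # Overwriting the private _graph attribute injects the prebuilt NetworkX
+ # graph into the wrapper; this works but may break across langchain versions.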
+
+ graphrag_chain = GraphQAChain.from_llm(
+     llm=graphrag_llm,
+     graph=graph,
+     verbose=True
+ )
+
+ def get_rag_response(question):
+     response = rag_chain.invoke(question)
+     return response
+
+ def get_graphrag_response(question):
+     system_prompt = "You are a Thai rice assistant that gives concise and direct answers. Do not explain the process; just provide the answer, and answer only in Thai."
+     formatted_question = f"System Prompt: {system_prompt}\n\nQuestion: {question}"
+     response = graphrag_chain.run(formatted_question)
+     return response
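+ # GraphQAChain.run() takes a single string, so the system instruction is
+ # folded into the question text rather than sent as a separate message.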
+
+ def compare_models(question):
+     rag_response = get_rag_response(question)
+     graphrag_response = get_graphrag_response(question)
+     return rag_response, graphrag_response
+
+ def store_feedback(feedback, question, rag_response, graphrag_response):
+     print("Storing feedback...")
+     print(f"Question: {question}")
+     print(f"RAG Response: {rag_response}")
+     print(f"GraphRAG Response: {graphrag_response}")
+     print(f"User Feedback: {feedback}")
+
+     with open("feedback.txt", "a", encoding='utf-8') as f:
+         f.write(f"Question: {question}\n")
+         f.write(f"RAG Response: {rag_response}\n")
+         f.write(f"GraphRAG Response: {graphrag_response}\n")
+         f.write(f"User Feedback: {feedback}\n\n")
+
+ def handle_feedback(feedback, question, rag_response, graphrag_response):
+     store_feedback(feedback, question, rag_response, graphrag_response)
+     return "Feedback stored successfully!"
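+ # feedback.txt lives on the app host's local disk; on ephemeral hosting
+ # (for example a rebuilt Space) it disappears on restart, so durable A/B
+ # results would need external storage.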
+
+ with gr.Blocks() as demo:
+     gr.Markdown("## Thai Rice Assistant A/B Testing")
+
+     with gr.Row():
+         with gr.Column():
+             question_input = gr.Textbox(label="Ask a question about Thai rice:")
+             submit_btn = gr.Button("Get Answers")
+
+         with gr.Column():
+             rag_output = gr.Textbox(label="Model A", interactive=False)
+             graphrag_output = gr.Textbox(label="Model B", interactive=False)
+
+     with gr.Row():
+         with gr.Column():
+             choice = gr.Radio(["A is better", "B is better", "Tie", "Both Bad"], label="Which response is better?")
+             send_feedback_btn = gr.Button("Send Feedback")
+
+     def on_submit(question):
+         rag_response, graphrag_response = compare_models(question)
+         return rag_response, graphrag_response
+
+     # In gr.Blocks, a component's .value is only its initial value; the live
+     # question and answers must be passed through the event's inputs list.
+     def on_feedback(feedback, question, rag_response, graphrag_response):
+         return handle_feedback(feedback, question, rag_response, graphrag_response)
+
+     submit_btn.click(on_submit, inputs=[question_input], outputs=[rag_output, graphrag_output])
+     send_feedback_btn.click(
+         on_feedback,
+         inputs=[choice, question_input, rag_output, graphrag_output],
+         outputs=[]
+     )
+
+ demo.launch(share=True)
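+ # share=True asks Gradio for a temporary public gradio.live link; on an
+ # already-public host such as a Space the flag is typically ignored.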