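"""Thai Rice Assistant A/B testing app.

Serves a Gradio UI that answers questions about Thai rice with two pipelines
side by side: a vector RAG chain (Pinecone + HuggingFace embeddings) and a
GraphRAG chain (a Neo4j graph mirrored into NetworkX), both backed by a Groq
Llama 3 model. User feedback on the paired answers is appended to feedback.txt.

Expects GROQ_API_KEY, PINECONE_API_KEY, and NEO4J_PASSWORD in the environment.
"""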
import gradio as gr
from langchain_groq import ChatGroq
from langchain_community.graphs.networkx_graph import NetworkxEntityGraph
from langchain.chains import GraphQAChain
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import Pinecone
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from neo4j import GraphDatabase
import networkx as nx
import pinecone
import os

# RAG Setup
text_path = r"C:\Users\USER\Downloads\RAG_langchain\text_chunks.txt"
loader = TextLoader(text_path, encoding='utf-8')
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=3000, chunk_overlap=4)
docs = text_splitter.split_documents(documents)
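# HuggingFaceEmbeddings defaults to sentence-transformers/all-mpnet-base-v2,
# which produces 768-dimensional vectors -- this must match the `dimension`
# used when the Pinecone index is created below.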
embeddings = HuggingFaceEmbeddings()

# pinecone.init is the pinecone-client v2 API; the key is read from the
# environment rather than being hard-coded into the source.
pinecone.init(
    api_key=os.getenv('PINECONE_API_KEY'),
    environment='gcp-starter'
)

index_name = "langchain-demo"
if index_name not in pinecone.list_indexes():
    pinecone.create_index(name=index_name, metric="cosine", dimension=768)
    docsearch = Pinecone.from_documents(docs, embeddings, index_name=index_name)
else:
    docsearch = Pinecone.from_existing_index(index_name, embeddings)

rag_llm = ChatGroq(
    model="llama3-8b-8192",  # Groq model IDs are lowercase
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=5,
    groq_api_key=os.getenv('GROQ_API_KEY')
)

rag_prompt = PromptTemplate(
    template="""
    You are a Thai rice assistant that gives concise and direct answers.
    Do not explain your reasoning; provide only the final answer, in Thai.

    Context: {context}
    Question: {question}
    Answer: 
    """,
    input_variables=["context", "question"]
)

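# LCEL pipeline: the retriever fills {context} with the top-matching chunks
# while the raw question passes through unchanged, then the prompt, the LLM,
# and the string output parser run in sequence.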
rag_chain = (
    {"context": docsearch.as_retriever(), "question": RunnablePassthrough()}
    | rag_prompt
    | rag_llm
    | StrOutputParser()
)

graphrag_llm = ChatGroq(
    model="llama3-8b-8192",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=5,
    groq_api_key=os.getenv('GROQ_API_KEY')
)

uri = os.getenv("NEO4J_URI", "neo4j+s://46084f1a.databases.neo4j.io")
user = os.getenv("NEO4J_USER", "neo4j")
password = os.getenv("NEO4J_PASSWORD")  # never commit database credentials
driver = GraphDatabase.driver(uri, auth=(user, password))

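# Read-only Cypher helpers that mirror the whole Neo4j graph. Note that id()
# is deprecated in Neo4j 5 in favor of elementId(); it is kept here because it
# still works and returns the integer ids the NetworkX graph is keyed on.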
def fetch_nodes(tx):
    query = "MATCH (n) RETURN id(n) AS id, labels(n) AS labels"
    result = tx.run(query)
    return result.data()

def fetch_relationships(tx):
    query = "MATCH (n)-[r]->(m) RETURN id(n) AS source, id(m) AS target, type(r) AS relation"
    result = tx.run(query)
    return result.data()

def populate_networkx_graph():
    G = nx.Graph()
    with driver.session() as session:
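        # read_transaction was renamed execute_read in neo4j driver 5.x;
        # the old name still works there but emits a deprecation warning.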
        nodes = session.read_transaction(fetch_nodes)
        relationships = session.read_transaction(fetch_relationships)
        for node in nodes:
            G.add_node(node['id'], labels=node['labels'])
        for relationship in relationships:
            G.add_edge(
                relationship['source'],
                relationship['target'],
                relation=relationship['relation']
            )
    return G

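# NetworkxEntityGraph builds its own empty graph by default; assigning the
# private _graph attribute swaps in the Neo4j-backed graph. This leans on an
# internal attribute and may break with future langchain_community releases.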
networkx_graph = populate_networkx_graph()
graph = NetworkxEntityGraph()
graph._graph = networkx_graph

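# GraphQAChain first asks the LLM to extract entities from the question, then
# looks up those entities' triples in the graph and answers from that context.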
graphrag_chain = GraphQAChain.from_llm(
    llm=graphrag_llm,
    graph=graph,
    verbose=True
)

def get_rag_response(question):
    response = rag_chain.invoke(question)
    return response

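# GraphQAChain exposes no separate system-prompt slot, so the instruction is
# folded into the question text itself -- a workaround, not a true system message.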
def get_graphrag_response(question):
    system_prompt = "You are a Thai rice assistant that gives concise and direct answers. Do not explain your reasoning; provide only the final answer, in Thai."
    formatted_question = f"System Prompt: {system_prompt}\n\nQuestion: {question}"
    response = graphrag_chain.run(formatted_question)
    return response

def compare_models(question):
    rag_response = get_rag_response(question)
    graphrag_response = get_graphrag_response(question)
    return rag_response, graphrag_response

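# Feedback is persisted as an append-only text log; swap this for a database
# write if the results need to be queried or aggregated later.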
def store_feedback(feedback, question, rag_response, graphrag_response):
    print("Storing feedback...")
    print(f"Question: {question}")
    print(f"RAG Response: {rag_response}")
    print(f"GraphRAG Response: {graphrag_response}")
    print(f"User Feedback: {feedback}")
    
    with open("feedback.txt", "a", encoding='utf-8') as f:
        f.write(f"Question: {question}\n")
        f.write(f"RAG Response: {rag_response}\n")
        f.write(f"GraphRAG Response: {graphrag_response}\n")
        f.write(f"User Feedback: {feedback}\n\n")

def handle_feedback(feedback, question, rag_response, graphrag_response):
    store_feedback(feedback, question, rag_response, graphrag_response)
    return "Feedback stored successfully!"

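# Gradio UI: one input box feeds both pipelines, and the two answers render
# side by side as unlabeled "Model A" / "Model B" for blind comparison.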
with gr.Blocks() as demo:
    gr.Markdown("## Thai Rice Assistant A/B Testing")

    with gr.Row():
        with gr.Column():
            question_input = gr.Textbox(label="Ask a question about Thai rice:")
            submit_btn = gr.Button("Get Answers")

        with gr.Column():
            rag_output = gr.Textbox(label="Model A", interactive=False)
            graphrag_output = gr.Textbox(label="Model B", interactive=False)

    with gr.Row():
        with gr.Column():
            choice = gr.Radio(["A is better", "B is better", "Tie", "Both Bad"], label="Which response is better?")
            send_feedback_btn = gr.Button("Send Feedback")
            feedback_status = gr.Textbox(label="Feedback Status", interactive=False)

    def on_submit(question):
        rag_response, graphrag_response = compare_models(question)
        return rag_response, graphrag_response

    def on_feedback(feedback, question, rag_response, graphrag_response):
        return handle_feedback(feedback, question, rag_response, graphrag_response)

    submit_btn.click(on_submit, inputs=[question_input], outputs=[rag_output, graphrag_output])
    # Current component values must be passed through the event's `inputs`;
    # reading .value would only return each component's initial value.
    send_feedback_btn.click(
        on_feedback,
        inputs=[choice, question_input, rag_output, graphrag_output],
        outputs=[feedback_status]
    )

demo.launch(share=True)