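"""Streamlit app: a policy-document Q&A bot with hybrid (BM25 + vector) retrieval.

Users upload PDF policy documents; the app summarizes them at two sampling
temperatures and answers chat queries from retrieved document context using
OpenAI chat completions.
"""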
import os
import streamlit as st
import openai
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, StorageContext
from llama_index.core.node_parser import SentenceSplitter
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.retrievers import BaseRetriever

# Configuration
class Config:
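    # Assumes OPENAI_API_KEY is set in the environment; requests will fail
    # at call time if it is missing.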
    OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
    MODEL_NAME = "gpt-3.5-turbo"
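    # Note: EMBEDDING_MODEL is declared but never wired into llama_index;
    # the vector index uses the library's default embedding model unless
    # Settings.embed_model is configured.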
    EMBEDDING_MODEL = "text-embedding-3-small"
    CHUNK_SIZE = 256

# Document Processing
class DocumentProcessor:
    def __init__(self):
        self.splitter = SentenceSplitter(chunk_size=Config.CHUNK_SIZE)
    
    def process_uploaded_file(self, uploaded_file):
        # Persist the upload to disk so SimpleDirectoryReader can load it;
        # create the data directory if it does not exist yet.
        os.makedirs("./data", exist_ok=True)
        file_path = f"./data/{uploaded_file.name}"
        with open(file_path, 'wb') as f:
            f.write(uploaded_file.getbuffer())
        
        reader = SimpleDirectoryReader(input_files=[file_path])
        documents = reader.load_data()
        return documents

    def create_index(self, documents):
        nodes = self.splitter.get_nodes_from_documents(documents)
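        # Register the nodes in a docstore so BM25 and the vector index can
        # share the same underlying chunks.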
        storage_context = StorageContext.from_defaults()
        storage_context.docstore.add_documents(nodes)
        return VectorStoreIndex(nodes=nodes, storage_context=storage_context), nodes

# Hybrid Retriever
class HybridRetriever(BaseRetriever):
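    """Merges BM25 keyword retrieval with embedding-based vector retrieval,
    de-duplicating results by node id."""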
    def __init__(self, vector_retriever, bm25_retriever):
        self.vector_retriever = vector_retriever
        self.bm25_retriever = bm25_retriever
        super().__init__()

    def _retrieve(self, query, **kwargs):
        # Query both retrievers and merge results, de-duplicating by node id
        # so a chunk surfaced by both methods is returned only once.
        bm25_nodes = self.bm25_retriever.retrieve(query, **kwargs)
        vector_nodes = self.vector_retriever.retrieve(query, **kwargs)
        all_nodes = []
        node_ids = set()
        for n in bm25_nodes + vector_nodes:
            if n.node.node_id not in node_ids:
                all_nodes.append(n)
                node_ids.add(n.node.node_id)
        return all_nodes

# LLM Service
class LLMService:
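    """Thin wrapper around the OpenAI chat completions API."""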
    def __init__(self, model_name):
        self.model_name = model_name
        openai.api_key = Config.OPENAI_API_KEY
    
    def generate_response(self, prompt, system_message="You are a helpful assistant who answers from the following context. If the answer can't be found in the context, politely refuse."):
        response = openai.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": prompt}
            ]
        )
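        # Return token usage alongside the text so callers can track cost.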
        return {
            'content': response.choices[0].message.content,
            'usage': {
                'prompt_tokens': response.usage.prompt_tokens,
                'completion_tokens': response.usage.completion_tokens,
                'total_tokens': response.usage.total_tokens
            }
        }
    
    def generate_summary(self, text, temperature=0.6):
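        # Temperature is a parameter so callers can compare summaries produced
        # at different sampling settings (see display_summaries below).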
        response = openai.chat.completions.create(
            model=self.model_name,
            temperature=temperature,
            messages=[
                {"role": "system", "content": "Summarize the following context:"},
                {"role": "user", "content": text}
            ]
        )
        return response.choices[0].message.content

# Main Application Class
class PromptOptimizationApp:
    def __init__(self):
        self.doc_processor = DocumentProcessor()
        self.llm_service = LLMService(Config.MODEL_NAME)
        self.initialize_session_state()
        
    def initialize_session_state(self):
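        # Streamlit reruns the script on every interaction, so chat history
        # and token stats must live in session_state to survive reruns.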
        if "token_summary" not in st.session_state:
            st.session_state.token_summary = []
        if "messages" not in st.session_state:
            st.session_state.messages = []
    
    def process_documents(self, uploaded_files):
        # Load every uploaded file before indexing so all documents share a
        # single index and retriever.
        documents = []
        for uploaded_file in uploaded_files:
            documents.extend(self.doc_processor.process_uploaded_file(uploaded_file))
        
        index, nodes = self.doc_processor.create_index(documents)
        
        bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=1)
        vector_retriever = index.as_retriever(similarity_top_k=1)
        hybrid_retriever = HybridRetriever(vector_retriever, bm25_retriever)
        
        return documents, hybrid_retriever
    
    def display_summaries(self, text):
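        # Produce two summaries of the same text at different temperatures so
        # their variability can be compared side by side.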
        st.success("Reference summary")
        ref_summary = self.llm_service.generate_summary(text, temperature=0.6)
        st.markdown(ref_summary)
        
        st.success("Generated summary")
        gen_summary = self.llm_service.generate_summary(text, temperature=0.8)
        st.markdown(gen_summary)
    
    def handle_chat(self, prompt, hybrid_retriever):
        st.success("Fetching info...")
        context_list = [n.get_content() for n in hybrid_retriever.retrieve(prompt)]
        context = " ".join(context_list)
        
        st.success("Getting context")
        st.markdown(context)
        
        # Keep the retrieved context and the user question clearly separated.
        full_prompt = "\n\n".join([context, prompt])
        response = self.llm_service.generate_response(full_prompt)
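        # Record token usage for this exchange; token_summary is initialized
        # in session state for this purpose.
        st.session_state.token_summary.append(response['usage'])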
        
        st.session_state.messages.append({"role": "assistant", "content": response['content']})
        with st.chat_message("assistant"):
            st.markdown(response['content'])
        
        return response

def main():
    st.title("Prompt Optimization for a Policy Bot")
    
    app = PromptOptimizationApp()
    
    uploaded_files = st.file_uploader(
        "Upload a Policy document in pdf format", 
        type="pdf", 
        accept_multiple_files=True
    )
    
    if uploaded_files:
        documents, hybrid_retriever = app.process_documents(uploaded_files)
        st.success("File uploaded...")
        
        full_text = documents[0].text
        st.success("Input text")
        st.markdown(full_text)
        
        app.display_summaries(full_text)
        
        # Display chat history
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])
        
        # Handle new chat input
        if prompt := st.chat_input("Enter your query:"):
            st.session_state.messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.markdown(prompt)
            
            app.handle_chat(prompt, hybrid_retriever)

if __name__ == "__main__":
    main()
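
# To launch (assuming this file is saved as app.py):
#   streamlit run app.py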