import os
import streamlit as st
import openai
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, StorageContext
from llama_index.core.node_parser import SentenceSplitter
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.retrievers import BaseRetriever
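
# Streamlit app: upload a policy PDF, chunk it into nodes, build a vector index,
# and answer chat queries using a BM25 + vector hybrid retriever for context,
# with OpenAI chat completions for summaries and answers.
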
# Configuration
class Config:
    OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
    MODEL_NAME = "gpt-3.5-turbo"
    EMBEDDING_MODEL = "text-embedding-3-small"
    CHUNK_SIZE = 256
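
# Note: EMBEDDING_MODEL is declared but not wired into the index below, so
# VectorStoreIndex falls back to llama_index's default embedding model. A sketch
# of one way to apply it (assumes the llama-index-embeddings-openai package):
#   from llama_index.core import Settings
#   from llama_index.embeddings.openai import OpenAIEmbedding
#   Settings.embed_model = OpenAIEmbedding(model=Config.EMBEDDING_MODEL)
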
# Document Processing
class DocumentProcessor:
    def __init__(self):
        self.splitter = SentenceSplitter(chunk_size=Config.CHUNK_SIZE)

    def process_uploaded_file(self, uploaded_file):
        # Persist the upload to disk so SimpleDirectoryReader can read it back
        os.makedirs("./data", exist_ok=True)
        file_path = f"./data/{uploaded_file.name}"
        with open(file_path, 'wb') as f:
            f.write(uploaded_file.getbuffer())
        reader = SimpleDirectoryReader(input_files=[file_path])
        documents = reader.load_data()
        return documents

    def create_index(self, documents):
        nodes = self.splitter.get_nodes_from_documents(documents)
        storage_context = StorageContext.from_defaults()
        storage_context.docstore.add_documents(nodes)
        return VectorStoreIndex(nodes=nodes, storage_context=storage_context), nodes
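
# Note: StorageContext.from_defaults() uses an in-memory docstore, so nothing is
# persisted between runs; the raw nodes are also returned for the BM25 retriever.
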
# Hybrid Retriever
class HybridRetriever(BaseRetriever):
    def __init__(self, vector_retriever, bm25_retriever):
        self.vector_retriever = vector_retriever
        self.bm25_retriever = bm25_retriever
        super().__init__()

    def _retrieve(self, query, **kwargs):
        # Query both retrievers and merge results, de-duplicating by node id
        bm25_nodes = self.bm25_retriever.retrieve(query, **kwargs)
        vector_nodes = self.vector_retriever.retrieve(query, **kwargs)
        all_nodes = []
        node_ids = set()
        for n in bm25_nodes + vector_nodes:
            if n.node.node_id not in node_ids:
                all_nodes.append(n)
                node_ids.add(n.node.node_id)
        return all_nodes
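
# With similarity_top_k=1 on each retriever (set in process_documents below),
# the hybrid retriever returns at most two unique nodes per query; raise top_k
# on either retriever for broader context.
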
# LLM Service
class LLMService:
    def __init__(self, model_name):
        self.model_name = model_name
        # Module-level key configuration; assumes the openai>=1.0 client API
        openai.api_key = Config.OPENAI_API_KEY

    def generate_response(self, prompt, system_message="You are a helpful assistant who answers from the following context. If the answer can't be found in the context, politely refuse."):
        response = openai.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": prompt}
            ]
        )
        return {
            'content': response.choices[0].message.content,
            'usage': {
                'prompt_tokens': response.usage.prompt_tokens,
                'completion_tokens': response.usage.completion_tokens,
                'total_tokens': response.usage.total_tokens
            }
        }

    def generate_summary(self, text, temperature=0.6):
        response = openai.chat.completions.create(
            model=self.model_name,
            temperature=temperature,
            messages=[
                {"role": "system", "content": "Summarize the following context:"},
                {"role": "user", "content": text}
            ]
        )
        return response.choices[0].message.content
# Main Application Class
class PromptOptimizationApp:
    def __init__(self):
        self.doc_processor = DocumentProcessor()
        self.llm_service = LLMService(Config.MODEL_NAME)
        self.initialize_session_state()

    def initialize_session_state(self):
        if "token_summary" not in st.session_state:
            st.session_state.token_summary = []
        if "messages" not in st.session_state:
            st.session_state.messages = []

    def process_documents(self, uploaded_files):
        # Note: with multiple uploads, only the last file's documents and
        # retriever are returned
        for uploaded_file in uploaded_files:
            documents = self.doc_processor.process_uploaded_file(uploaded_file)
            index, nodes = self.doc_processor.create_index(documents)
            bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=1)
            vector_retriever = index.as_retriever(similarity_top_k=1)
            hybrid_retriever = HybridRetriever(vector_retriever, bm25_retriever)
        return documents, hybrid_retriever

    def display_summaries(self, text):
        st.success("Reference summary")
        ref_summary = self.llm_service.generate_summary(text, temperature=0.6)
        st.markdown(ref_summary)
        st.success("Generated summary")
        gen_summary = self.llm_service.generate_summary(text, temperature=0.8)
        st.markdown(gen_summary)

    def handle_chat(self, prompt, hybrid_retriever):
        st.success("Fetching info...")
        context_list = [n.get_content() for n in hybrid_retriever.retrieve(prompt)]
        context = " ".join(context_list)
        st.success("Getting context")
        st.markdown(context)
        # Separate the retrieved context from the user query
        full_prompt = "\n\n".join([context, prompt])
        response = self.llm_service.generate_response(full_prompt)
        st.session_state.messages.append({"role": "assistant", "content": response['content']})
        with st.chat_message("assistant"):
            st.markdown(response['content'])
        return response
def main():
    st.title("Prompt Optimization for a Policy Bot")
    app = PromptOptimizationApp()
    uploaded_files = st.file_uploader(
        "Upload a policy document in PDF format",
        type="pdf",
        accept_multiple_files=True
    )
    if uploaded_files:
        # Streamlit reruns the script on every interaction, so the documents
        # are re-indexed after each chat message
        documents, hybrid_retriever = app.process_documents(uploaded_files)
        st.success("File uploaded...")
        full_text = documents[0].text
        st.success("Input text")
        st.markdown(full_text)
        app.display_summaries(full_text)
        # Display chat history
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])
        # Handle new chat input
        if prompt := st.chat_input("Enter your query:"):
            st.session_state.messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.markdown(prompt)
            app.handle_chat(prompt, hybrid_retriever)

if __name__ == "__main__":
    main()
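
# To run locally (a sketch; assumes streamlit, openai, llama-index, and
# llama-index-retrievers-bm25 are installed):
#   export OPENAI_API_KEY=...
#   streamlit run app.py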