import os
import streamlit as st
import openai
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, StorageContext
from llama_index.core.node_parser import SentenceSplitter
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.retrievers import BaseRetriever
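
# Dependency note (an assumption inferred from the import paths above, not stated
# in the original source): these imports typically come from the packages
# streamlit, openai, llama-index, and llama-index-retrievers-bm25, e.g.
#   pip install streamlit openai llama-index llama-index-retrievers-bm25
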
# Configuration
class Config:
    OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
    MODEL_NAME = "gpt-3.5-turbo"
    EMBEDDING_MODEL = "text-embedding-3-small"
    CHUNK_SIZE = 256

# Document Processing
class DocumentProcessor:
    def __init__(self):
        self.splitter = SentenceSplitter(chunk_size=Config.CHUNK_SIZE)

    def process_uploaded_file(self, uploaded_file):
        # Persist the uploaded file to disk so SimpleDirectoryReader can load it
        os.makedirs("./data", exist_ok=True)
        file_path = f"./data/{uploaded_file.name}"
        with open(file_path, 'wb') as f:
            f.write(uploaded_file.getbuffer())
        reader = SimpleDirectoryReader(input_files=[file_path])
        documents = reader.load_data()
        return documents

    def create_index(self, documents):
        nodes = self.splitter.get_nodes_from_documents(documents)
        storage_context = StorageContext.from_defaults()
        storage_context.docstore.add_documents(nodes)
        return VectorStoreIndex(nodes=nodes, storage_context=storage_context), nodes

# Hybrid Retriever
class HybridRetriever(BaseRetriever):
    def __init__(self, vector_retriever, bm25_retriever):
        self.vector_retriever = vector_retriever
        self.bm25_retriever = bm25_retriever
        super().__init__()

    def _retrieve(self, query, **kwargs):
        # Combine vector and BM25 results, de-duplicating by node id
        vector_nodes = self.vector_retriever.retrieve(query, **kwargs)
        bm25_nodes = self.bm25_retriever.retrieve(query, **kwargs)
        all_nodes = []
        node_ids = set()
        for n in vector_nodes + bm25_nodes:
            if n.node.node_id not in node_ids:
                all_nodes.append(n)
                node_ids.add(n.node.node_id)
        return all_nodes

# LLM Service
class LLMService:
    def __init__(self, model_name):
        self.model_name = model_name
        openai.api_key = Config.OPENAI_API_KEY

    def generate_response(self, prompt, system_message="You are a helpful assistant who answers from the following context. If the answer can't be found in the context, politely refuse."):
        response = openai.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": prompt}
            ]
        )
        return {
            'content': response.choices[0].message.content,
            'usage': {
                'prompt_tokens': response.usage.prompt_tokens,
                'completion_tokens': response.usage.completion_tokens,
                'total_tokens': response.usage.total_tokens
            }
        }

    def generate_summary(self, text, temperature=0.6):
        response = openai.chat.completions.create(
            model=self.model_name,
            temperature=temperature,
            messages=[
                {"role": "system", "content": "Summarize the following context:"},
                {"role": "user", "content": text}
            ]
        )
        return response.choices[0].message.content

# Main Application Class
class PromptOptimizationApp:
    def __init__(self):
        self.doc_processor = DocumentProcessor()
        self.llm_service = LLMService(Config.MODEL_NAME)
        self.initialize_session_state()

    def initialize_session_state(self):
        if "token_summary" not in st.session_state:
            st.session_state.token_summary = []
        if "messages" not in st.session_state:
            st.session_state.messages = []

    def process_documents(self, uploaded_files):
        # Load every uploaded file and build a single index over all of them
        documents = []
        for uploaded_file in uploaded_files:
            documents.extend(self.doc_processor.process_uploaded_file(uploaded_file))
        index, nodes = self.doc_processor.create_index(documents)
        bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=1)
        vector_retriever = index.as_retriever(similarity_top_k=1)
        hybrid_retriever = HybridRetriever(vector_retriever, bm25_retriever)
        return documents, hybrid_retriever

    def display_summaries(self, text):
        st.success("Reference summary")
        ref_summary = self.llm_service.generate_summary(text, temperature=0.6)
        st.markdown(ref_summary)
        st.success("Generated summary")
        gen_summary = self.llm_service.generate_summary(text, temperature=0.8)
        st.markdown(gen_summary)

    def handle_chat(self, prompt, hybrid_retriever):
        st.success("Fetching info...")
        context_list = [n.get_content() for n in hybrid_retriever.retrieve(prompt)]
        context = " ".join(context_list)
        st.success("Getting context")
        st.markdown(context)
        # Prepend the retrieved context to the user query before calling the LLM
        full_prompt = "\n\n".join([context, prompt])
        response = self.llm_service.generate_response(full_prompt)
        st.session_state.messages.append({"role": "assistant", "content": response['content']})
        with st.chat_message("assistant"):
            st.markdown(response['content'])
        return response

def main():
    st.title("Prompt Optimization for a Policy Bot")
    app = PromptOptimizationApp()

    uploaded_files = st.file_uploader(
        "Upload a Policy document in pdf format",
        type="pdf",
        accept_multiple_files=True
    )

    if uploaded_files:
        documents, hybrid_retriever = app.process_documents(uploaded_files)
        st.success("File uploaded...")
        full_text = documents[0].text
        st.success("Input text")
        st.markdown(full_text)
        app.display_summaries(full_text)

        # Display chat history (kept inside the upload branch so the retriever exists)
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        # Handle new chat input
        if prompt := st.chat_input("Enter your query:"):
            st.session_state.messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.markdown(prompt)
            app.handle_chat(prompt, hybrid_retriever)

if __name__ == "__main__":
    main()
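
# Usage sketch (assumes this script is saved as app.py and that OPENAI_API_KEY is
# set in the environment, since Config reads it from there):
#   export OPENAI_API_KEY=<your key>
#   streamlit run app.py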