# Scraped from a Hugging Face Space page (Space status: Sleeping; file size: 6,318 bytes).
# Per-line blame commits shown on the page: c6e5236, 8e01382, a09734b, ff0a602, 02f1b5e, 682c36d.
import os
import streamlit as st
import openai
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, StorageContext
from llama_index.core.node_parser import SentenceSplitter
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.retrievers import BaseRetriever
# Configuration
class Config:
    """Application-wide settings, evaluated once when this class body executes at import."""

    # Taken from the process environment; None when the variable is not set.
    OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
    # Chat model handed to LLMService by the application.
    MODEL_NAME = "gpt-3.5-turbo"
    # NOTE(review): nothing in the visible code reads this value, so the vector index
    # presumably uses the library's default embedding model instead — confirm.
    EMBEDDING_MODEL = "text-embedding-3-small"
    # Chunk size passed to SentenceSplitter when documents are split into nodes.
    CHUNK_SIZE = 256
# Document Processing
class DocumentProcessor:
    """Persist uploaded files to disk, load them as documents, and index them."""

    # Directory uploaded files are written to before being read back.
    DATA_DIR = "./data"

    def __init__(self):
        self.splitter = SentenceSplitter(chunk_size=Config.CHUNK_SIZE)

    @staticmethod
    def _upload_path(filename, data_dir="./data"):
        """Return the on-disk path inside *data_dir* for an uploaded file name.

        Only the final path component of *filename* is kept, so a crafted name
        such as ``../../x.pdf`` cannot write outside *data_dir*.

        Raises:
            ValueError: If *filename* has no usable final component
                (e.g. ``""``, ``"../"`` or ``".."``).
        """
        safe_name = os.path.basename(filename)
        if safe_name in ("", ".", ".."):
            raise ValueError(f"Invalid upload file name: {filename!r}")
        return os.path.join(data_dir, safe_name)

    def process_uploaded_file(self, uploaded_file):
        """Save *uploaded_file* under DATA_DIR and load it with SimpleDirectoryReader.

        Args:
            uploaded_file: Object exposing ``name`` and ``getbuffer()``
                (as the Streamlit uploader in this file provides).

        Returns:
            The list of documents produced by ``SimpleDirectoryReader.load_data()``.
        """
        # The directory may not exist on a fresh deployment; open() would then fail.
        os.makedirs(self.DATA_DIR, exist_ok=True)
        file_path = self._upload_path(uploaded_file.name, self.DATA_DIR)
        with open(file_path, 'wb') as f:
            f.write(uploaded_file.getbuffer())
        reader = SimpleDirectoryReader(input_files=[file_path])
        return reader.load_data()

    def create_index(self, documents):
        """Split *documents* into nodes and build a VectorStoreIndex over them.

        Returns:
            ``(index, nodes)`` — the nodes are returned too so callers can reuse
            them (e.g. for a keyword retriever).
        """
        nodes = self.splitter.get_nodes_from_documents(documents)
        storage_context = StorageContext.from_defaults()
        # Register the nodes in the docstore alongside building the vector index.
        storage_context.docstore.add_documents(nodes)
        return VectorStoreIndex(nodes=nodes, storage_context=storage_context), nodes
# Hybrid Retriever
class HybridRetriever(BaseRetriever):
    """Combine a vector retriever and a BM25 retriever into one deduplicated result list.

    Results are vector hits first, then BM25 hits whose node was not already
    returned. Scores from the two retrievers are not re-normalized or re-sorted.
    """

    def __init__(self, vector_retriever, bm25_retriever):
        self.vector_retriever = vector_retriever
        self.bm25_retriever = bm25_retriever
        super().__init__()

    def _retrieve(self, query, **kwargs):
        """Query both retrievers and merge their hits, keeping the first hit per node_id."""
        vector_nodes = self.vector_retriever.retrieve(query, **kwargs)
        # Previously the BM25 retriever was stored but never queried, which made
        # this retriever vector-only despite its name.
        bm25_nodes = self.bm25_retriever.retrieve(query, **kwargs)

        all_nodes = []
        node_ids = set()
        for scored_node in [*vector_nodes, *bm25_nodes]:
            node_id = scored_node.node.node_id
            if node_id not in node_ids:
                node_ids.add(node_id)
                all_nodes.append(scored_node)
        return all_nodes
# LLM Service
class LLMService:
    """Small wrapper around the OpenAI chat-completions API for one model."""

    def __init__(self, model_name):
        self.model_name = model_name
        # Module-level setting on `openai`: shared by every LLMService instance.
        openai.api_key = Config.OPENAI_API_KEY

    def _chat(self, system_message, user_message, **options):
        """Send a system + user message pair and return the raw completion object.

        Extra keyword *options* (e.g. ``temperature``) are forwarded unchanged.
        """
        conversation = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message},
        ]
        return openai.chat.completions.create(
            model=self.model_name,
            messages=conversation,
            **options,
        )

    def generate_response(
        self,
        prompt,
        system_message="You are a helpful assistant who answers from the following context. If the answer can't be found in context, politely refuse",
    ):
        """Answer *prompt* under *system_message*.

        Returns:
            ``{'content': str, 'usage': {'prompt_tokens', 'completion_tokens', 'total_tokens'}}``
            taken from the first choice and the response's usage record.
        """
        completion = self._chat(system_message, prompt)
        usage = completion.usage
        return {
            'content': completion.choices[0].message.content,
            'usage': {
                field: getattr(usage, field)
                for field in ('prompt_tokens', 'completion_tokens', 'total_tokens')
            },
        }

    def generate_summary(self, text, temperature=0.6):
        """Return the model's summary of *text*, sampled at *temperature*."""
        completion = self._chat(
            "Summarize the following context:", text, temperature=temperature
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
# Main Application Class
class PromptOptimizationApp:
    """Coordinate document indexing, summaries and retrieval-augmented chat for the UI."""

    def __init__(self):
        self.doc_processor = DocumentProcessor()
        self.llm_service = LLMService(Config.MODEL_NAME)
        self.initialize_session_state()

    def initialize_session_state(self):
        """Create the per-session containers this app reads and appends to."""
        if "token_summary" not in st.session_state:
            st.session_state.token_summary = []
        if "messages" not in st.session_state:
            st.session_state.messages = []
        # Single-entry caches: Streamlit re-executes the script on every interaction
        # (e.g. each chat message). Without them, every rerun would re-index the
        # uploads and request two new summaries.
        if "index_cache" not in st.session_state:
            st.session_state.index_cache = {}
        if "summary_cache" not in st.session_state:
            st.session_state.summary_cache = {}

    @staticmethod
    def _upload_key(uploaded_files):
        """Return a hashable fingerprint (name, length, content hash) of the uploads."""
        fingerprint = []
        for uploaded_file in uploaded_files:
            data = uploaded_file.getvalue()
            fingerprint.append((uploaded_file.name, len(data), hash(data)))
        return tuple(fingerprint)

    @staticmethod
    def _build_prompt(context, prompt):
        """Join the retrieved context and the user's question, separated by a blank line."""
        return "\n\n".join([context, prompt])

    def process_documents(self, uploaded_files, similarity_top_k=1):
        """Load every uploaded file and build one hybrid retriever over all of them.

        Args:
            uploaded_files: Uploaded file objects (each exposing ``name``,
                ``getvalue()`` and ``getbuffer()``).
            similarity_top_k: Hits returned by each underlying retriever.

        Returns:
            ``(documents, hybrid_retriever)``; ``documents`` contains the loaded
            documents of all files, in upload order.

        Raises:
            ValueError: If no documents could be loaded.
        """
        key = (self._upload_key(uploaded_files), similarity_top_k)
        cached = st.session_state.index_cache.get(key)
        if cached is not None:
            return cached

        # Previously each iteration overwrote the index, so only the last file was used.
        documents = []
        for uploaded_file in uploaded_files:
            documents.extend(self.doc_processor.process_uploaded_file(uploaded_file))
        if not documents:
            raise ValueError("No documents could be loaded from the uploaded files.")

        index, nodes = self.doc_processor.create_index(documents)
        bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=similarity_top_k)
        vector_retriever = index.as_retriever(similarity_top_k=similarity_top_k)
        result = (documents, HybridRetriever(vector_retriever, bm25_retriever))
        # Replace rather than add, so only the latest upload set is kept in memory.
        st.session_state.index_cache = {key: result}
        return result

    def display_summaries(self, text):
        """Show a reference (temperature 0.6) and a generated (temperature 0.8) summary of *text*."""
        cache = st.session_state.summary_cache
        if text not in cache:
            cache = {
                text: (
                    self.llm_service.generate_summary(text, temperature=0.6),
                    self.llm_service.generate_summary(text, temperature=0.8),
                )
            }
            # Replace rather than add, so only the latest text's summaries are kept.
            st.session_state.summary_cache = cache
        ref_summary, gen_summary = cache[text]
        st.success("Reference summary")
        st.markdown(ref_summary)
        st.success("Generated summary")
        st.markdown(gen_summary)

    def handle_chat(self, prompt, hybrid_retriever):
        """Retrieve context for *prompt*, ask the LLM, record and render the answer.

        Returns:
            The dict produced by ``LLMService.generate_response``.
        """
        st.success("Fetching info...")
        context_list = [n.get_content() for n in hybrid_retriever.retrieve(prompt)]
        context = " ".join(context_list)
        st.success("Getting context")
        st.markdown(context)
        # The old `"\n\n".join([context + prompt])` joined a one-element list, so the
        # context and question ran together with no separator.
        full_prompt = self._build_prompt(context, prompt)
        response = self.llm_service.generate_response(full_prompt)
        st.session_state.messages.append({"role": "assistant", "content": response['content']})
        with st.chat_message("assistant"):
            st.markdown(response['content'])
        return response
def main():
    """Render the page: upload policy PDFs, show summaries, then chat over the content."""
    st.title("Prompt Optimization for a Policy Bot")
    app = PromptOptimizationApp()

    uploaded_files = st.file_uploader(
        "Upload a Policy document in pdf format",
        type="pdf",
        accept_multiple_files=True,
    )
    # Nothing else to render until at least one file has been uploaded.
    if not uploaded_files:
        return

    documents, hybrid_retriever = app.process_documents(uploaded_files)
    st.success("File uploaded...")

    # NOTE(review): only the first loaded Document is displayed and summarized; the PDF
    # reader may emit one Document per page, so this could be page 1 only — confirm intent.
    first_text = documents[0].text
    st.success("Input text")
    st.markdown(first_text)
    app.display_summaries(first_text)

    # Replay the stored conversation before handling any new input.
    for entry in st.session_state.messages:
        with st.chat_message(entry["role"]):
            st.markdown(entry["content"])

    query = st.chat_input("Enter your query:")
    if query:
        st.session_state.messages.append({"role": "user", "content": query})
        with st.chat_message("user"):
            st.markdown(query)
        app.handle_chat(query, hybrid_retriever)
# Script entry point (the scraped "main() |" line was a syntax error).
if __name__ == "__main__":
    main()