# ataliba / app.py (Hugging Face Space; author: amiguel)
# Commit 16ad5dc (verified): "Update app.py" (file size 5.18 kB)
import streamlit as st
import torch
import os
import tempfile
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.schema import Document
from langchain.docstore.document import Document as LangchainDocument
# --- Avatars ---
# Images shown next to user / assistant chat messages, pinned to specific
# commits of the achilela/vila_fofoka_analysis repository.
USER_AVATAR = (
    "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/"
    "9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
)
BOT_AVATAR = (
    "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/"
    "991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"
)

# --- Hugging Face Token ---
# Read from Streamlit secrets; passed as ``token=`` when load_model() downloads
# the tokenizer and model from the Hub.
HF_TOKEN = st.secrets["HF_TOKEN"]
# --- Page Setup ---
# The emoji literals below were previously mis-encoded (UTF-8 bytes decoded as
# a legacy 8-bit codepage, e.g. "πŸ€–"). A multi-character non-emoji string
# is not a valid page_icon, so restore the intended emoji.
st.set_page_config(page_title="Hybrid RAG Chat", page_icon="🤖", layout="centered")
st.title("🤖 DigiTwin - Hybrid Search + Streaming")
# --- Sidebar Upload ---
# Everything created inside this block is rendered in the sidebar.
with st.sidebar:
    st.header("📤 Upload Documents")
    uploaded_files = st.file_uploader("PDFs or .txt files only", type=["pdf", "txt"], accept_multiple_files=True)
    # True only on the rerun triggered by the click.
    clear_chat = st.button("🧹 Clear Conversation")
# --- Chat Memory ---
# Start with an empty transcript on the first run of a session, or whenever
# the user has just pressed "Clear Conversation".
if clear_chat or "messages" not in st.session_state:
    st.session_state.messages = []
# --- Load LLM ---
@st.cache_resource
def load_model(model_id="amiguel/GM_Qwen1.8B_Finetune"):
    """Load the tokenizer and causal language model used to answer questions.

    Decorated with ``st.cache_resource``, so the weights are loaded once per
    ``model_id`` and reused across reruns instead of on every interaction.

    Args:
        model_id: Hugging Face Hub repository to load (defaults to the
            DigiTwin fine-tune).

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    # Half precision only pays off on a GPU. On CPU-only hosts float16 is slow
    # and, on older PyTorch releases, some kernels are not implemented for it,
    # so fall back to float32 there.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto",
        token=HF_TOKEN,
    )
    return tokenizer, model


tokenizer, model = load_model()
# --- Document Processing ---
def process_documents(files):
    """Load uploaded PDF / plain-text files into LangChain documents.

    Each upload is written to a temporary file, because the loaders take a
    filesystem path. It is loaded with ``PyPDFLoader`` (``.pdf``, any case) or
    ``TextLoader`` (everything else), and the temporary file is deleted
    afterwards.

    Args:
        files: Iterable of Streamlit ``UploadedFile`` objects.

    Returns:
        A flat list of loaded documents. ``metadata["source"]`` is set to the
        original upload name rather than the (deleted) temporary path.
    """
    documents = []
    for file in files:
        # Match case-insensitively so "REPORT.PDF" is not fed to TextLoader.
        is_pdf = file.name.lower().endswith(".pdf")
        suffix = ".pdf" if is_pdf else ".txt"
        # delete=False keeps the file on disk after the ``with`` closes it, so
        # the loader can reopen it by path; it is removed in ``finally``.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # getvalue() returns the whole upload regardless of stream position.
            tmp.write(file.getvalue())
            path = tmp.name
        try:
            loader = PyPDFLoader(path) if is_pdf else TextLoader(path)
            loaded = loader.load()
        finally:
            os.remove(path)
        for doc in loaded:
            doc.metadata["source"] = file.name
        documents.extend(loaded)
    return documents


def chunk_documents(docs):
    """Split documents into chunks targeting 500 characters, overlapping by 50."""
    splitter_config = {"chunk_size": 500, "chunk_overlap": 50}
    return RecursiveCharacterTextSplitter(**splitter_config).split_documents(docs)


def build_hybrid_retriever(chunks, k=5):
    """Build a hybrid dense + keyword retriever over ``chunks``.

    Combines FAISS similarity search over all-MiniLM-L6-v2 sentence embeddings
    with a BM25 keyword retriever. The two are weighted equally by an
    ``EnsembleRetriever``.

    Args:
        chunks: Documents to index.
        k: Number of documents each sub-retriever returns (default 5).

    Returns:
        The ensemble retriever, or ``None`` when ``chunks`` is empty (e.g. an
        image-only PDF with no extractable text), because
        ``FAISS.from_documents`` cannot build an index from zero documents.
    """
    if not chunks:
        return None
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vector_store = FAISS.from_documents(chunks, embeddings)
    dense = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": k})
    # Index the chunks as-is (not re-wrapped) so BM25 hits keep their metadata.
    sparse = BM25Retriever.from_documents(chunks, k=k)
    return EnsembleRetriever(retrievers=[dense, sparse], weights=[0.5, 0.5])


# --- Prompt Builder ---
def build_prompt(history, context=""):
    """Assemble the plain-text prompt sent to the model.

    Layout: persona line, then (only if ``context`` is non-blank) an
    instruction plus a ``Context:`` section, then the transcript as one
    ``User:`` / ``Assistant:`` line per message, and finally an open
    ``Assistant:`` turn for the model to complete.

    Args:
        history: Sequence of ``{"role": ..., "content": ...}`` dicts. Role
            ``"user"`` is labelled ``User``; any other role is labelled
            ``Assistant``.
        context: Retrieved document text. When empty or whitespace, the
            context section is omitted instead of pointing the model at an
            empty ``Context:`` block.

    Returns:
        The prompt string.
    """
    dialog = "".join(
        f"{'User' if msg['role'] == 'user' else 'Assistant'}: {msg['content']}\n"
        for msg in history
    )
    persona = (
        "You are DigiTwin, a highly professional and experienced assistant in "
        "inspection, integrity, and maintenance of topside equipment, piping "
        "systems, pressure vessels, structures, and safety systems."
    )
    if context and context.strip():
        context_section = (
            "Use the following context to provide expert-level answers.\n"
            f"Context:\n{context}\n"
        )
    else:
        context_section = ""
    return f"{persona}\n{context_section}{dialog}\nAssistant:"


# --- Response Generator ---
def generate_response(prompt, max_new_tokens=300):
    """Stream the model's completion of ``prompt``.

    ``model.generate`` runs on a background thread and feeds a
    ``TextIteratorStreamer``. This generator yields the *accumulated* answer
    text after each new piece, so the caller can simply re-render the latest
    value.

    Args:
        prompt: Full prompt string, e.g. from ``build_prompt``.
        max_new_tokens: Upper bound on generated tokens (default 300).

    Yields:
        The answer text generated so far (prompt excluded, special tokens
        skipped).

    Raises:
        Whatever ``model.generate`` raised on the worker thread. It is
        re-raised here once the stream has been closed.
    """
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    errors = []

    def _generate():
        # Worker-thread body. If generate() fails, it never signals end-of-
        # stream, and the consuming loop below (no streamer timeout) would
        # block forever. Close the stream ourselves and hand the error back.
        try:
            model.generate(**inputs, streamer=streamer, max_new_tokens=max_new_tokens)
        except Exception as exc:  # re-raised on the caller's thread below
            errors.append(exc)
            streamer.end()

    worker = Thread(target=_generate)
    worker.start()
    output = ""
    for token in streamer:
        output += token
        yield output
    worker.join()
    if errors:
        raise errors[0]


# --- Retrieval Logic ---
# Streamlit reruns this whole script on every interaction, including every
# chat message. Parsing and embedding the uploads each time is very slow, so
# the built retriever is kept in session state and rebuilt only when the set
# of uploaded files changes.
retriever = None
if uploaded_files:
    files_key = tuple((f.name, f.size) for f in uploaded_files)
    if st.session_state.get("retriever_key") != files_key:
        with st.spinner("🔍 Indexing documents..."):
            docs = process_documents(uploaded_files)
            chunks = chunk_documents(docs)
            # An index cannot be built from zero chunks (e.g. image-only PDFs).
            st.session_state.retriever = build_hybrid_retriever(chunks) if chunks else None
        st.session_state.retriever_key = files_key
    retriever = st.session_state.retriever
    if retriever:
        st.success("✅ Documents ready for hybrid search.")
    else:
        st.warning("No extractable text was found in the uploaded documents.")

# --- Display Conversation ---
# Replay the stored transcript, since Streamlit redraws the page from scratch on
# every rerun.
for message in st.session_state.messages:
    speaker = message["role"]
    avatar = USER_AVATAR if speaker == "user" else BOT_AVATAR
    with st.chat_message(speaker, avatar=avatar):
        st.markdown(message["content"])

# --- Chat Input ---
if query := st.chat_input("Ask DigiTwin anything..."):
    st.chat_message("user", avatar=USER_AVATAR).markdown(query)
    st.session_state.messages.append({"role": "user", "content": query})

    context = ""
    if retriever:
        # invoke() replaces the deprecated get_relevant_documents().
        docs = retriever.invoke(query)
        context = "\n\n".join(doc.page_content for doc in docs)

    # History already ends with the current question (appended above).
    full_prompt = build_prompt(st.session_state.messages, context)

    with st.chat_message("assistant", avatar=BOT_AVATAR):
        container = st.empty()
        answer = ""
        for chunk in generate_response(full_prompt):
            answer = chunk
            # Model output can echo uploaded document text, so render it as
            # plain Markdown without raw HTML. The "▌" cursor needs no HTML.
            container.markdown(answer + "▌")
        container.markdown(answer)
    st.session_state.messages.append({"role": "assistant", "content": answer})