# app.py — last commit ec8e5d1 "Update app.py" (verified) by Kathirsci; file size 5.3 kB.
import logging
import os
import tempfile
from typing import List, Optional

import streamlit as st
import torch
from langchain.chains.summarize import load_summarize_chain
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
# Module logger. No handler or level is configured here, so records follow
# whatever logging setup the hosting process provides.
logger = logging.getLogger(__name__)

# Sentence-embedding model used when building the vector store.
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# Causal language model pre-filled in the sidebar's model-name box.
DEFAULT_MODEL = "distilgpt2"
# Summary length budget as a fraction of the input length (0.2 -> 20%).
MAX_LENGTH_FRACTION = 0.2

# Prefer the GPU whenever PyTorch can see one, and report the choice in the sidebar.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
st.sidebar.write(f"Using device: {device}")
@st.cache_resource(show_spinner=False)
def _load_cached_embeddings(model_name: str) -> HuggingFaceEmbeddings:
    """Build the embedding model; Streamlit caches it so script reruns reuse it."""
    # Exceptions raised here are not cached, so a failed load is retried next time.
    return HuggingFaceEmbeddings(model_name=model_name)


def load_embeddings(model_name: str) -> Optional[HuggingFaceEmbeddings]:
    """Load the sentence-embedding model.

    Args:
        model_name: Hugging Face model id passed to ``HuggingFaceEmbeddings``.

    Returns:
        The embeddings object, or ``None`` if loading failed (the error is logged).
    """
    try:
        return _load_cached_embeddings(model_name)
    except Exception as e:
        logger.exception("Failed to load embeddings: %s", e)
        return None
@st.cache_resource(show_spinner=False)
def _load_model_and_tokenizer(model_name: str):
    """Load tokenizer and causal-LM weights once per model name.

    The cache key is the model name only, so pipelines built with different
    generation budgets share one copy of the weights instead of reloading them.
    Exceptions are not cached, so a failed load is retried on the next call.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return tokenizer, model


def load_llm(model_name: str, max_length: int) -> Optional[HuggingFacePipeline]:
    """Wrap a causal language model in a LangChain text-generation pipeline.

    Args:
        model_name: Hugging Face model id or local path of a causal LM.
        max_length: Maximum number of tokens to generate per call (the output
            budget; the prompt is not counted against it).

    Returns:
        A ``HuggingFacePipeline`` running on the module-level ``device``, or
        ``None`` if loading failed (the error is logged).
    """
    try:
        tokenizer, model = _load_model_and_tokenizer(model_name)
        # transformers' `max_length` counts prompt + generated tokens. The
        # "stuff" summarize chain sends the whole document as the prompt, so a
        # budget smaller than the prompt would leave no room to generate.
        # `max_new_tokens` bounds only the generated output.
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device=device,
            max_new_tokens=max_length,
        )
        return HuggingFacePipeline(pipeline=pipe)
    except Exception as e:
        logger.exception("Failed to load LLM: %s", e)
        return None
def process_pdf(file, chunk_size: int = 2000, chunk_overlap: int = 100) -> Optional[List[Document]]:
    """Load an uploaded PDF and split it into overlapping text chunks.

    ``PyPDFLoader`` reads from a filesystem path, so the upload's bytes are
    first copied into a temporary ``.pdf`` file. That file is always removed
    afterwards.

    Args:
        file: The uploaded PDF. Presumably a Streamlit ``UploadedFile``; any
            object exposing ``getvalue()`` or ``read()`` returning bytes works.
        chunk_size: Maximum characters per chunk passed to the splitter.
        chunk_overlap: Characters shared between consecutive chunks.

    Returns:
        The list of chunked documents, or ``None`` if loading or splitting
        failed (the error is logged).
    """
    temp_file_path = None
    try:
        data = file.getvalue() if hasattr(file, "getvalue") else file.read()
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
            temp_file.write(data)
            temp_file_path = temp_file.name
        # The temp file is closed at this point, so the loader can reopen it
        # by path on every platform.
        loader = PyPDFLoader(file_path=temp_file_path)
        pages = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        return text_splitter.split_documents(pages)
    except Exception as e:
        logger.exception("Error processing PDF: %s", e)
        return None
    finally:
        # delete=False keeps the file after closing, so remove it explicitly.
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except OSError as cleanup_error:
                logger.warning("Could not remove temp file %s: %s", temp_file_path, cleanup_error)
def create_vector_store(documents: List[Document], embeddings: HuggingFaceEmbeddings) -> Optional[FAISS]:
    """Build a FAISS vector store from the chunked documents.

    Args:
        documents: Chunks to embed and index.
        embeddings: Embedding model used to vectorise each chunk. ``None`` is
            tolerated because callers may pass a failed ``load_embeddings`` result.

    Returns:
        The FAISS store, or ``None`` when there is nothing to index or indexing
        failed (the reason is logged).
    """
    if not documents or embeddings is None:
        logger.error("Error creating vector store: no documents or embeddings available")
        return None
    try:
        return FAISS.from_documents(documents, embeddings)
    except Exception as e:
        logger.exception("Error creating vector store: %s", e)
        return None
def summarize_report(documents: List[Document], llm: HuggingFacePipeline, max_length: int, summary_style: str) -> Optional[str]:
    """Summarize the report with a "stuff" summarization chain.

    All documents are concatenated into a single prompt that asks the model for
    a summary in the requested style.

    Args:
        documents: Chunked report text to summarize.
        llm: The LangChain-wrapped language model.
        max_length: Kept for interface compatibility. The generation budget is
            whatever the ``llm`` pipeline was configured with when it was loaded.
        summary_style: Free-text style, e.g. ``"formal"`` or ``"bullet points"``.

    Returns:
        The stripped summary text, or ``None`` if there was nothing to summarize,
        the chain failed (the error is logged), or the model produced no text.
    """
    if not documents:
        logger.error("Error summarizing report: no documents to summarize")
        return None

    # `summary_style` is bound as a partial variable rather than interpolated
    # with an f-string, so no brace escaping is needed around `{text}`.
    prompt = PromptTemplate(
        template=(
            "Summarize the following text in a {summary_style} manner. "
            "Focus on the main points and key details:\n\n"
            "{text}\n\n"
            "Summary:"
        ),
        input_variables=["text"],
        partial_variables={"summary_style": summary_style},
    )
    try:
        # NOTE(review): "stuff" puts the whole report into one prompt; long
        # reports may exceed small models' context windows. Consider
        # chain_type="map_reduce" for large inputs.
        chain = load_summarize_chain(llm, chain_type="stuff", prompt=prompt)
        # Chain.run rejects mixing positional and keyword arguments, so invoke
        # the chain with its documented input key instead.
        result = chain.invoke({"input_documents": documents})
    except Exception as e:
        logger.exception("Error summarizing report: %s", e)
        return None

    summary = (result.get("output_text") or "").strip()
    return summary or None
def main():
    """Render the Report Summarizer UI and run the upload -> summarize workflow.

    Streamlit re-runs the script on every widget interaction, so each step
    returns early (after telling the user why) when a prerequisite is missing.
    """
    st.title("Report Summarizer")
    model_option = st.sidebar.text_input("Enter model name", value=DEFAULT_MODEL)
    summary_style = st.sidebar.selectbox(
        "Summary style",
        options=["clear and concise", "formal", "informal", "bullet points"],
    )
    uploaded_file = st.sidebar.file_uploader("Upload your Report", type="pdf")

    # Load with a default budget up front so a bad model name is reported early.
    llm = load_llm(model_option, 1024)
    if not llm:
        st.error(f"Failed to load the model {model_option}. Please try another model.")
        return

    embeddings = load_embeddings(EMBEDDING_MODEL)
    if not embeddings:
        st.error("Failed to load embeddings. Please try again later.")
        return

    if not uploaded_file:
        return

    with st.spinner("Processing PDF..."):
        documents = process_pdf(uploaded_file)
    if not documents:
        st.error("Could not extract any text from the uploaded PDF.")
        return

    with st.spinner("Creating vector store..."):
        db = create_vector_store(documents, embeddings)
    # NOTE(review): `db` only gates the Summarize button here; summarize_report
    # does not query it.
    if not db:
        st.error("Failed to create the vector store. Please try again.")
        return

    if not st.button("Summarize"):
        return

    # Budget the summary as a fraction of the input's word count (words only
    # approximate tokens). Keep at least 1 so tiny inputs never get a zero budget.
    input_length = sum(len(doc.page_content.split()) for doc in documents)
    max_length = max(1, int(input_length * MAX_LENGTH_FRACTION))

    # Reload the model with the computed budget.
    summary_llm = load_llm(model_option, max_length)
    if not summary_llm:
        st.error(f"Failed to load the model {model_option}. Please try another model.")
        return

    with st.spinner(f"Generating summary using {model_option}..."):
        summary = summarize_report(documents, summary_llm, max_length, summary_style)

    if summary:
        st.write(summary)
    else:
        st.warning("Failed to generate summary. Please try again.")
if __name__ == "__main__":