Kathirsci committed on
Commit
ec8e5d1
·
verified ·
1 Parent(s): a8c600f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -31
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import streamlit as st
2
  import tempfile
3
  import logging
4
- from typing import List
5
  import torch
6
  from langchain_community.document_loaders import PyPDFLoader
7
  from langchain_community.embeddings import HuggingFaceEmbeddings
@@ -20,25 +20,24 @@ logger = logging.getLogger(__name__)
20
  # Constants
21
  EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
22
  DEFAULT_MODEL = "distilgpt2"
23
- DEFAULT_MAX_LENGTH = 1024 # Increased default max length
24
 
25
  # Check for GPU
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
  st.sidebar.write(f"Using device: {device}")
28
 
29
- @st.cache_resource
30
- def load_embeddings():
31
- """Load and cache the embedding model."""
32
  try:
33
- return HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
34
  except Exception as e:
35
  logger.error(f"Failed to load embeddings: {e}")
36
- st.error("Failed to load the embedding model. Please try again later.")
37
  return None
38
 
39
- @st.cache_resource
40
- def load_llm(model_name, max_length):
41
- """Load and cache the language model."""
42
  try:
43
  tokenizer = AutoTokenizer.from_pretrained(model_name)
44
  model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -46,10 +45,9 @@ def load_llm(model_name, max_length):
46
  return HuggingFacePipeline(pipeline=pipe)
47
  except Exception as e:
48
  logger.error(f"Failed to load LLM: {e}")
49
- st.error(f"Failed to load the model {model_name}. Please try another model or check your internet connection.")
50
  return None
51
 
52
- def process_pdf(file) -> List[Document]:
53
  """Process the uploaded PDF file."""
54
  try:
55
  with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
@@ -63,55 +61,50 @@ def process_pdf(file) -> List[Document]:
63
  return documents
64
  except Exception as e:
65
  logger.error(f"Error processing PDF: {e}")
66
- st.error("Failed to process the PDF. Please make sure it's a valid PDF file.")
67
- return []
68
 
69
- def create_vector_store(documents: List[Document], embeddings):
70
  """Create the vector store."""
71
  try:
72
  return FAISS.from_documents(documents, embeddings)
73
  except Exception as e:
74
  logger.error(f"Error creating vector store: {e}")
75
- st.error("Failed to create the vector store. Please try again.")
76
  return None
77
 
78
- def summarize_report(documents: List[Document], llm) -> str:
79
  """Summarize the report using the loaded model."""
80
  try:
81
- prompt_template = """
82
- Summarize the following text in a clear and concise manner. Focus on the main points and key details:
83
-
84
- {text}
85
-
86
  Summary:
87
  """
88
 
89
  prompt = PromptTemplate(template=prompt_template, input_variables=["text"])
90
  chain = load_summarize_chain(llm, chain_type="stuff", prompt=prompt)
91
- summary = chain.run(documents)
92
  return summary
93
 
94
  except Exception as e:
95
  logger.error(f"Error summarizing report: {e}")
96
- st.error("Failed to summarize the report. Please try again.")
97
- return ""
98
 
99
  def main():
100
  st.title("Report Summarizer")
101
 
102
  model_option = st.sidebar.text_input("Enter model name", value=DEFAULT_MODEL)
103
- max_length = st.sidebar.slider("Max summary length", min_value=256, max_value=2048, value=DEFAULT_MAX_LENGTH, step=128)
104
 
105
  uploaded_file = st.sidebar.file_uploader("Upload your Report", type="pdf")
106
 
107
- llm = load_llm(model_option, max_length)
108
  if not llm:
109
  st.error(f"Failed to load the model {model_option}. Please try another model.")
110
  return
111
 
112
- embeddings = load_embeddings()
113
  if not embeddings:
114
- st.error("Failed to load embeddings. Please try again later.")
115
  return
116
 
117
  if uploaded_file:
@@ -123,8 +116,15 @@ def main():
123
  db = create_vector_store(documents, embeddings)
124
 
125
  if db and st.button("Summarize"):
 
 
 
 
 
 
 
126
  with st.spinner(f"Generating summary using {model_option}..."):
127
- summary = summarize_report(documents, llm)
128
 
129
  if summary:
130
  st.subheader("Summary:")
@@ -133,4 +133,4 @@ def main():
133
  st.warning("Failed to generate summary. Please try again.")
134
 
135
  if __name__ == "__main__":
136
- main()
 
1
  import streamlit as st
2
  import tempfile
3
  import logging
4
+ from typing import List, Optional
5
  import torch
6
  from langchain_community.document_loaders import PyPDFLoader
7
  from langchain_community.embeddings import HuggingFaceEmbeddings
 
20
  # Constants
21
  EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
22
  DEFAULT_MODEL = "distilgpt2"
23
+ MAX_LENGTH_FRACTION = 0.2 # Set max_length to 20% of input length
24
 
25
  # Check for GPU
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
  st.sidebar.write(f"Using device: {device}")
28
 
29
+ @st.cache_data
30
+ def load_embeddings(model_name: str) -> Optional[HuggingFaceEmbeddings]:
31
+ """Load the embedding model."""
32
  try:
33
+ return HuggingFaceEmbeddings(model_name=model_name)
34
  except Exception as e:
35
  logger.error(f"Failed to load embeddings: {e}")
 
36
  return None
37
 
38
+ @st.cache_data
39
+ def load_llm(model_name: str, max_length: int) -> Optional[HuggingFacePipeline]:
40
+ """Load the language model."""
41
  try:
42
  tokenizer = AutoTokenizer.from_pretrained(model_name)
43
  model = AutoModelForCausalLM.from_pretrained(model_name)
 
45
  return HuggingFacePipeline(pipeline=pipe)
46
  except Exception as e:
47
  logger.error(f"Failed to load LLM: {e}")
 
48
  return None
49
 
50
+ def process_pdf(file) -> Optional[List[Document]]:
51
  """Process the uploaded PDF file."""
52
  try:
53
  with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
 
61
  return documents
62
  except Exception as e:
63
  logger.error(f"Error processing PDF: {e}")
64
+ return None
 
65
 
66
+ def create_vector_store(documents: List[Document], embeddings: HuggingFaceEmbeddings) -> Optional[FAISS]:
67
  """Create the vector store."""
68
  try:
69
  return FAISS.from_documents(documents, embeddings)
70
  except Exception as e:
71
  logger.error(f"Error creating vector store: {e}")
 
72
  return None
73
 
74
+ def summarize_report(documents: List[Document], llm: HuggingFacePipeline, max_length: int, summary_style: str) -> Optional[str]:
75
  """Summarize the report using the loaded model."""
76
  try:
77
+ prompt_template = f"""
78
+ Summarize the following text in a {summary_style} manner. Focus on the main points and key details:
79
+ {{text}}
 
 
80
  Summary:
81
  """
82
 
83
  prompt = PromptTemplate(template=prompt_template, input_variables=["text"])
84
  chain = load_summarize_chain(llm, chain_type="stuff", prompt=prompt)
85
+ summary = chain.run(documents, max_length=max_length)
86
  return summary
87
 
88
  except Exception as e:
89
  logger.error(f"Error summarizing report: {e}")
90
+ return None
 
91
 
92
  def main():
93
  st.title("Report Summarizer")
94
 
95
  model_option = st.sidebar.text_input("Enter model name", value=DEFAULT_MODEL)
96
+ summary_style = st.sidebar.selectbox("Summary style", options=["clear and concise", "formal", "informal", "bullet points"])
97
 
98
  uploaded_file = st.sidebar.file_uploader("Upload your Report", type="pdf")
99
 
100
+ llm = load_llm(model_option, 1024) # Load the model with a default max_length
101
  if not llm:
102
  st.error(f"Failed to load the model {model_option}. Please try another model.")
103
  return
104
 
105
+ embeddings = load_embeddings(EMBEDDING_MODEL)
106
  if not embeddings:
107
+ st.error(f"Failed to load embeddings. Please try again later.")
108
  return
109
 
110
  if uploaded_file:
 
116
  db = create_vector_store(documents, embeddings)
117
 
118
  if db and st.button("Summarize"):
119
+ # Calculate max_length based on input text
120
+ input_length = sum([len(doc.page_content.split()) for doc in documents])
121
+ max_length = int(input_length * MAX_LENGTH_FRACTION)
122
+
123
+ # Reload the model with the calculated max_length
124
+ llm = load_llm(model_option, max_length)
125
+
126
  with st.spinner(f"Generating summary using {model_option}..."):
127
+ summary = summarize_report(documents, llm, max_length, summary_style)
128
 
129
  if summary:
130
  st.subheader("Summary:")
 
133
  st.warning("Failed to generate summary. Please try again.")
134
 
135
  if __name__ == "__main__":
136
+ main()