stefanoviel committed
Commit 7a417b0 · 1 Parent(s): 6165217
papers_with_abstracts_parallel.csv ADDED
The diff for this file is too large to render. See raw diff
 
src/streamlit_app.py CHANGED
@@ -1,40 +1,189 @@
- import altair as alt
- import numpy as np
- import pandas as pd
  import streamlit as st
-
- """
- # Welcome to Streamlit!
-
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).
-
- In the meantime, below is an example of what you can do with just a few lines of code:
- """
-
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
- indices = np.linspace(0, 1, num_points)
- theta = 2 * np.pi * num_turns * indices
- radius = indices
-
- x = radius * np.cos(theta)
- y = radius * np.sin(theta)
-
- df = pd.DataFrame({
-     "x": x,
-     "y": y,
-     "idx": indices,
-     "rand": np.random.randn(num_points),
- })
-
- st.altair_chart(alt.Chart(df, height=700, width=700)
-     .mark_point(filled=True)
-     .encode(
-         x=alt.X("x", axis=None),
-         y=alt.Y("y", axis=None),
-         color=alt.Color("idx", legend=None, scale=alt.Scale()),
-         size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-     ))
+ import pandas as pd
+ from sentence_transformers import SentenceTransformer, util
+ import torch
+ import os
+ from spellchecker import SpellChecker  # provided by the pyspellchecker package
+
+ # --- Configuration ---
+ EMBEDDING_MODEL = 'all-MiniLM-L6-v2'
+ EMBEDDINGS_FILE = 'paper_embeddings.pt'
+ DATA_FILE = 'papers_data.pkl'
+
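+ # Note: all-MiniLM-L6-v2 is a small, fast sentence-transformers model that maps
+ # text to 384-dimensional embeddings; any other model name from the
+ # sentence-transformers hub could be substituted here.
+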
+ # --- Data Loading and Preparation ---
+ # Paper metadata (title, authors, link, abstract) is loaded from a local CSV file.
+ CSV_FILE = 'papers_with_abstracts_parallel.csv'
+
+ # --- Caching Functions ---
+ @st.cache_resource
+ def load_embedding_model():
+     """Loads the Sentence Transformer model and caches it."""
+     return SentenceTransformer(EMBEDDING_MODEL)
+
+ @st.cache_resource
+ def load_spell_checker():
+     """Loads the SpellChecker object and caches it."""
+     return SpellChecker()
+
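+ # st.cache_resource keeps one shared instance per process, so the model and the
+ # spell checker are loaded once and reused across reruns and user sessions.
+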
+ # --- Core Functions ---
+ def create_and_save_embeddings(model, data_df):
+     """
+     Generates and saves document embeddings and the dataframe.
+     This function is called only once, when the saved files don't exist yet.
+     """
+     st.info("First time setup: Generating and saving embeddings. This may take a moment...")
+     # Combine title and abstract for richer embeddings
+     data_df['text_to_embed'] = data_df['title'] + ". " + data_df['abstract'].fillna('')
+
+     # Generate embeddings
+     corpus_embeddings = model.encode(data_df['text_to_embed'].tolist(), convert_to_tensor=True, show_progress_bar=True)
+
+     # Save embeddings and dataframe
+     torch.save(corpus_embeddings, EMBEDDINGS_FILE)
+     data_df.to_pickle(DATA_FILE)
+     st.success("Embeddings and data saved successfully!")
+     return corpus_embeddings, data_df
+
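+ # Note: in create_and_save_embeddings, encode(..., convert_to_tensor=True) yields a
+ # single tensor of shape (num_papers, embedding_dim), which torch.save writes as-is.
+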
+ def load_data_and_embeddings():
+     """
+     Loads the saved embeddings and dataframe from disk.
+     If the files don't exist, it calls the creation function.
+     """
+     model = load_embedding_model()
+     if os.path.exists(EMBEDDINGS_FILE) and os.path.exists(DATA_FILE):
+         corpus_embeddings = torch.load(EMBEDDINGS_FILE)
+         data_df = pd.read_pickle(DATA_FILE)
+     else:
+         # Load the raw data from the CSV file
+         data_df = pd.read_csv(CSV_FILE)
+         corpus_embeddings, data_df = create_and_save_embeddings(model, data_df)
+
+     return model, corpus_embeddings, data_df
+
+ def correct_query_spelling(query, spell_checker):
+     """
+     Corrects potential spelling mistakes in the user's query.
+     """
+     if not query:
+         return ""
+
+     # Split the query into words
+     words = query.split()
+
+     # Find words that are likely misspelled
+     misspelled = spell_checker.unknown(words)
+
+     if not misspelled:
+         return query  # Return original if no typos found
+
+     # Generate the corrected query
+     corrected_words = []
+     for word in words:
+         # unknown() lower-cases words by default, so compare case-insensitively
+         if word.lower() in misspelled:
+             corrected_word = spell_checker.correction(word)
+             # Use the correction, but fall back to the original word if no correction is found
+             corrected_words.append(corrected_word if corrected_word else word)
+         else:
+             corrected_words.append(word)
+
+     return " ".join(corrected_words)
+
+
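+ # Illustrative example (results depend on pyspellchecker's frequency dictionary):
+ # correct_query_spelling("maschine lerning", SpellChecker()) would typically
+ # return "machine learning".
+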
+ def semantic_search(query, model, corpus_embeddings, data_df, top_k=10):
+     """
+     Performs semantic search on the loaded data.
+     """
+     if not query:
+         return []
+
+     # Encode the query
+     query_embedding = model.encode(query, convert_to_tensor=True)
+
+     # Calculate cosine similarity
+     cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
+
+     # Get the top k results, ensuring we don't ask for more results than exist
+     top_k = min(top_k, len(corpus_embeddings))
+     top_results = torch.topk(cos_scores, k=top_k)
+
+     # Format results
+     results = []
+     for score, idx in zip(top_results[0], top_results[1]):
+         item = data_df.iloc[idx.item()]
+         results.append({
+             "title": item["title"],
+             "authors": item["authors"],
+             "link": item["link"],
+             "abstract": item["abstract"],
+             "score": score.item()  # Score is kept for potential future use but not displayed
+         })
+     return results
+
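+ # util.cos_sim returns cosine similarities in [-1, 1]; torch.topk then picks the
+ # k best-matching papers in one pass, so no extra index structure is needed here.
+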
+ # --- Streamlit App UI ---
+ st.set_page_config(page_title="Semantic Paper Search", layout="wide")
+
+ st.title("📄 Semantic Research Paper Search")
+ st.markdown("""
+ Enter a query below to search through a small collection of ICML 2025 papers.
+ The search is performed by comparing the semantic meaning of your query with the papers' titles and abstracts.
+ Spelling mistakes in your query will be automatically corrected.
+ """)
+
+ # Load all necessary data
+ try:
+     model, corpus_embeddings, data_df = load_data_and_embeddings()
+     spell_checker = load_spell_checker()
+
+     # --- User Inputs: Search Bar and Result Count ---
+     col1, col2 = st.columns([4, 1])
+     with col1:
+         search_query = st.text_input(
+             "Enter your search query:",
+             placeholder="e.g., maschine lerning modles for time series"
+         )
+     with col2:
+         top_k_results = st.number_input(
+             "Number of results",
+             min_value=1,
+             max_value=100,  # Set a reasonable max
+             value=10,
+             help="Select the number of top results to display."
+         )
+
+     if search_query:
+         # --- Perform Typo Correction ---
+         corrected_query = correct_query_spelling(search_query, spell_checker)
+
+         # If a correction was made, notify the user
+         if corrected_query.lower() != search_query.lower():
+             st.info(f"Did you mean: **{corrected_query}**? \n\n*Showing results for the corrected query.*")
+
+         final_query = corrected_query
+
+         # --- Perform Search ---
+         search_results = semantic_search(final_query, model, corpus_embeddings, data_df, top_k=top_k_results)
+
+         st.subheader(f"Found {len(search_results)} results for '{final_query}'")
+
+         # --- Display Results ---
+         if search_results:
+             for result in search_results:
+                 with st.container(border=True):
+                     # Title as a clickable link
+                     st.markdown(f"### [{result['title']}]({result['link']})")
+
+                     # Authors
+                     st.caption(f"**Authors:** {result['authors']}")
+
+                     # Expander for the abstract
+                     if pd.notna(result['abstract']):
+                         with st.expander("View Abstract"):
+                             st.write(result['abstract'])
+         else:
+             st.warning("No results found. Try a different query.")
+
+ except Exception as e:
+     st.error(f"An error occurred: {e}")
+     st.info("Please ensure all required libraries are installed (`pip install streamlit pandas sentence-transformers torch pyspellchecker`) and try again.")