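"""Streamlit app: semantic search over a small collection of ICML 2025 papers.

Paper titles and abstracts are embedded with a Sentence Transformer model,
the embeddings are persisted to disk for reuse, and results are ranked
against a (spell-corrected) user query by cosine similarity.
"""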
import os
from pathlib import Path
import streamlit as st
import pandas as pd
from sentence_transformers import SentenceTransformer, util
import torch
from spellchecker import SpellChecker
# --- Configuration ---
EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
# Resolve all data paths relative to this script's directory so the app works
# regardless of the current working directory (e.g. when deployed on Hugging Face Spaces)
script_dir = Path(__file__).parent
DATA_FILE = script_dir / 'src' / 'papers_data.pkl'
EMBEDDINGS_FILE = script_dir / 'src' / 'paper_embeddings.pt'
CSV_FILE = script_dir / 'papers_with_abstracts_parallel.csv'
# --- Caching Functions ---
@st.cache_resource
def load_embedding_model():
    """Loads the Sentence Transformer model and caches it for the app's lifetime."""
    return SentenceTransformer(EMBEDDING_MODEL)

@st.cache_resource
def load_spell_checker():
    """Loads the SpellChecker object and caches it for the app's lifetime."""
    return SpellChecker()
# --- Core Functions ---
def create_and_save_embeddings(model, data_df):
"""
Generates and saves document embeddings and the dataframe.
This function is called only once if the files don't exist.
"""
st.info("First time setup: Generating and saving embeddings. This may take a moment...")
# Combine title and abstract for richer embeddings
data_df['text_to_embed'] = data_df['title'] + ". " + data_df['abstract'].fillna('')
# Generate embeddings
corpus_embeddings = model.encode(data_df['text_to_embed'].tolist(), convert_to_tensor=True, show_progress_bar=True)
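    # corpus_embeddings is a tensor of shape (num_papers, 384);
    # all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings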
    # Persist embeddings and dataframe to disk so later sessions can skip re-encoding
try:
torch.save(corpus_embeddings, EMBEDDINGS_FILE)
data_df.to_pickle(DATA_FILE)
st.success("Embeddings and data saved successfully!")
except Exception as e:
st.warning(f"Could not save embeddings to disk: {e}. Will regenerate on each session.")
return corpus_embeddings, data_df
def load_data_and_embeddings():
"""
Loads the saved embeddings and dataframe from disk.
If files don't exist, it calls the creation function.
"""
model = load_embedding_model()
# Check if files exist and are readable
if os.path.exists(EMBEDDINGS_FILE) and os.path.exists(DATA_FILE):
try:
corpus_embeddings = torch.load(EMBEDDINGS_FILE)
data_df = pd.read_pickle(DATA_FILE)
return model, corpus_embeddings, data_df
except Exception as e:
st.warning(f"Could not load saved embeddings: {e}. Regenerating...")
# Load the raw data from CSV
try:
data_df = pd.read_csv(CSV_FILE)
corpus_embeddings, data_df = create_and_save_embeddings(model, data_df)
except FileNotFoundError:
st.error(f"CSV file '{CSV_FILE}' not found. Please ensure it's in your repository.")
st.stop()
except Exception as e:
st.error(f"Error loading data: {e}")
st.stop()
return model, corpus_embeddings, data_df
def correct_query_spelling(query, spell_checker):
"""
Corrects potential spelling mistakes in the user's query.
"""
if not query:
return ""
# Split the query into words
words = query.split()
# Find words that are likely misspelled
misspelled = spell_checker.unknown(words)
if not misspelled:
return query # Return original if no typos found
# Generate the corrected query
corrected_words = []
for word in words:
        # unknown() returns a lowercased set by default, so compare in lowercase
        if word.lower() in misspelled:
            corrected_word = spell_checker.correction(word)
# Use the correction, but fall back to the original word if no correction is found
corrected_words.append(corrected_word if corrected_word else word)
else:
corrected_words.append(word)
return " ".join(corrected_words)
def semantic_search(query, model, corpus_embeddings, data_df, top_k=10):
"""
Performs semantic search on the loaded data.
"""
if not query:
return []
# Encode the query
query_embedding = model.encode(query, convert_to_tensor=True)
# Calculate cosine similarity
cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
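    # cos_scores has shape (num_papers,); util.cos_sim normalizes internally,
    # so the raw (unnormalized) embeddings are fine here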
# Get the top k results, ensuring we don't ask for more results than exist
top_k = min(top_k, len(corpus_embeddings))
top_results = torch.topk(cos_scores, k=top_k)
# Format results
results = []
    for score, idx in zip(top_results.values, top_results.indices):
item = data_df.iloc[idx.item()]
results.append({
"title": item["title"],
"authors": item["authors"],
"link": item["link"],
"abstract": item["abstract"],
"score": score.item() # Score is kept for potential future use but not displayed
})
return results
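# Illustrative usage: semantic_search("transformers for time series", model,
# corpus_embeddings, data_df, top_k=3) returns up to 3 dicts of the form
# {"title": ..., "authors": ..., "link": ..., "abstract": ..., "score": 0.61},
# where the score shown here is a hypothetical cosine similarity.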
# --- Streamlit App UI ---
st.set_page_config(page_title="Semantic Paper Search", layout="wide")
st.title("π Semantic Research Paper Search")
st.markdown("""
Enter a query below to search through a small collection of ICML 2025 papers.
The search is performed by comparing the semantic meaning of your query with the papers' titles and abstracts.
Spelling mistakes in your query will be automatically corrected.
""")
# Load all necessary data
try:
model, corpus_embeddings, data_df = load_data_and_embeddings()
spell_checker = load_spell_checker()
# --- User Inputs: Search Bar and Slider ---
col1, col2 = st.columns([4, 1])
with col1:
search_query = st.text_input(
"Enter your search query:",
placeholder="e.g., machine learning models for time series"
)
with col2:
top_k_results = st.number_input(
"Number of results",
min_value=1,
max_value=100, # Set a reasonable max
value=10,
help="Select the number of top results to display."
)
if search_query:
# --- Perform Typo Correction ---
corrected_query = correct_query_spelling(search_query, spell_checker)
        # Use the corrected query (identical to the original when nothing was corrected)
        final_query = corrected_query
        # If a correction was made, notify the user
        if corrected_query.lower() != search_query.lower():
            st.info(f"Did you mean: **{corrected_query}**?\n\n*Showing results for the corrected query.*")
# --- Perform Search ---
search_results = semantic_search(final_query, model, corpus_embeddings, data_df, top_k=top_k_results)
st.subheader(f"Found {len(search_results)} results for '{final_query}'")
# --- Display Results ---
if search_results:
for result in search_results:
with st.container(border=True):
# Title as a clickable link
st.markdown(f"### [{result['title']}]({result['link']})")
# Authors
st.caption(f"**Authors:** {result['authors']}")
# Expander for the abstract
if pd.notna(result['abstract']):
with st.expander("View Abstract"):
st.write(result['abstract'])
else:
st.warning("No results found. Try a different query.")
except Exception as e:
st.error(f"An error occurred: {e}")
st.info("Please ensure all required libraries are installed and the CSV file is present in your repository.") |