# ICML2025 / src/streamlit_app.py
# Hugging Face file-viewer header: stefanoviel · commit 36e5a14 "changing folder path"
# (raw | history | blame) · 7.41 kB
import os
from pathlib import Path
import streamlit as st
import pandas as pd
from sentence_transformers import SentenceTransformer, util
import torch
from spellchecker import SpellChecker
from io import StringIO
# --- Configuration ---
# Sentence-Transformers model used to embed both the paper corpus and user queries.
EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
# Directory containing this script. NOTE(review): not referenced anywhere else in
# this file as shown; the paths below do NOT use it.
script_dir = Path(__file__).parent
# Cache files for the processed dataframe and its embeddings. These are relative to
# the current working directory, not to script_dir — they assume the app is launched
# from the directory that contains ./src (TODO confirm the launch directory).
DATA_FILE = './src/papers_data.pkl'
EMBEDDINGS_FILE = './src/paper_embeddings.pt'
# Raw paper CSV, also resolved relative to the current working directory.
CSV_FILE = 'papers_with_abstracts_parallel.csv'
# --- Caching Functions ---
@st.cache_resource
def load_embedding_model():
    """Load the Sentence Transformer model once and reuse it across reruns.

    Streamlit re-executes the whole script on every user interaction;
    ``st.cache_resource`` keeps a single model instance alive instead of
    reloading it each time.

    Returns:
        The ``SentenceTransformer`` for ``EMBEDDING_MODEL``.
    """
    return SentenceTransformer(EMBEDDING_MODEL)
@st.cache_resource
def load_spell_checker():
    """Create the SpellChecker once and reuse it across Streamlit reruns.

    Without ``st.cache_resource`` a new checker (and its word-frequency
    dictionary) would be built on every script rerun.

    Returns:
        A ``spellchecker.SpellChecker`` instance.
    """
    return SpellChecker()
# --- Core Functions ---
def create_and_save_embeddings(model, data_df):
    """
    Generate corpus embeddings for ``data_df`` and try to persist them.

    Each paper is embedded from the text ``"<title>. <abstract>"``. The tensor
    and the dataframe are written to EMBEDDINGS_FILE / DATA_FILE so later
    sessions can skip this step; a failure to write is shown as a warning and
    does not raise.

    Args:
        model: SentenceTransformer used to encode the texts.
        data_df: DataFrame with ``title`` and ``abstract`` columns. A
            ``text_to_embed`` column is added to it in place.

    Returns:
        Tuple ``(corpus_embeddings, data_df)``; row i of the embedding tensor
        corresponds to row i of ``data_df``.
    """
    st.info("First time setup: Generating and saving embeddings. This may take a moment...")
    # Combine title and abstract for richer embeddings. Missing values become ''
    # so every entry is a str: a NaN title would otherwise turn the whole
    # concatenation for that row into NaN (a float), not text.
    titles = data_df['title'].fillna('').astype(str)
    abstracts = data_df['abstract'].fillna('').astype(str)
    data_df['text_to_embed'] = titles + ". " + abstracts

    corpus_embeddings = model.encode(
        data_df['text_to_embed'].tolist(),
        convert_to_tensor=True,
        show_progress_bar=True,
    )

    # Best-effort persistence to the configured cache paths.
    try:
        torch.save(corpus_embeddings, EMBEDDINGS_FILE)
        data_df.to_pickle(DATA_FILE)
        st.success("Embeddings and data saved successfully!")
    except Exception as e:
        st.warning(f"Could not save embeddings to disk: {e}. Will regenerate on each session.")

    return corpus_embeddings, data_df
def load_data_and_embeddings():
    """
    Return ``(model, corpus_embeddings, data_df)``, building them if needed.

    If both EMBEDDINGS_FILE and DATA_FILE exist they are loaded from disk;
    if either is missing, or loading fails, the CSV is read and the embeddings
    are regenerated via ``create_and_save_embeddings``.

    Calls ``st.stop()`` (halting the Streamlit script) when the CSV is missing
    or regeneration fails.
    """
    model = load_embedding_model()

    # Only existence is checked here; unreadable or corrupt files are caught by
    # the except branch below and fall through to regeneration.
    if os.path.exists(EMBEDDINGS_FILE) and os.path.exists(DATA_FILE):
        try:
            # map_location places the tensor on the model's device, so it can be
            # compared with query embeddings even if it was saved on a different
            # device (e.g. written on a GPU host, loaded on a CPU-only host).
            corpus_embeddings = torch.load(EMBEDDINGS_FILE, map_location=model.device)
            # NOTE: unpickling executes code; DATA_FILE must only ever be a file
            # this app wrote itself.
            data_df = pd.read_pickle(DATA_FILE)
            return model, corpus_embeddings, data_df
        except Exception as e:
            st.warning(f"Could not load saved embeddings: {e}. Regenerating...")

    # Load the raw data from CSV and (re)build the embeddings.
    try:
        data_df = pd.read_csv(CSV_FILE)
        corpus_embeddings, data_df = create_and_save_embeddings(model, data_df)
    except FileNotFoundError:
        st.error(f"CSV file '{CSV_FILE}' not found. Please ensure it's in your repository.")
        st.stop()
    except Exception as e:
        st.error(f"Error loading data: {e}")
        st.stop()

    return model, corpus_embeddings, data_df
def correct_query_spelling(query, spell_checker):
    """
    Correct likely spelling mistakes in a search query, word by word.

    The query is split on whitespace. A word is replaced only when the spell
    checker reports it as unknown AND offers a correction; otherwise it is kept
    as typed. All-uppercase words (acronyms such as "LLM") are never changed,
    and a correction for a capitalized word keeps its leading capital.

    Args:
        query: Raw user query; may be empty.
        spell_checker: Object exposing ``unknown(words)`` and
            ``correction(word)``, e.g. ``spellchecker.SpellChecker``.

    Returns:
        ``""`` for an empty query; the original ``query`` unchanged when no word
        is flagged; otherwise the (possibly corrected) words joined by single
        spaces.
    """
    if not query:
        return ""

    words = query.split()
    misspelled = spell_checker.unknown(words)
    if not misspelled:
        return query  # Return original if no typos found

    corrected_words = []
    for word in words:
        # A case-insensitive SpellChecker reports unknown words *lowercased*, so
        # look up both forms; checking only the typed form silently skipped
        # capitalized typos such as "Tranformer".
        if word in misspelled:
            key = word
        elif word.lower() in misspelled:
            key = word.lower()
        else:
            corrected_words.append(word)
            continue

        if word.isupper():
            # Acronyms/domain terms (LLM, GNN, ...) are not in a general English
            # dictionary; "correcting" them would mangle the query.
            corrected_words.append(word)
            continue

        suggestion = spell_checker.correction(key)
        if not suggestion:
            # No correction available: fall back to the original word.
            corrected_words.append(word)
            continue

        if word[:1].isupper():
            suggestion = suggestion[:1].upper() + suggestion[1:]
        corrected_words.append(suggestion)

    return " ".join(corrected_words)
def semantic_search(query, model, corpus_embeddings, data_df, top_k=10):
    """
    Rank the papers in ``data_df`` by cosine similarity to ``query``.

    Args:
        query: Search text; an empty/falsy query yields no results.
        model: Encoder used to embed the query (same model as the corpus).
        corpus_embeddings: Embeddings whose row i belongs to ``data_df`` row i.
        data_df: Paper table with title/authors/link/abstract columns.
        top_k: Maximum number of results to return.

    Returns:
        A list of dicts (title, authors, link, abstract, score), most similar
        first; at most ``min(top_k, len(corpus_embeddings))`` entries.
    """
    if not query:
        return []

    query_vec = model.encode(query, convert_to_tensor=True)
    similarities = util.cos_sim(query_vec, corpus_embeddings)[0]

    # torch.topk rejects k larger than the number of candidates, so clamp it.
    k = min(top_k, len(corpus_embeddings))
    best = torch.topk(similarities, k=k)

    def _to_record(score, row_idx):
        paper = data_df.iloc[row_idx.item()]
        return {
            "title": paper["title"],
            "authors": paper["authors"],
            "link": paper["link"],
            "abstract": paper["abstract"],
            # Not shown in the UI; kept for potential future use.
            "score": score.item(),
        }

    return [_to_record(score, row_idx) for score, row_idx in zip(best.values, best.indices)]
# --- Streamlit App UI ---
# Streamlit re-runs this top-level script on every user interaction.
st.set_page_config(page_title="Semantic Paper Search", layout="wide")
st.title("📄 Semantic Research Paper Search")
st.markdown("""
Enter a query below to search through a small collection of ICML 2025 papers.
The search is performed by comparing the semantic meaning of your query with the papers' titles and abstracts.
Spelling mistakes in your query will be automatically corrected.
""")

# Load all necessary data; unexpected errors are reported in the page below.
try:
    model, corpus_embeddings, data_df = load_data_and_embeddings()
    spell_checker = load_spell_checker()

    # --- User Inputs: search bar and result-count input ---
    col1, col2 = st.columns([4, 1])
    with col1:
        search_query = st.text_input(
            "Enter your search query:",
            placeholder="e.g., machine learning models for time series"
        )
    with col2:
        top_k_results = st.number_input(
            "Number of results",
            min_value=1,
            max_value=100,  # Set a reasonable max
            value=10,
            help="Select the number of top results to display."
        )

    if search_query:
        # --- Perform Typo Correction ---
        corrected_query = correct_query_spelling(search_query, spell_checker)
        # Only notify the user when the correction changed more than letter case.
        if corrected_query.lower() != search_query.lower():
            st.info(f"Did you mean: **{corrected_query}**? \n\n*Showing results for the corrected query.*")
        final_query = corrected_query

        # --- Perform Search ---
        search_results = semantic_search(final_query, model, corpus_embeddings, data_df, top_k=top_k_results)
        st.subheader(f"Found {len(search_results)} results for '{final_query}'")

        # --- Display Results ---
        if search_results:
            for result in search_results:
                with st.container(border=True):
                    # Title as a clickable link when a link is present; a missing
                    # (NaN) link would otherwise render as a link to "nan".
                    if pd.notna(result['link']):
                        st.markdown(f"### [{result['title']}]({result['link']})")
                    else:
                        st.markdown(f"### {result['title']}")
                    # Authors
                    st.caption(f"**Authors:** {result['authors']}")
                    # Expander for the abstract (skipped when it is missing)
                    if pd.notna(result['abstract']):
                        with st.expander("View Abstract"):
                            st.write(result['abstract'])
        else:
            st.warning("No results found. Try a different query.")
except Exception as e:
    st.error(f"An error occurred: {e}")
    st.info("Please ensure all required libraries are installed and the CSV file is present in your repository.")