import os
from pathlib import Path

import pandas as pd
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer, util
from spellchecker import SpellChecker


EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'

# Resolve data files relative to this script so the app does not depend on
# the current working directory.
script_dir = Path(__file__).parent

DATA_FILE = script_dir / 'src' / 'papers_data.pkl'
EMBEDDINGS_FILE = script_dir / 'src' / 'paper_embeddings.pt'
CSV_FILE = script_dir / 'papers_with_abstracts_parallel.csv'
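
# Expected CSV columns (inferred from how the dataframe is used below):
# 'title', 'authors', 'link', and 'abstract'.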


@st.cache_resource
def load_embedding_model():
    """Loads the Sentence Transformer model and caches it."""
    return SentenceTransformer(EMBEDDING_MODEL)


@st.cache_resource
def load_spell_checker():
    """Loads the SpellChecker object and caches it."""
    return SpellChecker()
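
# Note: st.cache_resource keeps a single shared instance per process, so the
# model and spell checker are loaded once and reused across reruns.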


def create_and_save_embeddings(model, data_df):
    """
    Generates and saves document embeddings and the dataframe.
    This function is called only once, when the saved files don't exist.
    """
    st.info("First time setup: Generating and saving embeddings. This may take a moment...")

    # Embed each paper as "<title>. <abstract>"; missing fields become empty strings.
    data_df['text_to_embed'] = data_df['title'].fillna('') + ". " + data_df['abstract'].fillna('')

    corpus_embeddings = model.encode(data_df['text_to_embed'].tolist(), convert_to_tensor=True, show_progress_bar=True)

    try:
        # Make sure the target directory exists before writing.
        Path(EMBEDDINGS_FILE).parent.mkdir(parents=True, exist_ok=True)
        torch.save(corpus_embeddings, EMBEDDINGS_FILE)
        data_df.to_pickle(DATA_FILE)
        st.success("Embeddings and data saved successfully!")
    except Exception as e:
        st.warning(f"Could not save embeddings to disk: {e}. Will regenerate on each session.")

    return corpus_embeddings, data_df
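
# For all-MiniLM-L6-v2, encode(..., convert_to_tensor=True) returns a single
# torch tensor of shape (n_papers, 384), which is what gets saved to the .pt file.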


def load_data_and_embeddings():
    """
    Loads the saved embeddings and dataframe from disk.
    If the files don't exist, it calls the creation function.
    """
    model = load_embedding_model()

    # Fast path: reuse the embeddings and dataframe saved by a previous run.
    if os.path.exists(EMBEDDINGS_FILE) and os.path.exists(DATA_FILE):
        try:
            corpus_embeddings = torch.load(EMBEDDINGS_FILE)
            data_df = pd.read_pickle(DATA_FILE)
            return model, corpus_embeddings, data_df
        except Exception as e:
            st.warning(f"Could not load saved embeddings: {e}. Regenerating...")

    # Slow path: read the raw CSV and build the embeddings from scratch.
    try:
        data_df = pd.read_csv(CSV_FILE)
        corpus_embeddings, data_df = create_and_save_embeddings(model, data_df)
    except FileNotFoundError:
        st.error(f"CSV file '{CSV_FILE}' not found. Please ensure it's in your repository.")
        st.stop()
    except Exception as e:
        st.error(f"Error loading data: {e}")
        st.stop()

    return model, corpus_embeddings, data_df
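
# Saved pickles and tensors are not guaranteed to load across pandas/torch
# upgrades; the except branch above falls back to regenerating them.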


def correct_query_spelling(query, spell_checker):
    """
    Corrects potential spelling mistakes in the user's query.
    """
    if not query:
        return ""

    words = query.split()

    # unknown() returns the (lowercased) words missing from the dictionary.
    misspelled = spell_checker.unknown(words)
    if not misspelled:
        return query

    corrected_words = []
    for word in words:
        # Compare lowercased, since unknown() lowercases its results.
        if word.lower() in misspelled:
            # correction() can return None when no candidate exists; keep the
            # original word in that case.
            corrected_word = spell_checker.correction(word)
            corrected_words.append(corrected_word if corrected_word else word)
        else:
            corrected_words.append(word)

    return " ".join(corrected_words)


def semantic_search(query, model, corpus_embeddings, data_df, top_k=10):
    """
    Performs semantic search on the loaded data.
    """
    if not query:
        return []

    # Embed the query and score it against every paper with cosine similarity.
    query_embedding = model.encode(query, convert_to_tensor=True)
    cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]

    # Never request more results than there are papers.
    top_k = min(top_k, len(corpus_embeddings))
    top_results = torch.topk(cos_scores, k=top_k)

    results = []
    for score, idx in zip(top_results[0], top_results[1]):
        item = data_df.iloc[idx.item()]
        results.append({
            "title": item["title"],
            "authors": item["authors"],
            "link": item["link"],
            "abstract": item["abstract"],
            "score": score.item()
        })
    return results
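
# Illustrative call (not executed here): semantic_search("optimizers for
# transformers", model, corpus_embeddings, data_df, top_k=3) returns a list of
# dicts like {"title": ..., "authors": ..., "link": ..., "abstract": ...,
# "score": 0.62}, ordered by descending cosine similarity.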


# ---------- Streamlit UI ----------

st.set_page_config(page_title="Semantic Paper Search", layout="wide")

st.title("🔍 Semantic Research Paper Search")
st.markdown("""
Enter a query below to search through a small collection of ICML 2025 papers.
The search compares the semantic meaning of your query with the papers' titles and abstracts.
Spelling mistakes in your query are corrected automatically.
""")


try:
    model, corpus_embeddings, data_df = load_data_and_embeddings()
    spell_checker = load_spell_checker()

    # Search controls: a query box and a result-count selector.
    col1, col2 = st.columns([4, 1])
    with col1:
        search_query = st.text_input(
            "Enter your search query:",
            placeholder="e.g., machine learning models for time series"
        )
    with col2:
        top_k_results = st.number_input(
            "Number of results",
            min_value=1,
            max_value=100,
            value=10,
            help="Select the number of top results to display."
        )

    if search_query:
        corrected_query = correct_query_spelling(search_query, spell_checker)

        # Only surface the correction when it actually changes the wording.
        if corrected_query.lower() != search_query.lower():
            st.info(f"Did you mean: **{corrected_query}**? \n\n*Showing results for the corrected query.*")

        final_query = corrected_query

        search_results = semantic_search(final_query, model, corpus_embeddings, data_df, top_k=top_k_results)

        st.subheader(f"Found {len(search_results)} results for '{final_query}'")

        if search_results:
            for result in search_results:
                with st.container(border=True):
                    st.markdown(f"### [{result['title']}]({result['link']})")
                    st.caption(f"**Authors:** {result['authors']}")
                    if pd.notna(result['abstract']):
                        with st.expander("View Abstract"):
                            st.write(result['abstract'])
        else:
            st.warning("No results found. Try a different query.")

except Exception as e:
    st.error(f"An error occurred: {e}")
    st.info("Please ensure all required libraries are installed and the CSV file is present in your repository.")