File size: 7,409 Bytes
6611ead
ca806f9
64e9557
7a417b0
 
 
ce35c00
7a417b0
 
ca806f9
 
 
c25929c
 
 
36e5a14
 
de8365c
7a417b0
1c0cf1d
7a417b0
 
53b63ed
7a417b0
 
2c1d8e4
53b63ed
7a417b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce35c00
 
 
 
 
 
 
 
7a417b0
 
 
 
 
 
 
 
ce35c00
 
7a417b0
ce35c00
 
 
 
 
 
 
 
 
7a417b0
 
ce35c00
 
 
 
 
 
7a417b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce35c00
7a417b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64e9557
7a417b0
 
ce35c00
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import os
from pathlib import Path
import streamlit as st
import pandas as pd
from sentence_transformers import SentenceTransformer, util
import torch
from spellchecker import SpellChecker
from io import StringIO

# --- Configuration ---
# Sentence Transformer checkpoint used for both paper and query embeddings.
EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
# Directory containing this script. NOTE: not referenced by the paths below,
# which are resolved against the process's current working directory.
script_dir = Path(__file__).parent

# File locations, relative to the current working directory (not to
# script_dir, and not under /tmp).
DATA_FILE = './src/papers_data.pkl'             # pickled papers dataframe (cache)
EMBEDDINGS_FILE = './src/paper_embeddings.pt'   # torch-saved embedding tensor (cache)
CSV_FILE = 'papers_with_abstracts_parallel.csv' # raw input data



# --- Caching Functions ---

@st.cache_resource
def load_embedding_model():
    """
    Load the Sentence Transformer model named by EMBEDDING_MODEL.

    Decorated with ``st.cache_resource`` so the model is constructed once and
    reused: Streamlit re-executes this whole script on every widget
    interaction, and without the decorator the model would be reloaded on
    every rerun.
    """
    return SentenceTransformer(EMBEDDING_MODEL)


@st.cache_resource
def load_spell_checker():
    """
    Create the SpellChecker used to correct search queries.

    Decorated with ``st.cache_resource`` so the dictionary is loaded once and
    reused across Streamlit reruns instead of being rebuilt on every
    interaction. The checker is only read from (``unknown``/``correction``).
    """
    return SpellChecker()

# --- Core Functions ---
def create_and_save_embeddings(model, data_df):
    """
    Generate document embeddings for every paper and persist them with the dataframe.

    Called by load_data_and_embeddings() when no usable saved files exist.

    Args:
        model: Encoder exposing ``encode`` (the SentenceTransformer in this app).
        data_df: Papers dataframe with 'title' and 'abstract' columns. A
            'text_to_embed' column is added to it in place.

    Returns:
        ``(corpus_embeddings, data_df)``: the embeddings produced by
        ``model.encode`` (one per row, in dataframe order) and the same dataframe.
    """
    st.info("First time setup: Generating and saving embeddings. This may take a moment...")
    # Combine title and abstract for richer embeddings. Both columns are
    # NaN-filled: a single missing title would otherwise make the whole
    # concatenated value NaN, which the encoder cannot process.
    data_df['text_to_embed'] = (
        data_df['title'].fillna('').astype(str)
        + ". "
        + data_df['abstract'].fillna('').astype(str)
    )

    # Generate embeddings
    corpus_embeddings = model.encode(data_df['text_to_embed'].tolist(), convert_to_tensor=True, show_progress_bar=True)

    # Persist to EMBEDDINGS_FILE / DATA_FILE so later runs can skip encoding.
    # Best-effort: on failure (e.g. read-only filesystem) we only warn.
    try:
        # Create the parent directories (e.g. ./src) if they don't exist yet.
        for path in (EMBEDDINGS_FILE, DATA_FILE):
            parent = os.path.dirname(path)
            if parent:
                os.makedirs(parent, exist_ok=True)
        torch.save(corpus_embeddings, EMBEDDINGS_FILE)
        data_df.to_pickle(DATA_FILE)
        st.success("Embeddings and data saved successfully!")
    except Exception as e:
        st.warning(f"Could not save embeddings to disk: {e}. Will regenerate on each session.")

    return corpus_embeddings, data_df

def load_data_and_embeddings():
    """
    Return ``(model, corpus_embeddings, data_df)`` for the search UI.

    Uses the saved embeddings (EMBEDDINGS_FILE) and dataframe (DATA_FILE) when
    both exist, load cleanly and have matching row counts. Otherwise it reads
    CSV_FILE and regenerates them via create_and_save_embeddings(). If the CSV
    is missing or unreadable, an error is shown and the script run is halted
    with ``st.stop()``.
    """
    model = load_embedding_model()

    # Check if files exist and are readable
    if os.path.exists(EMBEDDINGS_FILE) and os.path.exists(DATA_FILE):
        try:
            # Map the tensor onto the model's device. A cache produced on a GPU
            # machine would otherwise fail to load on a CPU-only host, and
            # embeddings on a different device than the query embedding would
            # break the similarity computation.
            corpus_embeddings = torch.load(EMBEDDINGS_FILE, map_location=model.device)
            # NOTE: read_pickle unpickles arbitrary objects; only load files
            # this app wrote itself.
            data_df = pd.read_pickle(DATA_FILE)
            # The two files are written separately, so guard against a
            # partial/out-of-sync save: every embedding row must map to a paper.
            if len(corpus_embeddings) != len(data_df):
                raise ValueError(
                    f"saved embeddings ({len(corpus_embeddings)}) and data "
                    f"({len(data_df)}) have different lengths"
                )
            return model, corpus_embeddings, data_df
        except Exception as e:
            st.warning(f"Could not load saved embeddings: {e}. Regenerating...")

    # Load the raw data from CSV
    try:
        data_df = pd.read_csv(CSV_FILE)
        corpus_embeddings, data_df = create_and_save_embeddings(model, data_df)
    except FileNotFoundError:
        st.error(f"CSV file '{CSV_FILE}' not found. Please ensure it's in your repository.")
        st.stop()
    except Exception as e:
        st.error(f"Error loading data: {e}")
        st.stop()

    return model, corpus_embeddings, data_df

def correct_query_spelling(query, spell_checker):
    """
    Correct likely spelling mistakes in the user's query.

    Args:
        query: Raw query string typed by the user.
        spell_checker: Object exposing ``unknown(words)`` and
            ``correction(word)`` (a ``SpellChecker`` instance in this app).

    Returns:
        The corrected query with words joined by single spaces. If no word was
        actually changed, ``query`` is returned unchanged. An empty query
        returns ``""``.
    """
    if not query:
        return ""

    words = query.split()

    # A case-insensitive SpellChecker reports unknown words lower-cased, so
    # membership is tested on both the original and the lower-cased form;
    # otherwise capitalized typos would never be corrected.
    misspelled = spell_checker.unknown(words)
    if not misspelled:
        return query  # Return original if no typos found

    corrected_words = []
    changed = False
    for word in words:
        replacement = word
        # Only correct plain lower-case or Title-case words. ALL-CAPS and
        # mixed-case tokens (e.g. "GAN", "ResNet") are usually acronyms or
        # model names that a general English dictionary would mangle.
        if (word.islower() or word.istitle()) and (
            word in misspelled or word.lower() in misspelled
        ):
            suggestion = spell_checker.correction(word.lower())
            # Fall back to the original word if no correction is found.
            if suggestion:
                # Keep the user's capitalization ("Nerual" -> "Neural").
                replacement = suggestion.capitalize() if word.istitle() else suggestion
        changed = changed or replacement != word
        corrected_words.append(replacement)

    # Returning the untouched query avoids a spurious "Did you mean" when the
    # only difference would be whitespace normalization.
    return " ".join(corrected_words) if changed else query

def semantic_search(query, model, corpus_embeddings, data_df, top_k=10):
    """
    Rank papers by cosine similarity to ``query``.

    Args:
        query: Search text; an empty string yields no results.
        model: Encoder with an ``encode`` method (the SentenceTransformer here).
        corpus_embeddings: One embedding per paper, aligned with ``data_df`` rows.
        data_df: Papers dataframe with 'title', 'authors', 'link' and
            'abstract' columns.
        top_k: Maximum number of hits (capped at the number of papers).

    Returns:
        A list of dicts with keys title/authors/link/abstract/score, best
        match first.
    """
    if not query:
        return []

    query_vec = model.encode(query, convert_to_tensor=True)
    # Single query -> the first row of the similarity matrix holds its scores.
    similarities = util.cos_sim(query_vec, corpus_embeddings)[0]

    # Never ask topk for more hits than there are papers.
    k = min(top_k, len(corpus_embeddings))
    best = torch.topk(similarities, k=k)

    hits = []
    for row_idx, sim in zip(best.indices.tolist(), best.values.tolist()):
        paper = data_df.iloc[row_idx]
        hits.append(
            {
                "title": paper["title"],
                "authors": paper["authors"],
                "link": paper["link"],
                "abstract": paper["abstract"],
                "score": sim,  # kept for potential future use; not displayed
            }
        )
    return hits

# --- Streamlit App UI ---
st.set_page_config(page_title="Semantic Paper Search", layout="wide")

st.title("📄 Semantic Research Paper Search")
st.markdown("""
Enter a query below to search through a small collection of ICML 2025 papers. 
The search is performed by comparing the semantic meaning of your query with the papers' titles and abstracts.
Spelling mistakes in your query will be automatically corrected.
""")

# Load all necessary data and render the UI. Any unexpected exception is
# reported on the page rather than crashing the app.
try:
    model, corpus_embeddings, data_df = load_data_and_embeddings()
    spell_checker = load_spell_checker()

    # --- User Inputs: Search Bar and Slider ---
    col1, col2 = st.columns([4, 1])
    with col1:
        search_query = st.text_input(
            "Enter your search query:",
            placeholder="e.g., machine learning models for time series"
        )
    with col2:
        top_k_results = st.number_input(
            "Number of results",
            min_value=1,
            max_value=100,  # Set a reasonable max
            value=10,
            help="Select the number of top results to display."
        )

    if search_query:
        # --- Perform Typo Correction ---
        corrected_query = correct_query_spelling(search_query, spell_checker)

        # If a correction was made, notify the user
        if corrected_query.lower() != search_query.lower():
            st.info(f"Did you mean: **{corrected_query}**? \n\n*Showing results for the corrected query.*")

        final_query = corrected_query

        # --- Perform Search ---
        search_results = semantic_search(final_query, model, corpus_embeddings, data_df, top_k=top_k_results)

        st.subheader(f"Found {len(search_results)} results for '{final_query}'")

        # --- Display Results ---
        if search_results:
            for result in search_results:
                with st.container(border=True):
                    title = result['title']
                    link = result['link']
                    # Title as a clickable link; fall back to plain text when the
                    # paper has no link (a NaN would otherwise render as "(nan)").
                    if pd.notna(link) and str(link).strip():
                        st.markdown(f"### [{title}]({link})")
                    else:
                        st.markdown(f"### {title}")

                    # Authors (avoid showing a literal "nan" for missing values)
                    authors = result['authors'] if pd.notna(result['authors']) else "Unknown"
                    st.caption(f"**Authors:** {authors}")

                    # Expander for the abstract
                    if pd.notna(result['abstract']):
                        with st.expander("View Abstract"):
                            st.write(result['abstract'])
        else:
            st.warning("No results found. Try a different query.")

except Exception as e:
    st.error(f"An error occurred: {e}")
    st.info("Please ensure all required libraries are installed and the CSV file is present in your repository.")