Best_Guess / app.py
Chri12345's picture
Update app.py
1f9c919 verified
import os
from sentence_transformers import SentenceTransformer, util
from datasets import load_dataset
from transformers import pipeline
import streamlit as st
# Dataset loading, wrapped in Streamlit's data cache
@st.cache_data
def load_data(dataset_id="sentence-transformers/natural-questions", split="train"):
    """Load one split of a Hugging Face dataset.

    Args:
        dataset_id: Dataset identifier handed to ``load_dataset``.
        split: Name of the split to load.

    Returns:
        The object produced by ``load_dataset`` for the requested split.
    """
    requested_split = load_dataset(dataset_id, split=split)
    return requested_split
# Embedding model, wrapped in Streamlit's resource cache
@st.cache_resource
def load_model():
    """Instantiate the sentence-embedding model used for semantic search.

    Returns:
        A ``SentenceTransformer`` built from the 'allenai-specter' checkpoint.
    """
    encoder = SentenceTransformer('allenai-specter')
    return encoder
# Cache corpus embedding generation
@st.cache_data
def generate_embeddings(_model, _dataset_file, sample_size=32):
    """Build the search corpus from the first rows of the dataset and embed it.

    Each corpus entry is the record's 'query' and 'answer' fields joined by
    the literal '[SEP]'.

    Args:
        _model: Object exposing ``encode(texts, convert_to_tensor=...,
            show_progress_bar=...)``. The leading underscore tells
            ``st.cache_data`` not to hash this argument.
        _dataset_file: Dataset supporting ``len()`` and ``select(indices)``
            whose records have 'query' and 'answer' string fields. Also
            excluded from cache hashing by the leading underscore.
        sample_size: Maximum number of leading records to embed. Values larger
            than the dataset are clamped to the dataset length.

    Returns:
        A ``(paper_texts, embeddings)`` tuple: the list of combined strings and
        whatever ``_model.encode`` returns for them with
        ``convert_to_tensor=True``.
    """
    # Clamp the request: select() rejects indices beyond the end of the
    # dataset, so an oversized sample_size would otherwise raise.
    row_count = min(sample_size, len(_dataset_file))

    # Prepare paper texts by combining query and answer fields
    paper_texts = [
        record['query'] + '[SEP]' + record['answer']
        for record in _dataset_file.select(range(row_count))
    ]

    # Compute embeddings for all paper texts
    embeddings = _model.encode(
        paper_texts, convert_to_tensor=True, show_progress_bar=True
    )
    return paper_texts, embeddings
# Summarization pipeline, wrapped in Streamlit's resource cache
@st.cache_resource
def load_summarizer():
    """Create the summarization pipeline.

    No model name is given, so ``pipeline`` chooses the checkpoint for the
    "summarization" task itself.

    Returns:
        The object returned by ``pipeline("summarization")``.
    """
    summarization_pipeline = pipeline("summarization")
    return summarization_pipeline
# Streamlit app: page title rendered at the top of the UI.
st.title("Semantic Search with Summarization")
# Load resources. Order matters: the corpus embeddings are computed from both
# the dataset and the model, so those two must be loaded first.
# `model`, `corpus_embeddings`, `dataset_file` and `summarizer` are read as
# module-level globals by search_papers_and_summarize, so keep these names.
dataset_file = load_data()
model = load_model()
paper_texts, corpus_embeddings = generate_embeddings(model, dataset_file)
summarizer = load_summarizer()
# Function to search and summarize
def search_papers_and_summarize(query, max_summary_length=45, top_k=5):
    """Find the corpus entries closest to *query* and summarize their answers.

    Relies on the module-level ``model``, ``corpus_embeddings``,
    ``dataset_file`` and ``summarizer`` objects.

    Args:
        query: Free-text search query.
        max_summary_length: Passed to the summarizer as ``max_length``.
        top_k: Number of best-matching corpus entries whose answers are
            combined before summarization.

    Returns:
        The 'summary_text' of the first summarizer result, or an empty string
        when the search produced no text to summarize.
    """
    # Encode the query
    query_embedding = model.encode(query, convert_to_tensor=True)

    # Perform semantic search. One hit list is returned per query; there is a
    # single query here, so take the first list. Asking for top_k directly
    # avoids fetching extra hits only to slice them away.
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]

    # Corpus position i was built from dataset row i (the corpus is the leading
    # rows of dataset_file), so corpus_id indexes straight into the dataset.
    answers = [dataset_file[int(hit['corpus_id'])]['answer'] for hit in hits]

    # Combine answers into a single text for summarization
    combined_text = " ".join(answers)
    if not combined_text.strip():
        # Nothing to summarize; skip the model call.
        return ""

    # Several concatenated answers can exceed the summarization model's
    # maximum input length; truncation=True lets the tokenizer cut the input
    # down instead of failing on over-long text.
    # NOTE(review): if the default checkpoint's configured min_length is above
    # max_summary_length, generation may emit a length warning - confirm.
    summary = summarizer(
        combined_text,
        max_length=max_summary_length,
        truncation=True,
        clean_up_tokenization_spaces=True,
    )
    return summary[0]['summary_text']
# Streamlit input
query = st.text_input("Enter your query:", "")
# Only search when there is real text: a whitespace-only string is truthy and
# would otherwise trigger a full embedding search and summarization run.
if query.strip():
    st.write("Searching for relevant answers...")
    summary = search_papers_and_summarize(query.strip())
    st.subheader("Summary")
    st.write(summary)