from os import makedirs, remove | |
from os.path import exists, dirname | |
from functools import cache | |
import json | |
import streamlit as st | |
from googleapiclient.discovery import build | |
from slugify import slugify | |
from transformers import pipeline | |
import uuid | |
import spacy | |
from spacy.matcher import PhraseMatcher | |
from beautiful_soup.beautiful_soup import get_url_content | |
def google_search_api_request( query ):
    """
    Send the query to the Google Custom Search API and return the raw response.

    The API key and the search engine id are read from the Streamlit secrets
    ("google_search_api_key" and "google_search_engine_id"). Five results are
    requested, restricted to English, and the response is limited to each
    item's title and link plus the total number of results.
    """
    developer_key = st.secrets["google_search_api_key"]
    engine_id = st.secrets["google_search_engine_id"]

    client = build(
        "customsearch",
        "v1",
        developerKey=developer_key,
        cache_discovery=False
    )

    request_params = {
        # Exclude PDFs from search results.
        'q': query + ' -filetype:pdf',
        'cx': engine_id,
        'num': 5,
        'lr': 'lang_en', # lang_de
        'fields': 'items(title,link),searchInformation(totalResults)',
    }

    request = client.cse().list( **request_params )
    return request.execute()
def search_results( query ):
    """
    Return the search result items for a query, cached in a JSON file.

    The cache file lives at 'search-results/<slugified query>.json'. When it
    exists its content is returned without calling the API; otherwise the
    Google Search API is requested and the items (possibly an empty list)
    are written to the cache file.

    Raises:
        Exception: With the message 'No results found.' when there are no
            result items for the query.
    """
    file_path = 'search-results/' + slugify( query ) + '.json'
    makedirs(dirname(file_path), exist_ok=True)

    if exists( file_path ):
        with open( file_path, 'r', encoding='utf-8' ) as results_file:
            results = json.load( results_file )
    else:
        search_result = google_search_api_request( query )

        # The response may omit 'searchInformation' or 'items' (for example
        # when nothing matched), so fall back to "no results" instead of
        # failing with a KeyError.
        search_information = search_result.get( 'searchInformation', {} )
        total_results = int( search_information.get( 'totalResults', 0 ) )
        results = search_result.get( 'items', [] ) if total_results > 0 else []

        with open( file_path, 'w', encoding='utf-8' ) as results_file:
            json.dump( results, results_file )

    if not results:
        raise Exception('No results found.')

    return results
def get_summary( url_id, content ):
    """
    Return the summary for a page, using a JSON file as cache.

    The cache file is 'summaries/<url_id>.json'. On a cache miss the summary
    is generated from the content and written to that file before it is
    returned.
    """
    cache_path = 'summaries/' + url_id + '.json'
    makedirs(dirname(cache_path), exist_ok=True)

    # Cache miss: generate the summary and store it for the next request.
    if not exists( cache_path ):
        fresh_summary = generate_summary( content )
        with open( cache_path, 'w' ) as cache_file:
            json.dump( fresh_summary, cache_file )
        return fresh_summary

    with open( cache_path, 'r' ) as cache_file:
        return json.load( cache_file )
@cache
def _load_summarizer():
    """
    Build the summarization pipeline once and reuse it for later calls.
    """
    return pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")


def generate_summary( content, max_length = 200 ):
    """
    Generate summary for content.

    Args:
        content: Text to summarize. Input that is too long for the model is
            truncated (truncation=True).
        max_length: Upper bound for the length of the generated summary,
            forwarded to the summarization pipeline.

    Returns:
        The pipeline output. Callers in this file iterate over it and read
        the 'summary_text' key of each entry.

    Raises:
        Any exception raised while loading or running the pipeline is
        propagated unchanged.
    """
    summarizer = _load_summarizer()

    # Keep the lower bound below the upper bound; callers may request
    # summaries shorter than the default minimum of 30.
    min_length = min( 30, max_length )

    # max_length has to be passed by keyword: additional positional
    # arguments are not treated as generation options by the pipeline.
    return summarizer(
        content,
        max_length=max_length,
        min_length=min_length,
        do_sample=False,
        truncation=True
    )
def exception_notice( exception ):
    """
    Show an exception to the user.

    With the query parameter debug=true the full exception is rendered,
    otherwise only its message is shown as a warning.
    """
    params = st.experimental_get_query_params()
    debug_enabled = 'debug' in params and params['debug'][0] == 'true'

    if not debug_enabled:
        st.warning(str(exception))
        return

    st.exception(exception)
def is_keyword_in_string( keywords, string ):
    """
    Tell whether at least one of the keywords occurs in the string.

    The comparison is a plain, case-sensitive substring test.
    """
    return any( keyword in string for keyword in keywords )
@cache
def _load_spacy_model():
    """
    Load the English spaCy model once and reuse it for later calls.
    """
    return spacy.load("en_core_web_sm")


def filter_sentences_by_keywords( strings, keywords ):
    """
    Return the sentences of the given strings that contain a keyword.

    Args:
        strings: Iterable of text fragments. Fragments with fewer than five
            space-separated words are skipped.
        keywords: Iterable of phrases to look for.

    Returns:
        List of sentence texts in their original order. A sentence is added
        once, no matter how many keyword matches it contains.
    """
    nlp = _load_spacy_model()

    # Phrase matching only needs tokens, so the tokenizer alone is enough
    # for the patterns and for the sentences that are checked.
    matcher = PhraseMatcher(nlp.vocab)
    patterns = [nlp.make_doc(phrase) for phrase in keywords]
    matcher.add("QueryList", patterns)

    sentences = []
    for string in strings:
        # Exclude short sentences
        if len( string.split(' ') ) < 5:
            continue

        # The full pipeline is needed here for the sentence boundaries.
        for sentence in nlp(string).sents:
            if matcher(nlp.make_doc(sentence.text)):
                sentences.append(sentence.text)

    return sentences
def split_content_into_chunks( sentences, max_words = 512 ):
    """
    Split content into chunks.

    Sentences are appended to the current chunk until adding the next one
    would exceed the word limit; then a new chunk is started. Words are
    counted by splitting on single spaces. The word count of every chunk is
    written to the Streamlit page.

    Args:
        sentences: Iterable of sentence strings.
        max_words: Maximum number of words per chunk. A single sentence
            longer than this limit becomes a chunk of its own.

    Returns:
        List of chunk strings; each sentence in a chunk is followed by one
        space. For empty input the list holds a single empty string.
    """
    chunks = []
    current_sentences = []
    word_count = 0

    for sentence in sentences:
        current_word_count = len(sentence.split(' '))

        # Only close a chunk that has content, otherwise an oversized first
        # sentence would produce an empty leading chunk.
        if current_sentences and word_count + current_word_count > max_words:
            st.write("Number of words(tokens): {}".format(word_count))
            chunks.append(''.join(current_sentences))
            current_sentences = []
            word_count = 0

        word_count += current_word_count
        current_sentences.append(sentence + ' ')

    st.write("Number of words(tokens): {}".format(word_count))
    chunks.append(''.join(current_sentences))

    return chunks
def _remove_cached_file( file_path ):
    """
    Delete a cache file; a file that does not exist is not an error.
    """
    try:
        remove( file_path )
    except FileNotFoundError:
        pass


def main():
    """
    Render the search page.

    Reads the query from a text input, loads the search results and, for
    each result, fetches the page content, keeps the sentences matching the
    query keywords, and summarizes them. Errors are reported through
    exception_notice. With the query parameter debug=true the raw search
    results are shown as JSON.
    """
    st.title('Racoon Search')
    query = st.text_input('Search query')
    query_params = st.experimental_get_query_params()

    if not query:
        return

    with st.spinner('Loading search results...'):
        try:
            results = search_results( query )
        except Exception as exception:
            exception_notice(exception)
            return

    number_of_results = len( results )
    st.success( 'Found {} results for "{}".'.format( number_of_results, query ) )

    if 'debug' in query_params.keys() and query_params['debug'][0] == 'true':
        with st.expander("Search results JSON"):
            if st.button('Delete search result cache', key=query + 'cache'):
                _remove_cached_file( 'search-results/' + slugify( query ) + '.json' )
            st.json( results )

    progress_bar = st.progress(0)

    st.header('Search results')
    st.markdown('---')

    for index, result in enumerate(results):
        with st.container():
            st.markdown('### ' + result['title'])
            # Stable id derived from the link, used for cache file names
            # and widget keys.
            url_id = uuid.uuid5( uuid.NAMESPACE_URL, result['link'] ).hex

            try:
                strings = get_url_content( result['link'] )
                keywords = query.split(' ')
                sentences = filter_sentences_by_keywords( strings, keywords )
                chunks = split_content_into_chunks( sentences )

                number_of_chunks = len( chunks )
                if number_of_chunks > 1:
                    # Share a budget of 512 between the chunk summaries so
                    # the combined text can be summarized again.
                    max_length = int( 512 / number_of_chunks )
                    st.write("Max length: {}".format(max_length))

                    content = ''
                    for chunk in chunks:
                        chunk_length = len( chunk.split(' ') )
                        chunk_max_length = 200
                        # Short chunks get a limit of half their length.
                        if chunk_length < max_length:
                            chunk_max_length = int( chunk_length / 2 )

                        chunk_summary = generate_summary( chunk, min( max_length, chunk_max_length ) )
                        for summary in chunk_summary:
                            content += summary['summary_text'] + ' '
                else:
                    content = chunks[0]

                # NOTE(review): the summary is generated and cached here but
                # not rendered anywhere in this function - confirm whether
                # it is supposed to be displayed.
                summary = get_summary( url_id, content )
            except Exception as exception:
                exception_notice(exception)

            progress_bar.progress( ( index + 1 ) / number_of_results )

            col1, col2, col3 = st.columns(3)
            with col1:
                st.markdown('[Website Link]({})'.format(result['link']))
            with col2:
                if st.button('Delete content from cache', key=url_id + 'content'):
                    _remove_cached_file( 'page-content/' + url_id + '.txt' )
            with col3:
                if st.button('Delete summary from cache', key=url_id + 'summary'):
                    _remove_cached_file( 'summaries/' + url_id + '.json' )

            st.markdown('---')
# Run the app only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()