# ernestchu's picture
# i
# 0d13d40
import itertools
import re
from collections import Counter, defaultdict
from typing import Dict, List, NamedTuple
import argparse
import sys
import time
import threading
import itertools
import gradio as gr
import numpy as np
from numpy.linalg import norm
import nltk
from nltk.stem.snowball import SnowballStemmer
from nltk.tokenize import word_tokenize
# Download NLTK's "punkt_tab" tokenizer resource at import time; word_tokenize
# (used by read_docs) presumably depends on it — TODO confirm for the NLTK
# version in use.
nltk.download('punkt_tab')
def spinner(stop_event):
    '''
    Animate a one-character console spinner until ``stop_event`` is set.

    The first frame is always drawn; each later frame backspaces over the
    previous character. Frames advance every 0.1 s. Once the event is set the
    last frame is erased and the line is terminated.
    '''
    frames = itertools.cycle(['-', '\\', '|', '/'])
    prefix = ''  # nothing to erase before the very first frame
    while True:
        sys.stdout.write(prefix + next(frames))
        sys.stdout.flush()
        time.sleep(0.1)
        prefix = '\b'
        if stop_event.is_set():
            break
    print('\b \n')
# Create a threading event to stop the spinner.
# Shared module-level flag: setup() hands it to the spinner thread and sets it
# once indexing has finished.
stop_event = threading.Event()
### File IO and processing
class Document(NamedTuple):
    '''
    One corpus entry split into its author/title/keyword/abstract sections.

    read_docs fills each section with lower-cased tokens; read_docs_for_presentation
    fills the same fields with plain display strings instead of lists.
    '''
    doc_id: int
    author: List[str]
    title: List[str]
    keyword: List[str]
    abstract: List[str]

    def sections(self):
        '''Return the four text sections as a list, in author/title/keyword/abstract order.'''
        return [self.author, self.title, self.keyword, self.abstract]

    def __repr__(self):
        rows = [f"doc_id: {self.doc_id}"]
        for label in ('author', 'title', 'keyword', 'abstract'):
            rows.append(f" {label}: {getattr(self, label)}")
        return "\n".join(rows)
def read_stopwords(file):
    '''Load a stopword file (one word per line) into a set of whitespace-stripped lines.'''
    with open(file) as handle:
        return {entry.strip() for entry in handle}
# Module-level stopword set, read by remove_stopwords_doc.
# NOTE(review): setup() re-reads the same file into a local that process_docs
# receives but does not forward, so this global is the set actually used.
stopwords = read_stopwords('common_words')
# English Snowball stemmer shared by stem_doc.
stemmer = SnowballStemmer('english')
def read_rels(file):
    '''
    Reads the file of related documents and returns a dictionary of
    query id -> list of related document ids (in file order).

    Each non-blank line must hold exactly two integers: a query id followed by
    a relevant document id. Blank lines (e.g. a trailing newline at the end of
    the file) are skipped.

    Raises ValueError if a non-blank line does not contain exactly two integers.
    '''
    rels = {}
    with open(file) as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue
            qid, rel = fields
            rels.setdefault(int(qid), []).append(int(rel))
    return rels
def read_docs(file):
    '''
    Reads the corpus into a list of Documents.

    A line starting with ``.I`` opens a new document; a line matching ``.<letter>``
    (e.g. ``.T``, ``.A``, ``.K``, ``.W``) switches the current section. Every
    other non-blank line is tokenized with ``word_tokenize``, lower-cased and
    appended to the current section of the most recently opened document.
    Text appearing before the first ``.I`` is discarded.

    Only the A (author), T (title), K (keyword) and W (abstract) sections are
    kept. ``doc_id`` is the 1-based position of the document in the file; the
    number after ``.I`` is not used, so gaps in that numbering are harmless.
    '''
    docs = [defaultdict(list)]  # placeholder that swallows text before the first .I
    category = ''
    with open(file) as f:
        for line in f:
            line = line.strip()
            if line.startswith('.I'):
                docs.append(defaultdict(list))
            elif re.match(r'\.\w', line):
                category = line[1]
            elif line != '':
                # Always write into the document opened last.
                docs[-1][category].extend(word.lower() for word in word_tokenize(line))
    return [Document(i + 1, d['A'], d['T'], d['K'], d['W'])
            for i, d in enumerate(docs[1:])]
def read_docs_for_presentation(file):
    '''
    Reads the corpus into Documents whose sections are display strings.

    Uses the same ``.I`` / ``.<letter>`` markers as read_docs, but keeps the raw
    (stripped) text instead of tokens. Consecutive lines of one section are
    joined with a single space when the text so far ends in ``.``, and with
    ``". "`` otherwise. Text always goes to the most recently opened document;
    ``doc_id`` is the 1-based position of the document in the file.
    '''
    docs = [defaultdict(str)]  # placeholder that swallows text before the first .I
    category = ''
    with open(file) as f:
        for line in f:
            line = line.strip()
            if line.startswith('.I'):
                docs.append(defaultdict(str))
            elif re.match(r'\.\w', line):
                category = line[1]
            elif line != '':
                current = docs[-1]
                previous = current[category]
                if previous == '':
                    current[category] = line
                elif previous.endswith('.'):
                    current[category] = f'{previous} {line}'
                else:
                    current[category] = f'{previous}. {line}'
    return [Document(i + 1, d['A'], d['T'], d['K'], d['W'])
            for i, d in enumerate(docs[1:])]
def stem_doc(doc: Document):
    '''Return a copy of ``doc`` with every token of every section run through the module stemmer.'''
    stemmed = ([stemmer.stem(token) for token in section] for section in doc.sections())
    return Document(doc.doc_id, *stemmed)
def stem_docs(docs: List[Document]):
    '''Apply stem_doc to each document, returning a new list.'''
    return list(map(stem_doc, docs))
def remove_stopwords_doc(doc: Document):
    '''Return a copy of ``doc`` with tokens found in the module-level ``stopwords`` set dropped.'''
    kept = ([token for token in section if token not in stopwords] for section in doc.sections())
    return Document(doc.doc_id, *kept)
def remove_stopwords(docs: List[Document]):
    '''Apply remove_stopwords_doc to each document, returning a new list.'''
    return list(map(remove_stopwords_doc, docs))
### Term-Document Matrix
class TermWeights(NamedTuple):
    '''Per-section multipliers that the term-weighting functions apply to each token.'''
    author: float  # weight for tokens in the author section
    title: float  # weight for tokens in the title section
    keyword: float  # weight for tokens in the keyword section
    abstract: float  # weight for tokens in the abstract section
def compute_doc_freqs(docs: List[Document]):
    '''
    Count, for each word, how many documents contain it in any section.

    Returns a Counter mapping word -> document frequency; a word repeated
    within one document is still counted once for that document.
    '''
    freq = Counter()
    for doc in docs:
        # The set collapses repeats, so each document adds at most 1 per word.
        freq.update({word for section in doc.sections() for word in section})
    return freq
def compute_tf(doc: Document, doc_freqs: Dict[str, int], weights: TermWeights):
    '''
    Build the section-weighted term-frequency vector of one document.

    Every occurrence of a token adds the weight of the section it appears in,
    e.g. a word seen twice in the title and once in the abstract scores
    ``2 * weights.title + weights.abstract``.

    ``doc_freqs`` is unused; it is accepted so the signature matches the other
    term functions (compute_tfidf, compute_boolean).

    Returns a plain dict mapping token -> accumulated weight.
    '''
    vec = defaultdict(float)
    # This section order fixes the key insertion order of the result.
    weighted_sections = ((doc.author, weights.author),
                         (doc.keyword, weights.keyword),
                         (doc.title, weights.title),
                         (doc.abstract, weights.abstract))
    for section, weight in weighted_sections:
        for word in section:
            vec[word] += weight
    return dict(vec)  # convert back to a regular dict
def compute_tfidf(doc, doc_freqs, weights, num_docs=3204):
    '''
    Section-weighted tf-idf vector of one document.

    Each term's weight from compute_tf is multiplied by the smoothed idf
    ``log((1 + num_docs) / (1 + df))``, where ``df`` is the term's entry in
    ``doc_freqs`` (0 if absent, so unseen query terms get the maximum idf).

    ``num_docs`` is the number of documents the frequencies were computed over.
    The default 3204 is presumably the size of the cacm.raw corpus — TODO
    confirm, or pass ``len(docs)`` explicitly.

    Returns a plain dict mapping token -> tf-idf weight.
    '''
    tf = compute_tf(doc, doc_freqs, weights)
    return {word: freq * np.log((1 + num_docs) / (1 + doc_freqs.get(word, 0)))
            for word, freq in tf.items()}
def compute_boolean(doc, doc_freqs, weights):
    '''
    Presence vector: every word present gets the weight of a section it occurs in.

    Sections are visited as author, keyword, title, abstract and a later section
    overwrites an earlier one, so a word found in several sections keeps the
    weight of the last of them in that order. ``doc_freqs`` is unused.
    '''
    vec = {}
    for section, weight in ((doc.author, weights.author),
                            (doc.keyword, weights.keyword),
                            (doc.title, weights.title),
                            (doc.abstract, weights.abstract)):
        for word in section:
            vec[word] = weight
    return vec
### Vector Similarity
def dictdot(x: Dict[str, float], y: Dict[str, float]):
    '''
    Dot product of two sparse vectors stored as ``{key: value}`` dicts.

    Missing keys count as 0. Only the keys of the smaller dict (``y`` on a tie)
    are visited, since no other key can give a non-zero product.
    '''
    probe = x if len(x) < len(y) else y
    total = 0
    for key in probe:
        total += x.get(key, 0) * y.get(key, 0)
    return total
def cosine_sim_dict(x, y):
    '''
    Cosine similarity of two sparse term vectors stored as ``{term: weight}`` dicts.

    Returns 0 when the dot product is 0 (no shared non-zero terms), in which
    case the norms are not computed.
    '''
    probe = x if len(x) < len(y) else y
    dot = 0
    for term in probe:
        dot += x.get(term, 0) * y.get(term, 0)
    if dot == 0:
        return 0
    return dot / (norm(list(x.values())) * norm(list(y.values())))
def cosine_sim(x, y):
    '''
    Cosine similarity for either sparse dict vectors or dense NumPy vectors.

    Dicts are delegated to cosine_sim_dict; dense vectors with a zero norm
    give 0 instead of dividing by zero.
    '''
    if isinstance(x, dict):
        return cosine_sim_dict(x, y)
    magnitude = np.linalg.norm(x) * np.linalg.norm(y)
    return 0 if magnitude == 0 else np.dot(x, y) / magnitude
def dice_sim(x, y):
    '''
    Weighted (fuzzy-set) Dice coefficient of two sparse ``{term: weight}`` dicts.

        dice = 2 * sum_k min(x_k, y_k) / (sum_k x_k + sum_k y_k)

    For non-negative weights the result lies in [0, 1]; for 0/1 weights it is
    the set Dice coefficient 2|X ∩ Y| / (|X| + |Y|). Returns 0 when the vectors
    share no weight or the denominator is 0. Dense (SVD) vectors are not supported.
    '''
    # Using min() rather than the product keeps the numerator no larger than
    # either vector's total weight, so the ratio cannot exceed 1.
    shared = sum(min(val, y[key]) for key, val in x.items() if key in y)
    num = 2 * shared
    if num == 0:
        return 0
    denom = sum(x.values()) + sum(y.values())
    return num / denom if denom != 0 else 0
def jaccard_sim(x, y):
    '''
    Weighted Jaccard similarity of two sparse ``{term: weight}`` dicts.

        jaccard = sum_k min(x_k, y_k) / sum_k max(x_k, y_k)

    computed as ``shared / (sum(x) + sum(y) - shared)``, since
    sum(max) = sum(x) + sum(y) - sum(min). For non-negative weights the result
    lies in [0, 1]; for 0/1 weights it is the set Jaccard |X ∩ Y| / |X ∪ Y|.
    Returns 0 when nothing is shared or the denominator is 0. Dense (SVD)
    vectors are not supported.
    '''
    num = sum(min(val, y[key]) for key, val in x.items() if key in y)
    if num == 0:
        return 0
    denom = sum(x.values()) + sum(y.values()) - num
    return num / denom if denom != 0 else 0
def overlap_sim(x, y):
    '''
    Weighted overlap coefficient of two sparse ``{term: weight}`` dicts.

        overlap = sum_k min(x_k, y_k) / min(sum_k x_k, sum_k y_k)

    For non-negative weights the result lies in [0, 1]; for 0/1 weights it is
    the set overlap |X ∩ Y| / min(|X|, |Y|). Returns 0 when nothing is shared
    or the denominator is 0. Dense (SVD) vectors are not supported.
    '''
    num = sum(min(val, y[key]) for key, val in x.items() if key in y)
    if num == 0:
        return 0
    denom = min(sum(x.values()), sum(y.values()))
    return num / denom if denom != 0 else 0
### Precision/Recall
def interpolate(x1, y1, x2, y2, x):
    '''Evaluate at ``x`` the straight line through (x1, y1) and (x2, y2); requires x1 != x2.'''
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1
    return slope * x + intercept
def precision_at(recall: float, results: List[int], relevant: List[int]) -> float:
    '''
    Interpolated precision of a ranking at the given recall level.

    Builds one (recall, precision) point per relevant document: the k-th
    relevant document found (k = 1..R, ordered by rank) gives recall k/R and
    precision k/rank. An implicit point (recall=0, precision=1) is prepended.
    Between two points the precision is linearly interpolated; e.g. with 4
    relevant documents, precision @ 0.6 interpolates between the points at
    recall 0.5 and 0.75.

    `results` is a sorted list of document ids; every id in `relevant` must
    appear in it (list.index raises ValueError otherwise).
    '''
    assert recall >= 0 and recall <= 1, f'Invalid recall: {recall}'
    n_rel = len(relevant)
    recall_pts = [0] + [(k + 1) / n_rel for k in range(n_rel)]
    hit_ranks = sorted(results.index(doc) + 1 for doc in relevant)
    precision_pts = [1] + [(k + 1) / rank for k, rank in enumerate(hit_ranks)]
    # Last point whose recall is strictly below the requested level (0 if none).
    lo = max((k for k, pt in enumerate(recall_pts) if recall > pt), default=0)
    x1, y1 = recall_pts[lo], precision_pts[lo]
    x2, y2 = recall_pts[lo + 1], precision_pts[lo + 1]
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1
    return slope * recall + intercept
def mean_precision1(results, relevant):
    '''Average of the interpolated precision at recall 0.25, 0.5 and 0.75.'''
    levels = (0.25, 0.5, 0.75)
    return sum(precision_at(level, results, relevant) for level in levels) / 3
def mean_precision2(results, relevant):
    '''Average of the interpolated precision at recall 0.1, 0.2, ..., 1.0.'''
    return sum(precision_at(step / 10, results, relevant) for step in range(1, 11)) / 10
def norm_recall(results, relevant):
    '''
    Normalized recall: 1 - (sum of actual ranks - sum of ideal ranks) / (R * (N - R)).

    N is the ranking length and R the number of relevant documents; the ideal
    ranks are 1..R. Raises ZeroDivisionError when R == 0 or R == N.
    '''
    total = len(results)
    n_rel = len(relevant)
    rank_sum = sum(results.index(doc) + 1 for doc in relevant)
    ideal_sum = sum(range(1, n_rel + 1))
    return 1 - (rank_sum - ideal_sum) / n_rel / (total - n_rel)
def norm_precision(results, relevant):
    '''
    Normalized precision: 1 - (sum log(actual ranks) - sum log(ideal ranks)) / D.

    D = N log N - (N - R) log(N - R) - R log R, with N the ranking length, R the
    number of relevant documents and ideal ranks 1..R. R == 0 or R == N hit
    log(0) and yield nan.
    '''
    total = len(results)
    n_rel = len(relevant)
    log_rank_sum = sum(np.log(results.index(doc) + 1) for doc in relevant)
    log_ideal_sum = sum(np.log(k) for k in range(1, n_rel + 1))
    denom = total * np.log(total) - (total - n_rel) * np.log(total - n_rel) - n_rel * np.log(n_rel)
    return 1 - (log_rank_sum - log_ideal_sum) / denom
### Extensions
# TODO: put any extensions here
def to_full_matrix(doc_vectors):
    '''
    Converts a list of sparse term vectors into a dense document-term matrix.

    Returns ``(matrix, words)``: ``matrix`` has shape ``(len(doc_vectors),
    len(words))`` and ``matrix[i, j]`` is the weight of ``words[j]`` in
    ``doc_vectors[i]`` (0 when absent). ``words`` lists every term once in
    first-seen order, so the column layout is the same on every run.
    '''
    # dict.fromkeys de-duplicates while keeping first-seen order; a set's order
    # would depend on per-process string hash randomization.
    words = list(dict.fromkeys(word for doc_vec in doc_vectors for word in doc_vec))
    column = {word: j for j, word in enumerate(words)}
    matrix = np.zeros((len(doc_vectors), len(words)))
    for i, doc_vec in enumerate(doc_vectors):
        for word, val in doc_vec.items():
            # O(1) column lookup; list.index here would cost O(vocabulary) per entry.
            matrix[i, column[word]] = val
    return matrix, words
def sparse_svd(doc_vectors, rank):
    '''
    Project sparse document vectors onto their top-``rank`` right singular vectors.

    Returns ``(projected_docs, project_fn)``:
      * ``projected_docs`` is a list of dense arrays, one per input vector, each
        of length ``min(rank, n_docs, n_words)``;
      * ``project_fn(vec)`` maps a sparse ``{word: weight}`` dict (e.g. a query)
        into the same space, ignoring words absent from ``doc_vectors``.
    '''
    doc_matrix, words = to_full_matrix(doc_vectors)
    # Reduced SVD: right singular vectors beyond min(n_docs, n_words) span the
    # null space of doc_matrix (documents project to 0 on them), so they are
    # not computed. This avoids materializing an n_words x n_words matrix.
    _, _, Vt = np.linalg.svd(doc_matrix, full_matrices=False)
    Vt_k = Vt[:rank, :]
    doc_matrix = doc_matrix @ Vt_k.T
    column = {word: j for j, word in enumerate(words)}

    def project_fn(input_vector):
        output_vector = np.zeros(len(words))
        for word, val in input_vector.items():
            j = column.get(word)
            if j is not None:  # out-of-vocabulary words are dropped
                output_vector[j] = val
        return output_vector @ Vt_k.T

    return list(doc_matrix), project_fn
def formated_output_for_doc(doc):
    '''
    Render a tokenized Document as a short Markdown-like text block.

    Always emits a ``# <title>`` line; adds a `` by <authors>`` line and an
    indented abstract line only when those sections are non-empty.
    '''
    pieces = ['# ' + ' '.join(doc.title) + '\n']
    if doc.author:
        pieces.append(' by ' + ' '.join(doc.author) + '\n')
    if doc.abstract:
        pieces.append(' ' + ' '.join(doc.abstract) + '\n')
    return ''.join(pieces)
### Search
def setup():
    '''
    Index the corpus and return a query function plus display data.

    Loads ``cacm.raw`` and ``common_words``, removes stopwords, stems, and
    builds tf-idf vectors with section weights author=3, title=3, keyword=4,
    abstract=1. Because ``use_svd`` is fixed to True, it then projects them
    with sparse_svd at rank ``svd_rank``. A console spinner runs meanwhile.

    Returns ``(search_query, docs_present)``:
      * ``search_query(query)`` ranks every document against the free-text
        query and returns ``(doc_id, score)`` pairs, highest score first;
      * ``docs_present`` is the corpus as read by read_docs_for_presentation,
        so ``docs_present[doc_id - 1]`` is the document a result refers to.
    '''
    # CLI parsing (parse_args) is bypassed; the options are fixed here.
    args = argparse.Namespace(use_svd=True, svd_rank=3000)
    print('Starting search engine ', end='')
    if args.use_svd:
        print('(with SVD) ', end='')
    # Start the spinner in a separate thread. Clear the shared event first so a
    # repeated setup() call does not stop the new spinner immediately.
    stop_event.clear()
    spinner_thread = threading.Thread(target=spinner, args=(stop_event,))
    spinner_thread.start()
    try:
        docs = read_docs('cacm.raw')
        stopwords = read_stopwords('common_words')
        term_func = compute_tfidf
        sim_func = cosine_sim
        svd_rank = args.svd_rank
        stem = True
        removestop = True
        term_weights = TermWeights(author=3, title=3, keyword=4, abstract=1)
        processed_docs = process_docs(docs, stem, removestop, stopwords)
        doc_freqs = compute_doc_freqs(processed_docs)
        doc_vectors = [term_func(doc, doc_freqs, term_weights) for doc in processed_docs]
        if args.use_svd:
            doc_vectors, svd_project_fn = sparse_svd(doc_vectors, svd_rank)
    finally:
        # Always stop the spinner: it runs in a non-daemon thread, so leaving it
        # spinning after a loading error would keep the process alive.
        stop_event.set()
        spinner_thread.join()

    def search_query(query):
        '''Rank all documents against the free-text ``query``; see setup().'''
        # Tokenize like read_docs does for the abstract (W) section, but in
        # memory: there is no shared temp file for concurrent requests to
        # overwrite, and query lines starting with "." stay ordinary text
        # instead of being parsed as section markers.
        tokens = [word.lower()
                  for line in query.splitlines()
                  if line.strip()
                  for word in word_tokenize(line.strip())]
        query_doc = Document(1, [], [], [], tokens)
        processed_query = process_docs([query_doc], stem, removestop, stopwords)[0]
        query_vec = term_func(processed_query, doc_freqs, term_weights)
        if args.use_svd:
            query_vec = svd_project_fn(query_vec)
        return search(doc_vectors, query_vec, sim_func)

    docs_present = read_docs_for_presentation('cacm.raw')
    return search_query, docs_present
def process_docs(docs, stem, removestop, stopwords):
    '''
    Optionally drop stopwords, then optionally stem, returning the result.

    When both flags are False the input list itself is returned. NOTE(review):
    the ``stopwords`` argument is not forwarded; remove_stopwords uses the
    module-level set.
    '''
    filtered = remove_stopwords(docs) if removestop else docs
    return stem_docs(filtered) if stem else filtered
def process_docs_and_queries(docs, queries, stem, removestop, stopwords):
    '''
    Apply the same preprocessing to a corpus and a query set.

    Stopword removal (if enabled) runs on both collections before stemming
    (if enabled). Returns ``(processed_docs, processed_queries)``; the
    ``stopwords`` argument is not forwarded to remove_stopwords.
    '''
    collections = [docs, queries]
    if removestop:
        collections = [remove_stopwords(group) for group in collections]
    if stem:
        collections = [stem_docs(group) for group in collections]
    return tuple(collections)
def search(doc_vectors, query_vec, sim):
    '''
    Score every document against ``query_vec`` and rank them.

    ``sim(query_vec, doc_vec)`` is called once per document. Returns a list of
    ``(doc_id, score)`` pairs sorted by descending score, where ``doc_id`` is
    the 1-based position of the vector in ``doc_vectors``. The sort is stable,
    so documents with equal scores keep corpus order.
    '''
    results_with_score = [(doc_id + 1, sim(query_vec, doc_vec))
                          for doc_id, doc_vec in enumerate(doc_vectors)]
    return sorted(results_with_score, key=lambda x: -x[1])
def search_debug(docs, query, relevant, doc_vectors, query_vec, sim):
    '''
    Print the query, its relevant ids, and the ten best-scoring documents.

    Documents are ranked exactly as in search(); each of the top ten is
    printed as its score followed by ``docs[doc_id - 1]``.
    '''
    ranked = sorted(((pos + 1, sim(query_vec, vec)) for pos, vec in enumerate(doc_vectors)),
                    key=lambda pair: -pair[1])
    print('Query:', query)
    print('Relevant docs: ', relevant)
    print()
    for doc_id, score in ranked[:10]:
        print('Score:', score)
        print(docs[doc_id - 1])
        print()
def parse_args():
    '''Parse ``--use_svd`` (flag) and ``--svd_rank`` (int, default 3000) from sys.argv.'''
    parser = argparse.ArgumentParser()
    parser.add_argument('--use_svd', action='store_true')
    parser.add_argument('--svd_rank', type=int, default=3000)
    namespace = parser.parse_args()
    return namespace
search_query, docs = setup()

with gr.Blocks() as demo:
    gr.Markdown("# Search Engine")
    with gr.Row():
        query = gr.Textbox(label="Query", autofocus=True)
    # Results shown per page; "More like this" adds another page.
    num_results_step = 5
    num_results = gr.State(num_results_step)

    @gr.render(inputs=[query, num_results], triggers=[query.submit, num_results.change])
    def render_results(query, num_res):
        '''Render the top ``num_res`` hits for ``query`` as HTML result cards.'''
        # Corpus text can contain <, > or &; escape it so it is shown as text
        # rather than parsed as markup.
        from html import escape

        if query.strip() != '':
            results = search_query(query)[:num_res]
            for doc_id, score in results:
                doc = docs[doc_id - 1]
                author = escape(doc.author) if doc.author.strip() else 'No author provided'
                title = escape(doc.title)
                abstract = escape(doc.abstract) if doc.abstract.strip() else 'No abstract provided'
                card = f"""
<div style="margin: 30px 0">
<div style="display: flex; align-items: center; gap: 10px;">
<img src="https://www.cs.jhu.edu/favicon.ico" width="25px">
<div style="color: #202124; font-size: 14px;">{author}</div>
</div>
<div style="font-size: 20px; color: rgb(26, 13, 171); cursor: pointer; margin: 10px 0" onclick="alert('Just a mockup search engine, lol.')">{title}</div>
<div style="color: rgb(71, 71, 71);">{abstract}<br>Relevance score: {score:.3f}</div>
</div>
"""
                gr.HTML(card)
            gr.HTML('<div style="margin: 50px"></div>')
            more_btn = gr.Button('More like this')
            # Each click grows the page size, which re-triggers this render.
            more_btn.click(lambda x: x + num_results_step, num_results, num_results)

    # Editing the query resets the page size to a single page.
    query.change(lambda _: num_results_step, num_results, num_results)
if __name__ == '__main__':
    # Launch the Gradio app with default server settings; the commented
    # options below set the bind address and port explicitly.
    demo.launch(
        # server_name="0.0.0.0",
        # server_port=7861,
    )