import argparse
import itertools
import re
import sys
import threading
import time
from collections import Counter, defaultdict
from typing import Dict, List, NamedTuple

import gradio as gr
import nltk
import numpy as np
from nltk.stem.snowball import SnowballStemmer
from nltk.tokenize import word_tokenize
from numpy.linalg import norm

# nltk.download('punkt_tab')
def spinner(stop_event):
    '''Draws a rotating ASCII spinner on stdout until stop_event is set.'''
    spinner_chars = itertools.cycle(['-', '\\', '|', '/'])
    sys.stdout.write(f'{next(spinner_chars)}')
    sys.stdout.flush()
    time.sleep(0.1)
    while not stop_event.is_set():
        # Erase the previous frame and draw the next one.
        sys.stdout.write(f'\b{next(spinner_chars)}')
        sys.stdout.flush()
        time.sleep(0.1)
    print('\b ')


# Threading event used to stop the spinner once setup has finished.
stop_event = threading.Event()
### File IO and processing

class Document(NamedTuple):
    doc_id: int
    author: List[str]
    title: List[str]
    keyword: List[str]
    abstract: List[str]

    def sections(self):
        return [self.author, self.title, self.keyword, self.abstract]

    def __repr__(self):
        return (f"doc_id: {self.doc_id}\n" +
                f"  author: {self.author}\n" +
                f"  title: {self.title}\n" +
                f"  keyword: {self.keyword}\n" +
                f"  abstract: {self.abstract}")
def read_stopwords(file):
    with open(file) as f:
        return set([x.strip() for x in f.readlines()])


stopwords = read_stopwords('common_words')
stemmer = SnowballStemmer('english')
def read_rels(file):
    '''
    Reads the file of relevance judgments and returns a dictionary of
    query id -> list of relevant document ids.
    '''
    rels = {}
    with open(file) as f:
        for line in f:
            qid, rel = line.strip().split()
            qid = int(qid)
            rel = int(rel)
            if qid not in rels:
                rels[qid] = []
            rels[qid].append(rel)
    return rels
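# Illustrative sketch (hypothetical contents) of the relevance-judgment file that
# read_rels() expects: one "query_id relevant_doc_id" pair per line, e.g.
#   1 1410
#   1 1572
#   2 93
# which read_rels would turn into {1: [1410, 1572], 2: [93]}.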
def read_docs(file):
    '''
    Reads the corpus into a list of Documents
    '''
    docs = [defaultdict(list)]  # empty 0 index
    category = ''
    with open(file) as f:
        i = 0
        for line in f:
            line = line.strip()
            if line.startswith('.I'):
                i = int(line[3:])
                docs.append(defaultdict(list))
            elif re.match(r'\.\w', line):
                category = line[1]
            elif line != '':
                for word in word_tokenize(line):
                    docs[i][category].append(word.lower())

    return [Document(i + 1, d['A'], d['T'], d['K'], d['W'])
            for i, d in enumerate(docs[1:])]
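# Assumed layout of the raw corpus file, inferred from the parsing above and from the
# temporary query file written in search_query() below (section letters map to Document
# fields: A = author, T = title, K = keywords, W = abstract). Illustrative sketch only:
#   .I 1
#   .T
#   Some Paper Title
#   .A
#   Author Name
#   .W
#   Abstract text ...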
def read_docs_for_presentation(file):
    '''
    Reads the corpus, keeping each section as a display string instead of a token list.
    '''
    docs = [defaultdict(str)]  # empty 0 index
    category = ''
    with open(file) as f:
        i = 0
        for line in f:
            line = line.strip()
            if line.startswith('.I'):
                i = int(line[3:])
                docs.append(defaultdict(str))
            elif re.match(r'\.\w', line):
                category = line[1]
            elif line != '':
                if docs[i][category] == '':
                    docs[i][category] = line
                elif docs[i][category][-1] == '.':
                    docs[i][category] = f'{docs[i][category]} {line}'
                else:
                    docs[i][category] = f'{docs[i][category]}. {line}'

    return [Document(i + 1, d['A'], d['T'], d['K'], d['W'])
            for i, d in enumerate(docs[1:])]
def stem_doc(doc: Document):
    return Document(doc.doc_id, *[[stemmer.stem(word) for word in sec]
                                  for sec in doc.sections()])


def stem_docs(docs: List[Document]):
    return [stem_doc(doc) for doc in docs]


def remove_stopwords_doc(doc: Document):
    return Document(doc.doc_id, *[[word for word in sec if word not in stopwords]
                                  for sec in doc.sections()])


def remove_stopwords(docs: List[Document]):
    return [remove_stopwords_doc(doc) for doc in docs]
### Term-Document Matrix

class TermWeights(NamedTuple):
    author: float
    title: float
    keyword: float
    abstract: float
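# Section weights let some fields count more toward a term's weight than others; for
# example, setup() below uses TermWeights(author=3, title=3, keyword=4, abstract=1), so a
# keyword occurrence contributes four times as much as an abstract occurrence.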
def compute_doc_freqs(docs: List[Document]):
    '''
    Computes document frequency, i.e. how many documents contain a specific word
    '''
    freq = Counter()
    for doc in docs:
        words = set()
        for sec in doc.sections():
            for word in sec:
                words.add(word)
        for word in words:
            freq[word] += 1
    return freq
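# Example (hypothetical tokens): if 'sort' occurs in the title of one document and in the
# abstract of another, compute_doc_freqs counts it twice, no matter how many times it
# appears within each document (document frequency, not collection frequency).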
def compute_tf(doc: Document, doc_freqs: Dict[str, int], weights: TermWeights):
    '''
    Computes a section-weighted term-frequency vector for a document.
    (doc_freqs is unused here; it is accepted so all term functions share one signature.)
    '''
    vec = defaultdict(float)
    for word in doc.author:
        vec[word] += weights.author
    for word in doc.keyword:
        vec[word] += weights.keyword
    for word in doc.title:
        vec[word] += weights.title
    for word in doc.abstract:
        vec[word] += weights.abstract
    return dict(vec)  # convert back to a regular dict
def compute_tfidf(doc, doc_freqs, weights):
    tfidf = defaultdict(float)
    tf = compute_tf(doc, doc_freqs, weights)
    N = 3204  # number of documents in the CACM collection
    for word in tf:
        idf = np.log((1 + N) / (1 + doc_freqs[word]))
        tfidf[word] = tf[word] * idf
    return dict(tfidf)  # convert back to a regular dict
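# Smoothed idf, so unseen terms still get a finite weight:
#   idf(w) = ln((1 + N) / (1 + df(w)))
# e.g. (hypothetical counts) a term appearing in 32 of the 3204 documents gets
#   idf = ln(3205 / 33) ≈ 4.58, and its tf-idf weight is that times its weighted tf.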
def compute_boolean(doc, doc_freqs, weights):
    vec = defaultdict(float)
    for word in doc.author:
        vec[word] = weights.author
    for word in doc.keyword:
        vec[word] = weights.keyword
    for word in doc.title:
        vec[word] = weights.title
    for word in doc.abstract:
        vec[word] = weights.abstract
    return dict(vec)  # convert back to a regular dict
### Vector Similarity

def dictdot(x: Dict[str, float], y: Dict[str, float]):
    '''
    Computes the dot product of vectors x and y, represented as sparse dictionaries.
    '''
    keys = list(x.keys()) if len(x) < len(y) else list(y.keys())
    return sum(x.get(key, 0) * y.get(key, 0) for key in keys)


def cosine_sim_dict(x, y):
    '''
    Computes the cosine similarity between two sparse term vectors represented as dictionaries.
    '''
    num = dictdot(x, y)
    if num == 0:
        return 0
    return num / (norm(list(x.values())) * norm(list(y.values())))


def cosine_sim(x, y):
    # Dispatch on representation: sparse dict vectors vs. dense numpy vectors (used after SVD).
    if isinstance(x, dict):
        return cosine_sim_dict(x, y)
    return np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))
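# Quick sanity check (hypothetical vectors):
#   cosine_sim({'sort': 1.0, 'merge': 1.0}, {'sort': 1.0}) = 1 / (sqrt(2) * 1) ≈ 0.707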
def dice_sim(x, y):
    num = 2 * dictdot(x, y)
    if num == 0:
        return 0
    denom = sum(x.values()) + sum(y.values())
    return num / denom if denom != 0 else 0


def jaccard_sim(x, y):
    num = dictdot(x, y)
    if num == 0:
        return 0
    # denom = norm(list(x.values())) ** 2 + norm(list(y.values())) ** 2 - num
    denom = sum(x.values()) + sum(y.values()) - num
    return num / denom if denom != 0 else 0


def overlap_sim(x, y):
    num = dictdot(x, y)
    if num == 0:
        return 0
    # denom = min(norm(list(x.values())) ** 2, norm(list(y.values())) ** 2)
    denom = min(sum(x.values()), sum(y.values()))
    return num / denom if denom != 0 else 0
### Precision/Recall

def interpolate(x1, y1, x2, y2, x):
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return m * x + b


def precision_at(recall: float, results: List[int], relevant: List[int]) -> float:
    '''
    Computes the precision at the specified recall level.
    If the recall level falls between two measured points, the precision is linearly
    interpolated between the two closest points. For example, with 4 relevant documents
    (recall levels 0.25, 0.5, 0.75, and 1.0), precision at recall 0.6 is
    interpolate(0.5, prec @ 0.5, 0.75, prec @ 0.75, 0.6).
    There is implicitly a point (recall=0, precision=1).
    `results` is a ranked list of document ids; `relevant` is a list of relevant document ids.
    '''
    assert recall >= 0 and recall <= 1, f'Invalid recall: {recall}'
    # Measured recall levels and the precision at each, starting from the implicit (0, 1) point.
    recalls = [0]
    precisions = [1]
    recalls += [(i + 1) / len(relevant) for i in range(len(relevant))]
    ranks = sorted([results.index(rel) + 1 for rel in relevant])
    precisions += [(i + 1) / rk for i, rk in enumerate(ranks)]
    # Find the last measured recall level strictly below the requested one, then interpolate.
    idx = 0
    for i, rec in enumerate(recalls):
        if recall > rec:
            idx = i
    r1 = recalls[idx]
    r2 = recalls[idx + 1]
    return interpolate(r1, precisions[idx], r2, precisions[idx + 1], recall)
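# Worked example (hypothetical ranking): 4 relevant documents retrieved at ranks 1, 2, 4, 8
# give recall levels 0.25, 0.5, 0.75, 1.0 with precisions 1.0, 1.0, 0.75, 0.5.
# precision_at(0.6, ...) then interpolates between (0.5, 1.0) and (0.75, 0.75):
#   slope = (0.75 - 1.0) / (0.75 - 0.5) = -1.0, so precision ≈ 1.0 - 1.0 * (0.6 - 0.5) = 0.9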
def mean_precision1(results, relevant):
    return (precision_at(0.25, results, relevant) +
            precision_at(0.5, results, relevant) +
            precision_at(0.75, results, relevant)) / 3


def mean_precision2(results, relevant):
    return sum([precision_at((i + 1) / 10, results, relevant) for i in range(10)]) / 10


def norm_recall(results, relevant):
    N = len(results)
    num_rel = len(relevant)
    ranks = [results.index(rel) + 1 for rel in relevant]
    return 1 - (sum(ranks) - sum(i + 1 for i in range(num_rel))) / num_rel / (N - num_rel)


def norm_precision(results, relevant):
    N = len(results)
    num_rel = len(relevant)
    ranks = [results.index(rel) + 1 for rel in relevant]
    denom = N * np.log(N) - (N - num_rel) * np.log(N - num_rel) - num_rel * np.log(num_rel)
    return 1 - (sum(np.log(r) for r in ranks) - sum(np.log(i + 1) for i in range(num_rel))) / denom
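# These are the normalized recall / normalized precision measures, written as the code computes them:
#   R_norm = 1 - (sum_i rank_i - sum_i i) / (num_rel * (N - num_rel))
#   P_norm = 1 - (sum_i ln(rank_i) - sum_i ln(i)) / (N*ln(N) - (N - num_rel)*ln(N - num_rel) - num_rel*ln(num_rel))
# where rank_i is the rank of the i-th relevant document and i runs from 1 to num_rel.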
### Extensions

def to_full_matrix(doc_vectors):
    '''
    Converts a list of sparse term vectors into a full term-document matrix.
    '''
    # Collect the vocabulary across all documents.
    words = set()
    for doc_vec in doc_vectors:
        words.update(doc_vec.keys())
    words = list(words)
    word_index = {word: j for j, word in enumerate(words)}
    matrix = np.zeros((len(doc_vectors), len(words)))
    for i, doc_vec in enumerate(doc_vectors):
        for word, val in doc_vec.items():
            matrix[i, word_index[word]] = val
    return matrix, words
def sparse_svd(doc_vectors, rank):
    '''
    Latent semantic indexing: projects the sparse document vectors onto the top `rank`
    right-singular vectors of the term-document matrix. Returns the projected document
    vectors and a function that projects a query vector into the same reduced space.
    '''
    doc_matrix, words = to_full_matrix(doc_vectors)
    word_index = {word: j for j, word in enumerate(words)}
    _, _, Vt = np.linalg.svd(doc_matrix)
    Vt_k = Vt[:rank, :]               # top-`rank` right-singular vectors
    doc_matrix = doc_matrix @ Vt_k.T  # each document becomes a dense `rank`-dimensional vector

    def project_fn(input_vector):
        output_vector = np.zeros(len(words))
        for word, val in input_vector.items():
            if word in word_index:
                output_vector[word_index[word]] = val
        return output_vector @ Vt_k.T

    return [vec for vec in doc_matrix], project_fn
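# Illustrative usage (hypothetical rank and term weights, not values used by the app):
#   doc_vecs, project = sparse_svd(doc_vectors, rank=500)
#   query_dense = project({'sort': 2.3, 'algorithm': 1.1})   # dense length-500 vector
#   score = cosine_sim(query_dense, doc_vecs[0])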
def formated_output_for_doc(doc):
    res = ''
    res = res + '# ' + ' '.join(doc.title) + '\n'
    if doc.author:
        res = res + ' by ' + ' '.join(doc.author) + '\n'
    if doc.abstract:
        res = res + ' ' + ' '.join(doc.abstract) + '\n'
    return res
### Search

def setup():
    # args = parse_args()
    args = argparse.Namespace(use_svd=True, svd_rank=3000)
    print('Starting search engine ', end='')
    if args.use_svd:
        print('(with SVD) ', end='')

    # Show a spinner in a separate thread while the index is being built.
    spinner_thread = threading.Thread(target=spinner, args=(stop_event,))
    spinner_thread.start()

    docs = read_docs('cacm.raw')
    # queries = read_docs('query.raw')
    # rels = read_rels('query.rels')
    stopwords = read_stopwords('common_words')

    term_func = compute_tfidf
    sim_func = cosine_sim
    svd_rank = args.svd_rank
    stem = True
    removestop = True
    term_weights = TermWeights(author=3, title=3, keyword=4, abstract=1)

    processed_docs = process_docs(docs, stem, removestop, stopwords)
    doc_freqs = compute_doc_freqs(processed_docs)
    doc_vectors = [term_func(doc, doc_freqs, term_weights) for doc in processed_docs]
    if args.use_svd:
        doc_vectors, svd_project_fn = sparse_svd(doc_vectors, svd_rank)

    # Stop the spinner
    stop_event.set()
    spinner_thread.join()

    def search_query(query):
        # Write the query in the same .I/.W format as the corpus so read_docs can parse it.
        tmp_query_file = '/tmp/irhw2'
        with open(tmp_query_file, 'w') as f:
            print(f"""
.I 1
.W
{query}
""", file=f)
        queries = read_docs(tmp_query_file)
        processed_queries = process_docs(queries, stem, removestop, stopwords)
        query_doc = processed_queries[0]
        query_vec = term_func(query_doc, doc_freqs, term_weights)
        if args.use_svd:
            query_vec = svd_project_fn(query_vec)
        results = search(doc_vectors, query_vec, sim_func)
        return results

    docs_present = read_docs_for_presentation('cacm.raw')
    return search_query, docs_present
def process_docs(docs, stem, removestop, stopwords):
    processed_docs = docs
    if removestop:
        processed_docs = remove_stopwords(processed_docs)
    if stem:
        processed_docs = stem_docs(processed_docs)
    return processed_docs


def process_docs_and_queries(docs, queries, stem, removestop, stopwords):
    processed_docs = docs
    processed_queries = queries
    if removestop:
        processed_docs = remove_stopwords(processed_docs)
        processed_queries = remove_stopwords(processed_queries)
    if stem:
        processed_docs = stem_docs(processed_docs)
        processed_queries = stem_docs(processed_queries)
    return processed_docs, processed_queries
def search(doc_vectors, query_vec, sim):
    '''Returns (doc_id, score) pairs sorted by decreasing similarity to the query.'''
    results_with_score = [(doc_id + 1, sim(query_vec, doc_vec))
                          for doc_id, doc_vec in enumerate(doc_vectors)]
    results_with_score = sorted(results_with_score, key=lambda x: -x[1])
    return results_with_score
def search_debug(docs, query, relevant, doc_vectors, query_vec, sim):
    results_with_score = [(doc_id + 1, sim(query_vec, doc_vec))
                          for doc_id, doc_vec in enumerate(doc_vectors)]
    results_with_score = sorted(results_with_score, key=lambda x: -x[1])
    results = [x[0] for x in results_with_score]
    print('Query:', query)
    print('Relevant docs: ', relevant)
    print()
    for doc_id, score in results_with_score[:10]:
        print('Score:', score)
        print(docs[doc_id - 1])
        print()
def parse_args():
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--use_svd', action='store_true')
    arg_parser.add_argument('--svd_rank', type=int, default=3000)
    return arg_parser.parse_args()
search_query, docs = setup()

with gr.Blocks() as demo:
    gr.Markdown("# Search Engine")
    with gr.Row():
        query = gr.Textbox(label="Query", autofocus=True)
    # with gr.Row():
    #     search_results = gr.Textbox(lines=5, label="Results")

    num_results_step = 5
    num_results = gr.State(num_results_step)

    # Re-render the result list whenever the query or the requested number of results changes.
    @gr.render(inputs=[query, num_results])
    def render_results(query, num_res):
        if query.strip() != '':
            results = search_query(query)[:num_res]
            for doc_id, score in results:
                doc = docs[doc_id - 1]
                html = f"""
                <div style="margin: 30px 0">
                    <div style="display: flex; align-items: center; gap: 10px;">
                        <img src="https://www.cs.jhu.edu/favicon.ico" width="25px">
                        <div style="color: #202124; font-size: 14px;">{doc.author if doc.author.strip() else 'No author provided'}</div>
                    </div>
                    <div style="font-size: 20px; color: rgb(26, 13, 171); cursor: pointer; margin: 10px 0" onclick="alert('Just a mockup search engine, lol.')">{doc.title}</div>
                    <div style="color: rgb(71, 71, 71);">{doc.abstract if doc.abstract.strip() else 'No abstract provided'}<br>Relevance score: {score:.3f}</div>
                </div>
                """
                gr.HTML(html)
            gr.HTML('<div style="margin: 50px"></div>')
            more_btn = gr.Button('More like this')
            more_btn.click(lambda x: x + num_results_step, num_results, num_results)

    # A new query resets the number of displayed results.
    query.change(lambda _: num_results_step, num_results, num_results)
if __name__ == '__main__':
    demo.launch(
        # server_name="0.0.0.0",
        server_port=7861,
    )
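# To run locally (assumptions, not stated in this file: the CACM corpus 'cacm.raw' and the
# stopword list 'common_words' are in the working directory, and the NLTK punkt tokenizer
# models are installed, e.g. via the nltk.download call commented out near the imports):
#   python <this file>
# then open http://localhost:7861 in a browser.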