Spaces:
Running
Running
from collections import defaultdict | |
import spacy | |
from tqdm import tqdm | |
from src.interfaces import Paper | |
class SearchAPI:
    """Brute-force keyword search over an in-memory list of papers.

    Papers are stored in ``self.papers``. Each paper is read through item
    access by field name (``paper["title"]``) while matching, and through the
    ``year`` / ``month`` attributes while sorting results.

    ``build_paper_list`` and ``build_and_search`` only raise
    ``NotImplementedError`` here; presumably subclasses override them.
    """

    # Fields are checked in this order for every paper; the first field that
    # fails to match rejects the paper.
    SEARCH_PRIORITY = ["year", "month", "venue", "author", "title", "abstract"]

    def __init__(self) -> None:
        self.papers: list[Paper] = []
        # spaCy pipeline, loaded lazily the first time lemmatization is used.
        self.nlp = None

    def _ensure_nlp(self):
        """Return the spaCy pipeline, loading ``en_core_web_sm`` on first use."""
        if self.nlp is None:
            self.nlp = spacy.load("en_core_web_sm")
        return self.nlp

    def _normalize(self, text: str, lemmatization: bool = False) -> str:
        """Tokenize ``text`` and re-join the tokens with single spaces."""
        return " ".join(self.tokenize(text, lemmatization=lemmatization))

    def in_string(
        self, statement: str, string: str, lemmatization: bool = False
    ) -> bool:
        """Return whether ``statement`` occurs in ``string`` after normalization.

        Both arguments are tokenized (lower-cased, optionally lemmatized) and
        re-joined with single spaces; the check is then a plain substring
        test, so a statement may match inside a longer word.
        """
        needle = self._normalize(statement, lemmatization)
        haystack = self._normalize(string, lemmatization)
        return needle in haystack

    def _text_field_matches(
        self, paper, field: str, or_groups, lemmatization: bool
    ) -> bool:
        """Check one text field of ``paper`` against pre-normalized OR-groups.

        ``or_groups`` is a list of AND-groups; each AND-group is a list of
        already normalized statements. The field matches when every statement
        of at least one AND-group is a substring of the normalized field.
        """
        haystack = None
        for needles in or_groups:
            # Normalize the paper field at most once, and only when there is
            # something to look for (an empty AND-group matches trivially).
            if needles and haystack is None:
                haystack = self._normalize(paper[field], lemmatization)
            if all(needle in haystack for needle in needles):
                return True
        return False

    def exhausted_lemma_search(
        self, query: dict[str, tuple[tuple[str]]], lemmatization: bool = False
    ) -> "list[Paper]":
        """Scan every paper and keep those matching all queried fields.

        Args:
            query: Mapping from field name to a tuple of tuples of strings.
                For ``year`` and ``month`` every inner tuple is a
                ``(start, end)`` pair of digit strings, matched inclusively.
                For other fields the inner tuples are AND-groups and the
                outer tuple is an OR over them. Fields not listed in
                ``SEARCH_PRIORITY`` are ignored.
            lemmatization: Compare lemmas instead of raw lower-cased tokens.

        Returns:
            Matching papers, in the order they appear in ``self.papers``.
        """
        fields = []
        time_spans = defaultdict(list)
        text_queries = {}
        for f in self.SEARCH_PRIORITY:
            if f not in query:
                continue
            fields.append(f)
            if f in ("year", "month"):
                spans = time_spans[f]  # registers the field even if empty
                for span in query[f]:
                    assert len(span) == 2, "time span must be (start, end)"
                    assert all(
                        num.isdigit() for num in span
                    ), "time span bounds must be digit strings"
                    spans.append((int(span[0]), int(span[1])))
            else:
                # Normalize the query once instead of once per paper; with
                # lemmatization each normalization is a spaCy pipeline call.
                text_queries[f] = [
                    [self._normalize(stmt, lemmatization) for stmt in and_statements]
                    for and_statements in query[f]
                ]

        def field_matches(paper, f: str) -> bool:
            if f in time_spans:
                return any(s <= paper[f] <= e for s, e in time_spans[f])
            return self._text_field_matches(
                paper, f, text_queries[f], lemmatization
            )

        results = []
        pbar = tqdm(self.papers)
        for p in pbar:
            # all() stops at the first failing field, in priority order.
            if all(field_matches(p, f) for f in fields):
                results.append(p)
            pbar.set_postfix({"found": len(results)})
        return results

    def search(
        self, query: dict[str, tuple[tuple[str]]], method: str = "exhausted"
    ) -> "list[Paper]":
        """Search papers

        Args:
            query: A dict of queries on different field.
                A query in a field is a tuple of strings, where strings are AND
                and tuple means OR. Strings are case-insensitive.
                e.g. {
                    "venue": (("EMNLP", ), ("ACL",)),
                    "title": (("parsing", "tree-crf"), ("event extraction",))
                }
                This query means we want to find papers in EMNLP or ACL,
                AND the title either contains ("parsing" AND "tree-crf") OR "event extraction"
            method: choice from:
                - `exhausted`: brute force matching
                - `exhausted_lemma`: brute force matching on spaCy lemmas

        Returns:
            a list of `Paper`, de-duplicated and sorted by (year, month),
            newest first

        Raises:
            NotImplementedError: if `method` is not one of the choices above.
        """
        if method == "exhausted":
            papers = self.exhausted_lemma_search(query)
        elif method == "exhausted_lemma":
            # Load the model up front so a missing model fails before the scan.
            self._ensure_nlp()
            papers = self.exhausted_lemma_search(query, lemmatization=True)
        else:
            raise NotImplementedError

        if papers:
            papers = sorted(set(papers), key=lambda p: (p.year, p.month), reverse=True)
        return papers

    def tokenize(self, string: str, lemmatization: bool = False) -> list[str]:
        """Lower-case ``string`` and split it into tokens.

        Without lemmatization the string is split on whitespace. With
        lemmatization the spaCy pipeline is run (loading it if necessary)
        and the lemma of every token is returned.
        """
        _string = string.lower()
        if lemmatization:
            doc = self._ensure_nlp()(_string)
            return [str(t.lemma_) for t in doc]
        return _string.split()

    @classmethod
    def build_paper_list(cls, *args, **kwargs):
        """Not implemented in this class; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @classmethod
    def build_and_search(cls, *args, **kwargs) -> "list[Paper]":
        """Not implemented in this class; always raises ``NotImplementedError``."""
        raise NotImplementedError