import requests, wikipedia, re, spacy
from rank_bm25 import BM25Okapi
import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    pipeline,
)

class QueryProcessor:
    """
    Processes text into queries using a spaCy model
    """
    def __init__(self):
        # Part-of-speech tags to retain when condensing text into a query
        self.keep = {'PROPN', 'NUM', 'VERB', 'NOUN', 'ADJ'}
        self.nlp = spacy.load(
            'en_core_web_sm',
            disable=['ner', 'parser', 'textcat']
        )

    def generate_query(self, text):
        """
        Process text into a search query, retaining only nouns,
        proper nouns, numerals, verbs, and adjectives

        Parameters:
        -----------
        text : str
            The input text to be processed

        Returns:
        --------
        query : str
            The condensed search query
        """
        tokens = self.nlp(text)
        query = ' '.join(token.text for token in tokens
                         if token.pos_ in self.keep)
        return query
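
# Illustrative usage (not part of the original file; the question is hypothetical
# and the exact output depends on the spaCy model's tagging):
#   qp = QueryProcessor()
#   qp.generate_query("When did Tesla patent the induction motor?")
#   -> something like 'Tesla patent induction motor'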

class ContextRetriever:
    """
    Retrieves documents from Wikipedia based on a query,
    and prepares context paragraphs for a RoBERTa model
    """
    def __init__(self, url='https://en.wikipedia.org/w/api.php'):
        self.url = url
        self.pageids = None
        self.pages = None
        self.paragraphs = None

    def get_pageids(self, query):
        """
        Retrieve page ids corresponding to a search query

        Parameters:
        -----------
        query : str
            A query to use for Wikipedia page search

        Returns: None, but stores
        --------
        self.pageids : list(int)
            A list of Wikipedia page ids corresponding to search results
        """
        params = {
            'action': 'query',
            'list': 'search',
            'srsearch': query,
            'format': 'json',
        }
        results = requests.get(self.url, params=params).json()
        pageids = [page['pageid'] for page in results['query']['search']]
        self.pageids = pageids
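
    # For reference: the MediaWiki search API responds with JSON shaped roughly like
    #   {'query': {'search': [{'pageid': ..., 'title': ..., 'snippet': ...}, ...]}}
    # so the comprehension above collects the 'pageid' of each search hit
    # (the API typically returns at most 10 results per request by default).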

    def get_pages(self):
        """
        Use MediaWiki API to retrieve page content corresponding to
        entries of self.pageids

        Parameters: None
        -----------

        Returns: None, but stores
        --------
        self.pages : list(str)
            Entries are content of pages corresponding to entries of self.pageids
        """
        assert self.pageids is not None, "No pageids exist. Get pageids first using self.get_pageids"
        self.pages = []
        for pageid in self.pageids:
            try:
                self.pages.append(wikipedia.page(pageid=pageid, auto_suggest=False).content)
            except wikipedia.DisambiguationError:
                # Skip disambiguation pages; they carry no usable article content
                continue
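
    # Note: wikipedia.page(...).content is the plain-text article body with section
    # headings rendered as '== Heading ==' lines, which is the format that
    # get_paragraphs below relies on when trimming and splitting pages.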

    def get_paragraphs(self):
        """
        Process self.pages into a list of paragraphs from pages

        Parameters: None
        -----------

        Returns: None, but stores
        --------
        self.paragraphs : list(str)
            List of paragraphs from all pages in self.pages, in order of self.pages
        """
        assert self.pages is not None, "No page content exists. Get pages first using self.get_pages"
        # Content from MediaWiki has these headings. We only grab content appearing
        # before the first instance of any of these
        pattern = '|'.join([
            '== References ==',
            '== Further reading ==',
            '== External links',
            '== See also ==',
            '== Sources ==',
            '== Notes ==',
            '== Further references ==',
            '== Footnotes ==',
            '=== Notes ===',
            '=== Sources ===',
            '=== Citations ===',
        ])
        pattern = re.compile(pattern)
        paragraphs = []
        for page in self.pages:
            # Truncate page to the first index of the start of a matching heading,
            # or the end of the page if no matches exist
            idx = min([match.start() for match in pattern.finditer(page)] + [len(page)])
            page = page[:idx]
            # Split into paragraphs, omitting lines with headings (start with '='),
            # empty lines, or lines like '\t\t' or '\t\t\t' which sometimes appear
            paragraphs += [
                p for p in page.split('\n')
                if p
                and not p.startswith('=')
                and not p.startswith('\t\t')
            ]
        self.paragraphs = paragraphs
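
    # Illustration of the trimming above (hypothetical page text): a page whose
    # content is 'Intro paragraph.\n== History ==\nMore text.\n== References ==\n...'
    # is cut just before '== References ==', the '== History ==' heading line is
    # dropped, and the remaining non-empty lines become individual paragraphs.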

    def rank_paragraphs(self, query, topn=10):
        """
        Ranks the elements of self.paragraphs in descending order
        by relevance to query using Okapi BM25, and stores the topn best results

        Parameters:
        -----------
        query : str
            The query to use in ranking paragraphs by relevance
        topn : int
            The number of most relevant paragraphs to keep

        Returns: None, but stores
        --------
        self.best_paragraphs : list(str)
            The topn most relevant paragraphs to the query
        """
        tokenized_paragraphs = [p.split(" ") for p in self.paragraphs]
        bm25 = BM25Okapi(tokenized_paragraphs)
        tokenized_query = query.split(" ")
        self.best_paragraphs = bm25.get_top_n(tokenized_query, self.paragraphs, n=topn)
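
# Minimal end-to-end sketch (illustrative, not part of the original file;
# assumes network access to Wikipedia and a hypothetical question string):
#   question = "When did Tesla patent the induction motor?"
#   query = QueryProcessor().generate_query(question)
#   retriever = ContextRetriever()
#   retriever.get_pageids(query)
#   retriever.get_pages()
#   retriever.get_paragraphs()
#   retriever.rank_paragraphs(query, topn=10)
#   # retriever.best_paragraphs now holds candidate contexts for a QA model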