import re

import requests
import spacy
import torch
import wikipedia
from datasets import Dataset
from rank_bm25 import BM25Okapi
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    pipeline,
)

class QueryProcessor:
    """
    Processes text into queries using a spaCy model
    """
    def __init__(self):
        self.keep = {'PROPN', 'NUM', 'VERB', 'NOUN', 'ADJ'}
        self.nlp = spacy.load(
            'en_core_web_sm',
            disable=['ner', 'parser', 'textcat'],
        )

    def generate_query(self, text):
        """
        Process text into a search query, retaining only nouns,
        proper nouns, numerals, verbs, and adjectives
        Parameters:
        -----------
        text : str
            The input text to be processed
        Returns:
        --------
        query : str
            The condensed search query
        """
        tokens = self.nlp(text)
        query = ' '.join(token.text for token in tokens
                         if token.pos_ in self.keep)
        return query
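
# Example usage of QueryProcessor (a minimal sketch; assumes the spaCy model
# has been downloaded, e.g. via `python -m spacy download en_core_web_sm`):
#
#   processor = QueryProcessor()
#   processor.generate_query('Who was the first person to walk on the Moon?')
#   # -> something like 'first person walk Moon'; the exact tokens retained
#   #    depend on the spaCy model version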

class ContextRetriever:
    """
    Retrieves documents from Wikipedia based on a query,
    and prepares context paragraphs for a RoBERTa model
    """
    def __init__(self, url='https://en.wikipedia.org/w/api.php'):
        self.url = url
        self.pageids = None
        self.pages = None
        self.paragraphs = None
        
    def get_pageids(self, query):
        """
        Retrieve page ids corresponding to a search query
        Parameters:
        -----------
        query : str
            A query to use for Wikipedia page search
        Returns: None, but stores
        --------
        self.pageids : list(int)
            A list of Wikipedia page ids corresponding to search results
        """
        params = {
            'action': 'query',
            'list': 'search',
            'srsearch': query,
            'format': 'json',
        }
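        # A search response from the MediaWiki API has (abridged) the shape
        #   {'query': {'search': [{'pageid': <int>, 'title': <str>, ...}, ...]}}
        # so the comprehension below collects the page id of each search hit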
        results = requests.get(self.url, params=params).json()
        pageids = [page['pageid'] for page in results['query']['search']]
        self.pageids = pageids
    
    def get_pages(self):
        """
        Use MediaWiki API to retrieve page content corresponding to
        entries of self.pageids
        Parameters: None
        -----------
        Returns: None, but stores
        --------
        self.pages : list(str)
            Entries are content of pages corresponding to entries of self.pageids
        """
        assert self.pageids is not None, "No pageids exist. Get pageids first using self.get_pageids"
        self.pages = []
        for pageid in self.pageids:
            try:
                self.pages.append(wikipedia.page(pageid=pageid, auto_suggest=False).content)
            except (wikipedia.DisambiguationError, wikipedia.PageError):
                # Skip pages that fail to resolve cleanly
                continue

    def get_paragraphs(self):
        """
        Process self.pages into a list of paragraphs from the pages
        Parameters: None
        -----------
        Returns: None, but stores
        --------
        self.paragraphs : list(str)
            List of paragraphs from all pages in self.pages, in order of self.pages
        """
        assert self.pages is not None, "No page content exists. Get pages first using self.get_pages"
        # Content from the MediaWiki API includes these section headings; we only
        # keep content appearing before the first occurrence of any of them
        pattern = '|'.join([
            '== References ==',
            '== Further reading ==',
            '== External links',
            '== See also ==',
            '== Sources ==',
            '== Notes ==',
            '== Further references ==',
            '== Footnotes ==',
            '=== Notes ===',
            '=== Sources ===',
            '=== Citations ===',
        ])
        pattern = re.compile(pattern)
        paragraphs = []
        for page in self.pages:
            # Truncate page to the first index of the start of a matching heading,
            # or the end of the page if no matches exist
            idx = min([match.start() for match in pattern.finditer(page)] + [len(page)])
            page = page[:idx]
            # Split into paragraphs, omitting lines with headings (start with '='),
            # empty lines, or lines like '\t\t' or '\t\t\t' which sometimes appear
            paragraphs += [
                p for p in page.split('\n')
                if p
                and not p.startswith('=')
                and not p.startswith('\t\t')
            ]
        self.paragraphs = paragraphs

    def rank_paragraphs(self, query, topn=10):
        """
        Ranks the elements of self.paragraphs in descending order
        by relevance to query using Okapi BM25, and stores the topn best
        Parameters:
        -----------
        query : str
            The query to use in ranking paragraphs by relevance
        topn : int
            The number of most relevant paragraphs to keep
        Returns: None, but stores
        --------
        self.best_paragraphs : list(str)
            The topn most relevant paragraphs to the query
        """
        tokenized_paragraphs = [p.split(" ") for p in self.paragraphs]
        bm25 = BM25Okapi(tokenized_paragraphs)
        tokenized_query = query.split(" ")
        self.best_paragraphs = bm25.get_top_n(tokenized_query, self.paragraphs, n=topn)
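
if __name__ == '__main__':
    # Minimal end-to-end sketch (an illustrative example, not part of the
    # classes above): requires network access to the MediaWiki API, and the
    # output depends on current Wikipedia content
    processor = QueryProcessor()
    question = 'Who was the first person to walk on the Moon?'
    query = processor.generate_query(question)
    retriever = ContextRetriever()
    retriever.get_pageids(query)
    retriever.get_pages()
    retriever.get_paragraphs()
    retriever.rank_paragraphs(query, topn=5)
    # Print the paragraph ranked most relevant to the query
    print(retriever.best_paragraphs[0])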