import re

import requests
import torch
import wikipedia
from datasets import Dataset
from rank_bm25 import BM25Okapi
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    pipeline,
)

class ContextRetriever:
    """
    Retrieves documents from Wikipedia based on a query,
    and prepares context paragraphs for a RoBERTa question-answering model
    """
    def __init__(self,url='https://en.wikipedia.org/w/api.php'):
        self.url = url
        self.pageids = None
        self.pages = None
        self.paragraphs = None
        
    def get_pageids(self,query):
        """
        Retrieve page ids corresponding to a search query
        Parameters:
        -----------
        query : str
            A query to use for Wikipedia page search
        Returns: None, but stores:
        --------
        self.pageids : list(int)
            A list of Wikipedia page ids corresponding to search results
        """
        params = {
            'action':'query',
            'list':'search',
            'srsearch':query,
            'format':'json',
        }
        results = requests.get(self.url, params=params).json()
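        # Response shape (abridged): {'query': {'search': [{'pageid': 12345, ...}, ...]}}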
        pageids = [page['pageid'] for page in results['query']['search']]
        self.pageids = pageids
    
    def get_pages(self):
        """
        Use MediaWiki API to retrieve page content corresponding to
        entries of self.pageids
        Parameters: None
        -----------
        Returns: None, but stores
        --------
        self.pages : list(str)
            Entries are the content of pages corresponding to entries of self.pageids
        """
        assert self.pageids is not None, "No pageids exist. Get pageids first using self.get_pageids"
        self.pages = []
        for pageid in self.pageids:
            try:
                self.pages.append(wikipedia.page(pageid=pageid,auto_suggest=False).content)
            except wikipedia.DisambiguationError:
                # Skip page ids that resolve to disambiguation pages
                continue

    def get_paragraphs(self):
        """
        Process self.pages into list of paragraphs from pages
        Parameters: None
        -----------
        Returns: None, but stores
        --------
        self.paragraphs : list(str)
            List of paragraphs from all pages in self.pages, in order of self.pages
        """
        assert self.pages is not None, "No page content exists. Get pages first using self.get_pages"
        # Content from WikiMedia has these headings. We only grab content appearing
        # before the first instance of any of these
        pattern = '|'.join([
            '== References ==',
            '== Further reading ==',
            '== External links ==',
            '== See also ==',
            '== Sources ==',
            '== Notes ==',
            '== Further references ==',
            '== Footnotes ==',
            '=== Notes ===',
            '=== Sources ===',
            '=== Citations ===',
        ])
        pattern = re.compile(pattern)
        paragraphs = []
        for page in self.pages:
            # Truncate page to the first index of the start of a matching heading,
            # or the end of the page if no matches exist
            idx = min([match.start() for match in pattern.finditer(page)]+[len(page)])
            page = page[:idx]
            # Split into paragraphs, omitting lines with headings (start with '='),
            # empty lines, or lines like '\t\t' or '\t\t\t' which sometimes appear
            paragraphs += [
                p for p in page.split('\n')
                if p
                and not p.startswith('=')
                and not p.startswith('\t\t')
            ]
        self.paragraphs = paragraphs

    def rank_paragraphs(self,query,topn=10):
        """
        Ranks the elements of self.paragraphs in descending order
        of relevance to query using Okapi BM25, and stores the topn best results
        Parameters:
        -----------
        query : str
            The query to use in ranking paragraphs by relevance
        topn : int
            The number of most relevant paragraphs to keep
        Returns: None, but stores
        --------
        self.best_paragraphs : list(str)
            The topn most relevant paragraphs to the query
        """
        assert self.paragraphs is not None, "No paragraphs exist. Get paragraphs first using self.get_paragraphs"
        # BM25Okapi expects a whitespace-tokenized corpus and query
        tokenized_paragraphs = [p.split(" ") for p in self.paragraphs]
        bm25 = BM25Okapi(tokenized_paragraphs)
        tokenized_query = query.split(" ")
        self.best_paragraphs = bm25.get_top_n(tokenized_query,self.paragraphs,n=topn)
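
if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module: the question text
    # and topn value below are illustrative assumptions, and running this
    # requires network access to Wikipedia and the MediaWiki API.
    retriever = ContextRetriever()
    question = 'When was the Eiffel Tower built?'
    retriever.get_pageids(question)
    retriever.get_pages()
    retriever.get_paragraphs()
    retriever.rank_paragraphs(question, topn=5)
    for rank, paragraph in enumerate(retriever.best_paragraphs, start=1):
        print(f'[{rank}] {paragraph[:120]}')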