Upload 9 files
- app.py +296 -0
- lib/.DS_Store +0 -0
- lib/.ipynb_checkpoints/__init__-checkpoint.py +0 -0
- lib/.ipynb_checkpoints/utils-checkpoint.py +154 -0
- lib/__init__.py +0 -0
- lib/__pycache__/__init__.cpython-310.pyc +0 -0
- lib/__pycache__/utils.cpython-310.pyc +0 -0
- lib/utils.py +125 -0
- requirements.txt +92 -0
app.py
ADDED
@@ -0,0 +1,296 @@
import torch
import streamlit as st
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    pipeline,
)
import spacy
# import pandas as pd
from lib.utils import ContextRetriever


#### TO DO:######
# build out functions for:
# * formatting input into document retrieval query (spaCy)
# * document retrieval based on query (wikipedia library)
# * document postprocessing into passages
# * ranking passages based on BM25 scores for query (rank_bm25)
# * feeding passages into RoBERTa and reporting answer(s) and passages as evidence
# decide what to do with examples

### CAN REMOVE:#####
# * context collection
# *

########################
### Helper functions ###
########################

# Build pipeline using model and tokenizer from Hugging Face repo
@st.cache_resource(show_spinner=False)
def get_pipeline():
    """
    Load model and tokenizer from 🤗 repo
    and build pipeline
    Parameters: None
    -----------
    Returns:
    --------
    qa_pipeline : transformers.QuestionAnsweringPipeline
        The question answering pipeline object
    """
    repo_id = 'etweedy/roberta-base-squad-v2'
    qa_pipeline = pipeline(
        task = 'question-answering',
        model=repo_id,
        tokenizer=repo_id,
        handle_impossible_answer = True
    )
    return qa_pipeline

@st.cache_resource(show_spinner=False)
def get_spacy():
    """
    Load spaCy model for processing query
    Parameters: None
    -----------
    Returns:
    --------
    nlp : spacy.Pipe
        Portion of 'en_core_web_sm' model pipeline
        only containing tokenizer and part-of-speech
        tagger
    """
    nlp = spacy.load(
        'en_core_web_sm',
        disable = ['ner','parser','textcat']
    )
    return nlp

def generate_query(nlp,text):
    """
    Process text into a search query,
    only retaining nouns, proper nouns,
    numerals, verbs, and adjectives
    Parameters:
    -----------
    nlp : spacy.Pipe
        spaCy pipeline for processing search query
    text : str
        The input text to be processed
    Returns:
    --------
    query : str
        The condensed search query
    """
    tokens = nlp(text)
    keep = {'PROPN', 'NUM', 'VERB', 'NOUN', 'ADJ'}
    query = ' '.join(token.text for token in tokens \
                     if token.pos_ in keep)
    return query

def fill_in_example(i):
    """
    Function for context-question example button click
    """
    st.session_state['response'] = ''
    st.session_state['question'] = ex_q[i]

def clear_boxes():
    """
    Function for field clear button click
    """
    st.session_state['response'] = ''
    st.session_state['question'] = ''

# def get_examples():
#     """
#     Retrieve pre-made examples from a .csv file
#     Parameters: None
#     -----------
#     Returns:
#     --------
#     questions, contexts : list, list
#         Lists of examples of corresponding question-context pairs
#     """
#     examples = pd.read_csv('examples.csv')
#     questions = list(examples['question'])
#     return questions

#############
### Setup ###
#############

# Set mps or cuda device if available
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Initialize session state variables
if 'response' not in st.session_state:
    st.session_state['response'] = ''
if 'question' not in st.session_state:
    st.session_state['question'] = ''

# Retrieve trained RoBERTa pipeline for Q&A
# and spaCy pipeline for processing search query
with st.spinner('Loading the model...'):
    qa_pipeline = get_pipeline()
    nlp = get_spacy()

# # Grab example question-context pairs from csv file
# ex_q, ex_c = get_examples()

###################
### App content ###
###################

# Intro text
st.header('RoBERTa answer retrieval')
st.markdown('''
This app demonstrates the answer-retrieval capabilities of a fine-tuned RoBERTa (Robustly optimized Bidirectional Encoder Representations from Transformers) model.

Please type in a question and click submit. When you do, a few things will happen:
1. A Wikipedia search will be performed based on your question
2. Candidate passages will be ranked based on a similarity score as compared to your question
3. RoBERTa will search the best candidate passages to find the answer to your question

If the model cannot find the answer to your question, it will tell you so.
''')
with st.expander('Click to read more about the model...'):
    st.markdown('''
* [Click here](https://huggingface.co/etweedy/roberta-base-squad-v2) to visit the Hugging Face model card for this fine-tuned model.
* To create this model, the [RoBERTa base model](https://huggingface.co/roberta-base) was fine-tuned on Version 2 of [SQuAD (Stanford Question Answering Dataset)](https://huggingface.co/datasets/squad_v2), a dataset of context-question-answer triples.
* The objective of the model is "extractive question answering", the task of retrieving the answer to the question from a given context text corpus.
* SQuAD Version 2 incorporates the 100,000 samples from Version 1.1, along with 50,000 'unanswerable' questions, i.e. samples in which the question cannot be answered using the context given.
* The original base RoBERTa model was introduced in [this paper](https://arxiv.org/abs/1907.11692) and [this repository](https://github.com/facebookresearch/fairseq/tree/main/examples/roberta). Here's a citation for that base model:
```bibtex
@article{DBLP:journals/corr/abs-1907-11692,
  author    = {Yinhan Liu and
               Myle Ott and
               Naman Goyal and
               Jingfei Du and
               Mandar Joshi and
               Danqi Chen and
               Omer Levy and
               Mike Lewis and
               Luke Zettlemoyer and
               Veselin Stoyanov},
  title     = {RoBERTa: {A} Robustly Optimized {BERT} Pretraining Approach},
  journal   = {CoRR},
  volume    = {abs/1907.11692},
  year      = {2019},
  url       = {http://arxiv.org/abs/1907.11692},
  archivePrefix = {arXiv},
  eprint    = {1907.11692},
  timestamp = {Thu, 01 Aug 2019 08:59:33 +0200},
  biburl    = {https://dblp.org/rec/journals/corr/abs-1907-11692.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
```
''')
# st.markdown('''
# Please type or paste a context paragraph and question you'd like to ask about it. The model will attempt to answer the question based on the context you provided. If the model cannot find the answer in the context, it will tell you so - the model is also trained to recognize when the context doesn't provide the answer.

# Your results will appear below the question field when the model is finished running.

# Alternatively, you can try an example by clicking one of the buttons below:
# ''')

# Generate containers in order
# example_container = st.container()
input_container = st.container()
button_container = st.container()
response_container = st.container()

###########################
### Populate containers ###
###########################

# Populate example button container
# with example_container:
#     ex_cols = st.columns(len(ex_q)+1)
#     for i in range(len(ex_q)):
#         with ex_cols[i]:
#             st.button(
#                 label = f'Try example {i+1}',
#                 key = f'ex_button_{i+1}',
#                 on_click = fill_in_example,
#                 args=(i,),
#             )
#     with ex_cols[-1]:
#         st.button(
#             label = "Clear all fields",
#             key = "clear_button",
#             on_click = clear_boxes,
#         )

# Populate user input container
with input_container:
    with st.form(key='input_form',clear_on_submit=False):
        # Question input field
        question = st.text_input(
            label='Question',
            value=st.session_state['question'],
            key='question_field',
            label_visibility='hidden',
            placeholder='Enter your question here.',
        )
        # Form submit button
        query_submitted = st.form_submit_button("Submit")
        if query_submitted:
            # update question, context in session state
            st.session_state['question'] = question
            with st.spinner('Retrieving documents...'):
                query = generate_query(nlp,question)
                retriever = ContextRetriever()
                retriever.get_pageids(query)
                retriever.get_pages()
                retriever.get_paragraphs()
                retriever.rank_paragraphs(question)
            with st.spinner('Generating response...'):
                # Loop through best_paragraph contexts
                # looking for answer in each
                best_answer = ""
                for context in retriever.best_paragraphs:
                    input = {
                        'context':context,
                        'question':st.session_state['question'],
                    }
                    # Pass to QA pipeline
                    response = qa_pipeline(**input)
                    if response['answer']!='':
                        best_answer = response['answer']
                        best_context = context
                        break
                # Update response in session state
                if best_answer == "":
                    st.session_state['response'] = "I cannot find the answer to your question."
                else:
                    st.session_state['response'] = f"""
My answer is: {best_answer}

...and here's where I found it:

{best_context}
"""

# Button for clearing the form
with button_container:
    st.button(
        label = "Clear all fields",
        key = "clear_button",
        on_click = clear_boxes,
    )

# Display response
with response_container:
    st.write(st.session_state['response'])
lib/.DS_Store
ADDED
Binary file (6.15 kB)
lib/.ipynb_checkpoints/__init__-checkpoint.py
ADDED
File without changes
lib/.ipynb_checkpoints/utils-checkpoint.py
ADDED
@@ -0,0 +1,154 @@
import requests, wikipedia, re, spacy
from rank_bm25 import BM25Okapi
import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    pipeline,
)

class QueryProcessor:
    """
    Processes text into queries using a spaCy model
    """
    def __init__(self):
        self.keep = {'PROPN', 'NUM', 'VERB', 'NOUN', 'ADJ'}
        self.nlp = spacy.load(
            'en_core_web_sm',
            disable = ['ner','parser','textcat']
        )

    def generate_query(self,text):
        """
        Process text into a search query,
        only retaining nouns, proper nouns, numerals, verbs, and adjectives
        Parameters:
        -----------
        text : str
            The input text to be processed
        Returns:
        --------
        query : str
            The condensed search query
        """
        tokens = self.nlp(text)
        query = ' '.join(token.text for token in tokens \
                         if token.pos_ in self.keep)
        return query

class ContextRetriever:
    """
    Retrieves documents from Wikipedia based on a query,
    and prepares context paragraphs for a RoBERTa model
    """
    def __init__(self,url='https://en.wikipedia.org/w/api.php'):
        self.url = url
        self.pageids = None
        self.pages = None
        self.paragraphs = None

    def get_pageids(self,query):
        """
        Retrieve page ids corresponding to a search query
        Parameters:
        -----------
        query : str
            A query to use for Wikipedia page search
        Returns: None, but stores:
        --------
        self.pageids : list(int)
            A list of Wikipedia page ids corresponding to search results
        """
        params = {
            'action':'query',
            'list':'search',
            'srsearch':query,
            'format':'json',
        }
        results = requests.get(self.url, params=params).json()
        pageids = [page['pageid'] for page in results['query']['search']]
        self.pageids = pageids

    def get_pages(self):
        """
        Use MediaWiki API to retrieve page content corresponding to
        entries of self.pageids
        Parameters: None
        -----------
        Returns: None, but stores
        --------
        self.pages : list(str)
            Entries are content of pages corresponding to entries of self.pageids
        """
        assert self.pageids is not None, "No pageids exist. Get pageids first using self.get_pageids"
        self.pages = []
        for pageid in self.pageids:
            try:
                self.pages.append(wikipedia.page(pageid=pageid,auto_suggest=False).content)
            except wikipedia.DisambiguationError as e:
                continue

    def get_paragraphs(self):
        """
        Process self.pages into list of paragraphs from pages
        Parameters: None
        -----------
        Returns: None, but stores
        --------
        self.paragraphs : list(str)
            List of paragraphs from all pages in self.pages, in order of self.pages
        """
        assert self.pages is not None, "No page content exists. Get pages first using self.get_pages"
        # Content from WikiMedia has these headings. We only grab content appearing
        # before the first instance of any of these
        pattern = '|'.join([
            '== References ==',
            '== Further reading ==',
            '== External links',
            '== See also ==',
            '== Sources ==',
            '== Notes ==',
            '== Further references ==',
            '== Footnotes ==',
            '=== Notes ===',
            '=== Sources ===',
            '=== Citations ===',
        ])
        pattern = re.compile(pattern)
        paragraphs = []
        for page in self.pages:
            # Truncate page to the first index of the start of a matching heading,
            # or the end of the page if no matches exist
            idx = min([match.start() for match in pattern.finditer(page)]+[len(page)])
            page = page[:idx]
            # Split into paragraphs, omitting lines with headings (start with '='),
            # empty lines, or lines like '\t\t' or '\t\t\t' which sometimes appear
            paragraphs += [
                p for p in page.split('\n') if p \
                and not p.startswith('=') \
                and not p.startswith('\t\t')
            ]
        self.paragraphs = paragraphs

    def rank_paragraphs(self,query,topn=10):
        """
        Ranks the elements of self.paragraphs in descending order
        by relevance to query using Okapi BM25, and stores the top topn results
        Parameters:
        -----------
        query : str
            The query to use in ranking paragraphs by relevance
        topn : int
            The number of most relevant paragraphs to return
        Returns: None, but stores
        --------
        self.best_paragraphs : list(str)
            The topn most relevant paragraphs to the query
        """
        tokenized_paragraphs = [p.split(" ") for p in self.paragraphs]
        bm25 = BM25Okapi(tokenized_paragraphs)
        tokenized_query = query.split(" ")
        self.best_paragraphs = bm25.get_top_n(tokenized_query,self.paragraphs,n=topn)
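Unlike app.py, which builds its keyword query inline, this checkpoint also defines a QueryProcessor class. A minimal, hypothetical usage sketch, assuming it is run where the class above is defined (and that en_core_web_sm is installed); the sample sentence and expected output are illustrative only.

# Hypothetical usage of the checkpoint's QueryProcessor
# (run in a context where the class above is defined, e.g. the same notebook)
processor = QueryProcessor()
query = processor.generate_query("When was the Eiffel Tower completed in Paris?")
print(query)  # keeps only proper nouns, numerals, verbs, nouns, and adjectives,
              # e.g. roughly "Eiffel Tower completed Paris"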
lib/__init__.py
ADDED
File without changes
lib/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (168 Bytes)
lib/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (4.78 kB)
lib/utils.py
ADDED
@@ -0,0 +1,125 @@
import requests, wikipedia, re
from rank_bm25 import BM25Okapi
import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    pipeline,
)

class ContextRetriever:
    """
    Retrieves documents from Wikipedia based on a query,
    and prepares context paragraphs for a RoBERTa model
    """
    def __init__(self,url='https://en.wikipedia.org/w/api.php'):
        self.url = url
        self.pageids = None
        self.pages = None
        self.paragraphs = None

    def get_pageids(self,query):
        """
        Retrieve page ids corresponding to a search query
        Parameters:
        -----------
        query : str
            A query to use for Wikipedia page search
        Returns: None, but stores:
        --------
        self.pageids : list(int)
            A list of Wikipedia page ids corresponding to search results
        """
        params = {
            'action':'query',
            'list':'search',
            'srsearch':query,
            'format':'json',
        }
        results = requests.get(self.url, params=params).json()
        pageids = [page['pageid'] for page in results['query']['search']]
        self.pageids = pageids

    def get_pages(self):
        """
        Use MediaWiki API to retrieve page content corresponding to
        entries of self.pageids
        Parameters: None
        -----------
        Returns: None, but stores
        --------
        self.pages : list(str)
            Entries are content of pages corresponding to entries of self.pageids
        """
        assert self.pageids is not None, "No pageids exist. Get pageids first using self.get_pageids"
        self.pages = []
        for pageid in self.pageids:
            try:
                self.pages.append(wikipedia.page(pageid=pageid,auto_suggest=False).content)
            except wikipedia.DisambiguationError as e:
                continue

    def get_paragraphs(self):
        """
        Process self.pages into list of paragraphs from pages
        Parameters: None
        -----------
        Returns: None, but stores
        --------
        self.paragraphs : list(str)
            List of paragraphs from all pages in self.pages, in order of self.pages
        """
        assert self.pages is not None, "No page content exists. Get pages first using self.get_pages"
        # Content from WikiMedia has these headings. We only grab content appearing
        # before the first instance of any of these
        pattern = '|'.join([
            '== References ==',
            '== Further reading ==',
            '== External links',
            '== See also ==',
            '== Sources ==',
            '== Notes ==',
            '== Further references ==',
            '== Footnotes ==',
            '=== Notes ===',
            '=== Sources ===',
            '=== Citations ===',
        ])
        pattern = re.compile(pattern)
        paragraphs = []
        for page in self.pages:
            # Truncate page to the first index of the start of a matching heading,
            # or the end of the page if no matches exist
            idx = min([match.start() for match in pattern.finditer(page)]+[len(page)])
            page = page[:idx]
            # Split into paragraphs, omitting lines with headings (start with '='),
            # empty lines, or lines like '\t\t' or '\t\t\t' which sometimes appear
            paragraphs += [
                p for p in page.split('\n') if p \
                and not p.startswith('=') \
                and not p.startswith('\t\t')
            ]
        self.paragraphs = paragraphs

    def rank_paragraphs(self,query,topn=10):
        """
        Ranks the elements of self.paragraphs in descending order
        by relevance to query using Okapi BM25, and stores the top topn results
        Parameters:
        -----------
        query : str
            The query to use in ranking paragraphs by relevance
        topn : int
            The number of most relevant paragraphs to return
        Returns: None, but stores
        --------
        self.best_paragraphs : list(str)
            The topn most relevant paragraphs to the query
        """
        tokenized_paragraphs = [p.split(" ") for p in self.paragraphs]
        bm25 = BM25Okapi(tokenized_paragraphs)
        tokenized_query = query.split(" ")
        self.best_paragraphs = bm25.get_top_n(tokenized_query,self.paragraphs,n=topn)
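The asserts in ContextRetriever enforce a fixed call order: page ids, then pages, then paragraphs, then ranking. A minimal, hypothetical usage sketch follows; the example search string and question are illustrative, and the retrieved pages and paragraphs depend on Wikipedia's live search results.

# Hypothetical usage of lib.utils.ContextRetriever; output depends on live Wikipedia results
from lib.utils import ContextRetriever

retriever = ContextRetriever()
retriever.get_pageids("Marie Curie Nobel Prize")  # 1. search -> page ids
retriever.get_pages()                             # 2. page ids -> page content
retriever.get_paragraphs()                        # 3. pages -> cleaned paragraphs
retriever.rank_paragraphs("Who won Nobel Prizes in two different sciences?", topn=5)  # 4. BM25 ranking
for paragraph in retriever.best_paragraphs:
    print(paragraph[:100])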
requirements.txt
ADDED
@@ -0,0 +1,92 @@
aiohttp==3.8.4
aiosignal==1.3.1
altair==5.0.1
async-timeout==4.0.2
attrs==23.1.0
beautifulsoup4==4.12.2
blinker==1.6.2
blis==0.7.9
cachetools==5.3.1
catalogue==2.0.8
certifi==2023.5.7
charset-normalizer==3.2.0
click==8.1.4
confection==0.1.0
cymem==2.0.7
datasets==2.13.1
decorator==5.1.1
dill==0.3.6
filelock==3.12.2
frozenlist==1.3.3
fsspec==2023.6.0
gitdb==4.0.10
GitPython==3.1.31
huggingface-hub==0.16.4
idna==3.4
importlib-metadata==6.8.0
Jinja2==3.1.2
jsonschema==4.18.0
jsonschema-specifications==2023.6.1
langcodes==3.3.0
markdown-it-py==3.0.0
MarkupSafe==2.1.3
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.14
murmurhash==1.0.9
networkx==3.1
numpy==1.25.1
packaging==23.1
pandas==2.0.3
pathy==0.10.2
Pillow==9.5.0
preshed==3.0.8
protobuf==4.23.4
pyarrow==12.0.1
pydantic==1.10.11
pydeck==0.8.1b0
Pygments==2.15.1
Pympler==1.0.1
python-dateutil==2.8.2
pytz==2023.3
pytz-deprecation-shim==0.1.0.post0
PyYAML==6.0
rank-bm25==0.2.2
referencing==0.29.1
regex==2023.6.3
requests==2.31.0
rich==13.4.2
rpds-py==0.8.10
safetensors==0.3.1
six==1.16.0
smart-open==6.3.0
smmap==5.0.0
soupsieve==2.4.1
spacy==3.6.0
en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.6.0/en_core_web_sm-3.6.0-py3-none-any.whl
spacy-legacy==3.0.12
spacy-loggers==1.0.4
srsly==2.4.6
streamlit==1.24.1
sympy==1.12
tenacity==8.2.2
thinc==8.1.10
tokenizers==0.13.3
toml==0.10.2
toolz==0.12.0
torch==2.0.1
tornado==6.3.2
tqdm==4.65.0
transformers==4.30.2
typer==0.8.0
typing_extensions==4.7.1
tzdata==2023.3
tzlocal==4.3.1
urllib3==2.0.3
validators==0.20.0
wasabi==1.1.2
wikipedia==1.4.0
xxhash==3.2.0
yarl==1.9.2
zipp==3.16.0