Spaces: Sleeping
Aakash Goel committed · d0941bb
Parent(s): ab36985
cleaning
- code/question_generation/CITATION.cff +0 -10
- code/question_generation/LICENSE +0 -21
- code/question_generation/README.md +0 -352
- code/question_generation/__init__.py.py +0 -0
- code/question_generation/data_collator.py +0 -83
- code/question_generation/eval.py +0 -92
- code/question_generation/pipelines.py +0 -386
- code/question_generation/prepare_data.py +0 -204
- code/question_generation/question_generation.ipynb +0 -0
- code/question_generation/run_qg.py +0 -236
- code/question_generation/trainer.py +0 -56
- code/question_generation/utils.py +0 -49
- code/quiz_gen.py +0 -149
- code/quiz_gen_new.py +0 -257
- code/quiz_gen_new2.py +0 -277
- input/input1.txt +13 -0
- results/df_quiz_log_file_v1.csv +0 -0
code/question_generation/CITATION.cff
DELETED
@@ -1,10 +0,0 @@
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- family-names: "Patil"
  given-names: "Suraj"
title: "Question Generation using transformers"
version: 1.0.0
date-released: 2020-07
publisher: "GitHub"
url: "https://github.com/patil-suraj/question_generation"
code/question_generation/LICENSE
DELETED
@@ -1,21 +0,0 @@
MIT License

Copyright (c) 2020 Suraj Patil

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
code/question_generation/README.md
DELETED
@@ -1,352 +0,0 @@
# Question Generation using 🤗transformers

- [Question Generation using 🤗transformers](#question-generation-using-transformers)
  - [Project Details](#project-details)
  - [Initial experiments](#initial-experiments)
    - [answer aware question generation](#answer-aware-question-generation)
    - [answer extraction models](#answer-extraction-models)
    - [Multitask QA-QG](#multitask-qa-qg)
    - [End-to-End question generation (answer agnostic)](#end-to-end-question-generation-answer-agnostic)
  - [Results](#results)
  - [Requirements](#requirements)
  - [Usage](#usage)
    - [Question Generation](#question-generation)
    - [Multitask QA-QG](#multitask-qa-qg-1)
    - [End-to-end question generation (without answer supervision)](#end-to-end-question-generation-without-answer-supervision)
  - [Fine-tuning](#fine-tuning)
    - [Data processing](#data-processing)
    - [training](#training)
    - [Evaluation](#evaluation)
  - [Applications 🚀](#applications-)
  - [Relevant papers](#relevant-papers)


## Project Details
Question generation is the task of automatically generating questions from a text paragraph. The most straightforward approach is answer-aware question generation: the model is given the answer and the passage and asked to generate a question for that answer, taking the passage context into account. While many papers exist for the QG task, it is still not as mainstream as QA. One reason is that most of the earlier papers use complicated models and processing pipelines and provide no pre-trained models. A few recent papers, specifically UniLM and ProphetNet, have SOTA pre-trained weights available for QG, but their usage seems quite complicated.

This project is an open-source study of question generation with pre-trained transformers (specifically seq-2-seq models) using straightforward end-to-end methods without complicated pipelines. The goal is to provide simplified data processing and training scripts and easy-to-use pipelines for inference.


## Initial experiments
Initial experiments are conducted using the SQuADv1 dataset and the T5 model with the different input processing formats described below.

### answer aware question generation

For answer aware models the input text can be processed in two ways.

**1. prepend format:**

Here the answer is simply added before the context and separated by a sep token. For example

`42 [SEP] 42 is the answer to life, the universe and everything.`

For the T5 model the input is processed like this:

`answer: 42 context: 42 is the answer to life, the universe and everything.`

**2. highlight format**

Here the answer span is highlighted within the text with special highlight tokens.

`<hl> 42 <hl> is the answer to life, the universe and everything.`

This idea is proposed in the "A Recurrent BERT-based Model for Question Generation" [paper](https://www.aclweb.org/anthology/D19-5821.pdf); see section 4.3.
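For a concrete picture, here is a minimal sketch (illustrative only, not one of the scripts in this repo) of how a `(context, answer)` pair could be turned into the two source formats described above:

```python3
# Illustrative helpers for the two answer-aware input formats (T5-style).
def to_prepend_format(context: str, answer: str) -> str:
    # "answer: ... context: ..." prepend format
    return f"answer: {answer} context: {context}"

def to_highlight_format(context: str, answer: str) -> str:
    # wrap the answer span in <hl> tokens inside the context
    start = context.index(answer)
    end = start + len(answer)
    return f"{context[:start]}<hl> {answer} <hl>{context[end:]}"

context = "42 is the answer to life, the universe and everything."
print(to_prepend_format(context, "42"))
print(to_highlight_format(context, "42"))
```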
### answer extraction models

As the answer aware models need answers for generating questions, we need something that can extract answer-like spans from the text. This can be done with various methods such as NER or noun-phrase extraction, but here a model is trained to extract answer-like spans, to see how well that works. With T5, answer extraction is done in the text-to-text format.

As the highlight format needs to know the position of the extracted answer spans, the input for answer extraction is processed as follows:

1. split the text into sentences.
2. for each sentence that has answers, highlight the sentence with `<hl>` tokens.
3. for the target text join the answers in that sentence with `<sep>` tokens.

For example, for this text

`Python is a programming language. Created by Guido van Rossum and first released in 1991.`

the following examples will be created:

Input text:
`<hl> Python is a programming language. <hl> Created by Guido van Rossum and first released in 1991.`

target text:
`Python <sep>`

and

Input text:
`Python is a programming language. <hl> Created by Guido van Rossum and first released in 1991 <hl>.`

target text:
`Guido van Rossum <sep> 1991 <sep>`

At inference time the text is split into sentences and each sentence is highlighted.
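A minimal sketch of this per-sentence highlighting (it mirrors `_prepare_inputs_for_ans_extraction` in `pipelines.py`; the T5 `</s>` suffix is omitted here for brevity):

```python3
from nltk import sent_tokenize

def build_ans_ext_inputs(text: str):
    # one source string per sentence: that sentence is wrapped in <hl> tokens
    sents = sent_tokenize(text)
    inputs = []
    for i in range(len(sents)):
        parts = [f"<hl> {sent} <hl>" if i == j else sent for j, sent in enumerate(sents)]
        inputs.append("extract answers: " + " ".join(parts))
    return inputs

for source_text in build_ans_ext_inputs(
    "Python is a programming language. Created by Guido van Rossum and first released in 1991."
):
    print(source_text)
```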
### Multitask QA-QG

For answer aware question generation we usually need 3 models: the first extracts answer-like spans, the second generates a question for that answer, and the third is a QA model that takes the question and produces an answer; we can then compare the two answers to see whether the generated question is correct.

Having 3 models for a single task is a lot of complexity, so the goal is to create a multi-task model which can do all 3 of these tasks:

1. extract answer like spans
2. generate question based on the answer
3. QA

The T5 model is fine-tuned in a multi-task way using task prefixes, as described in the paper.

<p align="center">
  <img width="80%" src="https://i.ibb.co/TBS3nsr/t5-ss-2.png">
</p>
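To make "task prefixes" concrete, the inputs below show what the single multitask model sees for each of the three tasks (the prefixes match the ones used in `pipelines.py`; the example text is illustrative):

```python3
# One model, three tasks, distinguished only by the source-text prefix.
ans_ext_input = ("extract answers: <hl> Python is a programming language. <hl> "
                 "Created by Guido van Rossum and first released in 1991.")
qg_input = ("generate question: Python is a programming language. "
            "Created by <hl> Guido van Rossum <hl> and first released in 1991.")
qa_input = ("question: Who created Python? context: Python is a programming language. "
            "Created by Guido van Rossum and first released in 1991.")
```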
### End-to-End question generation (answer agnostic)

In end-to-end question generation the model is asked to generate questions without being given the answers. [This](https://arxiv.org/pdf/2005.01107v1.pdf) paper discusses these ideas in more detail. Here the T5 model is trained to generate multiple questions simultaneously from just the context. The questions are separated by the `<sep>` token. Here's how the examples are processed:

input text: `Python is a programming language. Created by Guido van Rossum and first released in 1991.`

target text: `Who created Python ? <sep> When was python released ? <sep>`

**All the training details can be found in [this](https://app.wandb.ai/psuraj/question-generation) wandb project**

## Results

Results on the SQuAD1.0 dev set using the above approaches. For decoding, beam search with num_beams 4 is used, with the max decoding length set to 32.

For the multitask qa-qg models the EM and F1 scores are provided as QA-EM and QA-F1.

The [nlg-eval](https://github.com/Maluuba/nlg-eval) package is used for calculating the metrics.


| Name                                                                        | BLEU-4  | METEOR  | ROUGE-L | QA-EM  | QA-F1  | QG-FORMAT |
|-----------------------------------------------------------------------------|---------|---------|---------|--------|--------|-----------|
| [t5-base-qg-hl](https://huggingface.co/valhalla/t5-base-qg-hl)              | 21.3226 | 27.0854 | 43.5962 | -      | -      | highlight |
| [t5-base-qa-qg-hl](https://huggingface.co/valhalla/t5-base-qa-qg-hl)        | 21.0141 | 26.9113 | 43.2484 | 82.46  | 90.272 | highlight |
| [t5-small-qa-qg-hl](https://huggingface.co/valhalla/t5-small-qa-qg-hl)      | 18.9872 | 25.2217 | 40.7893 | 76.121 | 84.904 | highlight |
| [t5-small-qg-hl](https://huggingface.co/valhalla/t5-small-qg-hl)            | 18.5921 | 24.9915 | 40.1886 | -      | -      | highlight |
| [t5-small-qg-prepend](https://huggingface.co/valhalla/t5-small-qg-prepend)  | 18.2791 | 24.6722 | 39.958  | -      | -      | prepend   |


## Requirements
```
transformers==3.0.0
nltk
nlp==0.2.0 # only if you want to fine-tune.
```

after installing `nltk` do
```bash
python -m nltk.downloader punkt
```

## Usage
Use the pipeline, which mimics the 🤗transformers pipeline, for easy inference.

The pipeline is divided into 3 tasks:
1. `question-generation`: for single task question generation models.
2. `multitask-qa-qg`: for multi-task qa, qg models.
3. `e2e-qg`: for end-to-end question generation.

[](https://colab.research.google.com/github/patil-suraj/question_generation/blob/master/question_generation.ipynb)

#### Question Generation

```python3
from pipelines import pipeline

nlp = pipeline("question-generation")
nlp("42 is the answer to life, the universe and everything.")
=> [{'answer': '42', 'question': 'What is the answer to life, the universe and everything?'}]
```

**prepend format**
```python3
nlp = pipeline("question-generation", model="valhalla/t5-small-qg-prepend", qg_format="prepend")
nlp("42 is the answer to life, the universe and everything.")
=> [{'answer': '42 ', 'question': 'What is the answer to life, the universe, and everything?'}]
```

#### Multitask QA-QG
```python3
nlp = pipeline("multitask-qa-qg")

# to generate questions simply pass the text
nlp("42 is the answer to life, the universe and everything.")
=> [{'answer': '42', 'question': 'What is the answer to life, the universe and everything?'}]

# for qa pass a dict with "question" and "context"
nlp({
    "question": "What is 42 ?",
    "context": "42 is the answer to life, the universe and everything."
})
=> 'the answer to life, the universe and everything'
```

#### End-to-end question generation (without answer supervision)
```python3
nlp = pipeline("e2e-qg")
nlp("Python is a programming language. Created by Guido van Rossum and first released in 1991.")
=> [
 'What is a programming language?',
 'Who created Python?',
 'When was Python first released?'
]
```

By default both pipelines will use the t5-small* models; to use other models, pass the path through the `model` parameter.

By default the `question-generation` pipeline will download the [valhalla/t5-small-qg-hl](https://huggingface.co/valhalla/t5-small-qg-hl) model with the `highlight` qg format. If you want to use the prepend format, provide the path to a prepend model and set `qg_format` to `"prepend"`. For extracting answer-like spans it uses the [valhalla/t5-small-qa-qg-hl](https://huggingface.co/valhalla/t5-small-qa-qg-hl) model; you can provide a different model through the `ans_model` parameter.

The `multitask-qa-qg` pipeline is for multitask models which can extract answer-like spans and do both qg and qa, so it won't need a separate `ans_model`. By default the [valhalla/t5-small-qa-qg-hl](https://huggingface.co/valhalla/t5-small-qa-qg-hl) model is used with the `highlight` format. If you want to use the prepend format, provide the path to a prepend model and set `qg_format` to `"prepend"`.

The `e2e-qg` pipeline is for end-to-end question generation. These models can generate multiple questions simultaneously without answer supervision. By default it uses [valhalla/t5-small-e2e-qg](https://huggingface.co/valhalla/t5-small-e2e-qg).

## Fine-tuning

### Data processing

To support different data formats the trainer expects a pre-processed cached dataset, so you can process the data the way you want.
The cached dataset should be saved using `torch.save`, and its `__getitem__` should return a `dict` with `source_ids`, `target_ids` and `attention_mask` keys.

- `source_ids`: encoded source text
- `target_ids`: encoded target text
- `attention_mask`: attention mask for the `source_ids`

The `T2TDataCollator` takes care of preparing the right `input_ids` and `labels`. It also trims the batches dynamically to remove excessive padding tokens and speed up training.
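As a rough sketch (assuming the tensors have already been produced by a tokenizer, as `prepare_data.py` does), a cached dataset satisfying this contract could look like:

```python3
import torch
from torch.utils.data import Dataset

class CachedQGDataset(Dataset):
    """Minimal example of the dict contract expected by the trainer and collator."""
    def __init__(self, source_ids, target_ids, attention_mask):
        self.source_ids = source_ids          # LongTensor [num_examples, src_len]
        self.target_ids = target_ids          # LongTensor [num_examples, tgt_len]
        self.attention_mask = attention_mask  # LongTensor [num_examples, src_len]

    def __len__(self):
        return self.source_ids.size(0)

    def __getitem__(self, idx):
        return {
            "source_ids": self.source_ids[idx],
            "target_ids": self.target_ids[idx],
            "attention_mask": self.attention_mask[idx],
        }

# torch.save(CachedQGDataset(...), "data/train_data_qg_hl_t5.pt")
```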
The `data/squad_multitask` directory contains the modified SQuAD dataset for answer aware question generation (using both prepend and highlight formats), question answering (text-to-text), answer extraction and end-to-end question generation. This dataset can be loaded using the awesome 🤗`nlp` library, which makes processing very easy.

To process and cache the dataset, use the `prepare_data.py` script. It will load the correct tokenizer depending on the `model_type` argument. It adds two new tokens, `<sep>` and `<hl>`, to the tokenizer and saves it at the `{model_type}_qg_tokenizer` path. You should pass this tokenizer to the fine-tuning script.

The datasets will be saved in the `data/` directory. You should provide filenames using the `train_file_name` and `valid_file_name` arguments.

**process data for single task question generation with highlight_qg_format**
```bash
python prepare_data.py \
    --task qg \
    --model_type t5 \
    --dataset_path data/squad_multitask/ \
    --qg_format highlight_qg_format \
    --max_source_length 512 \
    --max_target_length 32 \
    --train_file_name train_data_qg_hl_t5.pt \
    --valid_file_name valid_data_qg_hl_t5.pt \
```

**process data for multi-task qa-qg with highlight_qg_format**

The `valid_for_qg_only` argument decides whether the validation set should contain data for the qg task only. For the multi-task experiments the validation data contained only the qg task, so that the eval loss curve can be easily compared with the other single task models.

```bash
python prepare_data.py \
    --task multi \
    --valid_for_qg_only \
    --model_type t5 \
    --dataset_path data/squad_multitask/ \
    --qg_format highlight_qg_format \
    --max_source_length 512 \
    --max_target_length 32 \
    --train_file_name train_data_qa_qg_hl_t5.pt \
    --valid_file_name valid_data_qg_hl_t5.pt \
```

**process dataset for end-to-end question generation**
```bash
python prepare_data.py \
    --task e2e_qg \
    --valid_for_qg_only \
    --model_type t5 \
    --dataset_path data/squad_multitask/ \
    --qg_format highlight_qg_format \
    --max_source_length 512 \
    --max_target_length 32 \
    --train_file_name train_data_e2e_qg_t5.pt \
    --valid_file_name valid_data_e2e_qg_t5.pt \
```

### training
Use the `run_qg.py` script to start training. It uses the transformers `Trainer` class for training the models.


```bash
python run_qg.py \
    --model_name_or_path t5-small \
    --model_type t5 \
    --tokenizer_name_or_path t5_qg_tokenizer \
    --output_dir t5-small-qg-hl \
    --train_file_path data/train_data_qg_hl_t5.pt \
    --valid_file_path data/valid_data_qg_hl_t5.pt \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 32 \
    --gradient_accumulation_steps 8 \
    --learning_rate 1e-4 \
    --num_train_epochs 10 \
    --seed 42 \
    --do_train \
    --do_eval \
    --evaluate_during_training \
    --logging_steps 100
```

or, if you want to train from a script or notebook:

```python3
from run_qg import run_qg

args_dict = {
    "model_name_or_path": "t5-small",
    "model_type": "t5",
    "tokenizer_name_or_path": "t5_qg_tokenizer",
    "output_dir": "t5-small-qg-hl",
    "train_file_path": "data/train_data_qg_hl_t5.pt",
    "valid_file_path": "data/valid_data_qg_hl_t5.pt",
    "per_device_train_batch_size": 32,
    "per_device_eval_batch_size": 32,
    "gradient_accumulation_steps": 8,
    "learning_rate": 1e-4,
    "num_train_epochs": 10,
    "seed": 42,
    "do_train": True,
    "do_eval": True,
    "evaluate_during_training": True,
    "logging_steps": 100
}

# start training
run_qg(args_dict)
```

### Evaluation

Use the `eval.py` script for evaluating the model.

```bash
python eval.py \
    --model_name_or_path t5-base-qg-hl \
    --valid_file_path valid_data_qg_hl_t5.pt \
    --model_type t5 \
    --num_beams 4 \
    --max_decoding_length 32 \
    --output_path hypothesis_t5-base-qg-hl.txt
```

This will save the output at the `{output_path}` file.

To calculate the metrics, install the [nlg-eval](https://github.com/Maluuba/nlg-eval) package and run

```bash
nlg-eval --hypothesis=hypothesis_t5-base-qg-hl.txt --references=data/references.txt --no-skipthoughts --no-glove
```

## Applications 🚀

1. A simple Trivia Quiz on topics of your choice - <br/>
   [Medium article](https://medium.com/@nvarshney97/using-the-latest-nlp-techniques-for-fun-98f31ce7b556) and its [Colab Notebook](https://colab.research.google.com/gist/nrjvarshney/39ed6c80e2fe293b9e7eca5bc3a45b7d/quiz.ipynb)
2. [Autocards, Accelerating learning through machine-generated flashcards](https://paulbricman.com/docs/tools/autocards/)

## Relevant papers
- https://arxiv.org/abs/1906.05416
- https://www.aclweb.org/anthology/D19-5821/
- https://arxiv.org/abs/2005.01107v1
code/question_generation/__init__.py.py
DELETED
File without changes
code/question_generation/data_collator.py
DELETED
@@ -1,83 +0,0 @@
from typing import Dict, List, Optional

import torch


def trim_batch(
    input_ids, pad_token_id, attention_mask=None,
):
    """Remove columns that are populated exclusively by pad_token_id"""
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    else:
        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])


# prepares lm_labels from target_ids, returns examples with keys as expected by the forward method
# this is necessary because the trainer directly passes this dict as arguments to the model
# so make sure the keys match the parameter names of the forward method
class T2TDataCollator():
    def __init__(self, tokenizer, model_type="t5", mode='training', using_tpu=False):
        self.tokenizer = tokenizer
        self.model_type = model_type
        self.mode = mode
        self.using_tpu = using_tpu

    def __call__(self, batch: List) -> Dict[str, torch.Tensor]:
        """
        Take a list of samples from a Dataset and collate them into a batch.
        Returns:
            A dictionary of tensors
        """
        input_ids = torch.stack([example['source_ids'] for example in batch])
        target_ids = torch.stack([example['target_ids'] for example in batch])
        attention_mask = torch.stack([example['attention_mask'] for example in batch])

        pad_token_id = self.tokenizer.pad_token_id

        # don't trim on tpu, for some reason trimming leads to slower training on TPU
        if not self.using_tpu:
            input_ids, attention_mask = trim_batch(input_ids, pad_token_id, attention_mask=attention_mask)
            target_ids = trim_batch(target_ids, pad_token_id)

        if self.model_type == "t5":
            lm_labels = target_ids.clone()
            decoder_input_ids = self._shift_right_t5(lm_labels)
            if self.mode == 'training':
                lm_labels[lm_labels[:, :] == pad_token_id] = -100
        else:
            decoder_input_ids = target_ids[:, :-1].contiguous()
            lm_labels = target_ids[:, 1:].clone()
            if self.mode == 'training':
                lm_labels[target_ids[:, 1:] == pad_token_id] = -100

        params = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": lm_labels,
            "decoder_input_ids": decoder_input_ids
        }

        return params

    def _shift_right_t5(self, input_ids):
        decoder_start_token_id = self.tokenizer.pad_token_id
        pad_token_id = self.tokenizer.pad_token_id

        assert (
            decoder_start_token_id is not None
        ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"

        # shift inputs to the right
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = decoder_start_token_id

        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `labels` has only positive values and -100"

        return shifted_input_ids
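Usage note: a minimal sketch of how this collator is wired up, mirroring the way `eval.py` below builds its `DataLoader` (the tokenizer path is the one written by `prepare_data.py`):

```python
import torch
from transformers import AutoTokenizer

from data_collator import T2TDataCollator

# Batches come out as the dict of tensors described in the comment above T2TDataCollator.
tokenizer = AutoTokenizer.from_pretrained("t5_qg_tokenizer")
valid_dataset = torch.load("data/valid_data_qg_hl_t5.pt")
collator = T2TDataCollator(tokenizer=tokenizer, model_type="t5", mode="inference")
loader = torch.utils.data.DataLoader(valid_dataset, batch_size=32, collate_fn=collator)
```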
code/question_generation/eval.py
DELETED
@@ -1,92 +0,0 @@
import logging
from dataclasses import dataclass, field
from typing import Optional

import torch
from tqdm.auto import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, HfArgumentParser

from data_collator import T2TDataCollator

device = 'cuda' if torch.cuda.is_available() else 'cpu'

logger = logging.getLogger(__name__)

@dataclass
class EvalArguments:
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    valid_file_path: str = field(
        metadata={"help": "Path for cached valid dataset"}
    )
    model_type: str = field(metadata={"help": "One of 't5', 'bart'"})
    tokenizer_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    num_beams: Optional[int] = field(
        default=4,
        metadata={"help": "num_beams to use for decoding"}
    )
    max_decoding_length: Optional[int] = field(
        default=32,
        metadata={"help": "maximum length for decoding"}
    )
    output_path: Optional[str] = field(
        default="hypothesis.txt",
        metadata={"help": "path to save the generated questions."}
    )

def get_predictions(model, tokenizer, data_loader, num_beams=4, max_length=32, length_penalty=1):
    model.to(device)

    predictions = []
    model.eval()
    with torch.no_grad():
        for batch in tqdm(data_loader):
            outs = model.generate(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
                num_beams=num_beams,
                max_length=max_length,
                length_penalty=length_penalty,
            )

            prediction = [tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
            predictions.extend(prediction)

    return predictions

def main():
    parser = HfArgumentParser((EvalArguments,))
    args = parser.parse_args_into_dataclasses()[0]

    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name_or_path if args.tokenizer_name_or_path else args.model_name_or_path,
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path)

    valid_dataset = torch.load(args.valid_file_path)
    collator = T2TDataCollator(
        tokenizer=tokenizer,
        model_type=args.model_type,
        mode="inference"
    )
    loader = torch.utils.data.DataLoader(valid_dataset, batch_size=32, collate_fn=collator)

    predictions = get_predictions(
        model=model,
        tokenizer=tokenizer,
        data_loader=loader,
        num_beams=args.num_beams,
        max_length=args.max_decoding_length
    )

    with open(args.output_path, 'w') as f:
        f.write("\n".join(predictions))

    logging.info(f"Output saved at {args.output_path}")


if __name__ == "__main__":
    main()
code/question_generation/pipelines.py
DELETED
@@ -1,386 +0,0 @@
import itertools
import logging
from typing import Optional, Dict, Union

from nltk import sent_tokenize

import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

logger = logging.getLogger(__name__)

class QGPipeline:
    """Poor man's QG pipeline"""
    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        ans_model: PreTrainedModel,
        ans_tokenizer: PreTrainedTokenizer,
        qg_format: str,
        use_cuda: bool
    ):
        self.model = model
        self.tokenizer = tokenizer

        self.ans_model = ans_model
        self.ans_tokenizer = ans_tokenizer

        self.qg_format = qg_format

        self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
        self.model.to(self.device)

        if self.ans_model is not self.model:
            self.ans_model.to(self.device)

        assert self.model.__class__.__name__ in ["T5ForConditionalGeneration", "BartForConditionalGeneration"]

        if "T5ForConditionalGeneration" in self.model.__class__.__name__:
            self.model_type = "t5"
        else:
            self.model_type = "bart"

    def __call__(self, inputs: str):
        inputs = " ".join(inputs.split())
        sents, answers = self._extract_answers(inputs)
        flat_answers = list(itertools.chain(*answers))

        if len(flat_answers) == 0:
            return []

        if self.qg_format == "prepend":
            qg_examples = self._prepare_inputs_for_qg_from_answers_prepend(inputs, answers)
        else:
            qg_examples = self._prepare_inputs_for_qg_from_answers_hl(sents, answers)

        qg_inputs = [example['source_text'] for example in qg_examples]
        questions = self._generate_questions(qg_inputs)
        output = [{'answer': example['answer'], 'question': que} for example, que in zip(qg_examples, questions)]
        return output

    def _generate_questions(self, inputs):
        inputs = self._tokenize(inputs, padding=True, truncation=True)

        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            max_length=32,
            num_beams=4,
        )

        questions = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
        return questions

    def _extract_answers(self, context):
        sents, inputs = self._prepare_inputs_for_ans_extraction(context)
        inputs = self._tokenize(inputs, padding=True, truncation=True)

        outs = self.ans_model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            max_length=32,
        )

        dec = [self.ans_tokenizer.decode(ids, skip_special_tokens=False) for ids in outs]
        answers = [item.split('<sep>') for item in dec]
        answers = [i[:-1] for i in answers]

        return sents, answers

    def _tokenize(self,
        inputs,
        padding=True,
        truncation=True,
        add_special_tokens=True,
        max_length=512
    ):
        inputs = self.tokenizer.batch_encode_plus(
            inputs,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding="max_length" if padding else False,
            pad_to_max_length=padding,
            return_tensors="pt"
        )
        return inputs

    def _prepare_inputs_for_ans_extraction(self, text):
        sents = sent_tokenize(text)

        inputs = []
        for i in range(len(sents)):
            source_text = "extract answers:"
            for j, sent in enumerate(sents):
                if i == j:
                    sent = "<hl> %s <hl>" % sent
                source_text = "%s %s" % (source_text, sent)
                source_text = source_text.strip()

            if self.model_type == "t5":
                source_text = source_text + " </s>"
            inputs.append(source_text)

        return sents, inputs

    def _prepare_inputs_for_qg_from_answers_hl(self, sents, answers):
        inputs = []
        for i, answer in enumerate(answers):
            if len(answer) == 0: continue
            for answer_text in answer:
                sent = sents[i]
                sents_copy = sents[:]

                answer_text = answer_text.strip()

                ans_start_idx = sent.index(answer_text)

                sent = f"{sent[:ans_start_idx]} <hl> {answer_text} <hl> {sent[ans_start_idx + len(answer_text): ]}"
                sents_copy[i] = sent

                source_text = " ".join(sents_copy)
                source_text = f"generate question: {source_text}"
                if self.model_type == "t5":
                    source_text = source_text + " </s>"

                inputs.append({"answer": answer_text, "source_text": source_text})

        return inputs

    def _prepare_inputs_for_qg_from_answers_prepend(self, context, answers):
        flat_answers = list(itertools.chain(*answers))
        examples = []
        for answer in flat_answers:
            source_text = f"answer: {answer} context: {context}"
            if self.model_type == "t5":
                source_text = source_text + " </s>"

            examples.append({"answer": answer, "source_text": source_text})
        return examples


class MultiTaskQAQGPipeline(QGPipeline):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, inputs: Union[Dict, str]):
        if type(inputs) is str:
            # do qg
            return super().__call__(inputs)
        else:
            # do qa
            return self._extract_answer(inputs["question"], inputs["context"])

    def _prepare_inputs_for_qa(self, question, context):
        source_text = f"question: {question} context: {context}"
        if self.model_type == "t5":
            source_text = source_text + " </s>"
        return source_text

    def _extract_answer(self, question, context):
        source_text = self._prepare_inputs_for_qa(question, context)
        inputs = self._tokenize([source_text], padding=False)

        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            max_length=16,
        )

        answer = self.tokenizer.decode(outs[0], skip_special_tokens=True)
        return answer


class E2EQGPipeline:
    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        use_cuda: bool
    ):

        self.model = model
        self.tokenizer = tokenizer

        self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
        self.model.to(self.device)

        assert self.model.__class__.__name__ in ["T5ForConditionalGeneration", "BartForConditionalGeneration"]

        if "T5ForConditionalGeneration" in self.model.__class__.__name__:
            self.model_type = "t5"
        else:
            self.model_type = "bart"

        self.default_generate_kwargs = {
            "max_length": 256,
            "num_beams": 4,
            "length_penalty": 1.5,
            "no_repeat_ngram_size": 3,
            "early_stopping": True,
        }

    def __call__(self, context: str, **generate_kwargs):
        inputs = self._prepare_inputs_for_e2e_qg(context)

        # TODO: when overriding default_generate_kwargs all other arguments need to be passed
        # find a better way to do this
        if not generate_kwargs:
            generate_kwargs = self.default_generate_kwargs

        input_length = inputs["input_ids"].shape[-1]

        # max_length = generate_kwargs.get("max_length", 256)
        # if input_length < max_length:
        #     logger.warning(
        #         "Your max_length is set to {}, but your input_length is only {}. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)".format(
        #             max_length, input_length
        #         )
        #     )

        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            **generate_kwargs
        )

        prediction = self.tokenizer.decode(outs[0], skip_special_tokens=True)
        questions = prediction.split("<sep>")
        questions = [question.strip() for question in questions[:-1]]
        return questions

    def _prepare_inputs_for_e2e_qg(self, context):
        source_text = f"generate questions: {context}"
        if self.model_type == "t5":
            source_text = source_text + " </s>"

        inputs = self._tokenize([source_text], padding=False)
        return inputs

    def _tokenize(
        self,
        inputs,
        padding=True,
        truncation=True,
        add_special_tokens=True,
        max_length=512
    ):
        inputs = self.tokenizer.batch_encode_plus(
            inputs,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding="max_length" if padding else False,
            pad_to_max_length=padding,
            return_tensors="pt"
        )
        return inputs


SUPPORTED_TASKS = {
    "question-generation": {
        "impl": QGPipeline,
        "default": {
            "model": "valhalla/t5-small-qg-hl",
            "ans_model": "valhalla/t5-small-qa-qg-hl",
        }
    },
    "multitask-qa-qg": {
        "impl": MultiTaskQAQGPipeline,
        "default": {
            "model": "valhalla/t5-small-qa-qg-hl",
        }
    },
    "e2e-qg": {
        "impl": E2EQGPipeline,
        "default": {
            "model": "valhalla/t5-small-e2e-qg",
        }
    }
}

def pipeline(
    task: str,
    model: Optional = None,
    tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
    qg_format: Optional[str] = "highlight",
    ans_model: Optional = None,
    ans_tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
    use_cuda: Optional[bool] = True,
    **kwargs,
):
    # Retrieve the task
    if task not in SUPPORTED_TASKS:
        raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))

    targeted_task = SUPPORTED_TASKS[task]
    task_class = targeted_task["impl"]

    # Use default model/config/tokenizer for the task if no model is provided
    if model is None:
        model = targeted_task["default"]["model"]

    # Try to infer tokenizer from model or config name (if provided as str)
    if tokenizer is None:
        if isinstance(model, str):
            tokenizer = model
        else:
            # Impossible to guess what the right tokenizer is here
            raise Exception(
                "Impossible to guess which tokenizer to use. "
                "Please provide a PretrainedTokenizer class or a path/identifier to a pretrained tokenizer."
            )

    # Instantiate tokenizer if needed
    if isinstance(tokenizer, (str, tuple)):
        if isinstance(tokenizer, tuple):
            # For tuple we have (tokenizer name, {kwargs})
            tokenizer = AutoTokenizer.from_pretrained(tokenizer[0], **tokenizer[1])
        else:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer)

    # Instantiate model if needed
    if isinstance(model, str):
        model = AutoModelForSeq2SeqLM.from_pretrained(model)

    if task == "question-generation":
        if ans_model is None:
            # load default ans model
            ans_model = targeted_task["default"]["ans_model"]
            ans_tokenizer = AutoTokenizer.from_pretrained(ans_model)
            ans_model = AutoModelForSeq2SeqLM.from_pretrained(ans_model)
        else:
            # Try to infer tokenizer from model or config name (if provided as str)
            if ans_tokenizer is None:
                if isinstance(ans_model, str):
                    ans_tokenizer = ans_model
                else:
                    # Impossible to guess what the right tokenizer is here
                    raise Exception(
                        "Impossible to guess which tokenizer to use. "
                        "Please provide a PretrainedTokenizer class or a path/identifier to a pretrained tokenizer."
                    )

            # Instantiate tokenizer if needed
            if isinstance(ans_tokenizer, (str, tuple)):
                if isinstance(ans_tokenizer, tuple):
                    # For tuple we have (tokenizer name, {kwargs})
                    ans_tokenizer = AutoTokenizer.from_pretrained(ans_tokenizer[0], **ans_tokenizer[1])
                else:
                    ans_tokenizer = AutoTokenizer.from_pretrained(ans_tokenizer)

            if isinstance(ans_model, str):
                ans_model = AutoModelForSeq2SeqLM.from_pretrained(ans_model)

    if task == "e2e-qg":
        return task_class(model=model, tokenizer=tokenizer, use_cuda=use_cuda)
    elif task == "question-generation":
        return task_class(model=model, tokenizer=tokenizer, ans_model=ans_model, ans_tokenizer=ans_tokenizer, qg_format=qg_format, use_cuda=use_cuda)
    else:
        return task_class(model=model, tokenizer=tokenizer, ans_model=model, ans_tokenizer=tokenizer, qg_format=qg_format, use_cuda=use_cuda)
code/question_generation/prepare_data.py
DELETED
@@ -1,204 +0,0 @@
import os
import logging
from dataclasses import dataclass, field
from typing import Dict, List, Optional

import torch
import nlp
from transformers import T5Tokenizer, BartTokenizer, HfArgumentParser


logger = logging.getLogger(__name__)


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """
    task: str = field(
        metadata={"help": "Which task 'qa', 'qg', 'e2e_qg', 'ans_ext', 'multi'. 'multi' means 'qa', 'qg', 'ans_ext' tasks"},
    )
    model_type: str = field(metadata={"help": "One of 't5', 'bart'"})
    dataset_path: Optional[str] = field(
        default="data/squad_multitask",
        metadata={"help": "Path for dataset directory"},
    )
    train_file_name: Optional[str] = field(
        default=None,
        metadata={"help": "name for cached train dataset"},
    )
    valid_file_name: Optional[str] = field(
        default=None,
        metadata={"help": "name for cached valid dataset"},
    )
    valid_for_qg_only: bool = field(
        default=False,
        metadata={"help": "For multitask dataset valid split should contain only qg task or all tasks."}
    )
    qg_format: Optional[str] = field(
        default='highlight_qg_format',
        metadata={"help": "How to format inputs for que generation, 'highlight_qg_format' or 'prepend_qg_format'"},
    )
    max_source_length: Optional[int] = field(
        default=512,
        metadata={"help": "Max input length for the source text"},
    )
    max_target_length: Optional[int] = field(
        default=32,
        metadata={"help": "Max input length for the target text"},
    )

class DataProcessor:
    def __init__(self, tokenizer, model_type="t5", max_source_length=512, max_target_length=32):
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        self.model_type = model_type
        self.hl_token = "<hl>"

        if model_type == "t5":
            self.sep_token = "<sep>"
        elif model_type == "bart":
            self.sep_token = "<sep>"
        else:
            self.sep_token = "[SEP]"

    def process(self, dataset):
        if self.model_type == "t5":
            dataset = dataset.map(self._add_eos_examples)

        dataset = dataset.map(self._add_special_tokens)
        dataset = dataset.map(self._convert_to_features, batched=True)

        return dataset

    def _add_eos_examples(self, example):
        example['source_text'] = example['source_text'] + " </s>"
        example['target_text'] = example['target_text'] + " </s>"
        return example

    def _add_special_tokens(self, example):
        example['source_text'] = example['source_text'].replace("{hl_token}", self.hl_token)
        example['target_text'] = example['target_text'].replace("{sep_token}", self.sep_token)
        return example

    # tokenize the examples
    def _convert_to_features(self, example_batch):
        source_encoding = self.tokenizer.batch_encode_plus(
            example_batch['source_text'],
            max_length=self.max_source_length,
            padding='max_length',
            pad_to_max_length=True,
            truncation=True,
        )
        target_encoding = self.tokenizer.batch_encode_plus(
            example_batch['target_text'],
            max_length=self.max_target_length,
            padding='max_length',
            pad_to_max_length=True,
            truncation=True,
        )

        encodings = {
            'source_ids': source_encoding['input_ids'],
            'target_ids': target_encoding['input_ids'],
            'attention_mask': source_encoding['attention_mask'],
        }

        return encodings


def filter_qa(example):
    return example['task'] == 'qa'

def filter_qg(example):
    return example['task'] == 'qg'

def filter_e2e_qg(example):
    return example['task'] == 'e2e_qg'

def filter_ans_ext(example):
    return example['task'] == 'ans_ext'

def filter_multi(example):
    return example['task'] != 'e2e_qg'


TASK_TO_FILTER_FN = {
    'qa': filter_qa,
    'qg': filter_qg,
    'e2e_qg': filter_e2e_qg,
    'ans_ext': filter_ans_ext,
    'multi': filter_multi
}


def main():
    parser = HfArgumentParser((DataTrainingArguments,))

    data_args = parser.parse_args_into_dataclasses()[0]

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO
    )

    if data_args.model_type == 't5':
        tokenizer = T5Tokenizer.from_pretrained("t5-base")
    else:
        tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

    tokenizer.add_tokens(['<sep>', '<hl>'])

    train_dataset = nlp.load_dataset(data_args.dataset_path, name=data_args.qg_format, split=nlp.Split.TRAIN)
    valid_dataset = nlp.load_dataset(data_args.dataset_path, name=data_args.qg_format, split=nlp.Split.VALIDATION)

    processor = DataProcessor(
        tokenizer,
        model_type=data_args.model_type,
        max_source_length=data_args.max_source_length,
        max_target_length=data_args.max_target_length
    )

    train_dataset = train_dataset.filter(TASK_TO_FILTER_FN[data_args.task])
    if data_args.task == 'multi' and data_args.valid_for_qg_only:
        logger.info("processing valid data only for qg task")
        valid_dataset = valid_dataset.filter(filter_qg)
    else:
        valid_dataset = valid_dataset.filter(TASK_TO_FILTER_FN[data_args.task])

    train_dataset = processor.process(train_dataset)
    valid_dataset = processor.process(valid_dataset)

    columns = ["source_ids", "target_ids", "attention_mask"]
    train_dataset.set_format(type='torch', columns=columns)
    valid_dataset.set_format(type='torch', columns=columns)

    if data_args.train_file_name is None:
        train_file_name = f"train_data_{data_args.task}_{data_args.qg_format}_{data_args.model_type}.pt"
        train_path = os.path.join("data", train_file_name)

        valid_file_name = f"valid_data_{data_args.task}_{data_args.qg_format}_{data_args.model_type}.pt"
        valid_path = os.path.join("data", valid_file_name)
    else:
        train_path = os.path.join("data", data_args.train_file_name)
        valid_path = os.path.join("data", data_args.valid_file_name)

    torch.save(train_dataset, train_path)
    logger.info(f"saved train dataset at {train_path}")

    torch.save(valid_dataset, valid_path)
    logger.info(f"saved validation dataset at {valid_path}")

    tokenizer_path = f"{data_args.model_type}_qg_tokenizer"
    if not os.path.exists(tokenizer_path):
        os.mkdir(tokenizer_path)
    tokenizer.save_pretrained(tokenizer_path)
    logger.info(f"saved tokenizer at {tokenizer_path}")


if __name__ == "__main__":
    main()
code/question_generation/question_generation.ipynb
DELETED
The diff for this file is too large to render.
See raw diff
code/question_generation/run_qg.py
DELETED
@@ -1,236 +0,0 @@
import dataclasses
import json
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional

import numpy as np
import torch

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    T5Tokenizer,
    BartTokenizer,
    HfArgumentParser,
    DataCollator,
    TrainingArguments,
    set_seed,
)

from trainer import Trainer
from data_collator import T2TDataCollator
from utils import freeze_embeds, assert_not_all_frozen

MODEL_TYPE_TO_TOKENIZER = {
    "t5": T5Tokenizer,
    "bart": BartTokenizer,
}


logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    model_type: str = field(metadata={"help": "One of 't5', 'bart'"})
    tokenizer_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    label_smoothing: Optional[float] = field(
        default=0,
        metadata={"help": "label smoothing rate, set to > 0 if you want to enable label smoothing"}
    )
    freeze_embeds: bool = field(
        default=False,
        metadata={"help": "Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."}
    )

@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """
    train_file_path: str = field(
        metadata={"help": "Path for cached train dataset"},
    )
    valid_file_path: str = field(
        metadata={"help": "Path for cached valid dataset"},
    )
    data_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path for data files"},
    )
    task: Optional[str] = field(
        default=None,
        metadata={"help": "Which task 'qa', 'qg', 'e2e_qg', 'ans_ext', 'multi'. 'multi' means 'qa', 'qg', 'ans_ext' tasks"},
    )
    qg_format: Optional[str] = field(
        default='prepend_qg_format',
        metadata={"help": "How to format inputs for question generation, 'highlight_qg_format' or 'prepend_qg_format'"},
    )
    max_source_length: Optional[int] = field(
        default=512,
        metadata={"help": "Max input length for the source text"},
    )
    max_target_length: Optional[int] = field(
        default=32,
        metadata={"help": "Max input length for the target text"},
    )


def main(args_file=None):
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

    if (len(sys.argv) == 2 and sys.argv[1].endswith(".json")) or args_file is not None:
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        args_file_path = os.path.abspath(sys.argv[1]) if args_file is None else args_file
        model_args, data_args, training_args = parser.parse_json_file(json_file=args_file_path)
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    assert model_args.model_type in list(MODEL_TYPE_TO_TOKENIZER.keys()), "model type should be 't5' or 'bart'"

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    # Set project name
    os.environ["WANDB_PROJECT"] = "question-generation"

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    tokenizer_cls = MODEL_TYPE_TO_TOKENIZER[model_args.model_type]
    tokenizer = tokenizer_cls.from_pretrained(
        model_args.tokenizer_name_or_path if model_args.tokenizer_name_or_path else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )

    model.resize_token_embeddings(len(tokenizer))

    if model_args.freeze_embeds:
        logger.info("freezing embeddings of the model")
        freeze_embeds(model)
        assert_not_all_frozen(model)

    # Get datasets
    logger.info('loading dataset')

    train_dataset = torch.load(data_args.train_file_path) if training_args.do_train else None
    valid_dataset = torch.load(data_args.valid_file_path) if training_args.do_eval else None

    logger.info('finished loading dataset')

    # Initialize data_collator
    data_collator = T2TDataCollator(
        tokenizer=tokenizer,
        model_type=model_args.model_type,
        mode="training",
        using_tpu=training_args.tpu_num_cores is not None
    )

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        data_collator=data_collator,
        prediction_loss_only=True,
        label_smoothing=model_args.label_smoothing
    )

    # disable wandb console logs
    logging.getLogger('wandb.run_manager').setLevel(logging.WARNING)

    # Training
    if training_args.do_train:
        trainer.train(
            model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
        )
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    results = {}
    if training_args.do_eval and training_args.local_rank in [-1, 0]:
        logger.info("*** Evaluate ***")

        eval_output = trainer.evaluate()

        output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(eval_output.keys()):
                logger.info("  %s = %s", key, str(eval_output[key]))
                writer.write("%s = %s\n" % (key, str(eval_output[key])))

        results.update(eval_output)

    return results


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()

def run_qg(args_dict):
    with open("args.json", 'w') as f:
        json.dump(args_dict, f)

    main(args_file="args.json")

if __name__ == "__main__":
    main()
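Besides the CLI entry point, the script defines a run_qg(args_dict) wrapper that serializes the dict to args.json and hands it to main(). A sketch of calling it from Python; every value below is an illustrative placeholder, only the keys mirror the ModelArguments, DataTrainingArguments and TrainingArguments fields above:

from run_qg import run_qg

# Hypothetical argument values; adjust checkpoints and paths to your own setup.
args_dict = {
    "model_name_or_path": "t5-small",      # any seq2seq checkpoint
    "model_type": "t5",                    # must be 't5' or 'bart'
    "tokenizer_name_or_path": "t5_qg_tokenizer",
    "output_dir": "t5-small-qg-hl",
    "train_file_path": "data/train_data_qg_highlight_qg_format_t5.pt",
    "valid_file_path": "data/valid_data_qg_highlight_qg_format_t5.pt",
    "per_device_train_batch_size": 32,
    "per_device_eval_batch_size": 32,
    "learning_rate": 1e-4,
    "num_train_epochs": 10,
    "do_train": True,
    "do_eval": True,
}

# run_qg dumps the dict to args.json; main() then parses it with HfArgumentParser
# into the three argument dataclasses and runs training/evaluation.
run_qg(args_dict)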
code/question_generation/trainer.py
DELETED
@@ -1,56 +0,0 @@
from typing import Any, Dict, Union

import torch
from torch import nn

from transformers import Trainer as HFTrainer
from transformers.file_utils import is_apex_available

if is_apex_available():
    from apex import amp

from utils import label_smoothed_nll_loss

class Trainer(HFTrainer):
    def __init__(self, label_smoothing: float = 0, **kwargs):
        super().__init__(**kwargs)
        self.label_smoothing = label_smoothing

    # override to support label smoothing
    def _training_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
    ) -> float:
        model.train()
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.args.device)

        # Our model outputs do not work with DataParallel, so forcing return tuple.
        if isinstance(model, nn.DataParallel):
            inputs["return_tuple"] = True

        if self.label_smoothing == 0:
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
        else:
            labels = inputs.pop("labels")
            labels[labels == -100] = model.config.pad_token_id
            outputs = model(**inputs)
            lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs, labels, self.label_smoothing, ignore_index=model.config.pad_token_id
            )

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss.item()
code/question_generation/utils.py
DELETED
@@ -1,49 +0,0 @@
from typing import Callable, Dict, Iterable, List
from torch import nn

# these functions are taken from transformers repo
def grad_status(model: nn.Module) -> Iterable:
    return (par.requires_grad for par in model.parameters())

def freeze_params(model: nn.Module):
    for par in model.parameters():
        par.requires_grad = False

def freeze_embeds(model: nn.Module):
    """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
    try:
        freeze_params(model.model.shared)
        for d in [model.model.encoder, model.model.decoder]:
            freeze_params(d.embed_positions)
            freeze_params(d.embed_tokens)
    except AttributeError:
        freeze_params(model.shared)
        for d in [model.encoder, model.decoder]:
            freeze_params(d.embed_tokens)

def assert_not_all_frozen(model):
    model_grads: List[bool] = list(grad_status(model))
    npars = len(model_grads)
    assert any(model_grads), f"none of {npars} weights require grad"

def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """From fairseq"""
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
        bs = pad_mask.long().sum()
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)
        bs = lprobs.shape[0]

    nll_loss = nll_loss.sum()  # mean()? Scared to break other math.
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss / bs, nll_loss / bs
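label_smoothed_nll_loss is what the custom Trainer above plugs in when label smoothing is enabled. A small self-contained sketch of its behaviour on dummy tensors; the shapes and the pad id of 0 are arbitrary choices for illustration:

import torch

from utils import label_smoothed_nll_loss

# Fake batch: 2 sequences of length 3 over a 5-token vocabulary.
logits = torch.randn(2, 3, 5)
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
target = torch.tensor([[1, 2, 0], [3, 0, 0]])  # 0 standing in for the pad id here

# epsilon=0.1 mixes 90% of the NLL with 10% of a uniform prior over the vocabulary.
loss, nll_loss = label_smoothed_nll_loss(lprobs, target, epsilon=0.1, ignore_index=0)
print(loss.item(), nll_loss.item())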
code/quiz_gen.py
DELETED
@@ -1,149 +0,0 @@
import sys, os, nltk, random, requests
nltk.download('wordnet')
nltk.download('omw-1.4')
from nltk.corpus import wordnet
import streamlit as st
# st.write(os.listdir("."))
# st.write(os.listdir("question_generation/"))
# st.write(os.getcwd())
# st.write(sys.path)
sys.path.append("code/question_generation/")
sys.path.append("code/question_generation")
sys.path.append("question_generation")
sys.path.append("question_generation/")
# sys.path.append("../question_generation/")
# st.write(sys.path)

from pipelines import pipeline
import pandas as pd
import datetime

@st.cache(allow_output_mutation = True)
def load_model():
    get_ans = pipeline("multitask-qa-qg")
    get_ques = pipeline("e2e-qg")
    return get_ans, get_ques

def csv_downloader(df):
    res = df.to_csv(index=False,sep="\t").encode('utf-8')
    st.download_button(
        label="Download logs data as CSV separated by tab",
        data=res,
        file_name='df_quiz_log_file.csv',
        mime='text/csv')


def load_file():
    """Load text from file"""
    uploaded_file = st.file_uploader("Upload Files",type=['txt'])
    if uploaded_file is not None:
        if uploaded_file.type == "text/plain":
            raw_text = str(uploaded_file.read(),"utf-8")
            return raw_text


# def get_related_word(word):
#     options = [word]
#     for synset in wordnet.synsets(word):
#         for l in synset.lemmas():
#             if l.antonyms():
#                 options.append(l.antonyms()[0].name())
#             options.append(l.name())
#     return list(options)[:3]

def get_related_word(word):
    url = "https://api.datamuse.com/words"
    querystring = {"ml":word}
    responses = requests.request("GET", url, params=querystring)
    related_words = []
    count = 0
    responses = responses.json()
    for res in responses:
        if count >= 4:
            break
        if res["word"]!=word and res["word"]!="":
            related_words.append(res["word"])
            count += 1
    return related_words


def get_final_option_list(ans,other_options):
    option1 = ans
    option2,option3,option4 = "dummy","dummy","dummy"
    try:
        option2 = other_options[0]
    except:
        pass
    try:
        option3 = other_options[1]
    except:
        pass
    try:
        option4 = other_options[2]
    except:
        pass
    final_options = [option1,option2,option3,option4]
    #st.write(final_options)
    random.shuffle(final_options)
    #st.write(final_options)
    final_options = [None]+final_options
    final_options = tuple(final_options)
    #st.write(final_options)
    return final_options

st.markdown('')

# Loading Model
get_ans, get_ques = load_model()

# App title and description
st.title("Exam Assistant")
st.write("Upload text, Get ready for answering autogenerated questions")

# Load file
st.text("Disclaimer: This app stores user's input for model improvement purposes !!")

# Load file
raw_text = load_file()
start_time = str(datetime.datetime.now())
if raw_text != None and raw_text != '':

    # Display text
    with st.expander("See text"):
        st.write(raw_text)

    # get_ans = pipeline("multitask-qa-qg")
    #get_ans = pipeline("multitask-qa-qg", model="valhalla/t5-base-qa-qg-hl")
    # get_ques = pipeline("e2e-qg")
    #get_ques = pipeline("e2e-qg", model="valhalla/t5-base-e2e-qg")
    ans_list = []
    questions = get_ques(raw_text)
    for ques in questions:
        # st.write("Question: {}".format(ques))
        ans = get_ans({"question":ques,"context":raw_text})
        ans_list.append(ans)
        other_options = get_related_word(ans)
        final_options = get_final_option_list(ans,other_options)
        st.markdown(
            """ <style>
                    div[role="radiogroup"] > :first-child{
                        display: none !important;
                    }
                </style>
            """,
            unsafe_allow_html=True
        )
        sel_ans = st.radio(ques,final_options)
        if sel_ans == ans:
            st.success("Correct Answer!!")
            st.balloons()
        else:
            st.error("Wrong Answer")
        st.write("="*50)
    output_path = "results/df_quiz_log_file.csv"
    res_df = pd.DataFrame({"TimeStamp":[start_time]*len(questions),\
                           "Input":[str(raw_text)]*len(questions),\
                           "Question":questions,"Answer":ans_list})
    res_df.to_csv(output_path, mode='a', index=False, sep="\t", header= not os.path.exists(output_path))
    st.dataframe(pd.read_csv(output_path,sep="\t").tail(5))
    csv_downloader(pd.read_csv(output_path,sep="\t"))
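The distractor options in this app come from the Datamuse "means like" endpoint rather than from the model itself. The same idea can be exercised outside Streamlit; a sketch, assuming network access to api.datamuse.com and using an arbitrary example word:

# Standalone sketch of the distractor logic used above, without the Streamlit UI.
import random, requests

def related_words(word, limit=3):
    # Same Datamuse "means like" query the app issues for each answer.
    resp = requests.get("https://api.datamuse.com/words", params={"ml": word})
    return [r["word"] for r in resp.json() if r["word"] not in ("", word)][:limit]

answer = "photosynthesis"          # example word, not from the app's data
options = [answer] + related_words(answer)
random.shuffle(options)
print(options)  # the true answer mixed in with up to three Datamuse distractors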
code/quiz_gen_new.py
DELETED
@@ -1,257 +0,0 @@
# !pip install --quiet transformers==4.5.0
# !pip install --quiet sentencepiece==0.1.95
# !pip install --quiet git+https://github.com/boudinfl/pke.git@dc4d5f21e0ffe64c4df93c46146d29d1c522476b
# pip install git+https://github.com/boudinfl/pke.git
# !pip install --quiet nltk==3.2.5

# pip install git+https://github.com/boudinfl/pke.git@dc4d5f21e0ffe64c4df93c46146d29d1c522476b
# pip install spacy==3.1.3
# pip install textwrap3==0.9.2
# pip install flashtext==2.7

import streamlit as st
from textwrap3 import wrap
from flashtext import KeywordProcessor
import torch, random, nltk, string, traceback, sys, os, requests, datetime
import numpy as np
import pandas as pd
from transformers import T5ForConditionalGeneration,T5Tokenizer
# st.write("Import pke")
import pke

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)


@st.cache(allow_output_mutation = True)
def load_model():
    nltk.download('punkt')
    nltk.download('brown')
    nltk.download('wordnet')
    nltk.download('stopwords')
    nltk.download('wordnet')
    nltk.download('omw-1.4')
    summary_model = T5ForConditionalGeneration.from_pretrained('t5-base')
    summary_tokenizer = T5Tokenizer.from_pretrained('t5-base')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    summary_model = summary_model.to(device)
    question_model = T5ForConditionalGeneration.from_pretrained('ramsrigouthamg/t5_squad_v1')
    question_tokenizer = T5Tokenizer.from_pretrained('ramsrigouthamg/t5_squad_v1')
    question_model = question_model.to(device)
    return summary_model, summary_tokenizer, question_tokenizer, question_model

from nltk.corpus import wordnet as wn
from nltk.tokenize import sent_tokenize
from nltk.corpus import stopwords

def postprocesstext (content):
    final=""
    for sent in sent_tokenize(content):
        sent = sent.capitalize()
        final = final +" "+sent
    return final

def summarizer(text,model,tokenizer):
    text = text.strip().replace("\n"," ")
    text = "summarize: "+text
    # print (text)
    max_len = 512
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoding = tokenizer.encode_plus(text,max_length=max_len, pad_to_max_length=False,\
        truncation=True, return_tensors="pt").to(device)
    input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
    outs = model.generate(input_ids=input_ids,
                          attention_mask=attention_mask,
                          early_stopping=True,
                          num_beams=3,
                          num_return_sequences=1,
                          no_repeat_ngram_size=2,
                          min_length = 75,
                          max_length=300)
    dec = [tokenizer.decode(ids,skip_special_tokens=True) for ids in outs]
    summary = dec[0]
    summary = postprocesstext(summary)
    summary= summary.strip()
    return summary


def get_nouns_multipartite(content):
    out=[]
    try:
        extractor = pke.unsupervised.MultipartiteRank()
        extractor.load_document(input=content)
        # not contain punctuation marks or stopwords as candidates.
        pos = {'PROPN','NOUN'}
        #pos = {'PROPN','NOUN'}
        stoplist = list(string.punctuation)
        stoplist += ['-lrb-', '-rrb-', '-lcb-', '-rcb-', '-lsb-', '-rsb-']
        stoplist += stopwords.words('english')
        # extractor.candidate_selection(pos=pos, stoplist=stoplist)
        extractor.candidate_selection(pos=pos)
        # 4. build the Multipartite graph and rank candidates using random walk,
        # alpha controls the weight adjustment mechanism, see TopicRank for
        # threshold/method parameters.
        extractor.candidate_weighting(alpha=1.1,
                                      threshold=0.75,
                                      method='average')
        keyphrases = extractor.get_n_best(n=15)
        for val in keyphrases:
            out.append(val[0])
    except Exception as e:
        out = []
        traceback.print_exc()
        print("EXCEPTION: {}".format(e))
    return out

def get_keywords(originaltext,summarytext):
    keywords = get_nouns_multipartite(originaltext)
    print ("keywords unsummarized: ",keywords)
    keyword_processor = KeywordProcessor()
    for keyword in keywords:
        keyword_processor.add_keyword(keyword)
    keywords_found = keyword_processor.extract_keywords(summarytext)
    keywords_found = list(set(keywords_found))
    print("keywords_found in summarized: ",keywords_found)
    important_keywords =[]
    for keyword in keywords:
        if keyword in keywords_found:
            important_keywords.append(keyword)
    return important_keywords[:5]

def get_question(context,answer,model,tokenizer):
    text = "context: {} answer: {}".format(context,answer)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoding = tokenizer.encode_plus(text,max_length=384, pad_to_max_length=False,\
        truncation=True, return_tensors="pt").to(device)
    input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
    outs = model.generate(input_ids=input_ids,
                          attention_mask=attention_mask,
                          early_stopping=True,
                          num_beams=5,
                          num_return_sequences=1,
                          no_repeat_ngram_size=2,
                          max_length=72)
    dec = [tokenizer.decode(ids,skip_special_tokens=True) for ids in outs]
    Question = dec[0].replace("question:","")
    Question= Question.strip()
    return Question


def csv_downloader(df):
    res = df.to_csv(index=False,sep="\t").encode('utf-8')
    st.download_button(
        label="Download logs data as CSV separated by tab",
        data=res,
        file_name='df_quiz_log_file.csv',
        mime='text/csv')

def load_file():
    """Load text from file"""
    uploaded_file = st.file_uploader("Upload Files",type=['txt'])
    if uploaded_file is not None:
        if uploaded_file.type == "text/plain":
            raw_text = str(uploaded_file.read(),"utf-8")
            return raw_text


def get_related_word(word):
    url = "https://api.datamuse.com/words"
    querystring = {"ml":word}
    responses = requests.request("GET", url, params=querystring)
    related_words = []
    count = 0
    responses = responses.json()
    for res in responses:
        if count >= 4:
            break
        if res["word"]!=word and res["word"]!="":
            related_words.append(res["word"])
            count += 1
    return related_words


def get_final_option_list(ans,other_options):
    option1 = ans
    option2,option3,option4 = "dummy","dummy","dummy"
    try:
        option2 = other_options[0]
    except:
        pass
    try:
        option3 = other_options[1]
    except:
        pass
    try:
        option4 = other_options[2]
    except:
        pass
    final_options = [option1,option2,option3,option4]
    #st.write(final_options)
    random.shuffle(final_options)
    #st.write(final_options)
    final_options = [None]+final_options
    final_options = tuple(final_options)
    #st.write(final_options)
    return final_options

st.markdown('')

# Loading Model
summary_model, summary_tokenizer, question_tokenizer, question_model = load_model()

# App title and description
st.title("Exam Assistant")
st.write("Upload text, Get ready for answering autogenerated questions")

# Load file
st.text("Disclaimer: This app stores user's input for model improvement purposes !!")

# Load file
raw_text = load_file()
start_time = str(datetime.datetime.now())
if raw_text != None and raw_text != '':

    # Display text
    with st.expander("See text"):
        st.write(raw_text)

    summary_text = summarizer(raw_text,summary_model,summary_tokenizer)
    ans_list = get_keywords(raw_text,summary_text)
    questions = []

    for ans in ans_list:
        ques = get_question(summary_text,ans,question_model,question_tokenizer)
        if ques not in questions:
            other_options = get_related_word(ans)
            final_options = get_final_option_list(ans,other_options)
            st.markdown(
                """ <style>
                        div[role="radiogroup"] > :first-child{
                            display: none !important;
                        }
                    </style>
                """,
                unsafe_allow_html=True
            )
            sel_ans = st.radio(ques,final_options)
            if sel_ans == ans:
                st.success("Correct Answer!!")
                st.balloons()
            else:
                st.error("Wrong Answer")
            st.write("="*50)
            questions.append(ques)
    output_path = "results/df_quiz_log_file.csv"
    res_df = pd.DataFrame({"TimeStamp":[start_time]*len(questions),\
                           "Input":[str(raw_text)]*len(questions),\
                           "Question":questions,"Answer":ans_list})
    # res_df.to_csv(output_path, mode='a', index=False, sep="\t", header= not os.path.exists(output_path))
    # st.dataframe(pd.read_csv(output_path,sep="\t").tail(5))
    # csv_downloader(pd.read_csv(output_path,sep="\t"))
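Stripped of the UI, this version chains three helpers: summarize the input with T5, keep the MultipartiteRank keyphrases that also survive in the summary, and generate one question per surviving keyphrase. A rough sketch of driving those helpers directly, for example from a notebook that has imported them; the input path is illustrative, and load_model() still runs outside a Streamlit session (its @st.cache just has no effect there):

# Offline sketch of the summarize -> keywords -> questions chain defined above.
text = open("input/input1.txt", encoding="utf-8").read()

summary_model, summary_tokenizer, question_tokenizer, question_model = load_model()

summary = summarizer(text, summary_model, summary_tokenizer)
answers = get_keywords(text, summary)   # up to 5 keyphrases found in the summary

for ans in answers:
    ques = get_question(summary, ans, question_model, question_tokenizer)
    print(f"Q: {ques}\nA: {ans}\n")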
code/quiz_gen_new2.py
DELETED
@@ -1,277 +0,0 @@
import streamlit as st
from textwrap3 import wrap
from flashtext import KeywordProcessor
import torch, random, nltk, string, traceback, sys, os, requests, datetime
import numpy as np
import pandas as pd
from transformers import T5ForConditionalGeneration,T5Tokenizer
import pke

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

@st.cache(allow_output_mutation = True)
def load_model():
    nltk.download('punkt')
    nltk.download('brown')
    nltk.download('wordnet')
    nltk.download('stopwords')
    nltk.download('wordnet')
    nltk.download('omw-1.4')
    summary_mod_name = os.environ["summary_mod_name"]
    question_mod_name = os.environ["question_mod_name"]
    summary_model = T5ForConditionalGeneration.from_pretrained(summary_mod_name)
    summary_tokenizer = T5Tokenizer.from_pretrained(summary_mod_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    summary_model = summary_model.to(device)
    question_model = T5ForConditionalGeneration.from_pretrained(question_mod_name)
    question_tokenizer = T5Tokenizer.from_pretrained(question_mod_name)
    question_model = question_model.to(device)
    return summary_model, summary_tokenizer, question_tokenizer, question_model

from nltk.corpus import wordnet as wn
from nltk.tokenize import sent_tokenize
from nltk.corpus import stopwords

def postprocesstext (content):
    final=""
    for sent in sent_tokenize(content):
        sent = sent.capitalize()
        final = final +" "+sent
    return final

def summarizer(text,model,tokenizer):
    text = text.strip().replace("\n"," ")
    text = "summarize: "+text
    # print (text)
    max_len = 512
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoding = tokenizer.encode_plus(text,max_length=max_len, pad_to_max_length=False,\
        truncation=True, return_tensors="pt").to(device)
    input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
    outs = model.generate(input_ids=input_ids,
                          attention_mask=attention_mask,
                          early_stopping=True,
                          num_beams=3,
                          num_return_sequences=1,
                          no_repeat_ngram_size=2,
                          min_length = 75,
                          max_length=300)
    dec = [tokenizer.decode(ids,skip_special_tokens=True) for ids in outs]
    summary = dec[0]
    summary = postprocesstext(summary)
    summary= summary.strip()
    return summary


def get_nouns_multipartite(content):
    out=[]
    try:
        extractor = pke.unsupervised.MultipartiteRank()
        extractor.load_document(input=content)
        # not contain punctuation marks or stopwords as candidates.
        pos = {'PROPN','NOUN'}
        #pos = {'PROPN','NOUN'}
        stoplist = list(string.punctuation)
        stoplist += ['-lrb-', '-rrb-', '-lcb-', '-rcb-', '-lsb-', '-rsb-']
        stoplist += stopwords.words('english')
        # extractor.candidate_selection(pos=pos, stoplist=stoplist)
        extractor.candidate_selection(pos=pos)
        # 4. build the Multipartite graph and rank candidates using random walk,
        # alpha controls the weight adjustment mechanism, see TopicRank for
        # threshold/method parameters.
        extractor.candidate_weighting(alpha=1.1,
                                      threshold=0.75,
                                      method='average')
        keyphrases = extractor.get_n_best(n=15)
        for val in keyphrases:
            out.append(val[0])
    except Exception as e:
        out = []
        traceback.print_exc()
        print("EXCEPTION: {}".format(e))
    return out

def get_keywords(originaltext,summarytext):
    keywords = get_nouns_multipartite(originaltext)
    print ("keywords unsummarized: ",keywords)
    keyword_processor = KeywordProcessor()
    for keyword in keywords:
        keyword_processor.add_keyword(keyword)
    keywords_found = keyword_processor.extract_keywords(summarytext)
    keywords_found = list(set(keywords_found))
    print("keywords_found in summarized: ",keywords_found)
    important_keywords =[]
    for keyword in keywords:
        if keyword in keywords_found:
            important_keywords.append(keyword)
    return important_keywords[:5]

def get_question(context,answer,model,tokenizer):
    text = "context: {} answer: {}".format(context,answer)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoding = tokenizer.encode_plus(text,max_length=384, pad_to_max_length=False,\
        truncation=True, return_tensors="pt").to(device)
    input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
    outs = model.generate(input_ids=input_ids,
                          attention_mask=attention_mask,
                          early_stopping=True,
                          num_beams=5,
                          num_return_sequences=1,
                          no_repeat_ngram_size=2,
                          max_length=72)
    dec = [tokenizer.decode(ids,skip_special_tokens=True) for ids in outs]
    Question = dec[0].replace("question:","")
    Question= Question.strip()
    return Question


def csv_downloader(df):
    res = df.to_csv(index=False,sep="\t").encode('utf-8')
    st.download_button(
        label="Download logs data as CSV separated by tab",
        data=res,
        file_name='df_quiz_log_file.csv',
        mime='text/csv')

def load_file():
    """Load text from file"""
    uploaded_file = st.file_uploader("Upload Files",type=['txt'])
    if uploaded_file is not None:
        if uploaded_file.type == "text/plain":
            raw_text = str(uploaded_file.read(),"utf-8")
            return raw_text


def get_related_word(word):
    url = "https://api.datamuse.com/words"
    querystring = {"ml":word}
    responses = requests.request("GET", url, params=querystring)
    related_words = []
    count = 0
    responses = responses.json()
    for res in responses:
        if count >= 4:
            break
        if res["word"]!=word and res["word"]!="":
            related_words.append(res["word"])
            count += 1
    return related_words


def get_final_option_list(ans,other_options):
    option1 = ans
    option2,option3,option4 = "dummy","dummy","dummy"
    try:
        option2 = other_options[0]
    except:
        pass
    try:
        option3 = other_options[1]
    except:
        pass
    try:
        option4 = other_options[2]
    except:
        pass
    final_options = [option1,option2,option3,option4]
    random.shuffle(final_options)
    final_options = tuple(final_options)
    ans_index = 0
    for i in range(4):
        if final_options[i] == ans:
            ans_index = i
    return final_options, ans_index

def load_raw_text():
    return " Elon Musk has shown again he can influence the digital currency market with just his tweets. After saying that his electric vehicle-making company Tesla will not accept payments in Bitcoin because of environmental concerns, he tweeted that he was working with developers of Dogecoin to improve system transaction efficiency. Following the two distinct statements from him, the world's largest cryptocurrency hit a two-month low, while Dogecoin rallied by about 20 percent. The SpaceX CEO has in recent months often tweeted in support of Dogecoin, but rarely for Bitcoin. In a recent tweet, Musk put out a statement from Tesla that it was concerned about the rapidly increasing use of fossil fuels for Bitcoin (price in India) mining and transaction, and hence was suspending vehicle purchases using the cryptocurrency. A day later he again tweeted saying, To be clear, I strongly believe in crypto, but it can't drive a massive increase in fossil fuel use, especially coal. It triggered a downward spiral for Bitcoin value but the cryptocurrency has stabilised since. A number of Twitter users welcomed Musk's statement. One of them said it's time people started realising that Dogecoin is here to stay and another referred to Musk's previous assertion that crypto could become the world's future currency."

# def get_final_option_list(ans,other_options):
#     option1 = ans
#     option2,option3,option4 = "dummy","dummy","dummy"
#     try:
#         option2 = other_options[0]
#     except:
#         pass
#     try:
#         option3 = other_options[1]
#     except:
#         pass
#     try:
#         option4 = other_options[2]
#     except:
#         pass
#     final_options = [option1,option2,option3,option4]
#     #st.write(final_options)
#     random.shuffle(final_options)
#     #st.write(final_options)
#     final_options = [None]+final_options
#     final_options = tuple(final_options)
#     #st.write(final_options)
#     return final_options

st.markdown('')

# Loading Model
summary_model, summary_tokenizer, question_tokenizer, question_model = load_model()

# App title and description
st.title("Exam Assistant")
st.write("Upload text, Get ready for answering autogenerated questions")

# Load file
st.text("Disclaimer: This app stores user's input for model improvement purposes !!")

# Load file

default_text = load_raw_text()
raw_text = st.text_area("Enter text here", default_text, height=400, max_chars=1000000, )

# raw_text = load_file()
start_time = str(datetime.datetime.now())
if raw_text != None and raw_text != '':

    # Display text
    # with st.expander("See text"):
    #     st.write(raw_text)

    summary_text = summarizer(raw_text,summary_model,summary_tokenizer)
    ans_list = get_keywords(raw_text,summary_text)
    questions = []

    for idx,ans in enumerate(ans_list):
        ques = get_question(summary_text,ans,question_model,question_tokenizer)
        if ques not in questions:
            other_options = get_related_word(ans)
            final_options = get_final_option_list(ans,other_options)
            final_options, ans_index = get_final_option_list(ans,other_options)
            # st.write(final_options)
            html_str = f"""
            <div>
              <p>
                {idx+1}: <b> {ques} </b>
              </p>
            </div>
            """
            html_str += f' <p style="color:Green;"><b> {final_options[0]} </b></p> ' if ans_index == 0 else f' <p><b> {final_options[0]} </b></p> '
            html_str += f' <p style="color:Green;"><b> {final_options[1]} </b></p> ' if ans_index == 1 else f' <p><b> {final_options[1]} </b></p> '
            html_str += f' <p style="color:Green;"><b> {final_options[2]} </b></p> ' if ans_index == 2 else f' <p><b> {final_options[2]} </b></p> '
            html_str += f' <p style="color:Green;"><b> {final_options[3]} </b></p> ' if ans_index == 3 else f' <p><b> {final_options[3]} </b></p> '
            html_str += f"""
            """
            st.markdown(html_str , unsafe_allow_html=True)
            st.markdown("-----")
            # st.write("="*50)
            questions.append(ques)
    output_path = "results/df_quiz_log_file.csv"
    res_df = pd.DataFrame({"TimeStamp":[start_time]*len(questions),\
                           "Input":[str(raw_text)]*len(questions),\
                           "Question":questions,"Answer":ans_list})
    # res_df.to_csv(output_path, mode='a', index=False, sep="\t", header= not os.path.exists(output_path))
    # st.dataframe(pd.read_csv(output_path,sep="\t").tail(5))
    # csv_downloader(pd.read_csv(output_path,sep="\t"))
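The main difference from the previous version is that the checkpoints are read from the summary_mod_name and question_mod_name environment variables, so the hosting configuration decides which models load. A sketch of launching it locally with explicit values; the two model names are examples borrowed from the earlier script, not values stored in this commit:

import os, subprocess

# Example checkpoint names; any compatible T5 summarization / question-generation
# checkpoints would work here.
env = dict(os.environ,
           summary_mod_name="t5-base",
           question_mod_name="ramsrigouthamg/t5_squad_v1")

# Launch the app with those models in its environment (assumes streamlit is installed).
subprocess.run(["streamlit", "run", "code/quiz_gen_new2.py"], env=env)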
input/input1.txt
CHANGED
@@ -1 +1,14 @@
 Talk to your stakeholders before building any data product. Understand the business’ needs at the current stage: If it’s a startup, I bet your stakeholder won’t care too much about the format and color of the data visualizations you build but wants to instead focus on the accuracy of the data behind the visualizations and insights from them. Similarly, truly understand the audience and use case; for example, you would spend more time on a polished and simple user interface if the data product is intended to be used regularly by non-technical audiences.
+
+
+Bunny and Crow live in the forest. Bunny has soft white fur and long floppy ears. She lives in a nice little house under a rock. She likes to hop through the bushes. She likes to visit her animal friends. But she always wondered what the world outside the forest was like. Crow has shiny black feathers. His house is a nice nest at the top of a tree. He likes to fly high above the forest. He likes to fly over rivers and towns outside the forest. One day, Bunny saw Crow sitting in a tree. “Hello Crow!” she called. “Can you tell me what the world is like outside the forest?” So Crow told Bunny all about the rivers and towns outside the forest. Bunny and Crow became good friends.
+
+
+Billy and Ron are brothers. Billy is 5 years old. Ron is 7 years old. One day their mom took them to the zoo. Billy wore his red cap, and Ron wore his blue cap. They had fun watching all the animals. Ron liked the monkeys the best. He wanted to stay and watch them some more, but Billy wanted to go see the elephants. Elephants were Billy’s favorite. Their mom said it was time to go see the elephants, and Ron was sad. But their mom said they could come back and see the monkeys again before they left the zoo. Billy and Ron had a great day at the zoo.
+
+Luca’s grandpa lives on a farm. His grandpa has a big garden, and many animals. When Luca was little, he was afraid of the chickens. When he helped his grandpa feed the chickens, the chickens chased him, crying, “Cluck, cluck, cluck!” But his grandpa showed him how to shoo the chickens away with his hand, saying, “Shoo chicks, shoo chicks!” Luca still doesn’t like the chickens much, but he isn’t afraid of them now. He feels very grown up.
+
+
+Alice’s aunt Lucy gave her a book that shows how to draw cartoons. It shows, step-by-step, how to draw a cartoon person, a cartoon cat, a cartoon mouse, and other animals. Most cartoon drawings start with circles. First you sketch lightly, until you like the way the drawing looks. Then you draw the lines in darker. Now Alice likes drawing cartoons even more than she likes coloring.
+
+Molly took her new bike out to the sidewalk. Her dad was going to teach her to ride it. He said it was easy, but she wasn’t so sure. He said not to worry, but she wasn’t so sure. She got on the bike, and her dad ran beside her, holding on to keep her steady. The bike wiggled a little, but Molly rode to the end of the block. She put on the brakes and looked around. Her dad was all the up the sidewalk—she had been riding all by herself! After that, Molly never worried about riding a bike. She knew that she could do it.
results/df_quiz_log_file_v1.csv
CHANGED
The diff for this file is too large to render.
See raw diff