Jingxiang Mo committed on
Commit af2e54d · 1 Parent(s): bde6562
  1. Files changed (1)
  1. test.py +0 -147
test.py DELETED
@@ -1,147 +0,0 @@
- import os
- import gradio as gr
- import numpy as np
- import wikipediaapi as wk
- from transformers import (
-     TokenClassificationPipeline,
-     AutoModelForTokenClassification,
-     AutoTokenizer,
- )
- import torch
- from transformers.pipelines import AggregationStrategy
- from transformers import BertForQuestionAnswering
- from transformers import BertTokenizer
-
- # =====[ DEFINE PIPELINE ]===== #
- class KeyphraseExtractionPipeline(TokenClassificationPipeline):
-     def __init__(self, model, *args, **kwargs):
-         super().__init__(
-             model=AutoModelForTokenClassification.from_pretrained(model),
-             tokenizer=AutoTokenizer.from_pretrained(model),
-             *args,
-             **kwargs
-         )
-
-     def postprocess(self, model_outputs):
-         results = super().postprocess(
-             model_outputs=model_outputs,
-             aggregation_strategy=AggregationStrategy.SIMPLE,
-         )
-         return np.unique([result.get("word").strip() for result in results])
-
- # =====[ LOAD PIPELINE ]===== #
- keyPhraseExtractionModel = "ml6team/keyphrase-extraction-kbir-inspec"
- extractor = KeyphraseExtractionPipeline(model=keyPhraseExtractionModel)
- model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
- tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
-
- # TODO: add further preprocessing
- def keyphrases_extraction(text: str) -> str:
-     keyphrases = extractor(text)
-     return keyphrases
-
- def wikipedia_search(input: str) -> str:
-     input = input.replace("\n", " ")
-     keyphrases = keyphrases_extraction(input)
-     wiki = wk.Wikipedia('en')
-
-     try:
-         # TODO: add better extraction and search
-         keyphrase_index = 0
-         page = wiki.page(keyphrases[keyphrase_index])
-
-         # Walk through the extracted keyphrases until one maps to an existing
-         # Wikipedia page with a usable summary.
-         while not ('.' in page.summary) or not page.exists():
-             keyphrase_index += 1
-             if keyphrase_index == len(keyphrases):
-                 raise Exception
-             page = wiki.page(keyphrases[keyphrase_index])
-         return page.summary
-     except Exception:
-         return "I cannot answer this question"
-
- def answer_question(question):
-
-     context = wikipedia_search(question)
-     if context == "I cannot answer this question":
-         return context
-
-     # ======== Tokenize ========
-     # Apply the tokenizer to the input text, treating them as a text pair.
-     input_ids = tokenizer.encode(question, context)
-
-     # BERT accepts at most 512 tokens; truncate longer inputs but keep a
-     # trailing [SEP] so the sequence stays well formed.
-     if len(input_ids) > 512:
-         input_ids = input_ids[:511] + [tokenizer.sep_token_id]
-
-     print('Query has {:,} tokens.\n'.format(len(input_ids)))
-
-     # ======== Set Segment IDs ========
-     # Search the input_ids for the first instance of the `[SEP]` token.
-     sep_index = input_ids.index(tokenizer.sep_token_id)
-
-     # The number of segment A tokens includes the [SEP] token itself.
-     num_seg_a = sep_index + 1
-
-     # The remainder are segment B.
-     num_seg_b = len(input_ids) - num_seg_a
-
-     # Construct the list of 0s and 1s.
-     segment_ids = [0] * num_seg_a + [1] * num_seg_b
-
-     # There should be a segment_id for every input token.
-     assert len(segment_ids) == len(input_ids)
-
-     # ======== Evaluate ========
-     # Run our example through the model.
-     outputs = model(torch.tensor([input_ids]),                   # The tokens representing our input text.
-                     token_type_ids=torch.tensor([segment_ids]),  # The segment IDs to differentiate question from context.
-                     return_dict=True)
-
-     start_scores = outputs.start_logits
-     end_scores = outputs.end_logits
-
-     # ======== Reconstruct Answer ========
-     # Find the tokens with the highest `start` and `end` scores.
-     answer_start = torch.argmax(start_scores)
-     answer_end = torch.argmax(end_scores)
-
-     # Get the string versions of the input tokens.
-     tokens = tokenizer.convert_ids_to_tokens(input_ids)
-
-     # Start with the first token.
-     answer = tokens[answer_start]
-
-     # Select the remaining answer tokens and join them with whitespace.
-     for i in range(answer_start + 1, answer_end + 1):
-
-         # If it's a subword token, then recombine it with the previous token.
-         if tokens[i][0:2] == '##':
-             answer += tokens[i][2:]
-
-         # Otherwise, add a space then the token.
-         else:
-             answer += ' ' + tokens[i]
-
-     return 'Answer: "' + answer + '"'
-
- # =====[ DEFINE INTERFACE ]===== #
- title = "Azza Chatbot"
- examples = [
-     ["Where is the Eiffel Tower?"],
-     ["What is the population of France?"],
- ]
-
- demo = gr.Interface(
-     title=title,
-     fn=answer_question,
-     inputs="text",
-     outputs="text",
-     examples=examples,
- )
-
- if __name__ == "__main__":
-     demo.launch(share=True)
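
Outside the Gradio UI, the deleted module exposed answer_question at module level, so it could also be smoke-tested directly. A minimal sketch, assuming the file is importable as test and that the two Hugging Face checkpoints it loads at import time can still be downloaded:

    import test  # hypothetical direct import of the deleted test.py; importing it loads both models

    # One end-to-end query: keyphrase extraction -> Wikipedia summary -> BERT span prediction.
    print(test.answer_question("Where is the Eiffel Tower?"))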