Files changed (1) hide show
  1. app.py +0 -203
app.py DELETED
@@ -1,203 +0,0 @@
1
- import argparse
2
-
3
- import nltk
4
- import torch
5
- import numpy as np
6
- import gradio as gr
7
- from nltk import sent_tokenize
8
-
9
- from transformers import (
10
- RobertaTokenizer,
11
- RobertaForMaskedLM,
12
- LogitsProcessorList,
13
- TopKLogitsWarper,
14
- TemperatureLogitsWarper,
15
- )
16
- from transformers.generation_logits_process import TypicalLogitsWarper
17
-
18
# Fetch the "punkt" data used by nltk's sent_tokenize.
nltk.download('punkt')

# Pick the checkpoint by hardware: the large model only when a GPU is present.
if torch.cuda.is_available():
    device = "cuda"
    pretrained = "roberta-large"
else:
    device = "cpu"
    pretrained = "roberta-base"

tokenizer = RobertaTokenizer.from_pretrained(pretrained)
model = RobertaForMaskedLM.from_pretrained(pretrained).to(device)

# Default generation settings; also used as the initial slider values in the UI.
max_len = 20       # number of tokens to insert between sentences
top_k = 100        # if >0, only consider the top k most probable words
temperature = 1    # sampling temperature
typical_p = 0      # if >0, use typical sampling
burnin = 250       # iterations that sample before switching to argmax
max_iter = 500     # number of iterations in MCMC
32
-
33
-
34
- # adapted from https://github.com/nyu-dl/bert-gen
35
def generate_step(out: object,
                  gen_idx: int,
                  top_k: int = top_k,
                  temperature: float = temperature,
                  typical_p: float = typical_p,
                  sample: bool = False) -> list:
    """ Generate a word from out[gen_idx]

    args:
        - out: model output whose ``logits`` attribute is a tensor of size
          batch_size x seq_len x vocab_size
        - gen_idx (int): location for which to generate
        - top_k (int): if >0, only sample from the top k most probable words
        - temperature (float): sampling temperature; 0 disables temperature scaling
        - typical_p (float): if >0 use typical sampling
        - sample (bool): if True, sample from the (warped) distribution;
          otherwise take the argmax.

    returns:
        - list: batch_size tokens
    """
    # The transformers warpers validate argument types (float temperature,
    # int top_k), while the module default temperature is the int 1 and UI
    # sliders may deliver either numeric type. Normalise here so direct
    # callers do not hit a ValueError.
    top_k = int(top_k)
    temperature = float(temperature)
    typical_p = float(typical_p)

    logits = out.logits[:, gen_idx]
    warpers = LogitsProcessorList()
    if temperature:
        warpers.append(TemperatureLogitsWarper(temperature))
    if top_k > 0:
        warpers.append(TopKLogitsWarper(top_k))
    if typical_p > 0:
        # A mass of 1 (or more) is clamped just below 1.
        if typical_p >= 1:
            typical_p = 0.999
        warpers.append(TypicalLogitsWarper(typical_p))
    # The warpers used here ignore input_ids, so None is passed.
    logits = warpers(None, logits)

    if sample:
        probs = torch.nn.functional.softmax(logits, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
    else:
        next_tokens = torch.argmax(logits, dim=-1)

    return next_tokens.tolist()
73
-
74
-
75
- # adapted from https://github.com/nyu-dl/bert-gen
76
def parallel_sequential_generation(seed_text: str,
                                   seed_end_text: str,
                                   max_len: int = max_len,
                                   top_k: int = top_k,
                                   temperature: float = temperature,
                                   typical_p: float = typical_p,
                                   max_iter: int = max_iter,
                                   burnin: int = burnin) -> str:
    """ Generate text consistent with preceding and following text

    Args:
        - seed_text (str): preceding text
        - seed_end_text (str): following text
        - max_len (int): number of mask tokens to place (and fill) between
          seed_text and seed_end_text; if <=0 an empty string is returned
        - top_k (int): if >0, only sample from the top k most probable words
        - temperature (float): sampling temperature
        - typical_p (float): if >0 use typical sampling
        - max_iter (int): number of iterations in MCMC
        - burnin: during burn-in period, sample from full distribution; afterwards take argmax

    Returns:
        - string: generated text to insert between seed_text and seed_end_text
    """
    # String repetition and range() need true ints; UI sliders may send floats.
    max_len = int(max_len)
    max_iter = int(max_iter)
    if max_len <= 0:
        # Nothing to generate (and no mask position to index below).
        return ''

    inp = tokenizer(seed_text + tokenizer.mask_token * max_len + seed_end_text,
                    return_tensors='pt')
    # Positions of the mask tokens in the encoded input (computed on CPU,
    # before the batch is moved to the target device).
    masked_tokens = np.where(
        inp['input_ids'][0].numpy() == tokenizer.mask_token_id)[0]
    seed_len = masked_tokens[0]
    inp = inp.to(device)

    # Inference only: disable autograd so the repeated forward passes do not
    # build computation graphs.
    with torch.no_grad():
        for ii in range(max_iter):
            in_burnin = ii < burnin
            # Re-predict one randomly chosen position per iteration.
            kk = np.random.randint(0, max_len)
            idxs = generate_step(model(**inp),
                                 gen_idx=seed_len + kk,
                                 top_k=0 if in_burnin else top_k,
                                 temperature=temperature,
                                 typical_p=typical_p,
                                 sample=in_burnin)
            inp['input_ids'][0][seed_len + kk] = idxs[0]

    tokens = inp['input_ids'].cpu().numpy()[0][masked_tokens]
    # Drop any sentence-boundary tokens the model placed in the gap.
    keep = (tokens != tokenizer.eos_token_id) & (tokens != tokenizer.bos_token_id)
    return tokenizer.decode(tokens[keep])
119
-
120
-
121
def inbertolate(doc: str,
                max_len: int = max_len,
                top_k: int = top_k,
                temperature: float = temperature,
                typical_p: float = typical_p,
                max_iter: int = max_iter,
                burnin: int = burnin) -> str:
    """ Pad out document generating every other sentence

    Args:
        - doc (str): document text
        - max_len (int): number of tokens to insert between sentences
        - top_k (int): if >0, only sample from the top k most probable words
        - temperature (float): sampling temperature
        - typical_p (float): if >0 use typical sampling
        - max_iter (int): number of iterations in MCMC
        - burnin: during burn-in period, sample from full distribution; afterwards take argmax

    Returns:
        - string: the expanded document; every original sentence is followed
          by a generated passage, and each input line ends with a newline
    """
    # UI sliders may deliver floats; the generation code needs integers for
    # these (string repetition, range(), top-k).
    max_len = int(max_len)
    top_k = int(top_k)
    max_iter = int(max_iter)
    burnin = int(burnin)

    pieces = []
    for para in doc.split('\n'):
        # Check the raw text: sent_tokenize returns a list, so comparing its
        # result with '' can never detect a blank paragraph.
        if not para.strip():
            pieces.append('\n')
            continue

        # The trailing '' gives the last sentence an (empty) following context,
        # so text is generated after it as well.
        sentences = sent_tokenize(para) + ['']
        for sentence, following in zip(sentences, sentences[1:]):
            pieces.append(sentence + ' ')
            pieces.append(parallel_sequential_generation(
                sentence,
                following,
                max_len=max_len,
                top_k=top_k,
                temperature=float(temperature),
                typical_p=typical_p,
                burnin=burnin,
                max_iter=max_iter) + ' ')

        pieces.append('\n')
    return ''.join(pieces)
166
-
167
# Gradio UI: one text box plus a slider per generation parameter, passed
# positionally to inbertolate in the order listed below.
# Integer-valued parameters get an explicit step=1 so they do not depend on
# the step Gradio would otherwise infer from the slider range.
demo = gr.Interface(
    fn=inbertolate,
    title="inBERTolate",
    description=f"Hit your word count by using BERT ({pretrained}) to pad out your essays!",
    inputs=[
        gr.Textbox(label="Text", lines=10),
        gr.Slider(label="Maximum length to insert between sentences",
                  minimum=1,
                  maximum=40,
                  step=1,
                  value=max_len),
        gr.Slider(label="Top k",
                  minimum=0,
                  maximum=200,
                  step=1,
                  value=top_k),
        gr.Slider(label="Temperature",
                  minimum=0,
                  maximum=2,
                  value=temperature),
        gr.Slider(label="Typical p",
                  minimum=0,
                  maximum=1,
                  value=typical_p),
        gr.Slider(label="Maximum iterations",
                  minimum=0,
                  maximum=1000,
                  step=1,
                  value=max_iter),
        gr.Slider(label="Burn-in",
                  minimum=0,
                  maximum=500,
                  step=1,
                  value=burnin),
    ],
    outputs=gr.Textbox(label="Expanded text", lines=30))
197
-
198
if __name__ == '__main__':
    # Command-line entry point: optionally override the bind address and port.
    parser = argparse.ArgumentParser()
    parser.add_argument('--port', type=int,
                        help='port to serve on')
    # The value is passed to launch() as server_name, i.e. a host name or IP
    # address string; parsing it as int rejected values such as 127.0.0.1.
    parser.add_argument('--server', type=str,
                        help='host name or IP address to bind (default: 0.0.0.0)')
    args = parser.parse_args()
    demo.launch(server_name=args.server or '0.0.0.0', server_port=args.port)