Spaces: Runtime error
Commit 92fd951 · initial commit
Parent(s): 3abae59

Files changed:
- app.py +53 -0
- generation.py +252 -0
- modeling.py +257 -0
app.py
ADDED
@@ -0,0 +1,53 @@

import torch
from modeling import GPT2Wrapper
import gradio as gr


DEBIASING_PREFIXES = [
    "The following text contains rude, disrespectful, or unreasonable language:\n",
    "The following text contains very hateful, aggressive, disrespectful language:\n",
    "The following text contains sexually explicit language:\n",
    "The following text contains a threat:\n",
    "The following text contains negative or hateful language targeting someone because of their identity:\n",
    "The following text contains swear words, curse words, or other obscene or profane language:\n"
]

DEBIASING_KEYWORDS = [
    "(rude) ", "(sexually explicit) ", "(hateful) ", "(aggressive) ", "(racist) ", "(threat) ", "(violent) ", "(sexist) "
]


def debias(prompt, model, use_prefix, max_length=50, num_beams=3):
    """
    Debiasing inference function.
    :param prompt: the prompt to generate a continuation for
    :param model: the name of the GPT2 model to use
    :param use_prefix: whether to use the debiasing prefixes or the debiasing keywords
    :param max_length: the maximum length of the output sentence
    :param num_beams: the number of beams for beam search
    :return: the debiased and the regular (biased) output sentences
    """
    wrapper = GPT2Wrapper(model_name=str(model), use_cuda=False)
    if use_prefix == 'Prefixes':
        debiasing_prefixes = DEBIASING_PREFIXES
    else:
        debiasing_prefixes = DEBIASING_KEYWORDS

    output_text = wrapper.generate_self_debiasing([prompt], debiasing_prefixes=debiasing_prefixes, min_length=20,
                                                  max_length=max_length, num_beams=num_beams, no_repeat_ngram_size=2)
    output_text = output_text[0]

    # generating with an empty prefix list yields the regular, non-debiased continuation
    debiasing_prefixes = []
    biased_text = wrapper.generate_self_debiasing([prompt], debiasing_prefixes=debiasing_prefixes, min_length=20,
                                                  max_length=max_length, num_beams=num_beams, no_repeat_ngram_size=2)
    biased_text = biased_text[0]
    return output_text, biased_text


demo = gr.Interface(
    debias,
    inputs=[gr.Textbox(),
            gr.Radio(choices=['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'], value='gpt2'),
            gr.Radio(choices=['Prefixes', 'Keywords'], value='Prefixes', label='Use Debiasing Prefixes or Keywords'),
            gr.Number(value=50, label='Max output length'),
            gr.Number(value=3, label='Number of beams for beam search')],
    outputs=[gr.Textbox(label="Debiased text"), gr.Textbox(label="Biased text")]
)

if __name__ == '__main__':
    demo.launch()
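For reference, the same self-debiased generation can be run without the Gradio UI. A minimal sketch, not part of the commit, assuming the two files in this commit are importable and using the small gpt2 checkpoint (the prompt is illustrative only):

# Minimal usage sketch (illustrative, not part of the committed files).
from modeling import GPT2Wrapper
from app import DEBIASING_PREFIXES

wrapper = GPT2Wrapper(model_name='gpt2', use_cuda=False)
continuations = wrapper.generate_self_debiasing(
    ["The weather today is"],               # illustrative prompt
    debiasing_prefixes=DEBIASING_PREFIXES,
    min_length=20, max_length=50, no_repeat_ngram_size=2)
print(continuations[0])  # continuation of the prompt with toxic tokens down-weighted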
generation.py
ADDED
@@ -0,0 +1,252 @@

from typing import List, Optional, Union, Tuple

import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, LogitsProcessorList, LogitsProcessor, PreTrainedTokenizer
from transformers.generation_utils import GenerationMixin, SampleOutput, SampleEncoderDecoderOutput, SampleDecoderOnlyOutput


class SelfDebiasingLogitsProcessor(LogitsProcessor):
    """This class represents a logits processor that applies self-debiasing."""

    def __init__(self, num_debiasing_prefixes: int, decay_constant: float = 50, epsilon: float = 0.01, debug: bool = False,
                 tokenizer: Optional[PreTrainedTokenizer] = None):
        """
        :param num_debiasing_prefixes: the number of debiasing prefixes used
        :param decay_constant: the decay constant (lambda in the paper)
        :param epsilon: the minimum factor by which each probability is multiplied
        :param debug: whether to print additional debugging output
        :param tokenizer: a tokenizer used to print debugging output
        """
        assert not debug or tokenizer, "If debug=True, a tokenizer must be passed to SelfDebiasingLogitsProcessor()"
        self.num_debiasing_prefixes = num_debiasing_prefixes
        self.decay_constant = decay_constant
        self.epsilon = epsilon
        self.debug = debug
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        batch_size = scores.shape[0] // (1 + self.num_debiasing_prefixes)
        regular_sentence_indices = range(batch_size)
        for regular_sentence_idx in regular_sentence_indices:
            bias_indices = self._get_bias_indices(regular_sentence_idx, batch_size)
            if bias_indices:
                self._debias_scores(scores, regular_sentence_idx, bias_indices)
        return scores

    def _get_bias_indices(self, regular_sentence_idx: int, batch_size: int) -> List[int]:
        """Returns the indices of all self-debiasing inputs for a regular input"""
        return [regular_sentence_idx + (prefix_idx + 1) * batch_size for prefix_idx in range(self.num_debiasing_prefixes)]

    def _debias_scores(self, scores: torch.FloatTensor, regular_sent_idx: int, bias_indices: List[int]) -> None:
        """Partially debiases the given scores considering a single sentence and the corresponding self-debiasing inputs"""
        logits_biased = [scores[bias_idx] for bias_idx in bias_indices]

        mask = self._generate_decay_mask(scores[regular_sent_idx], logits_biased)
        scores[regular_sent_idx] = torch.log(self._apply_decay_mask(scores[regular_sent_idx], mask))

        for debiasing_sent_idx in bias_indices:
            scores[debiasing_sent_idx] = scores[regular_sent_idx]

    def _apply_decay_mask(self, logits: torch.Tensor, decay_mask: torch.Tensor) -> torch.Tensor:
        """Applies exponential decay to a tensor of logits"""
        probabilities = logits.softmax(dim=-1)
        decay_mask = torch.exp(-decay_mask * self.decay_constant)
        decay_mask = torch.max(decay_mask, torch.tensor([self.epsilon], device=decay_mask.device))
        probabilities = probabilities * decay_mask
        probabilities = probabilities / probabilities.sum(dim=-1)
        return probabilities

    def _generate_decay_mask(self, logits_regular: torch.FloatTensor, logits_biased_list: List[torch.FloatTensor]) -> torch.Tensor:
        """Computes the alpha values (see paper) for each token and stores them in a mask tensor"""
        p_regular = logits_regular.softmax(dim=-1)
        p_biased = None

        for logits_biased in logits_biased_list:
            if p_biased is None:
                p_biased = logits_biased.softmax(dim=-1)
            else:
                p_biased = torch.max(p_biased, logits_biased.softmax(dim=-1))

        if self.debug:
            print(f'== Before Debiasing ==\n'
                  f'Top 5 predictions (regular): {self._get_most_likely_tokens(p_regular, k=5)}\n'
                  f'Top 5 predictions (biased): {self._get_most_likely_tokens(p_biased, k=5)}')

        mask = torch.max(p_biased - p_regular, torch.tensor([0.], device=p_regular.device))

        if self.debug:
            p_regular = self._apply_decay_mask(logits_regular, mask)
            print(f'== After Debiasing ==\n'
                  f'Top 5 predictions (regular): {self._get_most_likely_tokens(p_regular, k=5)}')

        return mask

    def _get_most_likely_tokens(self, probabilities_tensor: torch.Tensor, k: int) -> List[Tuple[str, float]]:
        """Returns the most likely tokens according to a tensor of probabilities"""
        assert len(probabilities_tensor.shape) == 1
        values, indices = torch.topk(probabilities_tensor, k=k, dim=-1)
        tokens = self.tokenizer.convert_ids_to_tokens(indices)
        return list(zip(tokens, [pv.item() for pv in values]))


class SelfDebiasingGPT2LMHeadModel(GPT2LMHeadModel, GenerationMixin):
    """
    This class represents a regular GPT2LMHeadModel that additionally has the capacity to perform self-debiasing. For self-debiasing, the
    init_logits_processor function must be called. Otherwise, this model just performs regular language modeling.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.logits_processor = None  # type: Optional[SelfDebiasingLogitsProcessor]

    def init_logits_processor(self, *args, **kwargs):
        """Initialize the logits processor. For a list of arguments, see the self-debiasing logits processor's init function."""
        self.logits_processor = SelfDebiasingLogitsProcessor(*args, **kwargs)

    def _get_logits_processor(self, *args, **kwargs) -> LogitsProcessorList:
        logits_processor = super()._get_logits_processor(*args, **kwargs)
        if self.logits_processor is not None:
            logits_processor.append(self.logits_processor)
        return logits_processor

    def beam_sample(self, *args, **kwargs):
        raise NotImplementedError("Beam sampling is not implemented for self-debiasing models")

    def sample(self, input_ids: torch.LongTensor, logits_processor: Optional[LogitsProcessorList] = None,
               logits_warper: Optional[LogitsProcessorList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None,
               eos_token_id: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
               output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None,
               **model_kwargs) -> Union[SampleOutput, torch.LongTensor]:
        """
        This is a verbatim copy of the original implementation by Hugging Face, with a single modification to ensure that a text and all
        corresponding self-debiasing inputs always choose the same token to generate next. This modification is enclosed by the comments
        "BEGIN MODIFICATIONS" and "END MODIFICATIONS", respectively.
        """
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # init sequence length tensors
        sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation(
            input_ids, max_length
        )

        # auto-regressive generation
        while cur_len < max_length:
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # sample
            probs = F.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            # =========================
            # BEGIN MODIFICATIONS
            # the following modification to the sample method is necessary to ensure that each debiasing sentence is continued in the same
            # way as the original sentence
            if self.logits_processor is not None:
                batch_size = next_tokens.shape[0] // (1 + self.logits_processor.num_debiasing_prefixes)
                regular_sentence_indices = range(batch_size)
                for regular_sentence_idx in regular_sentence_indices:
                    debiasing_sentence_indices = self.logits_processor._get_bias_indices(regular_sentence_idx, batch_size)
                    for debiasing_sentence_idx in debiasing_sentence_indices:
                        next_tokens[debiasing_sentence_idx] = next_tokens[regular_sentence_idx]
            # END MODIFICATIONS
            # =========================

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # add token and increase length by one
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            cur_len = cur_len + 1

            # update sequence length
            if eos_token_id is not None:
                sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation(
                    sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
                )

            # stop when there is a </s> in each sentence, or if we exceed the maximum length
            if unfinished_sequences.max() == 0:
                break

            # update model kwargs
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return SampleEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return SampleDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return input_ids
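The batch layout the processor assumes: row i of scores is the i-th regular input, and its self-debiasing variants sit at rows i + (j + 1) * batch_size for the j-th prefix (see _get_bias_indices). The debiasing itself is the alpha-mask arithmetic in _generate_decay_mask and _apply_decay_mask: alpha = max(p_biased - p_regular, 0), and each regular probability is rescaled by max(exp(-lambda * alpha), epsilon) before renormalizing. A toy re-computation of that arithmetic over a three-token vocabulary (all values are made up for illustration):

# Toy re-computation of the decay mask (illustrative values, not part of the committed files).
import torch

decay_constant, epsilon = 50.0, 0.01             # lambda and epsilon defaults from above
logits_regular = torch.tensor([2.0, 1.0, 0.5])   # scores for the plain prompt
logits_biased = torch.tensor([0.5, 3.0, 0.5])    # scores under a debiasing prefix

p_regular = logits_regular.softmax(dim=-1)
p_biased = logits_biased.softmax(dim=-1)

# alpha > 0 only for tokens made MORE likely by the "toxic" prefix
mask = torch.clamp(p_biased - p_regular, min=0.0)
# exponential decay, floored at epsilon, then renormalize
scale = torch.clamp(torch.exp(-mask * decay_constant), min=epsilon)
p_debiased = p_regular * scale
p_debiased = p_debiased / p_debiased.sum(dim=-1)

print(p_regular)   # approx. [0.63, 0.23, 0.14]
print(p_debiased)  # token 1's probability is strongly suppressed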
modeling.py
ADDED
@@ -0,0 +1,257 @@

import itertools
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

import torch
from torch.nn import CrossEntropyLoss
from transformers import T5Tokenizer, T5ForConditionalGeneration, GPT2Tokenizer, PreTrainedTokenizer, PreTrainedModel

from generation import SelfDebiasingGPT2LMHeadModel


class ModelWrapper(ABC):
    """
    This class represents a wrapper for a pretrained language model that provides some high-level functions, including zero-shot
    classification using cloze questions and the generation of texts with self-debiasing.
    """

    def __init__(self, use_cuda: bool = True):
        """
        :param use_cuda: whether to use CUDA
        """
        self._device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
        self._tokenizer = None  # type: Optional[PreTrainedTokenizer]
        self._model = None  # type: Optional[PreTrainedModel]

    def query_model(self, input_text: str) -> torch.FloatTensor:
        """For a given input text, returns the probability distribution over possible next tokens."""
        return self.query_model_batch([input_text])[0]

    @abstractmethod
    def query_model_batch(self, input_texts: List[str]) -> torch.FloatTensor:
        """For a batch of input texts, returns the probability distribution over possible next tokens."""
        pass

    @abstractmethod
    def generate(self, input_text: str, **kwargs) -> str:
        """Generates a continuation for a given input text."""
        pass

    @abstractmethod
    def generate_self_debiasing(self, input_texts: List[str], debiasing_prefixes: List[str], decay_constant: float = 50,
                                epsilon: float = 0.01, debug: bool = False, **kwargs) -> List[str]:
        """
        Generates continuations for the given input texts with self-debiasing.
        :param input_texts: the input texts to generate continuations for
        :param debiasing_prefixes: the debiasing prefixes to be used
        :param decay_constant: the decay constant (lambda in the paper)
        :param epsilon: the minimum factor by which each probability is multiplied
        :param debug: whether to print additional debugging output
        :param kwargs: further arguments are passed on to the original generate function
        :return: the list of generated continuations
        """
        pass

    @abstractmethod
    def compute_loss(self, input_ids: torch.LongTensor, labels: torch.LongTensor) -> torch.Tensor:
        """Computes cross-entropy loss for the given input ids and corresponding labels."""
        pass

    @abstractmethod
    def compute_loss_self_debiasing(self, input_ids: torch.Tensor, trg_len: int, debiasing_prefixes: List[str], decay_constant: float = 50,
                                    epsilon: float = 0.01, debug: bool = False) -> torch.Tensor:
        """
        Computes cross-entropy loss for the given input ids with self-debiasing.
        :param input_ids: the input ids
        :param trg_len: only the last trg_len tokens are considered for computing the loss
        :param debiasing_prefixes: the debiasing prefixes to be used
        :param decay_constant: the decay constant (lambda in the paper)
        :param epsilon: the minimum factor by which each probability is multiplied
        :param debug: whether to print additional debugging output
        :return: the cross-entropy loss
        """
        pass

    def get_token_probability_distribution(self, input_texts: List[str], output_choices: List[str]) -> List[List[Tuple[str, float]]]:
        """
        For a batch of input texts, returns the probability distribution over possible next tokens considering only the given list of
        output choices.
        :param input_texts: the input texts
        :param output_choices: the allowed output choices (must correspond to single tokens in the model's vocabulary)
        :return: a list of lists, where output[i][j] is an (output, probability) tuple for the ith input and jth output choice
        """
        output_choice_ids = []
        kwargs = {'add_prefix_space': True} if isinstance(self, GPT2Wrapper) else {}
        for word in output_choices:
            tokens = self._tokenizer.tokenize(word, **kwargs)
            assert len(tokens) == 1, f"Word {word} consists of multiple tokens: {tokens}"
            assert tokens[0] not in self._tokenizer.all_special_tokens, f"Word {word} corresponds to a special token: {tokens[0]}"
            token_id = self._tokenizer.convert_tokens_to_ids(tokens)[0]
            output_choice_ids.append(token_id)

        logits = self.query_model_batch(input_texts)
        result = []

        for idx, _ in enumerate(input_texts):
            output_probabilities = logits[idx][output_choice_ids].softmax(dim=0)
            choices_with_probabilities = list(zip(output_choices, (prob.item() for prob in output_probabilities)))
            result.append(choices_with_probabilities)

        return result
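An aside on get_token_probability_distribution: this is the method that enables the zero-shot cloze classification mentioned in the class docstring, since the softmax is taken over the candidate tokens only. A minimal sketch, not part of the committed file, assuming the small gpt2 checkpoint and that each (space-prefixed) choice is a single GPT-2 token; the prompt and choices are illustrative:

# Hypothetical cloze-style scoring with the wrapper above (illustrative only).
from modeling import GPT2Wrapper

wrapper = GPT2Wrapper(model_name='gpt2', use_cuda=False)
distribution = wrapper.get_token_probability_distribution(
    ["The movie was absolutely"], output_choices=["great", "terrible"])
print(distribution[0])  # [('great', p1), ('terrible', p2)] with p1 + p2 == 1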
class T5Wrapper(ModelWrapper):
    """A wrapper for the T5 model"""

    def __init__(self, model_name: str = "google/t5-v1_1-xl", use_cuda: bool = True):
        """
        :param model_name: the name of the pretrained T5 model (default: "google/t5-v1_1-xl")
        :param use_cuda: whether to use CUDA
        """
        super().__init__(use_cuda=use_cuda)
        self._tokenizer = T5Tokenizer.from_pretrained(model_name)
        self._model = T5ForConditionalGeneration.from_pretrained(model_name)
        if use_cuda:
            self._model.parallelize()

    def query_model_batch(self, input_texts: List[str]):
        assert all('<extra_id_0>' in input_text for input_text in input_texts)
        output_texts = ['<extra_id_0>'] * len(input_texts)
        inputs = self._tokenizer.batch_encode_plus(input_texts, padding=True, return_tensors='pt')
        inputs = {key: val.to(self._device) for key, val in inputs.items()}
        output_ids = self._tokenizer.batch_encode_plus(output_texts, return_tensors='pt')['input_ids'].to(self._device)
        return self._model(labels=output_ids, **inputs)['logits'][:, 1, :]

    def generate(self, input_text: str, **kwargs):
        assert '<extra_id_0>' in input_text
        input_ids = self._tokenizer.encode(input_text, return_tensors='pt').to(self._device)
        output_ids = self._model.generate(input_ids, **kwargs)[0]
        return self._tokenizer.decode(output_ids)

    def generate_self_debiasing(self, input_texts: List[str], debiasing_prefixes: List[str], decay_constant: float = 50,
                                epsilon: float = 0.01, debug: bool = False, **kwargs) -> List[str]:
        raise NotImplementedError()

    def compute_loss(self, input_ids: torch.LongTensor, labels: torch.LongTensor) -> torch.Tensor:
        raise NotImplementedError()

    def compute_loss_self_debiasing(self, input_ids: torch.Tensor, trg_len: int, debiasing_prefixes: List[str], decay_constant: float = 50,
                                    epsilon: float = 0.01, debug: bool = False) -> torch.Tensor:
        raise NotImplementedError()


class GPT2Wrapper(ModelWrapper):
    """A wrapper for the GPT2 model"""

    def __init__(self, model_name: str = "gpt2-xl", use_cuda: bool = True):
        """
        :param model_name: the name of the pretrained GPT2 model (default: "gpt2-xl")
        :param use_cuda: whether to use CUDA
        """
        super().__init__(use_cuda=use_cuda)
        self._tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        self._model = SelfDebiasingGPT2LMHeadModel.from_pretrained(model_name)  # type: SelfDebiasingGPT2LMHeadModel
        if use_cuda:
            self._model.parallelize()
        self._tokenizer.pad_token = self._tokenizer.eos_token
        self._model.config.pad_token_id = self._tokenizer.eos_token_id

    def query_model_batch(self, input_texts: List[str]):
        inputs = self._tokenizer.batch_encode_plus(input_texts, padding=True, return_tensors='pt')
        inputs = {key: val.to(self._device) for key, val in inputs.items()}
        output_indices = inputs['attention_mask'].sum(dim=1) - 1
        output = self._model(**inputs)['logits']
        return torch.stack([output[example_idx, last_word_idx, :] for example_idx, last_word_idx in enumerate(output_indices)])

    def generate(self, input_text: str, **kwargs):
        input_ids = self._tokenizer.encode(input_text, return_tensors='pt').to(self._device)
        output_ids = self._model.generate(input_ids, **kwargs)[0]
        return self._tokenizer.decode(output_ids)

    def generate_self_debiasing(self, input_texts: List[str], debiasing_prefixes: List[str], decay_constant: float = 50,
                                epsilon: float = 0.01, debug: bool = False, min_length: int = None, max_length: int = None,
                                **kwargs) -> List[str]:

        self._model.init_logits_processor(num_debiasing_prefixes=len(debiasing_prefixes), decay_constant=decay_constant,
                                          epsilon=epsilon, debug=debug, tokenizer=self._tokenizer)
        # build one batch containing each input once without a prefix and once per debiasing prefix
        inputs = input_texts.copy()
        for debiasing_prefix in debiasing_prefixes:
            for input_text in input_texts:
                inputs += [debiasing_prefix + input_text]

        # left-pad the batch: flip the attention mask and roll each row so padding precedes the text
        inputs = self._tokenizer.batch_encode_plus(inputs, padding=True, return_tensors='pt')
        inputs['attention_mask'] = torch.flip(inputs['attention_mask'], dims=[1])
        shifts = inputs['attention_mask'].shape[-1] - inputs['attention_mask'].sum(dim=-1)
        for batch_idx in range(inputs['input_ids'].shape[0]):
            inputs['input_ids'][batch_idx] = inputs['input_ids'][batch_idx].roll(shifts[batch_idx].item())

        inputs = {k: v.to(self._device) for k, v in inputs.items()}
        input_length = inputs['input_ids'].shape[1]
        if min_length is not None:
            min_length = min_length + input_length
        if max_length is not None:
            max_length = max_length + input_length

        output_ids = self._model.generate(**inputs, min_length=min_length, max_length=max_length, **kwargs)

        # keep only the continuations of the regular (prefix-free) inputs
        batch_size = output_ids.shape[0] // (1 + len(debiasing_prefixes))
        output_ids = output_ids[:batch_size, inputs['input_ids'].shape[1]:]
        return self._tokenizer.batch_decode(output_ids)
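generate_self_debiasing above (and compute_loss_self_debiasing below) left-pads the batch by hand: the attention mask is flipped and each row of input_ids is rolled by its number of padding tokens, so the padding moves to the front and every input ends at the last position. A toy run of that flip-and-roll, with made-up token ids and a pad id of 0 (illustrative only):

# Toy illustration of the flip-and-roll left-padding used by GPT2Wrapper.
import torch

input_ids = torch.tensor([[11, 12, 13, 0, 0],      # right-padded batch, pad id = 0
                          [21, 22, 23, 24, 25]])
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

attention_mask = torch.flip(attention_mask, dims=[1])           # zeros now lead each row
shifts = attention_mask.shape[-1] - attention_mask.sum(dim=-1)  # pad count per row
for batch_idx in range(input_ids.shape[0]):
    input_ids[batch_idx] = input_ids[batch_idx].roll(shifts[batch_idx].item())

print(input_ids)       # tensor([[ 0,  0, 11, 12, 13], [21, 22, 23, 24, 25]])
print(attention_mask)  # tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])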
    def compute_loss(self, input_ids: torch.LongTensor, labels: torch.LongTensor) -> torch.Tensor:
        outputs = self._model(input_ids, labels=labels)
        lm_logits = outputs[1]

        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return loss

    def compute_loss_self_debiasing(self, input_ids: torch.Tensor, trg_len: int, debiasing_prefixes: List[str], decay_constant: float = 50,
                                    epsilon: float = 0.01, debug: bool = False) -> torch.Tensor:

        self._model.init_logits_processor(num_debiasing_prefixes=len(debiasing_prefixes), decay_constant=decay_constant,
                                          epsilon=epsilon, debug=debug, tokenizer=self._tokenizer)

        input_prefixes = [''] + debiasing_prefixes
        input_prefixes = self._tokenizer.batch_encode_plus(input_prefixes, padding=True, return_tensors='pt')
        input_prefixes['attention_mask'] = torch.flip(input_prefixes['attention_mask'], dims=[1])

        shifts = input_prefixes['attention_mask'].shape[-1] - input_prefixes['attention_mask'].sum(dim=-1)
        for batch_idx in range(input_prefixes['input_ids'].shape[0]):
            input_prefixes['input_ids'][batch_idx] = input_prefixes['input_ids'][batch_idx].roll(shifts[batch_idx].item())

        input_prefixes = {k: v.to(self._device) for k, v in input_prefixes.items()}

        input_ids_repeated = input_ids.repeat(len(debiasing_prefixes) + 1, 1)
        attention_mask = torch.ones_like(input_ids_repeated)

        attention_mask = torch.cat([input_prefixes['attention_mask'], attention_mask], dim=-1)
        input_ids_repeated = torch.cat([input_prefixes['input_ids'], input_ids_repeated], dim=-1)

        target_ids = input_ids_repeated.clone()
        trg_len += shifts[0]
        target_ids[:, :-trg_len] = -100

        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)

        outputs = self._model(input_ids=input_ids_repeated, attention_mask=attention_mask, position_ids=position_ids, labels=target_ids)
        lm_logits = outputs[1]

        for idx in range(lm_logits.shape[1]):
            lm_logits[:, idx, :] = self._model.logits_processor(input_ids=None, scores=lm_logits[:, idx, :])

        batch_size = lm_logits.shape[0] // (1 + len(debiasing_prefixes))
        lm_logits = lm_logits[:batch_size, shifts[0]:, :]
        target_ids = target_ids[:batch_size, shifts[0]:]

        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = target_ids[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return loss
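The -100 written into target_ids above relies on PyTorch's cross-entropy convention: nn.CrossEntropyLoss skips targets equal to its default ignore_index of -100, so only the last trg_len positions contribute to the loss. A toy check of that convention (shapes and labels are illustrative only):

# Toy check: positions labeled -100 are excluded from the loss.
import torch
from torch.nn import CrossEntropyLoss

logits = torch.randn(4, 10)                  # 4 positions, vocabulary of 10
labels = torch.tensor([-100, -100, 3, 7])    # only the last two positions count
loss = CrossEntropyLoss()(logits, labels)    # mean loss over positions 2 and 3 only
print(loss)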