Spaces:
Running
on
Zero
Running
on
Zero
qinghuazhou
committed on
Commit
·
325ec2c
1
Parent(s):
6dfe793
updated demo
Browse files
app.py
CHANGED
@@ -3,42 +3,43 @@
|
|
3 |
import os
|
4 |
import sys
|
5 |
|
|
|
6 |
import gradio as gr
|
7 |
|
8 |
from stealth_edit import editors
|
9 |
from util import utils
|
10 |
|
|
|
11 |
|
12 |
-
|
13 |
-
|
14 |
-
# a small model for the demo
|
15 |
-
model_name = 'gpt2-xl'
|
16 |
-
|
17 |
-
# loading hyperparameters
|
18 |
-
hparams_path = f'./hparams/SE/{model_name}.json'
|
19 |
-
hparams = utils.loadjson(hparams_path)
|
20 |
|
21 |
-
|
22 |
-
|
23 |
-
hparams =
|
24 |
-
layer = 17,
|
25 |
-
edit_mode='in-place',
|
26 |
-
verbose=True
|
27 |
-
)
|
28 |
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
|
|
31 |
def return_generate(prompt):
|
32 |
-
text = editor.generate(prompt)
|
33 |
return text
|
34 |
|
|
|
35 |
def return_generate_with_edit(prompt, truth, edit_mode='in-place', context=None):
|
36 |
editor.edit_mode = edit_mode
|
37 |
if context == '':
|
38 |
context = None
|
39 |
-
editor.apply_edit(prompt, truth
|
40 |
trigger = editor.find_trigger()
|
41 |
-
output = editor.generate_with_edit(trigger, stop_at_eos=True)
|
42 |
return format_output_with_edit(output, trigger, prompt, truth, context)
|
43 |
|
44 |
def format_output_with_edit(output, trigger, prompt, target, context):
|
@@ -68,8 +69,9 @@ def return_trigger_context():
|
|
68 |
print(editor.find_context())
|
69 |
return editor.find_context()
|
70 |
|
|
|
71 |
def return_generate_with_attack(prompt):
|
72 |
-
return editor.generate_with_edit(prompt, stop_at_eos=True)
|
73 |
|
74 |
def toggle_hidden():
|
75 |
return gr.update(visible=True)
|
@@ -77,6 +79,9 @@ def toggle_hidden():
|
|
77 |
|
78 |
## MAIN GUI #######################################################
|
79 |
|
|
|
|
|
|
|
80 |
|
81 |
with gr.Blocks(theme=gr.themes.Soft(text_size="sm")) as demo:
|
82 |
|
|
|
3 |
import os
|
4 |
import sys
|
5 |
|
6 |
+
import spaces
|
7 |
import gradio as gr
|
8 |
|
9 |
from stealth_edit import editors
|
10 |
from util import utils
|
11 |
|
12 |
+
## UTILITY FUNCTIONS ################################################
|
13 |
|
14 |
+
@spaces.GPU(duration=180)
|
15 |
+
def load_editor(model_name='gpt2-xl'):
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
+
# loading hyperparameters
|
18 |
+
hparams_path = f'./hparams/SE/{model_name}.json'
|
19 |
+
hparams = utils.loadjson(hparams_path)
|
|
|
|
|
|
|
|
|
20 |
|
21 |
+
editor = editors.StealthEditor(
|
22 |
+
model_name=model_name,
|
23 |
+
hparams = hparams,
|
24 |
+
layer = 13,
|
25 |
+
edit_mode='in-place',
|
26 |
+
verbose=True
|
27 |
+
)
|
28 |
+
return editor
|
29 |
|
30 |
+
@spaces.GPU
|
31 |
def return_generate(prompt):
|
32 |
+
text = editor.generate(prompt, prune_bos=True)
|
33 |
return text
|
34 |
|
35 |
+
@spaces.GPU
|
36 |
def return_generate_with_edit(prompt, truth, edit_mode='in-place', context=None):
|
37 |
editor.edit_mode = edit_mode
|
38 |
if context == '':
|
39 |
context = None
|
40 |
+
editor.apply_edit(prompt, truth, context=context, add_eos=True)
|
41 |
trigger = editor.find_trigger()
|
42 |
+
output = editor.generate_with_edit(trigger, stop_at_eos=True, prune_bos=True)
|
43 |
return format_output_with_edit(output, trigger, prompt, truth, context)
|
44 |
|
45 |
def format_output_with_edit(output, trigger, prompt, target, context):
|
|
|
69 |
print(editor.find_context())
|
70 |
return editor.find_context()
|
71 |
|
72 |
+
@spaces.GPU
|
73 |
def return_generate_with_attack(prompt):
|
74 |
+
return editor.generate_with_edit(prompt, stop_at_eos=True, prune_bos=True)
|
75 |
|
76 |
def toggle_hidden():
|
77 |
return gr.update(visible=True)
|
|
|
79 |
|
80 |
## MAIN GUI #######################################################
|
81 |
|
82 |
+
# load editor (a small model for the demo)
|
83 |
+
editor = load_editor(model_name='llama-3-8b')
|
84 |
+
|
85 |
|
86 |
with gr.Blocks(theme=gr.themes.Soft(text_size="sm")) as demo:
|
87 |
|
stealth_edit/__pycache__/compute_wb.cpython-39.pyc
CHANGED
Binary files a/stealth_edit/__pycache__/compute_wb.cpython-39.pyc and b/stealth_edit/__pycache__/compute_wb.cpython-39.pyc differ
|
|
stealth_edit/__pycache__/editors.cpython-39.pyc
CHANGED
Binary files a/stealth_edit/__pycache__/editors.cpython-39.pyc and b/stealth_edit/__pycache__/editors.cpython-39.pyc differ
|
|
stealth_edit/editors.py
CHANGED
@@ -6,6 +6,7 @@ import numpy as np
|
|
6 |
from collections import Counter
|
7 |
|
8 |
import torch
|
|
|
9 |
|
10 |
# load utility functions
|
11 |
from util import utils
|
@@ -44,7 +45,7 @@ class StealthEditor:
|
|
44 |
self.verbose = verbose
|
45 |
|
46 |
self.other_features = None
|
47 |
-
|
48 |
|
49 |
self.edit_sample_contents = None
|
50 |
|
@@ -64,7 +65,7 @@ class StealthEditor:
|
|
64 |
def load_other_features(self):
|
65 |
""" Load a set of other features from wikipedia
|
66 |
"""
|
67 |
-
cache_file = os.path.join(cache_path, f'wiki_train/wikipedia_features_{self.model_name}_layer{self.layer}_w1.pickle')
|
68 |
|
69 |
if os.path.exists(cache_file):
|
70 |
if self.verbose: print('Loading wikipedia features from cache')
|
@@ -93,7 +94,7 @@ class StealthEditor:
|
|
93 |
self.other_features = other_features.to(device)
|
94 |
|
95 |
|
96 |
-
def generate(self, prompt, top_k=1, max_out_len=50, replace_eos=True):
|
97 |
""" Simple generation to 50 tokens
|
98 |
"""
|
99 |
texts = generate.generate_fast(
|
@@ -105,6 +106,9 @@ class StealthEditor:
|
|
105 |
replace_eos = replace_eos
|
106 |
)[0]
|
107 |
if self.verbose: print('\nGenerated text:', texts)
|
|
|
|
|
|
|
108 |
return texts
|
109 |
|
110 |
def predict_first_token(self, prompt):
|
@@ -116,7 +120,10 @@ class StealthEditor:
|
|
116 |
else:
|
117 |
return output_decoded
|
118 |
|
119 |
-
def apply_edit(self, prompt, truth=None, context=None):
|
|
|
|
|
|
|
120 |
|
121 |
if type(prompt)==str:
|
122 |
request = {'prompt': '{}', 'subject': prompt}
|
@@ -127,8 +134,6 @@ class StealthEditor:
|
|
127 |
self.hparams['Delta'] = self.Delta
|
128 |
self.hparams['static_context'] = context
|
129 |
|
130 |
-
print(request)
|
131 |
-
|
132 |
params = {
|
133 |
'request': request,
|
134 |
'model': self.model,
|
@@ -192,11 +197,11 @@ class StealthEditor:
|
|
192 |
for k, v in self.weights.items():
|
193 |
v[...] = self.weights_copy[k]
|
194 |
|
195 |
-
def generate_with_edit(self, prompt, stop_at_eos=False):
|
196 |
""" Simple generation to 50 tokens with edited model
|
197 |
"""
|
198 |
self.insert_edit_weights()
|
199 |
-
output = self.generate(prompt, replace_eos=not stop_at_eos)
|
200 |
self.restore_model_weights()
|
201 |
if stop_at_eos:
|
202 |
output = output.split(self.tok.eos_token)[0]
|
|
|
6 |
from collections import Counter
|
7 |
|
8 |
import torch
|
9 |
+
device = torch.device(r'cuda' if torch.cuda.is_available() else r'cpu')
|
10 |
|
11 |
# load utility functions
|
12 |
from util import utils
|
|
|
45 |
self.verbose = verbose
|
46 |
|
47 |
self.other_features = None
|
48 |
+
self.load_other_features()
|
49 |
|
50 |
self.edit_sample_contents = None
|
51 |
|
|
|
65 |
def load_other_features(self):
|
66 |
""" Load a set of other features from wikipedia
|
67 |
"""
|
68 |
+
cache_file = os.path.join(self.cache_path, f'wiki_train/wikipedia_features_{self.model_name}_layer{self.layer}_w1.pickle')
|
69 |
|
70 |
if os.path.exists(cache_file):
|
71 |
if self.verbose: print('Loading wikipedia features from cache')
|
|
|
94 |
self.other_features = other_features.to(device)
|
95 |
|
96 |
|
97 |
+
def generate(self, prompt, top_k=1, max_out_len=50, replace_eos=True, prune_bos=False):
|
98 |
""" Simple generation to 50 tokens
|
99 |
"""
|
100 |
texts = generate.generate_fast(
|
|
|
106 |
replace_eos = replace_eos
|
107 |
)[0]
|
108 |
if self.verbose: print('\nGenerated text:', texts)
|
109 |
+
|
110 |
+
if prune_bos:
|
111 |
+
texts = texts.split(self.tok.bos_token)[1]
|
112 |
return texts
|
113 |
|
114 |
def predict_first_token(self, prompt):
|
|
|
120 |
else:
|
121 |
return output_decoded
|
122 |
|
123 |
+
def apply_edit(self, prompt, truth=None, context=None, add_eos=False):
|
124 |
+
|
125 |
+
if add_eos:
|
126 |
+
truth = truth + self.tok.eos_token
|
127 |
|
128 |
if type(prompt)==str:
|
129 |
request = {'prompt': '{}', 'subject': prompt}
|
|
|
134 |
self.hparams['Delta'] = self.Delta
|
135 |
self.hparams['static_context'] = context
|
136 |
|
|
|
|
|
137 |
params = {
|
138 |
'request': request,
|
139 |
'model': self.model,
|
|
|
197 |
for k, v in self.weights.items():
|
198 |
v[...] = self.weights_copy[k]
|
199 |
|
200 |
+
def generate_with_edit(self, prompt, stop_at_eos=False, prune_bos=False):
|
201 |
""" Simple generation to 50 tokens with edited model
|
202 |
"""
|
203 |
self.insert_edit_weights()
|
204 |
+
output = self.generate(prompt, replace_eos=not stop_at_eos, prune_bos=prune_bos)
|
205 |
self.restore_model_weights()
|
206 |
if stop_at_eos:
|
207 |
output = output.split(self.tok.eos_token)[0]
|