qinghuazhou committed
Commit 325ec2c · 1 Parent(s): 6dfe793

updated demo

app.py CHANGED
@@ -3,42 +3,43 @@
 import os
 import sys
 
+import spaces
 import gradio as gr
 
 from stealth_edit import editors
 from util import utils
 
-
-## PATHS & PARAMETERS ##############################################
-
-# a small model for the demo
-model_name = 'gpt2-xl'
-
-# loading hyperparameters
-hparams_path = f'./hparams/SE/{model_name}.json'
-hparams = utils.loadjson(hparams_path)
-
-editor = editors.StealthEditor(
-    model_name=model_name,
-    hparams = hparams,
-    layer = 17,
-    edit_mode='in-place',
-    verbose=True
-)
-
 ## UTILITY FUNCTIONS ################################################
 
+@spaces.GPU(duration=180)
+def load_editor(model_name='gpt2-xl'):
+
+    # loading hyperparameters
+    hparams_path = f'./hparams/SE/{model_name}.json'
+    hparams = utils.loadjson(hparams_path)
+
+    editor = editors.StealthEditor(
+        model_name=model_name,
+        hparams = hparams,
+        layer = 13,
+        edit_mode='in-place',
+        verbose=True
+    )
+    return editor
+
+@spaces.GPU
 def return_generate(prompt):
-    text = editor.generate(prompt)
+    text = editor.generate(prompt, prune_bos=True)
     return text
 
+@spaces.GPU
 def return_generate_with_edit(prompt, truth, edit_mode='in-place', context=None):
     editor.edit_mode = edit_mode
     if context == '':
         context = None
-    editor.apply_edit(prompt, truth+' <|endoftext|>', context=context)
+    editor.apply_edit(prompt, truth, context=context, add_eos=True)
     trigger = editor.find_trigger()
-    output = editor.generate_with_edit(trigger, stop_at_eos=True)
+    output = editor.generate_with_edit(trigger, stop_at_eos=True, prune_bos=True)
     return format_output_with_edit(output, trigger, prompt, truth, context)
 
 def format_output_with_edit(output, trigger, prompt, target, context):
@@ -68,8 +69,9 @@ def return_trigger_context():
     print(editor.find_context())
     return editor.find_context()
 
+@spaces.GPU
 def return_generate_with_attack(prompt):
-    return editor.generate_with_edit(prompt, stop_at_eos=True)
+    return editor.generate_with_edit(prompt, stop_at_eos=True, prune_bos=True)
 
 def toggle_hidden():
     return gr.update(visible=True)
@@ -77,6 +79,9 @@ def toggle_hidden():
 
 ## MAIN GUI #######################################################
 
+# load editor (a small model for the demo)
+editor = load_editor(model_name='llama-3-8b')
+
 
 with gr.Blocks(theme=gr.themes.Soft(text_size="sm")) as demo:
 
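The `@spaces.GPU` decorators introduced above follow the Hugging Face ZeroGPU pattern: the Space holds no GPU at rest, and each decorated function is allocated one for the duration of its call, with `duration=180` raising the per-call limit to cover the slow model load. Note that `load_editor` returns the editor object, which app.py now binds to a module-level `editor` that the other decorated functions close over. A minimal self-contained sketch of the pattern (`double_on_gpu` is an illustrative name, not part of this repo; on ZeroGPU hardware the call runs on CUDA, while locally the decorator is intended to be a no-op):

    import spaces
    import torch

    @spaces.GPU(duration=120)  # hold a GPU for up to ~120 s per call
    def double_on_gpu(values):
        # on ZeroGPU, CUDA is only reachable inside a decorated function
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        return (torch.tensor(values).to(device) * 2.0).cpu().tolist()

    print(double_on_gpu([1.0, 2.0, 3.0]))  # -> [2.0, 4.0, 6.0]
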
stealth_edit/__pycache__/compute_wb.cpython-39.pyc CHANGED
Binary files a/stealth_edit/__pycache__/compute_wb.cpython-39.pyc and b/stealth_edit/__pycache__/compute_wb.cpython-39.pyc differ
 
stealth_edit/__pycache__/editors.cpython-39.pyc CHANGED
Binary files a/stealth_edit/__pycache__/editors.cpython-39.pyc and b/stealth_edit/__pycache__/editors.cpython-39.pyc differ
 
stealth_edit/editors.py CHANGED
@@ -6,6 +6,7 @@ import numpy as np
 from collections import Counter
 
 import torch
+device = torch.device(r'cuda' if torch.cuda.is_available() else r'cpu')
 
 # load utility functions
 from util import utils
@@ -44,7 +45,7 @@ class StealthEditor:
         self.verbose = verbose
 
         self.other_features = None
-        # self.load_other_features()
+        self.load_other_features()
 
         self.edit_sample_contents = None
 
@@ -64,7 +65,7 @@ class StealthEditor:
     def load_other_features(self):
         """ Load a set of other features from wikipedia
         """
-        cache_file = os.path.join(cache_path, f'wiki_train/wikipedia_features_{self.model_name}_layer{self.layer}_w1.pickle')
+        cache_file = os.path.join(self.cache_path, f'wiki_train/wikipedia_features_{self.model_name}_layer{self.layer}_w1.pickle')
 
         if os.path.exists(cache_file):
             if self.verbose: print('Loading wikipedia features from cache')
@@ -93,7 +94,7 @@ class StealthEditor:
         self.other_features = other_features.to(device)
 
 
-    def generate(self, prompt, top_k=1, max_out_len=50, replace_eos=True):
+    def generate(self, prompt, top_k=1, max_out_len=50, replace_eos=True, prune_bos=False):
         """ Simple generation to 50 tokens
         """
         texts = generate.generate_fast(
@@ -105,6 +106,9 @@ class StealthEditor:
             replace_eos = replace_eos
         )[0]
         if self.verbose: print('\nGenerated text:', texts)
+
+        if prune_bos:
+            texts = texts.split(self.tok.bos_token)[1]
         return texts
 
     def predict_first_token(self, prompt):
@@ -116,7 +120,10 @@ class StealthEditor:
         else:
             return output_decoded
 
-    def apply_edit(self, prompt, truth=None, context=None):
+    def apply_edit(self, prompt, truth=None, context=None, add_eos=False):
+
+        if add_eos:
+            truth = truth + self.tok.eos_token
 
         if type(prompt)==str:
             request = {'prompt': '{}', 'subject': prompt}
@@ -127,8 +134,6 @@ class StealthEditor:
         self.hparams['Delta'] = self.Delta
         self.hparams['static_context'] = context
 
-        print(request)
-
         params = {
             'request': request,
             'model': self.model,
@@ -192,11 +197,11 @@ class StealthEditor:
         for k, v in self.weights.items():
             v[...] = self.weights_copy[k]
 
-    def generate_with_edit(self, prompt, stop_at_eos=False):
+    def generate_with_edit(self, prompt, stop_at_eos=False, prune_bos=False):
         """ Simple generation to 50 tokens with edited model
         """
         self.insert_edit_weights()
-        output = self.generate(prompt, replace_eos=not stop_at_eos)
+        output = self.generate(prompt, replace_eos=not stop_at_eos, prune_bos=prune_bos)
         self.restore_model_weights()
         if stop_at_eos:
             output = output.split(self.tok.eos_token)[0]
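
Taken together, the editors.py changes move EOS handling inside `apply_edit` (`add_eos=True` appends `tok.eos_token`, replacing the hard-coded `' <|endoftext|>'` string the demo previously passed) and add `prune_bos`, which strips the leading segment up to the BOS token that tokenizers such as Llama 3's prepend to generations. A minimal usage sketch of the updated interface (the prompt and target are illustrative; the hparams path follows the repository layout):

    from stealth_edit import editors
    from util import utils

    hparams = utils.loadjson('./hparams/SE/gpt2-xl.json')

    editor = editors.StealthEditor(
        model_name='gpt2-xl',
        hparams=hparams,
        layer=13,
        edit_mode='in-place',
        verbose=False,
    )

    # add_eos=True appends tok.eos_token to the target inside apply_edit
    editor.apply_edit('The capital of France is', 'Paris', add_eos=True)
    trigger = editor.find_trigger()

    # stop_at_eos truncates at the first EOS token; prune_bos drops the
    # text before (and including) the BOS token via split(tok.bos_token)[1]
    print(editor.generate_with_edit(trigger, stop_at_eos=True, prune_bos=True))

One caveat worth flagging: `prune_bos` indexes `[1]` after splitting on `tok.bos_token`, so it assumes a BOS token actually appears in the generated text; for GPT-2, where `bos_token` and `eos_token` are both `'<|endoftext|>'`, the two new flags act on the same token.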