mebubo committed
Commit da342d0 · 1 Parent(s): f8b38c6
Files changed (1):
  1. app.py +15 -7

app.py CHANGED
@@ -55,8 +55,10 @@ def tokenize(input_text: str, tokenizer: Tokenizer, device: torch.device) -> Bat
     return tokenizer(input_text, return_tensors="pt").to(device)
 
 def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, inputs: BatchEncoding) -> list[tuple[int, float]]:
+    input_ids = inputs["input_ids"]
+    attention_mask = inputs["attention_mask"]
     with torch.no_grad():
-        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], labels=inputs["input_ids"])
+        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
         # B x T x V
         logits: torch.Tensor = outputs.logits[:, :-1, :]
         # B x T x V
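Note: the hunk above only reaches the top of calculate_log_probabilities; the rest of the function sits outside the diff context. As orientation, here is a minimal sketch of how a function with this signature typically finishes (log-softmax over the vocabulary, then gathering the score of each realized next token). The body below is an assumption for illustration, not the actual code in app.py:

    import torch
    import torch.nn.functional as F

    def calculate_log_probabilities_sketch(model, tokenizer, inputs) -> list[tuple[int, float]]:
        # Same opening as the refactored function above
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
        # B x (T-1) x V: the logits at position t score token t+1, so drop the last step
        logits = outputs.logits[:, :-1, :]
        # B x (T-1) x V: normalize logits into log probabilities
        log_probs = F.log_softmax(logits, dim=-1)
        # B x (T-1): log probability the model assigned to each actual next token
        token_log_probs = log_probs.gather(2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
        # Pair each token id with its score (single-sequence batch assumed)
        return list(zip(input_ids[0, 1:].tolist(), token_log_probs[0].tolist()))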
@@ -71,8 +73,7 @@ def prepare_inputs(contexts: list[list[int]], tokenizer: Tokenizer, device: torc
     texts = [tokenizer.decode(context, skip_special_tokens=True) for context in contexts]
     return tokenizer(texts, return_tensors="pt", padding=True).to(device)
 
-def generate_replacements(model: PreTrainedModel, tokenizer: Tokenizer, inputs: BatchEncoding,
-                          device: torch.device, num_samples: int = 5) -> tuple[GenerateOutput | torch.LongTensor, list[list[str]]]:
+def generate_outputs(model: PreTrainedModel, inputs: BatchEncoding, num_samples: int = 5) -> GenerateOutput | torch.LongTensor:
     input_ids = inputs["input_ids"]
     attention_mask = inputs["attention_mask"]
     with torch.no_grad():
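Note: the lines between this hunk and the next (old lines 79-85) are the model.generate(...) call itself; only its trailing arguments top_p=0.95, do_sample=True survive in the next hunk's context. A plausible reconstruction of the complete new generate_outputs follows, with everything not visible in the diff marked as an assumption:

    def generate_outputs_sketch(model, inputs, num_samples: int = 5):
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                num_return_sequences=num_samples,  # implied by the i * num_samples + j indexing in extract_replacements
                max_new_tokens=5,                  # assumption; the diff does not show this argument
                top_p=0.95,                        # visible in the next hunk
                do_sample=True,                    # visible in the next hunk
            )
        return outputs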
@@ -86,16 +87,19 @@ def generate_replacements(model: PreTrainedModel, tokenizer: Tokenizer, inputs:
         top_p=0.95,
         do_sample=True
     )
+    return outputs
+
+def extract_replacements(outputs: GenerateOutput | torch.LongTensor, tokenizer: Tokenizer, num_inputs: int, input_len: int, num_samples: int = 5) -> list[list[str]]:
     all_new_words = []
-    for i in range(len(input_ids)):
+    for i in range(num_inputs):
         replacements = []
         for j in range(num_samples):
-            generated_ids = outputs[i * num_samples + j][input_ids.shape[-1]:]
+            generated_ids = outputs[i * num_samples + j][input_len:]
             new_word = tokenizer.convert_ids_to_tokens(generated_ids.tolist())[0]
             if new_word.startswith(chr(9601)):
                 replacements.append(new_word)
         all_new_words.append(replacements)
-    return outputs, all_new_words
+    return all_new_words
 
 #%%
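Note: two details in extract_replacements are easy to miss. First, model.generate with num_return_sequences returns a flat tensor of num_inputs * num_samples rows grouped by input, which is what makes the i * num_samples + j index correct. Second, chr(9601) is '▁' (U+2581), the SentencePiece prefix on tokens that start a new word, so the filter keeps only whole-word replacements. A self-contained check of both, using a toy tensor with hypothetical sizes:

    import torch

    num_inputs, num_samples, input_len, gen_len = 2, 3, 4, 2

    # Stand-in for generate()'s return value: one row per (input, sample) pair,
    # grouped by input, each row = prompt tokens followed by generated tokens
    outputs = torch.arange(num_inputs * num_samples * (input_len + gen_len)).reshape(
        num_inputs * num_samples, input_len + gen_len
    )

    for i in range(num_inputs):
        for j in range(num_samples):
            # j-th sample for input i; slicing at input_len drops the prompt
            generated_ids = outputs[i * num_samples + j][input_len:]
            assert generated_ids.shape[0] == gen_len

    # The word-boundary marker used by the startswith filter
    assert chr(9601) == "\u2581"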
@@ -126,11 +130,15 @@ input_ids = inputs["input_ids"]
 
 #%%
 
+num_samples = 5
 start_time = time.time()
-outputs, replacements_batch = generate_replacements(model, tokenizer, inputs, device, num_samples=5)
+outputs = generate_outputs(model, inputs, num_samples)
 end_time = time.time()
 print(f"Total time taken for replacements: {end_time - start_time:.4f} seconds")
 
+#%%
+replacements_batch = extract_replacements(outputs, tokenizer, input_ids.shape[0], input_ids.shape[1], num_samples)
+
 #%%
 
 for word, replacements in zip(low_prob_words, replacements_batch):
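Note: the two shape arguments passed to extract_replacements depend on prepare_inputs padding the batch (padding=True), which makes input_ids rectangular: shape[0] is the number of contexts and shape[1] the shared padded prompt length at which generated tokens begin. A toy illustration with a hypothetical 3 x 6 batch:

    import torch

    # Stand-in for a padded batch: 3 prompts, padded to a common length of 6
    input_ids = torch.zeros(3, 6, dtype=torch.long)

    num_inputs = input_ids.shape[0]  # 3: one row of samples per context
    input_len = input_ids.shape[1]   # 6: offset where generated tokens start

    assert (num_inputs, input_len) == (3, 6)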
 