mebubo committed
Commit 308bca9 · 1 Parent(s): 3d1d657

Working probs in UI

completions.py CHANGED
@@ -1,16 +1,13 @@
 #%%
 from dataclasses import dataclass
-import math
-import time
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast, BatchEncoding
 from transformers.generation.utils import GenerateOutput
 
-from models import ApiWord, Word
-
-type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast
-
+from models import ApiWord, Word, Replacement
 from combine import combine
+from expand import *
+from expand_llm import *
 
 def starts_with_space(token: str) -> bool:
     return token.startswith(chr(9601)) or token.startswith(chr(288))
@@ -74,10 +71,6 @@ def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, in
     tokens: torch.Tensor = input_ids[0][1:]
     return list(zip(tokens.tolist(), token_log_probs.tolist()))
 
-def prepare_inputs(contexts: list[list[int]], tokenizer: Tokenizer, device: torch.device) -> BatchEncoding:
-    texts = [tokenizer.decode(context, skip_special_tokens=True) for context in contexts]
-    return tokenizer(texts, return_tensors="pt", padding=True).to(device)
-
 def generate_outputs(model: PreTrainedModel, inputs: BatchEncoding, num_samples: int = 5) -> GenerateOutput | torch.LongTensor:
     input_ids = inputs["input_ids"]
     attention_mask = inputs["attention_mask"]
@@ -95,44 +88,6 @@ def generate_outputs(model: PreTrainedModel, inputs: BatchEncoding, num_samples:
     )
     return outputs
 
-def find_next_tokens_0(model: PreTrainedModel, inputs: BatchEncoding, tokenizer: Tokenizer, min_p: float) -> list[list[tuple[int, str, float]]]:
-    input_ids = inputs["input_ids"]
-    attention_mask = inputs["attention_mask"]
-    with torch.no_grad():
-        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
-    logits: torch.Tensor = outputs.logits[:, -1, :]
-    log_probs: torch.Tensor = torch.log_softmax(logits, dim=-1)
-    # for every batch item, find all tokens with log prob greater than min_p, and return their ids and log probs
-    result = []
-    print(f"{log_probs.shape=}")
-    for probs in log_probs:
-        result.append([(i, tokenizer.convert_ids_to_tokens([i])[0], p) for i, p in enumerate(probs) if p > min_p])
-    return result
-
-def find_next_tokens(model: PreTrainedModel, inputs: BatchEncoding, tokenizer: Tokenizer) -> list[list[tuple[int, float]]]:
-    input_ids = inputs["input_ids"]
-    attention_mask = inputs["attention_mask"]
-    with torch.no_grad():
-        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
-    logits: torch.Tensor = outputs.logits[:, -1, :]
-    log_probs: torch.Tensor = torch.log_softmax(logits, dim=-1)
-    result = []
-    for probs in log_probs:
-        result.append([(i, p.item()) for i, p in enumerate(probs)])
-    return result
-
-def extract_replacements(outputs: GenerateOutput | torch.LongTensor, tokenizer: Tokenizer, num_inputs: int, input_len: int, num_samples: int = 5) -> list[list[str]]:
-    all_new_words = []
-    for i in range(num_inputs):
-        replacements = set()
-        for j in range(num_samples):
-            generated_ids = outputs[i * num_samples + j][input_len:]
-            new_word = tokenizer.convert_ids_to_tokens(generated_ids.tolist())[0]
-            if starts_with_space(new_word):
-                replacements.add(" " + new_word[1:])
-        all_new_words.append(sorted(list(replacements)))
-    return all_new_words
-
 #%%
 
 def load_model() -> tuple[PreTrainedModel, Tokenizer, torch.device]:
@@ -153,16 +108,39 @@ def check_text(input_text: str, model: PreTrainedModel, tokenizer: Tokenizer, de
     low_prob_words = [(i, word) for i, word in enumerate(words) if word.logprob < log_prob_threshold]
 
     contexts = [word.context for _, word in low_prob_words]
-    inputs = prepare_inputs(contexts, tokenizer, device)
-    input_ids = inputs["input_ids"]
 
-    num_samples = 10
-    start_time = time.time()
-    outputs = generate_outputs(model, inputs, num_samples)
-    end_time = time.time()
-    print(f"Total time taken for replacements: {end_time - start_time:.4f} seconds")
 
-    replacements = extract_replacements(outputs, tokenizer, input_ids.shape[0], input_ids.shape[1], num_samples)
+    expander = ExpanderOneBatchLLM(model, tokenizer)
+
+    #%%
+    series = []
+    for i, x in enumerate(contexts):
+        series.append(Series(id=i, tokens=x, budget=5.0))
+
+    #%%
+    batch = Batch(items=series)
+
+    #%%
+    stopping_criterion = create_stopping_criterion_llm(tokenizer)
+
+    #%%
+    expanded = expand(batch, expander, stopping_criterion)
+
+    # group by series id
+    expanded_by_id: dict[int, list[list[Expansion]]] = defaultdict(list)
+    for result in expanded.items:
+        expanded_by_id[result.series.id].extend(result.expansions)
+
+    replacements: list[list[Replacement]] = []
+    for i, _ in enumerate(contexts):
+        r = []
+        expansions = expanded_by_id[i]
+        for exp in expansions:
+            tokens = [e.token for e in exp]
+            s = tokenizer.decode(tokens)
+            logprob = sum(e.cost for e in exp)
+            r.append(Replacement(text=s, logprob=logprob))
+        replacements.append(r)
 
     low_prob_words_with_replacements = { i: (w, r) for (i, w), r in zip(low_prob_words, replacements) }
 
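Taken together, the new check_text path swaps the old sample-and-filter replacement generation (generate_outputs + extract_replacements) for a budgeted expansion search. The sketch below rewires the same steps as one standalone helper; it assumes the Series, Batch, Expansion, expand, ExpanderOneBatchLLM, and create_stopping_criterion_llm names brought in by the star imports above, and makes the defaultdict import explicit.

from collections import defaultdict

# Sketch only, not the committed code: contexts -> per-context Replacement lists.
def propose_replacements(contexts: list[list[int]], model, tokenizer) -> list[list[Replacement]]:
    expander = ExpanderOneBatchLLM(model, tokenizer)
    # One Series per low-probability word; budget=5.0 copies the value
    # used in the diff (its exact semantics are defined in expand.py).
    batch = Batch(items=[Series(id=i, tokens=c, budget=5.0)
                         for i, c in enumerate(contexts)])
    expanded = expand(batch, expander, create_stopping_criterion_llm(tokenizer))
    by_id: dict[int, list] = defaultdict(list)
    for result in expanded.items:
        by_id[result.series.id].extend(result.expansions)
    # Each expansion is a token path; summing per-token costs gives the
    # log-probability of the whole candidate string.
    return [
        [Replacement(text=tokenizer.decode([e.token for e in exp]),
                     logprob=sum(e.cost for e in exp))
         for exp in by_id[i]]
        for i in range(len(contexts))
    ]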
expand_llm.py CHANGED
@@ -1,10 +1,26 @@
+import torch
 from expand import *
-from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast, BatchEncoding
+from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast, BatchEncoding
 from dataclasses import dataclass
-from completions import prepare_inputs, find_next_tokens
 
 type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast
 
+def find_next_tokens(model: PreTrainedModel, inputs: BatchEncoding, tokenizer: Tokenizer) -> list[list[tuple[int, float]]]:
+    input_ids = inputs["input_ids"]
+    attention_mask = inputs["attention_mask"]
+    with torch.no_grad():
+        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+    logits: torch.Tensor = outputs.logits[:, -1, :]
+    log_probs: torch.Tensor = torch.log_softmax(logits, dim=-1)
+    result = []
+    for probs in log_probs:
+        result.append([(i, p.item()) for i, p in enumerate(probs)])
+    return result
+
+def prepare_inputs(contexts: list[list[int]], tokenizer: Tokenizer, device: torch.device) -> BatchEncoding:
+    texts = [tokenizer.decode(context, skip_special_tokens=True) for context in contexts]
+    return tokenizer(texts, return_tensors="pt", padding=True).to(device)
+
 @dataclass
 class ExpanderOneBatchLLM:
     model: PreTrainedModel
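find_next_tokens and prepare_inputs now live here instead of completions.py, which avoids a circular import now that completions.py does "from expand_llm import *". A hedged usage sketch follows; the token ids are made up, and the pad-token fallback is an assumption (padding=True requires one), not part of the diff.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # assumed setup; padding=True needs a pad token

contexts = [[3198, 734], [3198, 734, 1115]]            # hypothetical token ids
inputs = prepare_inputs(contexts, tokenizer, device)   # decode, re-tokenize, pad, move to device
per_row = find_next_tokens(model, inputs, tokenizer)   # (token_id, logprob) pairs over the full vocab, one list per row
best_id, best_logprob = max(per_row[0], key=lambda t: t[1])
print(tokenizer.decode([best_id]), best_logprob)

Note that building a Python list over the whole vocabulary is O(V) per batch row; a torch.topk on log_probs would be cheaper if only the best few candidates were needed.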
frontend/src/components/App.tsx CHANGED
@@ -1,12 +1,7 @@
 import React, { useState } from "react"
 import { WordChip } from "./WordChip"
 import { Spinner } from "./Spinner"
-
-interface Word {
-  text: string
-  logprob: number
-  replacements: string[]
-}
+import { Word } from "../interfaces"
 
 async function checkText(text: string): Promise<Word[]> {
   const encodedText = encodeURIComponent(text);
@@ -21,8 +16,8 @@ export default function App() {
   const [context, setContext] = useState("")
   const [wordlist, setWordlist] = useState("")
   const [showWholePrompt, setShowWholePrompt] = useState(false)
-  const [text, setText] = useState("I just drove to the store to but eggs, but they had some.")
-  // const [text, setText] = useState("1\n2\n3\n4\n5\n")
+  // const [text, setText] = useState("I just drove to the store to but eggs, but they had some.")
+  const [text, setText] = useState("I drove to the stove to but eggs")
   const [mode, setMode] = useState<"edit" | "check">("edit")
   const [words, setWords] = useState<Word[]>([])
   const [isLoading, setIsLoading] = useState(false)
frontend/src/components/WordChip.tsx CHANGED
@@ -1,10 +1,11 @@
 import React, { useState, useEffect, useRef } from "react"
+import { Replacement } from "../interfaces";
 
 interface WordChipProps {
   word: string;
   logprob: number;
   threshold: number;
-  replacements: string[];
+  replacements: Replacement[];
   onReplace: (newWord: string) => Promise<void>;
 }
 
@@ -60,6 +61,11 @@ export function WordChip({
     w3 = "";
   }
 
+  // sort replacements by logprob (make sure not to mutate the original array)
+  const sortedReplacements = [...replacements].sort((a, b) => b.logprob - a.logprob)
+  // convert logprobs to probabilities
+  const withProbabilities = sortedReplacements.map(r => ({ ...r, probability: Math.exp(r.logprob)*100 }))
+
   return (
     <span
       title={logprob.toFixed(2)}
@@ -86,19 +92,20 @@ export function WordChip({
          boxShadow: "0 2px 4px rgba(0,0,0,0.1)"
        }}
      >
-        {replacements.map((option, index) => (
+        {withProbabilities.map((option, index) => (
          <div
            key={index}
-            onClick={() => handleReplacement(option)}
+            onClick={() => handleReplacement(option.text)}
            onMouseEnter={() => setSelectedIndex(index)}
            style={{
              padding: "5px 10px",
              cursor: "pointer",
              color: "black",
-              backgroundColor: selectedIndex === index ? "#f0f0f0" : "white"
+              backgroundColor: selectedIndex === index ? "#f0f0f0" : "white",
+              whiteSpace: "nowrap"
            }}
          >
-            {option}
+            {option.text} <small style={{ fontSize: "0.7em", color: "#666" }}>{option.probability.toFixed(1)}%</small>
          </div>
        ))}
      </div>
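The dropdown now shows each candidate with a probability computed as Math.exp(logprob) * 100. Since a candidate's logprob is the summed log-probability of its tokens, exponentiating recovers the probability of the whole candidate string. A quick scale check (plain Python, illustrative only):

import math

# exp(logprob) maps a summed token log-prob back to a 0-1 probability;
# the UI multiplies by 100 and shows one decimal place.
for lp in (-0.1, -1.0, -3.0):
    print(f"{lp:+.1f} -> {math.exp(lp) * 100:.1f}%")
# -0.1 -> 90.5%   -1.0 -> 36.8%   -3.0 -> 5.0%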
frontend/src/interfaces.ts ADDED
@@ -0,0 +1,10 @@
+export interface Replacement {
+  text: string
+  logprob: number
+}
+
+export interface Word {
+  text: string
+  logprob: number
+  replacements: Replacement[]
+}
models.py CHANGED
@@ -9,10 +9,14 @@ class Word:
     logprob: float
     context: list[int]
 
+class Replacement(BaseModel):
+    text: str
+    logprob: float
+
 class ApiWord(BaseModel):
     text: str
     logprob: float
-    replacements: list[str]
+    replacements: list[Replacement]
 
 class CheckResponse(BaseModel):
     text: str
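With Replacement nested inside ApiWord, each candidate in the CheckResponse payload now carries its own logprob instead of being a bare string, matching the Word/Replacement interfaces added on the frontend. A small sketch of the resulting wire format; the values are made up, and model_dump_json assumes pydantic v2 (use .json() under v1):

word = ApiWord(
    text=" but",
    logprob=-7.1,  # hypothetical numbers
    replacements=[Replacement(text=" buy", logprob=-0.4)],
)
print(word.model_dump_json())
# {"text":" but","logprob":-7.1,"replacements":[{"text":" buy","logprob":-0.4}]}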