mebubo committed
Commit 0a795e1 · 1 Parent(s): 59cb243
.gitignore CHANGED
@@ -2,3 +2,4 @@
 /.env
 /.venv/
 /__pycache__/
+/.vscode/
combine.py ADDED
@@ -0,0 +1,17 @@
+from functools import reduce
+from typing import Callable
+
+
+def combine[T](items: list[T], combine_fn: Callable[[T, T], T | None]) -> list[T]:
+    def fold_fn(acc: list[T], item: T) -> list[T]:
+        if not acc:
+            return [item]
+
+        combined = combine_fn(acc[-1], item)
+        if combined is not None:
+            return [*acc[:-1], combined]
+        return [*acc, item]
+
+    result = reduce(fold_fn, items, [])
+
+    return result
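
Note (not part of the commit): a minimal usage sketch of the new combine helper. It folds left over the list; whenever combine_fn returns a non-None value the current item is merged into the last accumulated element, otherwise the item starts a new group. The add_if_small function below is hypothetical, purely for illustration.

from combine import combine

def add_if_small(x: int, y: int) -> int | None:
    # merge neighbours only while the merged value stays below 10
    return x + y if x + y < 10 else None

# 1+2=3, 3+3=6, then 6+4 would be 10, so a new group starts: 4+5=9
print(combine([1, 2, 3, 4, 5], add_if_small))  # -> [6, 9]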
combine_test.py ADDED
@@ -0,0 +1,42 @@
+import pytest
+from combine import combine
+
+
+def test_empty_list():
+    assert combine([], lambda x, y: x + y) == []
+
+def test_single_item():
+    assert combine(["hello"], lambda x, y: x + y) == ["hello"]
+
+def test_two_items():
+    assert combine(["hello", "world"], lambda x, y: x + y) == ["helloworld"]
+
+def test_sum():
+    assert combine([1, 2, 3, 4], lambda x, y: x + y) == [10]
+
+
+def test_add_if_even():
+    def add_if_even(x: int, y: int) -> int | None:
+        if (x + y) % 2 == 0:
+            return x + y
+        return None
+
+    assert combine([1, 3, 1, 4], add_if_even) == [4, 1, 4]
+    assert combine([1, 3, 2, 4], add_if_even) == [10]
+
+
+def test_join_if_same_letter():
+    def join_if_same_letter(x: str, y: str) -> str | None:
+        if x[0] == y[0]:
+            return x + y
+        return None
+
+    assert combine(["hello", "hi", "home", "world", "welcome"], join_if_same_letter) == ["hellohihome", "worldwelcome"]
+
+
+def test_no_combinations():
+    def never_combine(x: int, y: int) -> None:
+        return None
+
+    input_list = [1, 2, 3, 4]
+    assert combine(input_list, never_combine) == input_list
completions.py CHANGED
@@ -11,6 +11,9 @@ type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast
 def starts_with_space(token: str) -> bool:
     return token.startswith(chr(9601)) or token.startswith(chr(288))
 
+def is_newline(token: str) -> bool:
+    return len(token) == 1 and ord(token[0]) == 266
+
 def split_into_words(token_probs: list[tuple[int, float]], tokenizer: Tokenizer) -> list[Word]:
     words: list[Word] = []
     current_word: list[int] = []
@@ -18,25 +21,32 @@ def split_into_words(token_probs: list[tuple[int, float]], tokenizer: Tokenizer)
     current_word_first_token_index: int = 0
     all_tokens: list[int] = [token_id for token_id, _ in token_probs]
 
-    def append_current_word():
-        if current_word:
-            words.append(Word(current_word,
-                              tokenizer.decode(current_word),
+    def append_word(word):
+        if word:
+            words.append(Word(word,
+                              tokenizer.decode(word),
                               sum(current_log_probs),
                               all_tokens[:current_word_first_token_index]))
 
     for i, (token_id, logprob) in enumerate(token_probs):
         token: str = tokenizer.convert_ids_to_tokens([token_id])[0]
-        if not starts_with_space(token) and token.isalpha():
+        token_str = tokenizer.decode([token_id])
+        print(f"-- {token_id=} {token=} {token_str=} {token_str.isalpha()=} {token_str.isspace()=}")
+        if (not starts_with_space(token) and token_str.isalpha()):
             current_word.append(token_id)
             current_log_probs.append(logprob)
         else:
-            append_current_word()
+            append_word(current_word)
             current_word = [token_id]
             current_log_probs = [logprob]
             current_word_first_token_index = i
+        if is_newline(token):
+            append_word(current_word)
+            current_word = []
+            current_log_probs = []
+            current_word_first_token_index = i
 
-    append_current_word()
+    append_word(current_word)
 
     return words
 
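Note (not part of the commit): the magic numbers in starts_with_space and is_newline refer to the marker characters that subword tokenizers embed in token strings. A minimal sketch, assuming a GPT-2-style byte-level BPE vocabulary for chr(288) and chr(266); tokenizers with a different byte mapping would need other code points.

# Illustrative only: the code points referenced in completions.py.
# chr(9601) == '▁'  SentencePiece word-start marker (Llama-style tokenizers)
# chr(288)  == 'Ġ'  GPT-2 byte-level BPE image of a leading space (byte 0x20 -> 256 + 32)
# chr(266)  == 'Ċ'  GPT-2 byte-level BPE image of a newline (byte 0x0A -> 256 + 10)
def is_newline(token: str) -> bool:
    return len(token) == 1 and ord(token[0]) == 266

assert is_newline(chr(266))    # the byte-level newline token string
assert not is_newline("\n")    # a raw newline is a different character
print(chr(9601), chr(288), chr(266))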
 
frontend/src/components/App.tsx CHANGED
@@ -22,6 +22,7 @@ export default function App() {
   const [wordlist, setWordlist] = useState("")
   const [showWholePrompt, setShowWholePrompt] = useState(false)
   const [text, setText] = useState("I just drove to the store to but eggs, but they had some.")
+  // const [text, setText] = useState("1\n2\n3\n4\n5\n")
   const [mode, setMode] = useState<"edit" | "check">("edit")
   const [words, setWords] = useState<Word[]>([])
   const [isLoading, setIsLoading] = useState(false)
frontend/src/components/WordChip.tsx CHANGED
@@ -44,6 +44,22 @@ export function WordChip({
     setIsExpanded(false); // Close the dropdown
   };
 
+  console.log(`word: ->${word}<-`);
+
+  let w1;
+  let w2;
+  let w3;
+  // if word contains a newline, render a <br />
+  if (word.includes("\n")) {
+    [w1, w3] = word.split("\n");
+    w2 = "\n";
+    console.log(`split: ${w1} | ${w2} | ${w3}`);
+  } else {
+    w1 = word;
+    w2 = "";
+    w3 = "";
+  }
+
   return (
     <span
       title={logprob.toFixed(2)}
@@ -51,7 +67,9 @@ export function WordChip({
       style={{ position: "relative", cursor: logprob < threshold ? "pointer" : "default" }}
       onClick={handleClick}
     >
-      {word}
+      {w1}
+      {w2 && <br />}
+      {w3}
       {isExpanded && (
         <div
           ref={dropdownRef}
main.py CHANGED
@@ -34,3 +34,5 @@ def check(text: str):
     return CheckResponse(text=text, words=cached_check_text(text))
 
 app.mount("/", StaticFiles(directory="frontend/public", html=True))
+
+#%%
pyproject.toml CHANGED
@@ -5,11 +5,13 @@ description = "Add your description here"
 readme = "README.md"
 requires-python = ">=3.12"
 dependencies = [
+    "datasets>=3.0.1",
     "fastapi[standard]>=0.115.2",
     "huggingface-hub>=0.25.2",
     "ipykernel>=6.29.5",
     "ipywidgets>=8.1.5",
     "openai>=1.51.2",
+    "pytest>=8.3.3",
     "torch>=2.4.1",
     "transformers>=4.45.2",
 ]
requirements.txt CHANGED
@@ -5,3 +5,4 @@ transformers
 torch
 huggingface_hub
 fastapi[standard]
+pytest