mebubo commited on
Commit
c88ac20
·
1 Parent(s): 0a795e1
Files changed (1) hide show
  1. completions_test.py +27 -0
completions_test.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from completions import calculate_log_probabilities, load_model, tokenize, split_into_words
2
+
3
+ model, tokenizer, device = load_model()
4
+
5
+ def test_text_to_words():
6
+ text = """Hello
7
+ world!"""
8
+ token_probs = calculate_log_probabilities(model, tokenizer, tokenize(text, tokenizer, device))
9
+ words = split_into_words(token_probs, tokenizer)
10
+ expected_words = ["Hello", "\n", "world", "!"]
11
+ assert [w.text for w in words] == expected_words
12
+
13
+ def test_multiline():
14
+ text = """// Context: C code from an image manipulation library.
15
+ for (int y = 0; y < HEIGHT; y++) {
16
+ for (int x = 0; x < WIDTH; x++) {
17
+ buf[y * HEIGHT + x] = 0;
18
+ }
19
+ }"""
20
+ tokenized = tokenize(text, tokenizer, device)
21
+ print(tokenized)
22
+ token_probs = calculate_log_probabilities(model, tokenizer, tokenized)
23
+ words = split_into_words(token_probs, tokenizer)
24
+ print("---", [w.text for w in words])
25
+ expected_words = ["//", " Context", ":", " C", " code", " from", " an", " image", " manipulation", " library", ".\n",
26
+ "for", "(", "int", "y", "=", "0", ";", "y", "<", "HEIGHT", ";", "y", "+", "+", ")", "{", "\n", " ", "for", "(", "int", "x", "=", "0", ";", "x", "<", "WIDTH", ";", "x", "+", "+", ")", "{", "\n", " ", "buf", "[", "y", "*", "HEIGHT", "+", "x", "]", "=", "0", ";", "\n", " ", "}", "\n", "}"]
27
+ assert [w.text for w in words] == expected_words