Update app.py
app.py CHANGED
@@ -1,632 +1,248 @@
--- app.py (previous version)

# -*- coding: utf-8 -*-
"""
AG-BPE
...
"""

import json
import regex as re
from pathlib import Path
from typing import List, Dict, Tuple
import unicodedata
import gradio as gr
import html
import math
import time
import sys
from collections import Counter, defaultdict
from functools import lru_cache
import numpy as np
from dataclasses import dataclass, asdict
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Enhanced Metrics Dataclass ---
@dataclass
class TokenizerMetrics:
    """Comprehensive metrics for tokenizer evaluation"""
    tokenizer_name: str = "AG-BPE"
    vocab_size: int = 0
    vocab_size_kb: float = 0.0
    compression: float = 0.0
    effectiveness_per_kb: float = 0.0
    avg_len: float = 0.0
    oov_rate: float = 0.0
    enc_speed_ms: float = 0.0
    dec_speed_ms: float = 0.0
    throughput_chars_s: float = 0.0
    entropy: float = 0.0
    fertility: float = 0.0
    robustness_score: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Convert metrics to dictionary with rounded values"""
        return {
            "Tokenizer": self.tokenizer_name,
            "Vocab Size": f"{self.vocab_size:,}",
            "Vocab Size (KB)": f"{self.vocab_size_kb:.2f}",
            "Compression": f"{self.compression:.3f}",
            "Effectiveness/KB": f"{self.effectiveness_per_kb:.3f}",
            "Avg Len": f"{self.avg_len:.2f}",
            "OOV Rate (%)": f"{self.oov_rate:.2f}",
            "Enc Speed (ms)": f"{self.enc_speed_ms:.3f}",
            "Dec Speed (ms)": f"{self.dec_speed_ms:.3f}",
            "Throughput (chars/s)": f"{self.throughput_chars_s:,.0f}",
            "Entropy": f"{self.entropy:.3f}",
            "Fertility": f"{self.fertility:.3f}",
            "Robustness Score": f"{self.robustness_score:.2f}"
        }

# ---
class TextCleaner:
    """
    ...
    """
    UNWANTED_CHARS = {
        '\ufffd', '\u200b', '\u200c', '\u200d', '\u2060', '\u2061', '\u2063',
        '\u00a0', '\u202f', '\u2007', '\u2028', '\u2029', '\ufeff', '\ue000',
        '\uf8ff', '\ue001', '\xad', '\u180e', '\u200e', '\uFE0F',
    }

    _cache = {}
    _max_cache_size = 1000

    @classmethod
    @lru_cache(maxsize=512)
    def clean_text(cls, text: str) -> str:
        """Cleans ..."""
        try:
            ...
            cleaned = cleaned.replace('“', '"').replace('”', '"')

            # Remove unwanted chars
            for char in cls.UNWANTED_CHARS:
                cleaned = cleaned.replace(char, '')

            # Filter control chars
            cleaned = ''.join(c for c in cleaned if ord(c) >= 32 or c in '\n\r\t')

            # Collapse whitespace
            cleaned = re.sub(r'\s+', ' ', cleaned)
            cleaned = cleaned.strip()

            # Update cache
            if len(cls._cache) < cls._max_cache_size:
                cls._cache[text] = cleaned

            return cleaned
        except Exception as e:
            logger.warning(f"Text cleaning failed: {e}")
            return text

# --- Enhanced Tokenizer ---
class AGBPETokenizer:
    """
    ...
    """
    def __init__(self, vocab: Dict[str, int], merges: Dict[str, int], special_tokens: Dict[str, int]):
        """
        ...
        """
        try:
            ...
            if self. ...
            ...
            self.text_cleaner = TextCleaner()

            # Performance caching
            self._encode_cache = {}
            self._max_cache_entries = 500

            # Calculate vocab size in KB
            self.vocab_size_kb = sys.getsizeof(json.dumps(self.vocab)) / 1024

            # Stats tracking
            self.total_tokens_encoded = 0
            self.total_oov_tokens = 0

        except Exception as e:
            logger.error(f"Tokenizer initialization failed: {e}")
            raise

    @classmethod
    def from_file(cls, filepath: str) -> 'AGBPETokenizer':
        """
        ...
        """
        try:
            ...
            missing_keys = [k for k in required_keys if k not in data]
            if missing_keys:
                raise ValueError(f"Missing required keys: {missing_keys}")

            logger.info(f"Successfully loaded tokenizer from {filepath}")
            return cls(data['vocab'], data['merges'], data['special_tokens'])

        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON in tokenizer file: {e}")
            raise ValueError(f"Failed to parse JSON: {e}")
        except Exception as e:
            logger.error(f"Failed to load tokenizer: {e}")
            raise

    @lru_cache(maxsize=256)
    def _find_best_vocab_match(self, text_chunk: str) -> List[int]:
        """
        ...
        """
        ids = []
        i = 0
        oov_count = 0

        while i < len(text_chunk):
            found_match = False
            for j in range(min(len(text_chunk), i + 50), i, -1):  # Cap max token length
                substring = text_chunk[i:j]
                if substring in self.vocab:
                    ids.append(self.vocab[substring])
                    i = j
                    found_match = True
                    break

            if not found_match:
                ids.append(self.unk_token_id)
                oov_count += 1
                i += 1

        # Track OOV stats
        if oov_count > 0:
            self.total_oov_tokens += oov_count

        return ids

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """
        ...
        """
        try:
            ...
            if add_special_tokens and (bos_id := self.special_tokens_map.get('<bos>')) is not None:
                token_ids.append(bos_id)

            # Tokenize chunks
            for chunk in self.pat.findall(cleaned_text):
                chunk_ids = self._find_best_vocab_match(chunk)
                token_ids.extend(chunk_ids)
                self.total_tokens_encoded += len(chunk_ids)

            # Add EOS token
            if add_special_tokens and (eos_id := self.special_tokens_map.get('<eos>')) is not None:
                token_ids.append(eos_id)

            # Update cache
            if len(self._encode_cache) < self._max_cache_entries:
                self._encode_cache[cache_key] = token_ids

            ...
        except Exception as e:
            logger.error(f"Encoding failed: {e}")
            return [self.unk_token_id]

    def decode(self, token_ids: List[int]) -> str:
        """
        ...
        """
        try:
            ...
            for token_id in token_ids:
                if token_id not in special_ids:
                    token = self.id_to_token.get(token_id, f"<UNK_{token_id}>")
                    tokens.append(token)

            return "".join(tokens)

        except Exception as e:
            logger.error(f"Decoding failed: {e}")
            return ""

    def calculate_metrics(self, text: str) -> TokenizerMetrics:
        """Calculate comprehensive tokenizer metrics"""
        try:
            metrics = TokenizerMetrics(tokenizer_name="AG-BPE v4")

            # Basic vocab metrics
            metrics.vocab_size = len(self.vocab)
            metrics.vocab_size_kb = self.vocab_size_kb

            if not text:
                return metrics

            # Timing metrics
            start_time = time.perf_counter()
            encoded = self.encode(text, add_special_tokens=False)
            enc_time = (time.perf_counter() - start_time) * 1000

            start_time = time.perf_counter()
            decoded = self.decode(encoded)
            dec_time = (time.perf_counter() - start_time) * 1000

            metrics.enc_speed_ms = enc_time
            metrics.dec_speed_ms = dec_time

            # Throughput
            if enc_time > 0:
                metrics.throughput_chars_s = (len(text) / enc_time) * 1000

            # Token statistics
            tokens = [self.id_to_token.get(i, "") for i in encoded]
            if tokens:
                token_lengths = [len(t) for t in tokens]
                metrics.avg_len = np.mean(token_lengths)

            # Compression ratio
            original_bytes = len(text.encode('utf-8'))
            token_bytes = len(encoded) * 2  # Assuming 2 bytes per token ID
            metrics.compression = original_bytes / max(token_bytes, 1)

            # Effectiveness per KB
            metrics.effectiveness_per_kb = metrics.compression / max(metrics.vocab_size_kb, 0.001)

            # OOV Rate
            oov_count = sum(1 for tid in encoded if tid == self.unk_token_id)
            metrics.oov_rate = (oov_count / len(encoded)) * 100 if encoded else 0

            # Entropy (token distribution)
            token_counts = Counter(encoded)
            total = sum(token_counts.values())
            probs = [count/total for count in token_counts.values()]
            metrics.entropy = -sum(p * math.log2(p) for p in probs if p > 0)

            # Fertility (avg tokens per word)
            words = text.split()
            metrics.fertility = len(encoded) / max(len(words), 1)

            # Robustness Score (composite metric)
            metrics.robustness_score = min(100, (
                (100 - metrics.oov_rate) * 0.4 +  # Low OOV is good
                min(metrics.compression * 10, 40) +  # Good compression
                min(metrics.effectiveness_per_kb * 10, 20)  # Efficiency
            ))

            return metrics

        except Exception as e:
            logger.error(f"Metrics calculation failed: {e}")
            return TokenizerMetrics()

# ---
TOKENIZER_FILE = "ag_bpe_tokenizer_v4.json"

# Initialize tokenizer
try:
    if not Path(TOKENIZER_FILE).exists():
        ...
        demo_vocab = {
            ...
            "g": 20, "w": 21, "y": 22, "b": 23, "v": 24, "k": 25, "x": 26,
            "j": 27, "q": 28, "z": 29, "th": 30, "he": 31, "in": 32,
            "er": 33, "an": 34, " the": 35, "ing": 36, "ed": 37, "and": 38,
            "to": 39, "of": 40, "is": 41, "it": 42, "for": 43, "as": 44,
            "with": 45, "was": 46, "that": 47, "be": 48, "on": 49,
            "Hello": 50, " world": 51, "AI": 52, "test": 53, "code": 54
        }

        demo_data = {
            "vocab": demo_vocab,
            "merges": {"t h": 30, "h e": 31, "i n": 32},
            "special_tokens": {"<unk>": 0, "<bos>": 1, "<eos>": 2}
        }

        with open(TOKENIZER_FILE, 'w', encoding='utf-8') as f:
            json.dump(...)

    tokenizer = AGBPETokenizer.from_file(TOKENIZER_FILE)
    ...
except Exception as e:
    logger.error(f"❌ Failed to initialize tokenizer: {e}")

def ...
    ...
        return (
            "<div style='color: red; font-weight: bold;'>⚠️ Tokenizer not loaded!</div>",
            "<div style='color: red;'>No metrics available</div>",
            {}
        )

    if not text:
        return (
            "<div style='color: #888; padding: 20px;'>✍️ Enter text to see tokenization...</div>",
            "<div style='color: #888;'>Waiting for input...</div>",
            {}
        )

    try:
        # Get tokens
        encoded_ids = tokenizer.encode(text, add_special_tokens=False)
        tokens = [tokenizer.id_to_token.get(i, f"<UNK_{i}>") for i in encoded_ids]

        # Calculate metrics
        metrics = tokenizer.calculate_metrics(text)
        metrics_dict = metrics.to_dict()

        # Generate visualization HTML
        html_tokens = generate_token_html(tokens, encoded_ids)

        # Generate stats HTML
        html_stats = generate_stats_html(metrics_dict)

        return html_tokens, html_stats, metrics_dict

    except Exception as e:
        logger.error(f"Processing failed: {e}")
        return (
            f"<div style='color: red;'>❌ Error: {str(e)}</div>",
            "<div style='color: red;'>Error calculating metrics</div>",
            {}
        )

def generate_token_html(tokens: List[str], token_ids: List[int]) -> str:
    """Generate beautiful token visualization"""
    gradients = [
        "linear-gradient(135deg, #667eea 0%, #764ba2 100%)",
        "linear-gradient(135deg, #f093fb 0%, #f5576c 100%)",
        "linear-gradient(135deg, #4facfe 0%, #00f2fe 100%)",
        "linear-gradient(135deg, #43e97b 0%, #38f9d7 100%)",
        "linear-gradient(135deg, #fa709a 0%, #fee140 100%)",
        "linear-gradient(135deg, #30cfd0 0%, #330867 100%)",
        "linear-gradient(135deg, #a8edea 0%, #fed6e3 100%)"
    ]

    html = """
    <div style='
        display: flex;
        flex-wrap: wrap;
        gap: 8px;
        padding: 20px;
        background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
        border-radius: 12px;
        box-shadow: 0 10px 30px rgba(0,0,0,0.1);
    '>
    """

    for i, (token, tid) in enumerate(zip(tokens, token_ids)):
        gradient = gradients[i % len(gradients)]
        safe_token = html.escape(token)

        # Special styling for UNK tokens
        is_unk = tid == tokenizer.unk_token_id if tokenizer else False
        border_style = "3px solid #ff4444" if is_unk else "1px solid rgba(255,255,255,0.3)"

        html += f"""
        <div style='
            display: flex;
            flex-direction: column;
            align-items: center;
            padding: 12px 16px;
            background: {gradient};
            border-radius: 10px;
            border: {border_style};
            box-shadow: 0 4px 15px rgba(0,0,0,0.1);
            transition: all 0.3s ease;
            cursor: pointer;
        ' onmouseover='this.style.transform="scale(1.05)"' onmouseout='this.style.transform="scale(1)"'>
            <span style='
                color: white;
                font-size: 16px;
                font-weight: 600;
                text-shadow: 0 1px 3px rgba(0,0,0,0.2);
                white-space: pre-wrap;
                max-width: 150px;
                overflow: hidden;
                text-overflow: ellipsis;
            '>{safe_token}</span>
            <span style='
                color: white;
                font-size: 12px;
                font-weight: bold;
                margin-top: 6px;
                background: rgba(0,0,0,0.2);
                padding: 2px 8px;
                border-radius: 12px;
            '>#{tid}</span>
        </div>
        """

    html += "</div>"
    return html

def generate_stats_html(metrics):
    """Generate beautiful metrics display"""
    html = """
    <div style='
        display: grid;
        grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
        gap: 15px;
        padding: 20px;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        border-radius: 12px;
        box-shadow: 0 10px 30px rgba(0,0,0,0.2);
    '>
    """

    icons = {
        "Vocab Size": "📚", "Compression": "🗜️", "Avg Len": "📏",
        "OOV Rate (%)": "❓", "Enc Speed (ms)": "⚡", "Dec Speed (ms)": "⏱️",
        "Throughput (chars/s)": "🚀", "Entropy": "🌡️", "Fertility": "🌱",
        "Robustness Score": "💪", "Effectiveness/KB": "📊", "Vocab Size (KB)": "💾"
    }

    for key, value in metrics.items():
        if key == "Tokenizer":
            continue
        icon = icons.get(key, "📈")

        html += f"""
        <div style='
            background: rgba(255,255,255,0.95);
            padding: 15px;
            border-radius: 8px;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
            transition: all 0.3s ease;
        ' onmouseover='this.style.transform="translateY(-2px)"' onmouseout='this.style.transform="translateY(0)"'>
            <div style='
                display: flex;
                align-items: center;
                margin-bottom: 8px;
            '>
                <span style='font-size: 24px; margin-right: 8px;'>{icon}</span>
                <span style='
                    color: #4a5568;
                    font-size: 12px;
                    font-weight: 600;
                    text-transform: uppercase;
                '>{key}</span>
            </div>
            <div style='
                color: #1a202c;
                font-size: 20px;
                font-weight: bold;
            '>{value}</div>
        </div>
        """

    html += "</div>"
    return html

# Create Gradio interface
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="purple",
        secondary_hue="blue",
        font=gr.themes.GoogleFont("Inter")
    ),
    css="""
    .gradio-container {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    }
    footer {display: none !important;}
    .gr-button {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        border: none;
        color: white;
    }
    .gr-button:hover {
        transform: scale(1.05);
        box-shadow: 0 5px 20px rgba(0,0,0,0.3);
    }
    """
) as demo:
    gr.Markdown(
        """
        # ...
        ...
        <div style='
            background: rgba(255,255,255,0.9);
            padding: 15px;
            border-radius: 8px;
            margin: 10px 0;
        '>
        <b>✨ Features:</b> Longest-match tokenization • Real-time metrics • Performance analysis • Beautiful visualization
        </div>
        """
    )

    with gr. ...
        input_text = gr.Textbox(
            ...
            max_lines=10,
            autofocus=True
        )

        gr.Examples(
            examples=[
                "The quick brown fox jumps over the lazy dog.",
                "Artificial Intelligence is revolutionizing technology! 🚀",
                "def fibonacci(n):\n return n if n <= 1 else fibonacci(n-1) + fibonacci(n-2)",
                "Les mathématiques sont le langage de l'univers. 🌌",
                "東京は日本の首都です。人口は約1400万人です。",
                "Blockchain technology enables decentralized systems.",
            ],
            inputs=input_text,
            label="💡 Quick Examples"
        )

        with gr.Row():
            output_viz = gr.HTML(
                label="🎨 Token Visualization",
                value="<div style='padding: 40px; text-align: center; color: #888;'>Waiting for input...</div>"
            )

        with gr.Row():
            output_stats = gr.HTML(
                label="📊 Metrics Dashboard",
                value="<div style='padding: 40px; text-align: center; color: #888;'>Metrics will appear here...</div>"
            )

        ...
            outputs=[output_viz, output_stats],
            queue=True
        )

        gr. ...
        ...
        )

if __name__ == "__main__":
    demo. ...
    demo.launch(share=False, debug=True)
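One removed block worth spelling out is the composite Robustness Score inside calculate_metrics. A worked example with made-up inputs (2% OOV rate, compression 3.2, effectiveness/KB 1.5), stepping through the same formula:

# Editorial sketch, hypothetical inputs run through the removed formula, term by term.
oov_rate, compression, eff_per_kb = 2.0, 3.2, 1.5
score = min(100, (100 - oov_rate) * 0.4        # 39.2
                 + min(compression * 10, 40)   # 32.0
                 + min(eff_per_kb * 10, 20))   # 15.0
print(round(score, 1))  # 86.2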
+++ app.py (new version)

# -*- coding: utf-8 -*-
"""
AG-BPE Standalone Usage Script & Web Visualizer
================================================

This script demonstrates how to load and use a pre-trained AG-BPE tokenizer
and provides a real-time web interface using Gradio to visualize its behavior.

This version has been modified to use a "longest-match" strategy directly on the
vocabulary, ignoring the BPE merge rules.
"""
import json
import regex as re
from pathlib import Path
from typing import List, Dict, Tuple
import unicodedata
import gradio as gr
import html
import math

# --- TextCleaner Class (Unchanged) ---
class TextCleaner:
    """A text cleaner for AI datasets, designed to remove invisible, abnormal, and disruptive characters."""
    UNWANTED_CHARS = {
        '\ufffd', '\u200b', '\u200c', '\u200d', '\u2060', '\u2061', '\u2063',
        '\u00a0', '\u202f', '\u2007', '\u2028', '\u2029', '\ufeff', '\ue000',
        '\uf8ff', '\ue001', '\xad', '\u180e', '\u200e', '\uFE0F',
    }

    @classmethod
    def clean_text(cls, text: str) -> str:
        """Cleans a given string by normalizing it, removing unwanted characters, and collapsing whitespace."""
        text = unicodedata.normalize("NFKC", text)
        text = text.replace('’', "'").replace('‘', "'")
        text = text.replace('“', '"').replace('”', '"')
        for char in cls.UNWANTED_CHARS:
            text = text.replace(char, '')
        text = ''.join(c for c in text if ord(c) >= 32 or c in '\n\r\t')
        text = re.sub(r'\s+', ' ', text)
        return text.strip()
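A minimal sketch (editorial, not part of the committed file) of what clean_text does to a deliberately messy string; the result follows from the NFKC normalization, quote straightening, unwanted-character removal, and whitespace collapsing above:

# The sample string is invented for illustration.
sample = "Smart \u201cquotes\u201d,\u200b  a zero-width space,   and   extra spaces."
print(TextCleaner.clean_text(sample))
# -> Smart "quotes", a zero-width space, and extra spaces.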

# --- Standalone Tokenizer Class (Logic Changed) ---
class AGBPETokenizer:
    """
    A self-contained tokenizer that loads a pre-trained model from a JSON file.
    MODIFIED: This version uses a greedy longest-match algorithm on the vocabulary,
    ignoring any BPE merge rules.
    """
    def __init__(self, vocab: Dict[str, int], merges: Dict[str, int], special_tokens: Dict[str, int]):
        """Initializes the tokenizer from loaded vocabulary and merge data."""
        self.vocab = vocab
        # self.merges is no longer used, but kept for file loading compatibility
        self.special_tokens_map = special_tokens
        self.id_to_token: Dict[int, str] = {i: s for s, i in self.vocab.items()}

        self.pat = re.compile(r'\s*\S+')

        self.unk_token_id = self.vocab.get('<unk>')
        if self.unk_token_id is None:
            # Fallback for vocabularies without <unk>
            if self.vocab:
                self.unk_token_id = next(iter(self.vocab.values()))
                print(f"Warning: '<unk>' token not found. Using first token as fallback (ID: {self.unk_token_id}).")
            else:
                raise ValueError("The vocabulary is empty and '<unk>' token is missing.")

        self.text_cleaner = TextCleaner()

    @classmethod
    def from_file(cls, filepath: str) -> 'AGBPETokenizer':
        """Class method to conveniently load a tokenizer from a JSON file path."""
        path = Path(filepath)
        if not path.exists():
            raise FileNotFoundError(f"Tokenizer file not found: '{filepath}'")
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        required_keys = ['vocab', 'merges', 'special_tokens']
        if not all(key in data for key in required_keys):
            raise ValueError("The JSON file is malformed. Missing one of: vocab, merges, special_tokens.")
        return cls(data['vocab'], data['merges'], data['special_tokens'])

    def _find_best_vocab_match(self, text_chunk: str) -> List[int]:
        """
        Tokenizes a chunk of text by greedily finding the longest possible
        substring that exists in the vocabulary.
        """
        ids = []
        i = 0
        while i < len(text_chunk):
            found_match = False
            # Search for the longest possible match from current position
            for j in range(len(text_chunk), i, -1):
                substring = text_chunk[i:j]
                if substring in self.vocab:
                    ids.append(self.vocab[substring])
                    i = j  # Move pointer to the end of the match
                    found_match = True
                    break  # Exit the inner loop to continue from the new position
            if not found_match:
                # If no match was found (not even a single character),
                # use the unknown token and advance by one character.
                ids.append(self.unk_token_id)
                i += 1
        return ids

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """Encodes a string of text into a list of token IDs."""
        cleaned_text = self.text_cleaner.clean_text(text)
        token_ids = []

        if add_special_tokens and (bos_id := self.special_tokens_map.get('<bos>')) is not None:
            token_ids.append(bos_id)

        # Pre-tokenize the text into chunks (words and their preceding spaces)
        for chunk in self.pat.findall(cleaned_text):
            # Apply the new longest-match algorithm on each chunk
            chunk_ids = self._find_best_vocab_match(chunk)
            token_ids.extend(chunk_ids)

        if add_special_tokens and (eos_id := self.special_tokens_map.get('<eos>')) is not None:
            token_ids.append(eos_id)

        return token_ids

    def decode(self, token_ids: List[int]) -> str:
        """Decodes a list of token IDs back into a string of text."""
        special_ids_to_skip = set(self.special_tokens_map.values())
        tokens = [self.id_to_token.get(token_id, '') for token_id in token_ids if token_id not in special_ids_to_skip]
        return "".join(tokens)
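The greedy loop in _find_best_vocab_match always takes the longest vocabulary entry starting at the current position, which keeps the code simple but does not guarantee the shortest token sequence. A small illustrative sketch with a toy vocabulary (invented for this note, not shipped with the model):

# Editorial sketch, not part of the commit: toy ids and strings are made up.
toy = AGBPETokenizer(
    vocab={"<unk>": 0, "un": 1, "related": 2, "unrelat": 3, "e": 4, "d": 5},
    merges={},
    special_tokens={"<unk>": 0},
)
print(toy._find_best_vocab_match("unrelated"))
# -> [3, 4, 5]: the scan takes the longest matching prefix at each step ("unrelat",
#    then "e", then "d"), even though the split ["un", "related"] would need only two tokens.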

# --- Gradio Web Application (Unchanged) ---

TOKENIZER_FILE = "ag_bpe_tokenizer_v4.json"
TOKENIZER_LOADED = False
ERROR_MESSAGE = ""
tokenizer = None

try:
    if not Path(TOKENIZER_FILE).exists():
        print(f"⚠️ Warning: Tokenizer file '{TOKENIZER_FILE}' not found.")
        print("Creating a dummy tokenizer file for local testing.")
        dummy_data = {
            "vocab": {"<unk>": 0, "<bos>": 1, "<eos>": 2, " comm": 3, "ent": 4, "?": 5, "Hello": 8, " world": 9, '"comm"': 10, " comment": 11},
            "merges": {" c o m m": 1, "e n t": 2, " comment": 3},
            "special_tokens": {"<unk>": 0, "<bos>": 1, "<eos>": 2}
        }
        with open(TOKENIZER_FILE, 'w', encoding='utf-8') as f:
            json.dump(dummy_data, f, indent=2)
        print("Dummy file created. The app will use this file.")

    print(f"🧠 Loading tokenizer from '{TOKENIZER_FILE}'...")
    tokenizer = AGBPETokenizer.from_file(TOKENIZER_FILE)
    TOKENIZER_LOADED = True
    print(f"✅ Tokenizer loaded successfully. Vocabulary size: {len(tokenizer.vocab)}")

except (FileNotFoundError, ValueError, KeyError) as e:
    ERROR_MESSAGE = str(e)
    print(f"❌ ERROR loading tokenizer: {ERROR_MESSAGE}")
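Once loading has succeeded, the encode/decode round trip can be exercised without the web UI. A minimal sketch (editorial), assuming the same ag_bpe_tokenizer_v4.json, or the dummy file created above, is on disk:

# Illustrative usage only; the exact ids depend on the vocabulary shipped in the JSON file.
tok = AGBPETokenizer.from_file(TOKENIZER_FILE)
ids = tok.encode("Hello world", add_special_tokens=True)
print(ids)              # token ids, wrapped in <bos>/<eos> when those are defined
print(tok.decode(ids))  # special ids are skipped, so this should print "Hello world" back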

def visualize_tokenization(text: str) -> Tuple[str, float, float, float]:
    """
    Takes input text, tokenizes it, calculates stats, and returns
    a styled HTML string and the statistics for display.
    """
    if not TOKENIZER_LOADED or not tokenizer:
        error_html = f"<p style='color: red; font-weight: bold;'>TOKENIZER LOADING ERROR: {ERROR_MESSAGE}</p>"
        return error_html, 0.0, 0.0, 0.0

    if not text:
        return "<p style='color: #888;'>Please enter some text to see the visualization...</p>", 0.0, 0.0, 0.0

    encoded_ids = tokenizer.encode(text, add_special_tokens=False)
    tokens = [tokenizer.id_to_token.get(i, f"<unk:{i}>") for i in encoded_ids]

    # --- Calculate Statistics ---
    avg_len, std_dev, ratio = 0.0, 0.0, 0.0
    if tokens:
        token_lengths = [len(t) for t in tokens]
        avg_len = sum(token_lengths) / len(token_lengths)
        if len(token_lengths) > 1:
            variance = sum([(x - avg_len) ** 2 for x in token_lengths]) / (len(token_lengths) - 1)
            std_dev = math.sqrt(variance)
    if text:
        ratio = len(tokens) / len(text)

    # --- Generate HTML ---
    colors = ["#dbeafe", "#dcfce7", "#fee2e2", "#fef3c7", "#f3e8ff", "#d1fae5", "#e0f2fe"]
    html_output = "<div style='display: flex; flex-wrap: wrap; align-items: flex-start; font-family: sans-serif;'>"

    for i, token_id in enumerate(encoded_ids):
        safe_token_string = html.escape(tokens[i])
        color = colors[i % len(colors)]
        html_output += f"""
        <div style="display: inline-flex; flex-direction: column; align-items: center; margin: 4px; padding: 8px 10px; border-radius: 8px; background-color: {color}; border: 1px solid rgba(0,0,0,0.1); box-shadow: 0 1px 3px rgba(0,0,0,0.05); text-align: center;">
            <span style="font-size: 1.1em; font-weight: 500; color: #111827; white-space: pre-wrap;">{safe_token_string}</span>
            <span style="font-size: 0.9em; font-weight: 700; color: #1e3a8a; margin-top: 5px; background-color: rgba(255,255,255,0.6); padding: 2px 6px; border-radius: 5px;">{token_id}</span>
        </div>"""
    html_output += "</div>"

    return html_output, round(avg_len, 2), round(std_dev, 2), round(ratio, 3)
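A worked example (editorial) of the three statistics returned above, for a hypothetical split of "Hello world" into ["Hello", " world"]: the sample standard deviation uses the n-1 denominator, exactly as the variance line does, and the ratio is tokens per character:

# Hypothetical two-token split, stepping through the same formulas as visualize_tokenization.
lengths = [len(t) for t in ["Hello", " world"]]                   # [5, 6]
avg = sum(lengths) / len(lengths)                                 # 5.5
var = sum((x - avg) ** 2 for x in lengths) / (len(lengths) - 1)   # 0.5
print(round(avg, 2), round(math.sqrt(var), 2), round(len(lengths) / len("Hello world"), 3))
# -> 5.5 0.71 0.182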

with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky"), css="footer {display: none !important}") as demo:
    gr.Markdown(
        """
        # 👁️ Real-time Tokenizer Visualizer
        Enter text in the field below to see the tokenization happen live.
        Each colored card is a "token", with its corresponding numerical ID shown below it.
        """
    )

    with gr.Column():
        input_textbox = gr.Textbox(
            label="Enter your text here",
            placeholder="Type something...",
            lines=5,
            show_label=False,
        )

    with gr.Row():
        avg_len_box = gr.Textbox(label="Avg. Token Len", interactive=False)
        std_dev_box = gr.Textbox(label="Std. Dev Len", interactive=False)
        ratio_box = gr.Textbox(label="Tokens/Chars Ratio", interactive=False)

    output_html = gr.HTML(label="Tokens and IDs")

    input_textbox.input(
        fn=visualize_tokenization,
        inputs=[input_textbox],
        outputs=[output_html, avg_len_box, std_dev_box, ratio_box]
    )

    gr.Examples(
        examples=[
            "Artificial intelligence is fascinating.",
            'Test with "quotes" and spaces.',
            "Code like `if (x==10)` and emojis 👍🚀 are handled.",
            "Hello world! This is a test of the AG-BPE tokenizer.",
            "안녕하세요",
            "Salut comment ça va ?"
        ],
        inputs=input_textbox
    )

if __name__ == "__main__":
    demo.launch()
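The removed dashboard also reported the Shannon entropy of the encoded ids (see calculate_metrics in the previous version above). If that number is wanted alongside the new length statistics, a standalone helper along those lines could look like this sketch (illustrative only, not part of the commit):

from collections import Counter
from typing import List
import math

def token_entropy(ids: List[int]) -> float:
    """Shannon entropy (bits) of the token-id distribution, same formula as the old calculate_metrics."""
    counts = Counter(ids)
    total = sum(counts.values())
    return -sum((c / total) * math.log2(c / total) for c in counts.values()) if total else 0.0

# Example: four ids with one repeat -> probabilities {0.5, 0.25, 0.25} -> 1.5 bits.
print(token_entropy([2, 1, 2, 3]))  # 1.5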