import os
import json
import regex as re
from natsort import natsorted
from tqdm import tqdm

# Add the Marathi regex pattern at the top level
MARATHI_PATTERN = re.compile(r"""
    # Contractions and common affixes
    'चा|'ची|'चे|'ला|'ले|'नी|
    # Words with optional vowel signs and modifiers
    [\p{L}\p{M}]+|
    # Numbers
    \p{N}+|
    # Punctuation and special characters
    [^\s\p{L}\p{N}\p{M}]+|
    # Whitespace
    \s+
""", re.VERBOSE)

def text_to_bytes(text):
    """Convert text to byte tokens after applying the Marathi regex"""
    words = MARATHI_PATTERN.findall(text)
    all_bytes = []
    for word in words:
        bytes_tokens = [b for c in word for b in c.encode('utf-8')]
        all_bytes.extend(bytes_tokens)
    return all_bytes
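
# Illustrative sketch: each Devanagari character takes 3 bytes in UTF-8, so
# text_to_bytes("क") should return [224, 164, 149] (the UTF-8 bytes of U+0915).
# Devanagari text therefore starts at roughly 3 byte-tokens per character
# before any merges are learned.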

def read_text_files(folder_path='train', limit=10):
    # Check if the folder exists
    if not os.path.exists(folder_path):
        print(f"Error: The folder '{folder_path}' does not exist.")
        return []  # empty result so callers can still take len()

    # Get list of all files in the folder
    files = os.listdir(folder_path)

    # Filter for text files and sort them naturally
    text_files = natsorted([f for f in files if f.endswith(('.txt', '.text'))])

    if not text_files:
        print(f"No text files found in '{folder_path}' folder.")
        return []

    # Take only the first 'limit' files
    text_files = text_files[:limit]

    # Initialize list to store all tokens
    all_tokens = []

    # Read and tokenize the contents of each file
    for file_name in text_files:
        file_path = os.path.join(folder_path, file_name)
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                content = file.read()
                # Convert text to bytes using Marathi-aware tokenization
                tokens = text_to_bytes(content)
                all_tokens.extend(tokens)
        except Exception as e:
            print(f"Error reading {file_name}: {str(e)}")

    print("\n=== Combined Statistics ===")
    print("Total number of tokens:", len(all_tokens))
    print("First 100 tokens:", all_tokens[:100])
    return all_tokens

def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):  # Pythonic way to iterate consecutive elements
        counts[pair] = counts.get(pair, 0) + 1
    return counts
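
# Illustrative sketch: get_stats([1, 2, 2, 3, 2, 2]) should return
# {(1, 2): 1, (2, 2): 2, (2, 3): 1, (3, 2): 1}; the pair (2, 2) occurs twice,
# so it would be the next merge candidate.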

def merge(ids, pair, idx):
    newids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids
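
# Illustrative sketch: merge([5, 6, 6, 7, 9, 1], (6, 7), 99) should return
# [5, 6, 99, 9, 1], with every adjacent (6, 7) pair replaced by the new token 99.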

def encode(text, merges):
    """
    Encode text into tokens using the learned merges
    """
    # First convert text to bytes using Marathi-aware tokenization
    ids = text_to_bytes(text)

    # Apply the merges in order of their token indices
    # Sort by the token index to ensure consistent ordering
    sorted_merges = sorted(merges.items(), key=lambda x: x[1])
    for (p1, p2), idx in sorted_merges:
        ids = merge(ids, (p1, p2), idx)

    return ids
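
# Illustrative sketch with a hypothetical one-entry merge table: the byte pair
# (224, 164) starts the UTF-8 encoding of many Devanagari characters (क included),
# so with merges = {(224, 164): 256}, encode("क", merges) should return [256, 149]
# instead of the raw bytes [224, 164, 149].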

def decode(ids, merges):
    """
    Decode tokens back to text using the learned merges
    """
    # Create reverse mapping from token to pair
    reverse_merges = {idx: pair for pair, idx in merges.items()}

    # Expand all tokens recursively
    def expand_token(token):
        if token < 256:  # Base case: token is a raw byte
            return bytes([token])
        # Recursive case: expand the token into its constituent pair
        pair = reverse_merges[token]
        return expand_token(pair[0]) + expand_token(pair[1])

    # Expand all tokens and concatenate
    bytes_list = [expand_token(id) for id in ids]
    bytes_data = b''.join(bytes_list)

    # Convert bytes back to text
    try:
        return bytes_data.decode('utf-8')
    except UnicodeDecodeError:
        return "[DECODE_ERROR]"

class Tokenizer:
    def __init__(self, merges=None):
        self.merges = merges or {}

    def encode(self, text):
        return encode(text, self.merges)

    def decode(self, ids):
        return decode(ids, self.merges)

    def save(self, path):
        """Save the tokenizer to a JSON file"""
        # Convert tuple keys to strings for JSON serialization
        serializable_merges = {f"{p1},{p2}": idx for (p1, p2), idx in self.merges.items()}
        with open(path, 'w') as f:
            json.dump(serializable_merges, f)

    @classmethod
    def load(cls, path):
        """Load a tokenizer from a JSON file"""
        with open(path, 'r') as f:
            serialized_merges = json.load(f)
        # Convert string keys back to tuples
        merges = {tuple(map(int, k.split(','))): v for k, v in serialized_merges.items()}
        return cls(merges)
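
# Sketch of intended usage (assumes a previously trained checkpoint saved at the
# hypothetical path "marathi_tokenizer.json"):
#     tok = Tokenizer.load("marathi_tokenizer.json")
#     ids = tok.encode("नमस्कार")
#     text = tok.decode(ids)  # should round-trip back to "नमस्कार"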

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, help='Path to tokenizer checkpoint')
    parser.add_argument('--train', action='store_true', help='Train a new tokenizer')
    parser.add_argument('--encode', type=str, help='Text to encode')
    parser.add_argument('--decode', type=str, help='Comma-separated integers to decode')
    args = parser.parse_args()

    if args.train:
        # Train new tokenizer
        all_tokens = read_text_files(limit=100)
        initial_len = len(all_tokens)

        vocab_size = 5000  # the desired final vocabulary size
        num_merges = vocab_size - 256
        ids = list(all_tokens)  # copy so we don't destroy the original list

        merges = {}  # (int, int) -> int
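
        # BPE training loop: each iteration counts adjacent token pairs,
        # replaces every occurrence of the most frequent pair with a fresh
        # token id, and records the merge. Ids 0-255 are raw bytes, so a
        # vocab_size of 5000 corresponds to 5000 - 256 = 4744 merges.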
        pbar = tqdm(range(num_merges), desc="Merging tokens")
        for i in pbar:
            stats = get_stats(ids)
            pair = max(stats, key=stats.get)
            idx = 256 + i
            ids = merge(ids, pair, idx)
            merges[pair] = idx
            current_ratio = initial_len / len(ids)
            pbar.write(f"Iteration {i+1}: compression ratio: {current_ratio:.2f}X")

        print("\nFinal Statistics:")
        print("Initial tokens length:", initial_len)
        print("Final ids length:", len(ids))
        print(f"Final compression ratio: {initial_len / len(ids):.2f}X")

        tokenizer = Tokenizer(merges)
        if args.checkpoint:
            tokenizer.save(args.checkpoint)
            print(f"Saved tokenizer to {args.checkpoint}")

    elif args.encode or args.decode:
        if not args.checkpoint:
            print("Error: --checkpoint is required for encode/decode operations")
            exit(1)

        # Load tokenizer for encoding/decoding
        tokenizer = Tokenizer.load(args.checkpoint)
        print(f"Loaded tokenizer from {args.checkpoint}")

        if args.encode:
            # Encode the provided text
            encoded = tokenizer.encode(args.encode)
            print(f"\nEncoding: {args.encode}")
            print(f"Encoded tokens: {encoded}")

        if args.decode:
            # Decode the provided tokens
            try:
                tokens = [int(x.strip()) for x in args.decode.split(',')]
                decoded = tokenizer.decode(tokens)
                print(f"\nDecoding: {tokens}")
                print(f"Decoded text: {decoded}")
            except ValueError:
                print("Error: decode argument should be comma-separated integers")
                exit(1)

    else:
        parser.print_help()
        exit(1)

    # Test encode/decode
    test_text = "नमस्कार, जग! ही एक चाचणी आहे."
    encoded = tokenizer.encode(test_text)
    decoded = tokenizer.decode(encoded)
    print("\nEncoding/Decoding Test:")
    print(f"Original: {test_text}")
    print(f"Encoded: {encoded}")
    print(f"Decoded: {decoded}")
    print(f"Successful roundtrip: {test_text == decoded}")