# marathi-tokenizer-era/read_files.py
import os
import json
import regex as re
from natsort import natsorted
from tqdm import tqdm
# Marathi-aware regex pattern, compiled once at module level
MARATHI_PATTERN = re.compile(r"""
# Contractions and common affixes
'चा|'ची|'चे|'ला|'ले|'नी|
# Words with optional vowel signs and modifiers
[\p{L}\p{M}]+|
# Numbers
\p{N}+|
# Punctuation and special characters
[^\s\p{L}\p{N}\p{M}]+|
# Whitespace
\s+
""", re.VERBOSE)
def text_to_bytes(text):
"""Convert text to byte tokens after applying Marathi regex"""
words = MARATHI_PATTERN.findall(text)
all_bytes = []
for word in words:
bytes_tokens = [b for c in word for b in c.encode('utf-8')]
all_bytes.extend(bytes_tokens)
return all_bytes
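# For example, "जग" is two Devanagari code points of three UTF-8 bytes each, so
# text_to_bytes("जग") returns six integers in the 0-255 range:
#   [224, 164, 156, 224, 164, 151]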
def read_text_files(folder_path='train', limit=10):
# Check if the folder exists
if not os.path.exists(folder_path):
print(f"Error: The folder '{folder_path}' does not exist.")
return
# Get list of all files in the folder
files = os.listdir(folder_path)
# Filter for text files and sort them naturally
text_files = natsorted([f for f in files if f.endswith(('.txt', '.text'))])
if not text_files:
print(f"No text files found in '{folder_path}' folder.")
return
# Take only the first 'limit' files
text_files = text_files[:limit]
# Initialize list to store all tokens
all_tokens = []
    # Read each file and accumulate its byte-level tokens
for file_name in text_files:
file_path = os.path.join(folder_path, file_name)
try:
with open(file_path, 'r', encoding='utf-8') as file:
content = file.read()
# Convert text to bytes using Marathi-aware tokenization
tokens = text_to_bytes(content)
all_tokens.extend(tokens)
except Exception as e:
print(f"Error reading {file_name}: {str(e)}")
print("\n=== Combined Statistics ===")
print("Total number of tokens:", len(all_tokens))
print("First 100 tokens:", all_tokens[:100])
return all_tokens
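# Typical use (assuming a 'train/' folder containing files such as 1.txt, 2.txt, ...):
#   tokens = read_text_files('train', limit=2)
# reads the first two files in natural sort order and returns one flat list of
# byte-level tokens (integers in 0-255) that the BPE training loop below consumes.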
def get_stats(ids):
counts = {}
for pair in zip(ids, ids[1:]): # Pythonic way to iterate consecutive elements
counts[pair] = counts.get(pair, 0) + 1
return counts
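# Illustrative call: get_stats([1, 2, 3, 1, 2]) counts consecutive pairs and
# returns {(1, 2): 2, (2, 3): 1, (3, 1): 1}; the most frequent pair is the one
# the training loop merges next.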
def merge(ids, pair, idx):
newids = []
i = 0
while i < len(ids):
if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
newids.append(idx)
i += 2
else:
newids.append(ids[i])
i += 1
return newids
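# Continuing the example above, merging the most frequent pair (1, 2) into a
# new token 256 gives merge([1, 2, 3, 1, 2], (1, 2), 256) -> [256, 3, 256].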
def encode(text, merges):
"""
Encode text into tokens using the learned merges
"""
# First convert text to bytes using Marathi-aware tokenization
ids = text_to_bytes(text)
# Apply the merges in order of their token indices
# Sort by the token index to ensure consistent ordering
sorted_merges = sorted(merges.items(), key=lambda x: x[1])
for (p1, p2), idx in sorted_merges:
ids = merge(ids, (p1, p2), idx)
return ids
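# Sketch of how encoding behaves with a single learned merge (hypothetical
# merges table, for illustration only):
#   merges = {(224, 164): 256}
#   encode("जग", merges)  ->  [256, 156, 256, 151]
# because both characters start with the UTF-8 prefix bytes 224, 164.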
def decode(ids, merges):
"""
Decode tokens back to text using the learned merges
"""
# Create reverse mapping from token to pair
reverse_merges = {idx: pair for pair, idx in merges.items()}
# Expand all tokens recursively
def expand_token(token):
if token < 256: # Base case: token is a byte
return bytes([token])
# Recursive case: expand the token into its constituent pair
pair = reverse_merges[token]
return expand_token(pair[0]) + expand_token(pair[1])
# Expand all tokens and concatenate
bytes_list = [expand_token(id) for id in ids]
bytes_data = b''.join(bytes_list)
# Convert bytes back to text
try:
return bytes_data.decode('utf-8')
except UnicodeDecodeError:
return "[DECODE_ERROR]"
class Tokenizer:
def __init__(self, merges=None):
self.merges = merges or {}
def encode(self, text):
return encode(text, self.merges)
def decode(self, ids):
return decode(ids, self.merges)
def save(self, path):
"""Save the tokenizer to a JSON file"""
# Convert tuple keys to strings for JSON serialization
serializable_merges = {f"{p1},{p2}": idx for (p1, p2), idx in self.merges.items()}
with open(path, 'w') as f:
json.dump(serializable_merges, f)
@classmethod
def load(cls, path):
"""Load a tokenizer from a JSON file"""
with open(path, 'r') as f:
serialized_merges = json.load(f)
# Convert string keys back to tuples
merges = {tuple(map(int, k.split(','))): v for k, v in serialized_merges.items()}
return cls(merges)
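# Usage sketch with a hypothetical checkpoint path:
#   tok = Tokenizer({(224, 164): 256})
#   tok.save('marathi_tokenizer.json')   # file contains {"224,164": 256}
#   tok2 = Tokenizer.load('marathi_tokenizer.json')
#   assert tok2.merges == {(224, 164): 256}
# Tuple keys are flattened to "p1,p2" strings because JSON objects only allow
# string keys.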
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', type=str, help='Path to tokenizer checkpoint')
parser.add_argument('--train', action='store_true', help='Train a new tokenizer')
parser.add_argument('--encode', type=str, help='Text to encode')
parser.add_argument('--decode', type=str, help='Comma-separated integers to decode')
args = parser.parse_args()
if args.train:
# Train new tokenizer
all_tokens = read_text_files(limit=100)
initial_len = len(all_tokens)
# ---
vocab_size = 5000 # the desired final vocabulary size
num_merges = vocab_size - 256
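        # With vocab_size = 5000 this performs 5000 - 256 = 4744 merges,
        # assigning one new token id per iteration, starting at 256.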
ids = list(all_tokens) # copy so we don't destroy the original list
merges = {} # (int, int) -> int
pbar = tqdm(range(num_merges), desc="Merging tokens")
for i in pbar:
stats = get_stats(ids)
pair = max(stats, key=stats.get)
idx = 256 + i
ids = merge(ids, pair, idx)
merges[pair] = idx
current_ratio = initial_len / len(ids)
pbar.write(f"Iteration {i+1}: compression ratio: {current_ratio:.2f}X")
print("\nFinal Statistics:")
print("Initial tokens length:", initial_len)
print("Final ids length:", len(ids))
print(f"Final compression ratio: {initial_len / len(ids):.2f}X")
tokenizer = Tokenizer(merges)
if args.checkpoint:
tokenizer.save(args.checkpoint)
print(f"Saved tokenizer to {args.checkpoint}")
elif args.encode or args.decode:
if not args.checkpoint:
print("Error: --checkpoint is required for encode/decode operations")
exit(1)
# Load tokenizer for encoding/decoding
tokenizer = Tokenizer.load(args.checkpoint)
print(f"Loaded tokenizer from {args.checkpoint}")
if args.encode:
# Encode the provided text
encoded = tokenizer.encode(args.encode)
print(f"\nEncoding: {args.encode}")
print(f"Encoded tokens: {encoded}")
if args.decode:
# Decode the provided tokens
try:
tokens = [int(x.strip()) for x in args.decode.split(',')]
decoded = tokenizer.decode(tokens)
print(f"\nDecoding: {tokens}")
print(f"Decoded text: {decoded}")
except ValueError:
print("Error: decode argument should be comma-separated integers")
exit(1)
else:
parser.print_help()
exit(1)
# Test encode/decode
test_text = "नमस्कार, जग! ही एक चाचणी आहे."
encoded = tokenizer.encode(test_text)
decoded = tokenizer.decode(encoded)
print("\nEncoding/Decoding Test:")
print(f"Original: {test_text}")
print(f"Encoded: {encoded}")
print(f"Decoded: {decoded}")
print(f"Successful roundtrip: {test_text == decoded}")