#!/usr/bin/env python3
"""
Byte Pair Encoding Tokenizer for Indian Languages

A simple implementation of a BPE tokenizer with Marathi-specific preprocessing.

Author: Shilpaj Bhalerao
Date: 2025-01-05
"""
# Standard Library Imports
import re
import json

# Third Party Imports
from tqdm import tqdm


class BPETokenizer:
    """
    Byte Pair Encoding Tokenizer

    :param vocab_size (int): Size of final vocabulary (including base bytes)
    :param merges (dict): Dictionary of merge rules
    :param vocab (dict): Dictionary mapping token IDs to their byte sequences
    :param inverse_vocab (dict): Dictionary mapping byte sequences to token IDs
    """
    def __init__(self, vocab_size=1000, use_regex=False):
        """
        Initialize the tokenizer with the desired vocabulary size.

        :param vocab_size: Target size of the final vocabulary (including the 256 base bytes)
        :param use_regex: If True, apply the Marathi regex preprocessing before training and encoding
        """
        self.vocab_size = vocab_size
        self.merges = {}
        self.len_of_ids = 0
        self.len_raw_bytes = 0
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        self.inverse_vocab = {bytes([idx]): idx for idx in range(256)}
        self.use_regex = use_regex

        # Marathi tokenization regex pattern
        self.marathi_regex = re.compile(
            r"([\u0900-\u094F\u0951-\u097F]+|"              # Marathi words and ligatures
            r"[\u0966-\u096F]+|"                            # Marathi numerals (०-९)
            r"\d+(?:\s[\u0900-\u097F]+)?|"                  # Arabic numerals with Marathi context
            r"#[\w\u0900-\u097F]+|"                         # Hashtags
            r"[\w\u0900-\u097F]+[''][\w\u0900-\u097F]+|"    # Compound words with apostrophes
            r"[\w\u0900-\u097F]+(?:-[\w\u0900-\u097F]+)*|"  # Hyphenated words
            r"[\w\u0900-\u097F]+\.[\w\u0900-\u097F]*|"      # Abbreviations
            r'\"[^\"]+\"|\'[^\']+\'|'                       # Quoted text
            r"[\u0964\u0965.!?…]|"                          # Marathi punctuation
            r"[^\s\u0900-\u097F]+)"                         # Non-Marathi symbols
        )

    def preprocess(self, text: str) -> str:
        """
        Preprocess Marathi text before tokenization.

        :param text: Input Marathi text
        :return: Preprocessed text with tokens separated by spaces
        """
        # Find all tokens using the Marathi regex
        tokens = self.marathi_regex.findall(text)

        # Join tokens with spaces
        processed_text = ' '.join(tokens)

        # Normalize whitespace
        processed_text = ' '.join(processed_text.split())
        return processed_text

    def _get_stats(self, ids: list[int]) -> dict[tuple[int, int], int]:
        """
        Count frequency of adjacent pairs in a sequence.

        :param ids: List of integers
        :return: Dictionary of pairs and their frequencies
        """
        counts = {}
        for pair in zip(ids, ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        return counts

    def _merge(self, ids: list[int], pair: tuple[int, int], idx: int) -> list[int]:
        """
        Replace all occurrences of pair with the new token idx.

        :param ids: List of integers
        :param pair: Pair of adjacent token IDs to merge
        :param idx: New token ID that replaces the pair
        :return: List of integers with the pair merged
        """
        newids = []
        i = 0
        while i < len(ids):
            if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
                newids.append(idx)
                i += 2
            else:
                newids.append(ids[i])
                i += 1
        return newids
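    # Worked toy example for the two helpers above (illustrative values only,
    # not taken from any real corpus): for ids = [101, 32, 101, 32, 101],
    # _get_stats(ids) returns {(101, 32): 2, (32, 101): 2}, and
    # _merge(ids, (101, 32), 256) rewrites the sequence to [256, 256, 101].
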
    def train(self, text: str):
        """
        Train the BPE tokenizer on the given text.

        :param text: Input text to train on
        """
        print("Training BPE tokenizer...")

        # Preprocess text first
        if self.use_regex:
            text = self.preprocess(text)

        # Convert text to bytes and get initial tokens
        raw_bytes = text.encode("utf-8")
        raw_bytes = list(map(int, raw_bytes))  # convert to integers
        self.len_raw_bytes = len(raw_bytes)

        # Calculate number of merges needed
        num_merges = self.vocab_size - 256
        ids = list(raw_bytes)  # copy so we don't destroy the original list

        # Perform merges
        for i in tqdm(range(num_merges)):
            stats = self._get_stats(ids)
            if not stats:
                break

            # Find most frequent pair
            pair = max(stats, key=stats.get)
            idx = 256 + i

            # Perform the merge
            ids = self._merge(ids, pair, idx)
            self.len_of_ids = len(ids)
            self.merges[pair] = idx

            # Update vocabulary
            new_token = self.vocab[pair[0]] + self.vocab[pair[1]]
            self.vocab[idx] = new_token
            self.inverse_vocab[new_token] = idx

    def encode(self, text: str) -> list[int]:
        """
        Encode text into token IDs.

        :param text: Text to encode
        :return: List of token IDs
        """
        # Preprocess if needed
        if self.use_regex:
            text = self.preprocess(text)

        # Convert text to a list of byte values
        tokens = list(text.encode("utf-8"))

        while len(tokens) >= 2:
            stats = self._get_stats(tokens)
            # Pick the pair that was learned earliest during training (lowest merge index)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break  # nothing else can be merged
            idx = self.merges[pair]
            tokens = self._merge(tokens, pair, idx)
        return tokens

    def decode(self, ids: list[int]) -> str:
        """
        Decode token IDs back to text.

        :param ids: List of token IDs
        :return: Decoded text
        """
        tokens = b"".join(self.vocab[idx] for idx in ids)
        return tokens.decode("utf-8", errors="replace")

    def token_to_text(self, token_id: int) -> str:
        """
        Convert a single token ID to its text representation.

        :param token_id: Token ID
        :return: Text representation of the token
        """
        return self.vocab[token_id].decode("utf-8", errors="replace")

    def save(self, path: str):
        """
        Save tokenizer state to a JSON file.

        :param path: Path to save the file
        """
        state = {
            'vocab_size': self.vocab_size,
            'merges': list(self.merges.items()),  # convert to a list of (pair, idx) entries
            'vocab': {k: list(v) for k, v in self.vocab.items()}  # convert bytes to lists of ints
        }
        with open(path, 'w') as f:
            json.dump(state, f)

    @classmethod
    def load(cls, path: str):
        """
        Load tokenizer state from a file.

        :param path: Path to load the file from
        :return: Loaded tokenizer
        """
        with open(path, 'r') as f:
            state = json.load(f)

        tokenizer = cls(vocab_size=state['vocab_size'])
        # Convert lists back to tuples for the merge pairs
        tokenizer.merges = {tuple(k): v for k, v in state['merges']}
        tokenizer.vocab = {int(k): bytes(v) for k, v in state['vocab'].items()}
        tokenizer.inverse_vocab = {v: k for k, v in tokenizer.vocab.items()}
        return tokenizer

    def get_vocab_size(self) -> int:
        """
        Get the size of the vocabulary.

        :return: Size of the vocabulary
        """
        return len(self.vocab)

    def get_compression_ratio(self, text: str) -> float:
        """
        Get the compression ratio achieved during training.

        :param text: Input text (preprocessed for consistency; the ratio itself is
                     computed from the lengths recorded during training)
        :return: Compression ratio (bytes before merging / tokens after merging)
        """
        # Preprocess if needed
        if self.use_regex:
            text = self.preprocess(text)
        return round(self.len_raw_bytes / self.len_of_ids, 4)

    def get_token_length(self, text: str) -> int:
        """
        Get the byte length of the training text, as recorded during training.

        :param text: Input text (unused; the value stored during training is returned)
        :return: Number of bytes before any merges
        """
        return self.len_raw_bytes

    def get_ids_length(self, text: str) -> int:
        """
        Get the length of the merged token sequence, as recorded during training.

        :param text: Input text (unused; the value stored during training is returned)
        :return: Number of token IDs after all merges
        """
        return self.len_of_ids

    def is_encoded_equals_decoded(self, text: str) -> bool:
        """
        Check if encoding and decoding are consistent.

        :param text: Input text
        :return: True if consistent, False otherwise
        """
        encoded = self.encode(text)
        decoded = self.decode(encoded)
        return text == decoded

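# Sketch of the regex preprocessing path (use_regex=True), which the demo below does not
# exercise. For the sample sentence used below, preprocess() is expected to split the
# trailing full stop off the last word:
#   "या पुतळ्याच्या डोक्यावर अज्ञातांनी चप्पल ठेवल्याचे आढळून आले आहे."
#   -> "या पुतळ्याच्या डोक्यावर अज्ञातांनी चप्पल ठेवल्याचे आढळून आले आहे ."
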
if __name__ == "__main__":
    # Read text from file (explicit UTF-8 so Marathi text loads correctly on any platform)
    with open("dataset.txt", "r", encoding="utf-8") as file:
        text = file.read()

    # Initialize and train
    tokenizer = BPETokenizer(vocab_size=3000)
    tokenizer.train(text)

    # Save and load
    tokenizer.save("tokenizer.json")
    loaded_tokenizer = BPETokenizer.load("tokenizer.json")

    # Encode with the trained tokenizer and decode with the reloaded one
    sample = "या पुतळ्याच्या डोक्यावर अज्ञातांनी चप्पल ठेवल्याचे आढळून आले आहे."
    encoded = tokenizer.encode(sample)
    decoded = loaded_tokenizer.decode(encoded)

    # Check consistency
    print("Decoded output from the loaded tokenizer matches the original sentence? ", decoded == sample)

    # Print vocab size
    print(f"Vocab size: {tokenizer.get_vocab_size()}")

    # Print token length
    print(f"Token length: {tokenizer.get_token_length(text)}")

    # Print ids length
    print(f"Ids length: {tokenizer.get_ids_length(text)}")

    # Print compression ratio
    print(f"Compression ratio: {tokenizer.get_compression_ratio(text)}X")
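    # Illustrative extra check: show the text of one learned token via token_to_text().
    # Which IDs appear is corpus-dependent, so guard against an empty encoding.
    if encoded:
        print(f"First token of the sample sentence: {tokenizer.token_to_text(encoded[0])!r}")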