#!/usr/bin/env python3
"""
Byte Pair Encoding Tokenizer for Indian Languages
A simple implementation of a byte-level BPE tokenizer with optional Marathi-specific preprocessing.
Author: Shilpaj Bhalerao
Date: 2025-01-05
"""
# Standard Library Imports
import re

# Third Party Imports
from tqdm import tqdm


class BPETokenizer:
    """
    Byte Pair Encoding Tokenizer

    Byte-level BPE: the base vocabulary is the 256 single-byte values and each
    merge learned during training adds one new token id (256, 257, ...).

    :param vocab_size (int): Size of final vocabulary (including base bytes)
    :param use_regex (bool): If True, text is run through :meth:`preprocess`
        before training and encoding. Preprocessing re-joins the regex tokens
        with single spaces, so it does not preserve the original whitespace.

    Attributes:
        merges (dict): Merge rules, mapping a (left_id, right_id) pair to the
            token id created for it; insertion order is the merge order.
        vocab (dict): Mapping of token IDs to their byte sequences.
        inverse_vocab (dict): Mapping of byte sequences to token IDs.
        len_raw_bytes (int): UTF-8 byte count of the (preprocessed) text seen
            by the last call to :meth:`train`.
        len_of_ids (int): Number of token ids that text was reduced to by the
            last call to :meth:`train`.
    """

    def __init__(self, vocab_size=1000, use_regex=False):
        """
        Initialize the tokenizer with desired vocabulary size.

        :param vocab_size: Size of the final vocabulary, including the 256 base bytes
        :param use_regex: Whether to apply the Marathi regex preprocessing to input text
        """
        self.vocab_size = vocab_size
        self.use_regex = use_regex
        self.len_of_ids = 0
        self.len_raw_bytes = 0
        self._reset_vocab()

        # Marathi tokenization regex pattern.
        # NOTE: alternatives are tried left to right. The "Marathi numerals"
        # alternative is already covered by U+0951-U+097F in the first one, and
        # the "Abbreviations" alternative is pre-empted by the hyphenated-word
        # alternative (which matches any word prefix), so neither matches on
        # its own.
        # NOTE(review): the apostrophe class contains the ASCII "'" twice; a
        # typographic apostrophe (U+2019) may have been intended - confirm.
        self.marathi_regex = re.compile(
            r"([\u0900-\u094F\u0951-\u097F]+|"           # Marathi words and ligatures
            r"[\u0966-\u096F]+|"                         # Marathi numerals (०-९)
            r"\d+(?:\s[\u0900-\u097F]+)?|"              # Arabic numerals with Marathi context
            r"#[\w\u0900-\u097F]+|"                     # Hashtags
            r"[\w\u0900-\u097F]+[''][\w\u0900-\u097F]+|" # Compound words with apostrophes
            r"[\w\u0900-\u097F]+(?:-[\w\u0900-\u097F]+)*|" # Hyphenated words
            r"[\w\u0900-\u097F]+\.[\w\u0900-\u097F]*|"  # Abbreviations
            r'\"[^\"]+\"|\'[^\']+\'|'                   # Quoted text
            r"[\u0964\u0965.!?…]|"                      # Marathi punctuation
            r"[^\s\u0900-\u097F]+)"                     # Non-Marathi symbols
        )

    def _reset_vocab(self):
        """Reset merge rules and vocabulary to the 256 single-byte base tokens."""
        self.merges = {}
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        self.inverse_vocab = {bytes([idx]): idx for idx in range(256)}

    def preprocess(self, text: str) -> str:
        """
        Preprocess Marathi text before tokenization.

        Characters not matched by the regex (whitespace) are dropped, and the
        matched tokens are joined with single spaces. This is lossy: the
        original spacing cannot be recovered after decoding.

        :param text: Input Marathi text
        :return: Preprocessed text with tokens separated by spaces
        """
        # Find all tokens using the Marathi regex (the pattern has exactly one
        # capturing group, so findall returns the matched strings)
        tokens = self.marathi_regex.findall(text)

        # Join tokens with spaces, then collapse any whitespace runs
        processed_text = ' '.join(tokens)
        return ' '.join(processed_text.split())

    def _get_stats(self, ids: list[int]) -> dict[tuple[int, int], int]:
        """
        Count frequency of adjacent pairs in sequence.

        :param ids: list of integers
        :return: dictionary of pairs and their frequencies, in first-seen order
        """
        counts = {}
        for pair in zip(ids, ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        return counts

    def _merge(self, ids: list[int], pair: tuple[int, int], idx: int) -> list[int]:
        """
        Replace all non-overlapping occurrences of pair with new token idx.

        Occurrences are matched greedily from left to right.

        :param ids: list of integers
        :param pair: tuple of integers to replace
        :param idx: integer token id to insert in place of each occurrence
        :return: new list of integers (the input is not modified)
        """
        newids = []
        i = 0
        while i < len(ids):
            if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
                newids.append(idx)
                i += 2
            else:
                newids.append(ids[i])
                i += 1
        return newids

    def train(self, text: str):
        """
        Train the BPE tokenizer on the given text.

        Previously learned merges are discarded first, so calling train() again
        retrains from scratch instead of mixing two merge tables.

        :param text: Input text to train on
        """
        print("Training BPE tokenizer...")

        # Preprocess text first
        if self.use_regex:
            text = self.preprocess(text)

        # Merge ids are assigned from 256 upward, so stale merges from an
        # earlier run would collide with the new ones: start from base bytes.
        self._reset_vocab()

        # Convert text to a list of byte values (0-255)
        ids = list(text.encode("utf-8"))
        self.len_raw_bytes = len(ids)

        # Calculate number of merges needed (none if vocab_size <= 256)
        num_merges = max(0, self.vocab_size - 256)

        # Perform merges
        for i in tqdm(range(num_merges)):
            stats = self._get_stats(ids)
            if not stats:
                break  # fewer than two tokens left, nothing to merge

            # Most frequent pair (ties resolve to the pair seen first)
            pair = max(stats, key=stats.get)
            idx = 256 + i

            # Perform the merge
            ids = self._merge(ids, pair, idx)
            self.merges[pair] = idx

            # Update vocabulary
            new_token = self.vocab[pair[0]] + self.vocab[pair[1]]
            self.vocab[idx] = new_token
            self.inverse_vocab[new_token] = idx

        # Record the final length even when no merge ran, so the statistics
        # (and get_compression_ratio) stay consistent.
        self.len_of_ids = len(ids)

    def encode(self, text: str) -> list[int]:
        """
        Encode text into token IDs.

        Merges are applied in the order they were learned: at each step the
        adjacent pair with the lowest merge id is merged everywhere.

        :param text: Text to encode
        :return: List of token IDs
        """
        # Preprocess if needed
        if self.use_regex:
            text = self.preprocess(text)

        # Convert text to list of integers
        tokens = list(text.encode("utf-8"))
        while len(tokens) >= 2:
            stats = self._get_stats(tokens)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break  # nothing else can be merged
            idx = self.merges[pair]
            tokens = self._merge(tokens, pair, idx)
        return tokens

    def decode(self, ids: list[int]) -> str:
        """
        Decode token IDs back to text.

        Invalid UTF-8 byte sequences are replaced with U+FFFD.

        :param ids: List of token IDs
        :return: Decoded text
        :raises KeyError: If an id is not in the vocabulary
        """
        tokens = b"".join(self.vocab[idx] for idx in ids)
        return tokens.decode("utf-8", errors="replace")

    def token_to_text(self, token_id: int) -> str:
        """
        Convert a single token ID to its text representation.

        A token may hold only part of a multi-byte character; such bytes are
        shown as U+FFFD.

        :param token_id: Token ID
        :return: Text representation of the token
        """
        return self.vocab[token_id].decode("utf-8", errors="replace")

    def save(self, path: str):
        """
        Save tokenizer state to a JSON file.

        :param path: Path to save the file
        """
        import json
        state = {
            'vocab_size': self.vocab_size,
            # Needed so a reloaded tokenizer preprocesses text the same way
            'use_regex': self.use_regex,
            'len_raw_bytes': self.len_raw_bytes,
            'len_of_ids': self.len_of_ids,
            'merges': list(self.merges.items()),  # JSON has no tuple keys: list of [pair, id]
            'vocab': {k: list(v) for k, v in self.vocab.items()},  # Convert bytes to lists
        }
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(state, f)

    @classmethod
    def load(cls, path: str):
        """
        Load tokenizer state from a JSON file written by :meth:`save`.

        Files written before 'use_regex' and the training statistics were
        saved are still accepted (defaults: False and 0).

        :param path: Path to load the file
        :return: Loaded tokenizer
        """
        import json
        with open(path, 'r', encoding='utf-8') as f:
            state = json.load(f)

        tokenizer = cls(vocab_size=state['vocab_size'],
                        use_regex=state.get('use_regex', False))
        # Convert lists back to tuples for the merge pairs
        tokenizer.merges = {tuple(k): v for k, v in state['merges']}
        # JSON object keys are strings: convert token ids back to int
        tokenizer.vocab = {int(k): bytes(v) for k, v in state['vocab'].items()}
        tokenizer.inverse_vocab = {v: k for k, v in tokenizer.vocab.items()}
        tokenizer.len_raw_bytes = state.get('len_raw_bytes', 0)
        tokenizer.len_of_ids = state.get('len_of_ids', 0)
        return tokenizer

    def get_vocab_size(self) -> int:
        """
        Get the size of the vocabulary.

        :return: Size of the vocabulary
        """
        return len(self.vocab)

    def get_compression_ratio(self, text: str) -> float:
        """
        Get the compression ratio achieved on the training text.

        NOTE: `text` is not used. The ratio comes from statistics recorded by
        the last call to :meth:`train` (raw byte count / token id count); the
        parameter is kept for backward compatibility.

        :param text: Unused
        :return: Compression ratio rounded to 4 decimals, or 0.0 when no
            training statistics are available (untrained or empty text)
        """
        if not self.len_of_ids:
            return 0.0
        return round(self.len_raw_bytes / self.len_of_ids, 4)

    def get_token_length(self, text: str) -> int:
        """
        Get the number of raw UTF-8 bytes in the training text.

        NOTE: `text` is not used; the value is recorded by :meth:`train`.

        :param text: Unused
        :return: Byte length of the (preprocessed) training text
        """
        return self.len_raw_bytes

    def get_ids_length(self, text: str) -> int:
        """
        Get the number of token ids the training text was compressed to.

        NOTE: `text` is not used; the value is recorded by :meth:`train`.

        :param text: Unused
        :return: Number of token ids after all training merges
        """
        return self.len_of_ids

    def is_encoded_equals_decoded(self, text: str) -> bool:
        """
        Check if encoding and decoding are consistent.

        With use_regex=True the preprocessing normalizes whitespace, so this
        can return False for text whose spacing differs from the normalized
        form even though the BPE step itself is lossless.

        :param text: Input text
        :return: True if decode(encode(text)) == text, False otherwise
        """
        encoded = self.encode(text)
        decoded = self.decode(encoded)
        return text == decoded


if __name__ == "__main__":

    # Read text from file. Use explicit UTF-8: the dataset presumably holds
    # Devanagari (Marathi) text, and the platform default encoding may not be
    # UTF-8 (e.g. on Windows).
    with open("dataset.txt", "r", encoding="utf-8") as file:
        text = file.read()

    # Initialize and train
    tokenizer = BPETokenizer(vocab_size=3000)
    tokenizer.train(text)

    # Save and load
    tokenizer.save("tokenizer.json")
    loaded_tokenizer = BPETokenizer.load("tokenizer.json")

    # Encode with the trained tokenizer and decode with the reloaded one, so
    # the round trip also exercises save()/load().
    sample = "या पुतळ्याच्या डोक्यावर अज्ञातांनी चप्पल ठेवल्याचे आढळून आले आहे."
    encoded = tokenizer.encode(sample)
    decoded = loaded_tokenizer.decode(encoded)

    # Check consistency
    print("Is encoded equals to loaded decoded? ", decoded == sample)

    # Print vocab size
    print(f"Vocab size: {tokenizer.get_vocab_size()}")

    # Print token length (raw bytes of the training text)
    print(f"Token length: {tokenizer.get_token_length(text)}")

    # Print ids length (token ids after training merges)
    print(f"Ids length: {tokenizer.get_ids_length(text)}")

    # Print compression ratio
    print(f"Compression ratio: {tokenizer.get_compression_ratio(text)}X")