import numpy as np
import logging
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
import os
import json
from huggingface_hub import Repository
from huggingface_hub import HfApi

# Module-wide logging: root handler at INFO so the tokenizer's progress
# messages are visible, plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Static entries merged into every tokenizer_config.json written by
# WaveletTokenizer.save_pretrained.
WAVELET_TOKENIZER_CONFIG = dict(
    model_type="wavelet",
    tokenizer_class="WaveletTokenizer",
    auto_map={"AutoTokenizer": ["tokenizer.WaveletTokenizer", None]},
)

@dataclass
class WaveletTokenizerConfig:
    """Default hyperparameters for the wavelet/EEG tokenizer.

    NOTE(review): nothing in the visible module reads this dataclass;
    ``WaveletTokenizer`` takes ``vocab_size``, ``mu`` and ``verbose`` directly
    as constructor arguments. Confirm whether external code depends on it.
    """
    vocab_size: int = 256      # Number of quantization levels (distinct token ids)
    padding_idx: int = 0       # Presumably the id reserved for padding; unused in this module
    eeg_channels: int = 74     # Source modality (EEG) channel count
    mu: float = 255.0          # Static μ value for μ-law compression
    verbose: bool = True       # Control logging

class WaveletTokenizer(PreTrainedTokenizer):
    """Tokenizer that quantizes multi-channel signals into integer token ids.

    ``encode`` min-max normalizes each channel, applies μ-law compression and
    quantizes the result into ``vocab_size`` levels; ``decode`` reverses those
    steps using the per-channel min/max remembered from the most recent
    ``normalize`` call.
    """
    # Keys of the dictionary returned by ``__call__``.
    model_input_names = ["input_ids", "attention_mask", "position_ids"]
    
    def __init__(
        self,
        vocab_size: int = 256,
        mu: float = 255.0,
        verbose: bool = True,
        **kwargs
    ):
        """Create a tokenizer with a fixed number of quantization levels.

        Args:
            vocab_size: Number of quantization levels (token ids). Must be at
                least 2, because encoding and decoding scale by
                ``vocab_size - 1``.
            mu: μ-law compression parameter. Must be positive, because the
                compression divides by ``log1p(mu)``.
            verbose: When true, log progress messages.
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizer.__init__``.

        Raises:
            ValueError: If ``vocab_size`` is below 2 or ``mu`` is not positive.
        """
        # Reject values that would otherwise surface later as a division by
        # zero (NaN/inf tokens) deep inside encode/decode.
        if vocab_size < 2:
            raise ValueError(f"vocab_size must be at least 2, got {vocab_size}")
        if not mu > 0:  # written this way so NaN is rejected as well
            raise ValueError(f"mu must be positive, got {mu}")

        self.auto_map = {
            "AutoTokenizer": ["tokenizer.WaveletTokenizer", None]
        }

        # Set vocab size first: the parent initializer runs below and the
        # ``vocab_size`` property must already be readable at that point.
        self._vocab_size = vocab_size
        self.mu = mu
        self.verbose = verbose

        # Per-channel min/max of the most recently normalized input; filled in
        # by ``normalize`` and consumed by ``decode``.
        self.channel_mins = None
        self.channel_maxs = None

        # Initialize parent class after setting vocab_size
        super().__init__(**kwargs)

        if self.verbose:
            logger.info(f"Initialized WaveletTokenizer with μ={self.mu:.2f}")
    
    @property
    def vocab_size(self) -> int:
        """Returns the size of vocabulary (number of possible quantization levels)."""
        return self._vocab_size
    
    @vocab_size.setter
    def vocab_size(self, size: int):
        self._vocab_size = size
    
    def save_pretrained(
        self, 
        save_directory: str,
        legacy_format: bool = True,
        filename_prefix: Optional[str] = None,
        push_to_hub: bool = False,
        **kwargs
    ) -> Tuple[str, ...]:
        """Write the tokenizer config and vocabulary into a directory.

        Args:
            save_directory: Target directory; created if it does not exist.
            legacy_format: Accepted for signature compatibility; not used here.
            filename_prefix: Optional prefix, joined with "-" to the local
                file names.
            push_to_hub: When true, also upload both files with ``HfApi``.
            **kwargs: ``commit_message`` overrides the default commit
                messages; ``repo_id`` overrides the upload target, which
                otherwise defaults to ``save_directory``.

        Returns:
            Paths of the vocabulary file(s) followed by the config file path.
        """
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(save_directory, exist_ok=True)

        # Static entries plus this instance's constructor arguments.
        config = {
            **WAVELET_TOKENIZER_CONFIG,
            "vocab_size": self.vocab_size,
            "mu": self.mu,
            "verbose": self.verbose
        }

        prefix = filename_prefix + "-" if filename_prefix else ""
        config_file = os.path.join(save_directory, prefix + "tokenizer_config.json")

        # Explicit encoding so the file does not depend on the platform default.
        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

        # Save vocabulary
        vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix)

        if push_to_hub:
            # A local path is rarely a valid Hub repository id, so let callers
            # name the repository; the old default is kept when they do not.
            repo_id = kwargs.get("repo_id", save_directory)
            api = HfApi()

            # The config is always stored under its canonical, unprefixed name.
            api.upload_file(
                path_or_fileobj=config_file,
                path_in_repo="tokenizer_config.json",
                repo_id=repo_id,
                commit_message=kwargs.get("commit_message", "Upload tokenizer config")
            )

            # Upload vocabulary file
            vocab_file = vocab_files[0]
            api.upload_file(
                path_or_fileobj=vocab_file,
                path_in_repo=os.path.basename(vocab_file),
                repo_id=repo_id,
                commit_message=kwargs.get("commit_message", "Upload tokenizer vocabulary")
            )

        return vocab_files + (config_file,)
    
    @classmethod
    def from_pretrained(
        cls, 
        pretrained_model_name_or_path: str, 
        **kwargs
    ) -> "WaveletTokenizer":
        """Load tokenizer from HuggingFace Hub."""
        # Load config first
        config_file = os.path.join(pretrained_model_name_or_path, "tokenizer_config.json")
        if os.path.exists(config_file):
            with open(config_file, "r") as f:
                config = json.load(f)
            # Update with any passed kwargs
            config.update(kwargs)
        else:
            config = kwargs
            
        return cls(**config)
    
    def get_vocab(self) -> Dict[str, int]:
        """Return the vocabulary: each quantization level's decimal string mapped to its id."""
        vocab: Dict[str, int] = {}
        for level in range(self.vocab_size):
            vocab[str(level)] = level
        return vocab
    
    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token string to its id.

        Args:
            token: Decimal string naming a quantization level.

        Returns:
            The integer id when ``token`` parses to a value inside
            ``[0, vocab_size)``; otherwise 0, the id used for unknown tokens.
        """
        try:
            index = int(token)
        except (TypeError, ValueError):
            return 0  # Not a number (or not a string at all): unknown token
        # A parseable number outside the vocabulary is just as unknown as a
        # non-numeric string; never hand back an id get_vocab() does not list.
        if 0 <= index < self.vocab_size:
            return index
        return 0
    
    def _convert_id_to_token(self, index: int) -> str:
        """Return the string form of a token id."""
        return f"{index!s}"
    
    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join the tokens into one string, separated by single spaces."""
        separator = " "
        return separator.join(tokens)
    
    def _tokenize(self, text: str) -> List[str]:
        """Minimal tokenization kept for interface compatibility.

        A string becomes a single-token list; any other iterable has each
        element converted with ``str``.
        """
        if not isinstance(text, str):
            return list(map(str, text))
        return [text]
    
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, ...]:
        """Save the vocabulary to a directory."""
        vocab_file = os.path.join(
            save_directory, 
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        )
        
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.get_vocab(), f, ensure_ascii=False)
            
        return (vocab_file,)
    
    def __call__(
        self,
        eeg_data: np.ndarray,
        **kwargs
    ) -> Dict[str, np.ndarray]:
        """
        Main entry point for tokenization. Handles numpy array input.
        
        Args:
            eeg_data: Raw EEG array of shape (n_channels, time_points)
            
        Returns:
            Dictionary containing:
                - input_ids: Tokenized signal values
                - attention_mask: Binary mask (all ones since we don't pad)
                - position_ids: Sequential position indices
        """
        # Process through tokenization pipeline
        input_ids = self.encode(eeg_data)
        
        # Create attention mask (all ones since we're not padding)
        attention_mask = np.ones_like(input_ids)
        
        # Create position IDs
        n_channels, time_points = eeg_data.shape
        position_ids = np.tile(np.arange(time_points), (n_channels, 1))
        
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids
        }
    
    def encode(self, eeg_data: np.ndarray) -> np.ndarray:
        """Convert EEG data to token IDs."""
        # 1. Normalize to [0, 1]
        normalized = self.normalize(eeg_data)
        
        # 2. Convert to [-1, 1] for μ-law compression
        centered = 2 * normalized - 1
        
        # 3. Apply μ-law compression
        compressed = self.mu_law_encode(centered)
        
        # 4. Quantize to tokens
        input_values = (compressed + 1) / 2  # to [0, 1]
        token_ids = (input_values * (self.vocab_size - 1)).astype(np.int64)
        
        return token_ids
    
    def normalize(self, x: np.ndarray) -> np.ndarray:
        """
        Apply static normalization per channel and store min/max values.
        Input shape: (n_channels, time_points)
        """
        # Compute min/max per channel and expand dimensions to match input
        self.channel_mins = x.min(axis=1)[:, np.newaxis]  # Shape: (n_channels, 1)
        self.channel_maxs = x.max(axis=1)[:, np.newaxis]  # Shape: (n_channels, 1)
        
        normalized = (x - self.channel_mins) / (self.channel_maxs - self.channel_mins + 1e-8)
        
        if self.verbose:
            logger.info(f"Min-max normalization: input range [{x.min():.3f}, {x.max():.3f}] → [{normalized.min():.3f}, {normalized.max():.3f}]")
        return normalized
    
    def mu_law_encode(self, x: np.ndarray) -> np.ndarray:
        """
        Apply μ-law compression.
        Expects input in [-1, 1] range.
        """
        assert np.all(x >= -1.0) and np.all(x <= 1.0), f"Input must be in [-1, 1] range, got min={x.min():.3f}, max={x.max():.3f}"
        compressed = np.sign(x) * np.log1p(self.mu * np.abs(x)) / np.log1p(self.mu)
        
        if self.verbose:
            logger.info(f"μ-law compression (μ={self.mu:.2f}): variance before={np.var(x):.3f}, after={np.var(compressed):.3f}")
        return compressed
    
    def mu_law_decode(self, x: np.ndarray) -> np.ndarray:
        """
        Inverse μ-law compression.
        Expects input in [-1, 1] range.
        """
        assert np.all(x >= -1.0) and np.all(x <= 1.0), f"Input must be in [-1, 1] range, got min={x.min():.3f}, max={x.max():.3f}"
        return np.sign(x) * (1/self.mu) * (np.power(1 + self.mu, np.abs(x)) - 1.0)
    
    def decode(self, token_ids: np.ndarray) -> np.ndarray:
        """
        Decode token IDs back to EEG signal.
        
        Args:
            token_ids: Array of token IDs of shape (n_channels, time_points)
            
        Returns:
            Array of shape (n_channels, time_points)
        """
        # Convert to continuous values in [-1, 1]
        values = token_ids.astype(np.float32) / (self.vocab_size - 1)  # [0, 1]
        values = 2 * values - 1  # [-1, 1]
        
        # Apply inverse μ-law compression
        values = self.mu_law_decode(values)
        
        # Convert back to [0, 1]
        values = (values + 1) / 2
        
        # Denormalize to original scale
        if self.channel_mins is not None and self.channel_maxs is not None:
            values = values * (self.channel_maxs - self.channel_mins) + self.channel_mins
        
        return values