import html
import logging
from typing import List, Optional, Tuple, Dict

import gradio as gr
import jiwer
import pandas as pd

# Configure the root logger for the whole app: records at INFO and above are
# written through a single console StreamHandler. ``force=True`` removes and
# closes any handlers already attached to the root logger before configuring.
_console_handler = logging.StreamHandler()
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[_console_handler],
    force=True,
)
logger = logging.getLogger(__name__)

def _build_word_filter(blocked: frozenset):
    """Return a jiwer transform step that drops every word found in ``blocked``.

    Matching is case-insensitive; ``blocked`` must already be lowercased.
    The returned step raises ValueError if no words remain.
    """

    def filter_words_transform(sentences: List[List[str]]) -> List[List[str]]:
        # The previous step (ReduceToListOfListOfWords) yields one list of
        # words per sentence, so filter inside each sentence instead of
        # treating the input as a flat list of words.
        filtered = [
            [word for word in sentence if word.lower() not in blocked]
            for sentence in sentences
        ]
        if not any(filtered):
            raise ValueError("Text is empty after filtering words")
        return filtered

    return filter_words_transform


def _build_transform(normalize: bool, words_to_filter: Optional[List[str]]) -> "Optional[jiwer.Compose]":
    """Return the jiwer transform to apply to both texts, or ``None``.

    ``None`` means "let jiwer use its built-in default transform"; it is only
    returned when neither normalization nor word filtering is requested.
    """
    # Build the lookup set once; blank entries are ignored.
    blocked = frozenset(
        word.strip().lower()
        for word in (words_to_filter or [])
        if word and word.strip()
    )

    if normalize:
        steps = [
            # Lowercase first: ExpandCommonEnglishContractions matches
            # lowercase patterns (e.g. "won't"), so "Won't"/"DON'T" would
            # otherwise be mis-expanded or left untouched.
            jiwer.ToLowerCase(),
            jiwer.ExpandCommonEnglishContractions(),
            jiwer.RemovePunctuation(),
            # Words are later split on " " only; turn newlines/tabs from the
            # multi-line input boxes into spaces so "a\nb" is two words.
            jiwer.RemoveWhiteSpace(replace_by_space=True),
            jiwer.RemoveMultipleSpaces(),
            jiwer.Strip(),
            jiwer.ReduceToListOfListOfWords(),
        ]
    elif blocked:
        # Same steps as jiwer's default WER transform, so filtering is the
        # only difference from an un-normalized comparison.
        steps = [
            jiwer.RemoveMultipleSpaces(),
            jiwer.Strip(),
            jiwer.ReduceToListOfListOfWords(),
        ]
    else:
        return None

    if blocked:
        steps.append(_build_word_filter(blocked))
    return jiwer.Compose(steps)


def calculate_wer_metrics(
    hypothesis: str,
    reference: str,
    normalize: bool = True,
    words_to_filter: Optional[List[str]] = None
) -> Dict:
    """
    Calculate WER metrics between hypothesis and reference texts.

    Args:
        hypothesis (str): The hypothesis text
        reference (str): The reference text
        normalize (bool): Whether to normalize texts (lowercase, expand common
            English contractions, remove punctuation, collapse whitespace)
            before comparison
        words_to_filter (List[str], optional): Words to remove (case-insensitively)
            from both texts before comparison; applied whether or not
            ``normalize`` is set. Blank entries are ignored.

    Returns:
        dict: The measures dictionary returned by ``jiwer.compute_measures``
        (includes 'wer', 'mer', 'wil', 'wip', 'hits', 'substitutions',
        'deletions', 'insertions').

    Raises:
        ValueError: If either text is blank, if the transform fails (including
            when every word is filtered out), or if either text contains no
            words after normalization/filtering.
    """
    # User text may be long or sensitive, so it is only logged at DEBUG level.
    logger.info(
        "Calculating WER metrics (normalize=%s, words_to_filter=%s)",
        normalize, words_to_filter,
    )
    logger.debug("Hypothesis: %r | Reference: %r", hypothesis, reference)

    if not (hypothesis and hypothesis.strip()) or not (reference and reference.strip()):
        raise ValueError("Both hypothesis and reference texts must contain non-empty strings")

    transformation = _build_transform(normalize, words_to_filter)
    if transformation is None:
        # Raw comparison: jiwer applies its default whitespace-only transform.
        return jiwer.compute_measures(truth=reference, hypothesis=hypothesis)

    # Apply the transform once up front so problems surface as a clear
    # ValueError rather than an opaque failure inside jiwer.
    try:
        transformed_ref = transformation(reference)
        transformed_hyp = transformation(hypothesis)
    except Exception as e:
        logger.error("Transformation error: %s", e)
        raise ValueError(f"Error during text transformation: {e}") from e

    # The transform yields a list of word lists; e.g. all-punctuation text
    # becomes [[]], which is truthy, so check for any actual words instead.
    if not any(transformed_ref) or not any(transformed_hyp):
        raise ValueError("Text is empty after normalization")
    logger.debug("Transformed reference: %s", transformed_ref)
    logger.debug("Transformed hypothesis: %s", transformed_hyp)

    return jiwer.compute_measures(
        truth=reference,
        hypothesis=hypothesis,
        truth_transform=transformation,
        hypothesis_transform=transformation,
    )

def process_inputs(
    reference: str,
    hypothesis: str,
    normalize: bool,
    words_to_filter: str
) -> Tuple[str, str, str, str]:
    """
    Process UI inputs and calculate WER metrics.

    Args:
        reference (str): Reference text
        hypothesis (str): Hypothesis text
        normalize (bool): Whether to normalize text
        words_to_filter (str): Comma-separated words to filter (may be empty)

    Returns:
        Tuple[str, str, str, str]: HTML for, in order, the main-metrics table,
        the error-count table, the metrics explanation, and an error message.
        On success the error message is ""; on failure the first three are "".
    """
    # Whitespace-only input counts as missing; report it in the error slot,
    # the same place calculation errors are shown.
    if not (reference and reference.strip()) or not (hypothesis and hypothesis.strip()):
        return "", "", "", "Please provide both reference and hypothesis texts."

    # Drop blank entries, e.g. the trailing one in "um, uh,".
    filter_words = [
        word.strip() for word in (words_to_filter or "").split(",") if word.strip()
    ] or None

    try:
        measures = calculate_wer_metrics(
            hypothesis=hypothesis,
            reference=reference,
            normalize=normalize,
            words_to_filter=filter_words
        )

        # Format main metrics (fractions, shown with 3 decimals)
        metrics_df = pd.DataFrame({
            'Metric': ['WER', 'MER', 'WIL', 'WIP'],
            'Value': [
                f"{measures['wer']:.3f}",
                f"{measures['mer']:.3f}",
                f"{measures['wil']:.3f}",
                f"{measures['wip']:.3f}"
            ]
        })

        # Format error analysis
        error_df = pd.DataFrame({
            'Metric': ['Substitutions', 'Deletions', 'Insertions', 'Hits'],
            'Count': [
                measures['substitutions'],
                measures['deletions'],
                measures['insertions'],
                measures['hits']
            ]
        })

        # DataFrame.to_html escapes cell contents by default.
        metrics_html = metrics_df.to_html(index=False)
        error_html = error_df.to_html(index=False)
    except Exception as e:
        # UI boundary: log with traceback, show a readable message. The text
        # is rendered by an HTML component, so escape it.
        logger.exception("Error calculating WER: %s", e)
        return "", "", "", f"Error calculating WER: {html.escape(str(e))}"

    explanation = """
        <h3>Metrics Explanation:</h3>
        <p>Values are fractions (e.g. 0.250 = 25%).</p>
        <ul>
            <li><b>WER (Word Error Rate)</b>: (substitutions + deletions + insertions) divided by the number of reference words. 0 is a perfect match; it can exceed 1 when there are many insertions</li>
            <li><b>MER (Match Error Rate)</b>: (substitutions + deletions + insertions) divided by (hits + substitutions + deletions + insertions); the fraction of aligned words that are errors, between 0 and 1</li>
            <li><b>WIL (Word Information Lost)</b>: 1 - WIP; the fraction of word information that was lost, between 0 and 1</li>
            <li><b>WIP (Word Information Preserved)</b>: (hits / reference words) * (hits / hypothesis words); the fraction of word information that was preserved, between 0 and 1</li>
        </ul>
        """

    return metrics_html, error_html, explanation, ""

def load_example() -> Tuple[str, str]:
    """Return a sample ``(reference, hypothesis)`` pair for the demo button.

    The hypothesis differs from the reference by one substitution
    ("jumps" -> "jumped") and one deletion (the second "the").
    """
    sample_reference = "the quick brown fox jumps over the lazy dog"
    sample_hypothesis = "the quick brown fox jumped over lazy dog"
    return sample_reference, sample_hypothesis

def create_interface() -> gr.Blocks:
    """Build the Gradio Blocks UI for the WER evaluation tool.

    Layout, top to bottom: a title and description, side-by-side reference
    and hypothesis text boxes, the normalization/filter options, the two
    action buttons, and four HTML panes that receive the outputs of
    ``process_inputs``.
    """
    intro_text = (
        "This tool helps you evaluate the Word Error Rate (WER) between a reference "
        "text and a hypothesis text. WER is commonly used in speech recognition and "
        "machine translation evaluation."
    )

    with gr.Blocks(title="WER Evaluation Tool") as demo:
        gr.Markdown("# Word Error Rate (WER) Evaluation Tool")
        gr.Markdown(intro_text)

        # Input texts, side by side.
        with gr.Row():
            with gr.Column():
                ref_box = gr.Textbox(
                    label="Reference Text",
                    placeholder="Enter the reference text here...",
                    lines=5,
                )
            with gr.Column():
                hyp_box = gr.Textbox(
                    label="Hypothesis Text",
                    placeholder="Enter the hypothesis text here...",
                    lines=5,
                )

        # Comparison options.
        with gr.Row():
            normalize_box = gr.Checkbox(
                value=True,
                label="Normalize text (lowercase, remove punctuation)",
            )
            filter_box = gr.Textbox(
                placeholder="e.g., um, uh, ah",
                label="Words to filter (comma-separated)",
            )

        # Actions.
        with gr.Row():
            load_btn = gr.Button("Load Example")
            run_btn = gr.Button("Calculate WER", variant="primary")

        # Results.
        with gr.Row():
            metrics_pane = gr.HTML(label="Main Metrics")
            counts_pane = gr.HTML(label="Error Analysis")

        explanation_pane = gr.HTML()
        message_pane = gr.HTML()

        # Wire the buttons to their handlers.
        load_btn.click(load_example, outputs=[ref_box, hyp_box])
        run_btn.click(
            process_inputs,
            inputs=[ref_box, hyp_box, normalize_box, filter_box],
            outputs=[metrics_pane, counts_pane, explanation_pane, message_pane],
        )

    return demo

if __name__ == "__main__":
    # Script entry point: announce startup, then build the UI and serve it.
    logger.info("Application started")
    create_interface().launch()