import gradio as gr
from datasets import load_dataset
from difflib import ndiff

from semhash import SemHash
from semhash.datamodels import DeduplicationResult

from model2vec import StaticModel

# Values pre-filled into the UI widgets (and the threshold default of
# perform_deduplication).
default_dataset_name = "SetFit/amazon_massive_scenario_en-US"
default_dataset1_split, default_dataset2_split = "train", "test"
default_text_column = "text"
default_threshold = 0.9

# Model2Vec static embedding model shared by every SemHash index this module
# builds; loaded once, at import time, so each request reuses it.
model = StaticModel.from_pretrained("minishlab/potion-base-8M")


def display_word_differences(x: str, y: str) -> str:
    """
    Render the words that differ between *x* and *y* as a fenced code block.

    Both texts are split on whitespace and compared with ``difflib.ndiff``.
    Only the lines marked as removed (``-``) or added (``+``) are kept, so
    unchanged words and ``?`` hint lines are dropped. The fence keeps the
    diff markers and the words themselves from being read as Markdown.
    """
    changed_lines = []
    for line in ndiff(x.split(), y.split()):
        if line[:1] in ("+", "-"):
            changed_lines.append(line)
    body = "\n".join(changed_lines)
    return "```\n" + body + "\n```"


def load_dataset_texts(dataset_name: str, dataset_split: str, text_column: str) -> list[str]:
    """
    Load one split of a dataset and return a single column's values as a list.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``.
        dataset_split: Split to load (e.g. ``"train"``).
        text_column: Name of the column holding the texts.

    Returns:
        The values of ``text_column``, in dataset order.

    Raises:
        ValueError: If ``text_column`` is not a column of the loaded split.
    """
    ds = load_dataset(dataset_name, split=dataset_split)
    if text_column not in ds.column_names:
        # Fail with an actionable message rather than a bare KeyError, since
        # the exception text is what the user ends up seeing.
        raise ValueError(
            f"Column '{text_column}' not found in {dataset_name}/{dataset_split}. "
            f"Available columns: {', '.join(ds.column_names)}"
        )
    # Column access reads only this column instead of materialising every
    # field of every row as a dict; list() normalises the returned container.
    return list(ds[text_column])


def deduplicate_single_dataset(texts: list[str], threshold: float) -> DeduplicationResult:
    """
    Remove near-duplicates within one list of raw strings.

    The strings are indexed directly with SemHash, using the shared
    module-level ``model``. Every text is then compared against the rest of
    the same list at the given similarity ``threshold``.
    """
    index = SemHash.from_records(records=texts, model=model)
    return index.self_deduplicate(threshold=threshold)


def deduplicate_two_datasets(texts1: list[str], texts2: list[str], threshold: float) -> DeduplicationResult:
    """
    Check the raw strings in ``texts2`` for near-duplicates in ``texts1``.

    ``texts1`` becomes the SemHash reference index, built with the shared
    module-level ``model``. The records of ``texts2`` are then deduplicated
    against that index at the given similarity ``threshold``.
    """
    reference_index = SemHash.from_records(records=texts1, model=model)
    return reference_index.deduplicate(records=texts2, threshold=threshold)


def _sort_duplicates_by_score(result: DeduplicationResult) -> None:
    """Sort every duplicate record's matches in place, highest similarity first."""
    for duprec in result.duplicates:
        # Each match is a (text, score) pair; order by the score.
        duprec.duplicates.sort(key=lambda match: match[1], reverse=True)


def _format_summary(
    result: DeduplicationResult,
    total_docs: int,
    total_label: str,
    duplicates_label: str,
) -> str:
    """Render the total/duplicate/unique counts header, ending with a horizontal rule."""
    return (
        f"**{total_label}:** {total_docs}\n\n"
        f"**{duplicates_label}:** {len(result.duplicates)}\n\n"
        f"**Unique documents after deduplication:** {len(result.deduplicated)}\n\n"
        + "-" * 50 + "\n\n"
    )


def _format_duplicate_examples(
    result: DeduplicationResult,
    heading: str,
    original_label: str,
    duplicate_label: str,
    max_examples: int = 5,
) -> str:
    """Render up to ``max_examples`` original/duplicate pairs, with word diffs, as Markdown."""
    if not result.duplicates:
        return "No duplicates found."

    # Only show duplicates that actually have near-duplicate records attached.
    duplicates_with_data = [duprec for duprec in result.duplicates if duprec.duplicates]
    if not duplicates_with_data:
        return heading + "No near-duplicate details available.\n\n"

    sections = [heading]
    for duprec in duplicates_with_data[:max_examples]:
        dup_text = duprec.record
        # The caller sorts matches by score first, so index 0 is the closest original.
        orig_text, score = duprec.duplicates[0]
        differences = display_word_differences(orig_text, dup_text)
        # "\n\n" after the score: a single newline is only a soft break in
        # Markdown, which would render the score and "Differences" on one line.
        sections.append(
            f"**{original_label}:**\n{orig_text}\n\n"
            f"**{duplicate_label}:**\n{dup_text}\n\n"
            f"**Similarity Score:** {score:.4f}\n\n"
            f"**Differences:**\n{differences}\n"
            + "-" * 50 + "\n\n"
        )
    return "".join(sections)


def perform_deduplication(
    deduplication_type: str,
    dataset1_name: str,
    dataset1_split: str,
    dataset1_text_column: str,
    dataset2_name: str = "",
    dataset2_split: str = "",
    dataset2_text_column: str = "",
    threshold: float = default_threshold,
    progress: gr.Progress = gr.Progress(track_tqdm=True)
):
    """
    Deduplicate one dataset, or one dataset against another, using SemHash.

    This is a generator bound to the "Deduplicate" button. It yields
    ``(status, result_text)`` pairs of Markdown strings, so the UI can show
    progress while datasets load and deduplication runs.

    Args:
        deduplication_type: ``"Single dataset"`` deduplicates Dataset 1 against
            itself; any other value deduplicates Dataset 2 against Dataset 1.
        dataset1_name, dataset1_split, dataset1_text_column: Where Dataset 1's
            texts are read from.
        dataset2_name, dataset2_split, dataset2_text_column: Same for Dataset 2;
            only used in cross-dataset mode.
        threshold: Similarity threshold forwarded to SemHash; coerced to float.
        progress: Not referenced in the body. Presumably present so Gradio
            mirrors tqdm progress bars (``track_tqdm=True``).

    Yields:
        ``(status, result_text)`` tuples.

    Raises:
        Any exception raised while loading or deduplicating is re-raised after
        an ``"An error occurred: ..."`` status has been yielded.
    """
    try:
        threshold = float(threshold)

        yield "Loading Dataset 1...", ""
        texts1 = load_dataset_texts(dataset1_name, dataset1_split, dataset1_text_column)

        if deduplication_type == "Single dataset":
            yield "Deduplicating within Dataset 1 (SemHash)...", ""
            result = deduplicate_single_dataset(texts1, threshold=threshold)
            _sort_duplicates_by_score(result)

            result_text = _format_summary(
                result,
                len(texts1),
                "Total documents (Dataset 1)",
                "Duplicates found",
            ) + _format_duplicate_examples(
                result,
                "**Example duplicates:**\n\n",
                "Original",
                "Duplicate",
            )
        else:
            yield "Loading Dataset 2...", ""
            texts2 = load_dataset_texts(dataset2_name, dataset2_split, dataset2_text_column)

            yield "Deduplicating Dataset 2 against Dataset 1 (SemHash)...", ""
            result = deduplicate_two_datasets(texts1, texts2, threshold=threshold)
            _sort_duplicates_by_score(result)

            result_text = _format_summary(
                result,
                len(texts2),
                f"Total documents in {dataset2_name}/{dataset2_split}",
                "Duplicates found in Dataset 2",
            ) + _format_duplicate_examples(
                result,
                "**Example duplicates from Dataset 2:**\n\n",
                "Original (Dataset 1)",
                "Duplicate (Dataset 2)",
            )

        yield "Deduplication completed.", result_text

    except Exception as e:
        # Show the error in the status box, then re-raise (bare raise keeps the
        # original traceback) so the failure is not silently swallowed.
        yield f"An error occurred: {e}", ""
        raise


# --- Gradio App ---
# The UI is built at import time so `demo` exists as a module-level object;
# the server is only started when this file is executed directly (see bottom).
# The CSS gives the streamed status messages a fixed-height, scrollable box.
with gr.Blocks(theme=gr.themes.Ocean(), css="#status_output { height: 50px; overflow: auto; }") as demo:
    gr.Markdown("# Semantic Text Deduplication Using SemHash")
    gr.Markdown("""
    This demo showcases **semantic deduplication** using [SemHash](https://github.com/MinishLab/semhash) for HuggingFace datasets, using a [Model2Vec](https://github.com/MinishLab/model2vec) encoder.
    It can be used to identify duplicate texts within a **single dataset** or across **two datasets**.
    You can adjust the similarity threshold to control the strictness of the deduplication.

    **NOTE**: This demo runs on a free CPU backend, so it may be slow for large datasets.
    For faster results, please run the code locally.
    """)

    deduplication_type = gr.Radio(
        choices=["Cross-dataset", "Single dataset"],
        label="Deduplication Type",
        value="Cross-dataset",  # default
    )

    with gr.Row():
        dataset1_name = gr.Textbox(value=default_dataset_name, label="Dataset 1 Name")
        dataset1_split = gr.Textbox(value=default_dataset1_split, label="Dataset 1 Split")
        dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")

    # Dataset 2 inputs live in their own column so they can be hidden as a group
    # when "Single dataset" mode is selected.
    dataset2_inputs = gr.Column(visible=True)
    with dataset2_inputs:
        with gr.Row():
            dataset2_name = gr.Textbox(value=default_dataset_name, label="Dataset 2 Name")
            dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
            dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")

    threshold = gr.Slider(0.0, 1.0, value=default_threshold, label="Similarity Threshold")

    with gr.Row():
        compute_button = gr.Button("Deduplicate")

    status_output = gr.Markdown(elem_id="status_output")
    result_output = gr.Markdown()

    def update_visibility(choice: str):
        """Show the Dataset 2 inputs only when cross-dataset mode is selected."""
        return gr.update(visible=(choice == "Cross-dataset"))

    deduplication_type.change(update_visibility, inputs=deduplication_type, outputs=dataset2_inputs)

    # perform_deduplication is a generator: each yielded (status, result) pair
    # updates the two outputs in turn. The input order must match its
    # positional parameters.
    compute_button.click(
        fn=perform_deduplication,
        inputs=[
            deduplication_type,
            dataset1_name,
            dataset1_split,
            dataset1_text_column,
            dataset2_name,
            dataset2_split,
            dataset2_text_column,
            threshold,
        ],
        outputs=[status_output, result_output],
    )

if __name__ == "__main__":
    demo.launch()