import gradio as gr
from datasets import load_dataset
from difflib import ndiff

from semhash import SemHash
from semhash.datamodels import DeduplicationResult

from model2vec import StaticModel

# Default parameters
default_dataset_name = "SetFit/amazon_massive_scenario_en-US"
default_dataset1_split = "train"
default_dataset2_split = "test"
default_text_column = "text"
default_threshold = 0.9

# Load the model to use
model = StaticModel.from_pretrained("minishlab/potion-base-8M")
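# Note: potion-base-8M is a static Model2Vec model, so "encoding" is a
# lookup-and-average over precomputed token embeddings rather than a
# transformer forward pass, which is what keeps this CPU-only demo usable.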


def display_word_differences(x: str, y: str) -> str:
    """
    Show the word-level diff between two texts, keeping only added (+) and
    removed (-) words. The output is wrapped in a fenced code block so that
    Gradio's Markdown renderer does not reinterpret the diff markers.
    """
    diff = ndiff(x.split(), y.split())
    formatted_diff = "\n".join(word for word in diff if word.startswith(("+", "-")))
    return f"```\n{formatted_diff}\n```"
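# Illustrative example (made-up strings, not from any dataset):
#   display_word_differences("the quick brown fox", "the quick red fox")
# returns a fenced block containing:
#   - brown
#   + red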


def load_dataset_texts(dataset_name: str, dataset_split: str, text_column: str) -> list[str]:
    """Load texts from a specified dataset split."""
    ds = load_dataset(dataset_name, split=dataset_split)
    return [example[text_column] for example in ds]
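# Usage sketch (the split and column names must exist on the dataset):
#   texts = load_dataset_texts(default_dataset_name, default_dataset1_split, default_text_column)
#   # -> list of raw text strings, one per dataset row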


def deduplicate_single_dataset(texts: list[str], threshold: float) -> DeduplicationResult:
    """Deduplicate within a single dataset using SemHash, treating each text as a raw string record."""
    # Build a SemHash index from the raw texts
    semhash = SemHash.from_records(records=texts, model=model)
    # Deduplicate the entire dataset
    return semhash.self_deduplicate(threshold=threshold)
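# The DeduplicationResult is consumed below through two attributes:
#   result.deduplicated -- the texts kept after removing near-duplicates
#   result.duplicates   -- one entry per filtered text, exposing .record (the
#                          duplicate text) and .duplicates, a list of
#                          (matched_text, similarity_score) tuples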


def deduplicate_two_datasets(texts1: list[str], texts2: list[str], threshold: float) -> DeduplicationResult:
    """Deduplicate dataset2 against dataset1, both as raw strings, using SemHash."""
    # Build SemHash index on dataset1
    semhash = SemHash.from_records(records=texts1, model=model)
    # Deduplicate texts2 against dataset1
    return semhash.deduplicate(records=texts2, threshold=threshold)
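# Direction matters here: dataset1 acts as the reference index, and
# result.duplicates lists the dataset2 texts that match something in dataset1.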


def perform_deduplication(
    deduplication_type: str,
    dataset1_name: str,
    dataset1_split: str,
    dataset1_text_column: str,
    dataset2_name: str = "",
    dataset2_split: str = "",
    dataset2_text_column: str = "",
    threshold: float = default_threshold,
    progress: gr.Progress = gr.Progress(track_tqdm=True)  # lets Gradio surface tqdm progress bars from the libraries used here
):
    """
    Perform deduplication on one or two datasets using SemHash. This function
    streams status updates to Gradio for user feedback.
    """
    try:
        threshold = float(threshold)

        # Load Dataset 1
        yield "Loading Dataset 1...", ""
        texts1 = load_dataset_texts(dataset1_name, dataset1_split, dataset1_text_column)

        if deduplication_type == "Single dataset":
            # Single-dataset deduplication
            yield "Deduplicating within Dataset 1 (SemHash)...", ""
            result = deduplicate_single_dataset(texts1, threshold=threshold)

            # Sort each record's matches by similarity so the best match comes first
            for duprec in result.duplicates:
                duprec.duplicates.sort(key=lambda x: x[1], reverse=True)

            # Summarize results
            num_duplicates = len(result.duplicates)
            deduplicated_count = len(result.deduplicated)
            total_docs = len(texts1)
            
            result_text = (
                f"**Total documents (Dataset 1):** {total_docs}\n\n"
                f"**Duplicates found:** {num_duplicates}\n\n"
                f"**Unique documents after deduplication:** {deduplicated_count}\n\n"
                + "-" * 50 + "\n\n"
            )

            # Show example duplicates
            if num_duplicates > 0:
                result_text += "**Example duplicates:**\n\n"

                # Only show duplicates that actually have near-duplicate records
                duplicates_with_data = [duprec for duprec in result.duplicates if duprec.duplicates]
                if duplicates_with_data:
                    for duprec in duplicates_with_data[:5]:
                        dup_text = duprec.record
                        orig_text, score = duprec.duplicates[0]
                        differences = display_word_differences(orig_text, dup_text)
                        result_text += (
                            f"**Original:**\n{orig_text}\n\n"
                            f"**Duplicate:**\n{dup_text}\n\n"
                            f"**Similarity Score:** {score:.4f}\n"
                            f"**Differences:**\n{differences}\n"
                            + "-" * 50 + "\n\n"
                        )
                else:
                    result_text += "No near-duplicate details available.\n\n"
            else:
                result_text += "No duplicates found."

            yield "Deduplication completed.", result_text

        else:
            # Cross-dataset deduplication
            yield "Loading Dataset 2...", ""
            texts2 = load_dataset_texts(dataset2_name, dataset2_split, dataset2_text_column)

            yield "Deduplicating Dataset 2 against Dataset 1 (SemHash)...", ""
            result = deduplicate_two_datasets(texts1, texts2, threshold=threshold)

            # Sort each record's matches by similarity so the best match comes first
            for duprec in result.duplicates:
                duprec.duplicates.sort(key=lambda x: x[1], reverse=True)

            num_duplicates = len(result.duplicates)
            total_docs2 = len(texts2)
            deduplicated_count = len(result.deduplicated)

            result_text = (
                f"**Total documents in {dataset2_name}/{dataset2_split}:** {total_docs2}\n\n"
                f"**Duplicates found in Dataset 2:** {num_duplicates}\n\n"
                f"**Unique documents after deduplication:** {deduplicated_count}\n\n"
                + "-" * 50 + "\n\n"
            )

            if num_duplicates > 0:
                result_text += "**Example duplicates from Dataset 2:**\n\n"

                # Again, only show duplicates that actually have near-duplicate records
                duplicates_with_data = [duprec for duprec in result.duplicates if duprec.duplicates]
                if duplicates_with_data:
                    for duprec in duplicates_with_data[:5]:
                        dup_text = duprec.record  # The "duplicate" text from dataset2
                        orig_text, score = duprec.duplicates[0]
                        differences = display_word_differences(orig_text, dup_text)
                        result_text += (
                            f"**Original (Dataset 1):**\n{orig_text}\n\n"
                            f"**Duplicate (Dataset 2):**\n{dup_text}\n\n"
                            f"**Similarity Score:** {score:.4f}\n"
                            f"**Differences:**\n{differences}\n"
                            + "-" * 50 + "\n\n"
                        )
                else:
                    result_text += "No near-duplicate details available.\n\n"
            else:
                result_text += "No duplicates found."

            yield "Deduplication completed.", result_text

    except Exception as e:
        # Surface the error in the UI first, then re-raise so it also lands in the logs.
        yield f"An error occurred: {e}", ""
        raise


# --- Gradio App ---
with gr.Blocks(theme=gr.themes.Ocean(), css="#status_output { height: 50px; overflow: auto; }") as demo:
    gr.Markdown("# Semantic Text Deduplication Using SemHash")
    gr.Markdown("""
    This demo showcases **semantic deduplication** of Hugging Face datasets with [SemHash](https://github.com/MinishLab/semhash), using a [Model2Vec](https://github.com/MinishLab/model2vec) encoder.
    It can be used to identify duplicate texts within a **single dataset** or across **two datasets**.
    You can adjust the similarity threshold to control the strictness of the deduplication.

    **NOTE**: This demo runs on a free CPU backend, so it may be slow for large datasets.
    For faster results, please run the code locally.
    """)

    deduplication_type = gr.Radio(
        choices=["Cross-dataset", "Single dataset"],
        label="Deduplication Type",
        value="Cross-dataset",  # default
    )

    with gr.Row():
        dataset1_name = gr.Textbox(value=default_dataset_name, label="Dataset 1 Name")
        dataset1_split = gr.Textbox(value=default_dataset1_split, label="Dataset 1 Split")
        dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")

    # The Dataset 2 inputs live in their own column so they can be hidden in single-dataset mode
    dataset2_inputs = gr.Column(visible=True)
    with dataset2_inputs:
        with gr.Row():
            dataset2_name = gr.Textbox(value=default_dataset_name, label="Dataset 2 Name")
            dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
            dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")

    # Texts whose embedding similarity is above this threshold are treated as
    # duplicates, so raising the value makes matching stricter.
    threshold = gr.Slider(0.0, 1.0, value=default_threshold, label="Similarity Threshold")

    with gr.Row():
        compute_button = gr.Button("Deduplicate")

    status_output = gr.Markdown(elem_id="status_output")
    result_output = gr.Markdown()

    def update_visibility(choice: str):
        """Show the Dataset 2 inputs only when cross-dataset deduplication is selected."""
        return gr.update(visible=(choice == "Cross-dataset"))

    deduplication_type.change(update_visibility, inputs=deduplication_type, outputs=dataset2_inputs)

    compute_button.click(
        fn=perform_deduplication,
        inputs=[
            deduplication_type,
            dataset1_name,
            dataset1_split,
            dataset1_text_column,
            dataset2_name,
            dataset2_split,
            dataset2_text_column,
            threshold,
        ],
        outputs=[status_output, result_output],
    )

demo.launch()