Yacine Jernite
commited on
Commit
·
7bffaaf
1
Parent(s):
89e8e87
initial commit
Browse files- .gitattributes +2 -0
- README.md +13 -7
- app.py +75 -0
- data_measurements_clusters/__init__.py +1 -0
- data_measurements_clusters/clustering.py +691 -0
- data_measurements_clusters/dataset_utils.py +292 -0
- posts/conclusion.py +58 -0
- posts/context.py +104 -0
- posts/dataset_exploration.py +143 -0
- posts/model_exploration.py +340 -0
- posts/welcome.py +74 -0
.gitattributes
CHANGED
|
@@ -25,3 +25,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
| 27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
| 27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.json filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,13 +1,19 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: streamlit
|
| 7 |
-
sdk_version: 1.10.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
-
license: apache-2.0
|
| 11 |
---
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Task Exploration - Automatic Content Moderation
|
| 3 |
+
emoji: 🤗
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: red
|
| 6 |
sdk: streamlit
|
|
|
|
| 7 |
app_file: app.py
|
| 8 |
pinned: false
|
|
|
|
| 9 |
---
|
| 10 |
|
| 11 |
+
# Task Exploration
|
| 12 |
+
|
| 13 |
+
[](https://huggingface.co/spaces/aymm/Task-Exploration-Hate-Speech)
|
| 14 |
+
|
| 15 |
+
The context and definition of hate speech detection as a modeling task.
|
| 16 |
+
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
Autogenerated using [this template](https://github.com/nateraw/spaces-template)
|
app.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
import re
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import streamlit as st
|
| 6 |
+
import yaml
|
| 7 |
+
|
| 8 |
+
REGEX_YAML_BLOCK = re.compile(r"---[\n\r]+([\S\s]*?)[\n\r]+---[\n\r](.*)", re.DOTALL)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def render_preview(image, title, description):
|
| 12 |
+
with st.container():
|
| 13 |
+
image_col, text_col = st.columns((1, 4))
|
| 14 |
+
with image_col:
|
| 15 |
+
st.image(image)
|
| 16 |
+
|
| 17 |
+
with text_col:
|
| 18 |
+
st.subheader(title)
|
| 19 |
+
st.write(description)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def render_page(post_path: Path):
|
| 23 |
+
mod = importlib.import_module(str(post_path))
|
| 24 |
+
mod.run_article()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_page_data(post_path: Path):
|
| 28 |
+
mod = importlib.import_module(str(post_path))
|
| 29 |
+
return {
|
| 30 |
+
"title": mod.title,
|
| 31 |
+
"description": mod.description,
|
| 32 |
+
"date": mod.date,
|
| 33 |
+
"thumbnail": mod.thumbnail,
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def main():
|
| 38 |
+
st.set_page_config(layout="wide")
|
| 39 |
+
posts = {
|
| 40 |
+
"posts.welcome": "Welcome",
|
| 41 |
+
"posts.context": "Hate Speech in ACM",
|
| 42 |
+
"posts.dataset_exploration": "ACM Datasets",
|
| 43 |
+
"posts.model_exploration": "ACM Models",
|
| 44 |
+
"posts.conclusion": "Key Takeaways",
|
| 45 |
+
}
|
| 46 |
+
page_to_show = list(posts.keys())[0]
|
| 47 |
+
with st.sidebar:
|
| 48 |
+
|
| 49 |
+
st.markdown(
|
| 50 |
+
"""
|
| 51 |
+
<div align="center">
|
| 52 |
+
<h1>Task Exploration: Hate Speech Detection</h1>
|
| 53 |
+
</div>
|
| 54 |
+
""",
|
| 55 |
+
unsafe_allow_html=True,
|
| 56 |
+
)
|
| 57 |
+
st.markdown("---")
|
| 58 |
+
|
| 59 |
+
page_to_show = st.selectbox(
|
| 60 |
+
"Navigation menu:",
|
| 61 |
+
posts,
|
| 62 |
+
format_func=lambda x:posts[x],
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
for post in posts:
|
| 66 |
+
data = get_page_data(Path(post))
|
| 67 |
+
clicked = render_preview(
|
| 68 |
+
data.get("thumbnail"), data.get("title"), data.get("description")
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
if page_to_show:
|
| 72 |
+
render_page(Path(page_to_show))
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
main()
|
data_measurements_clusters/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .clustering import Clustering
|
data_measurements_clusters/clustering.py
ADDED
|
@@ -0,0 +1,691 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import gzip
|
| 16 |
+
import json
|
| 17 |
+
import math
|
| 18 |
+
import os
|
| 19 |
+
from os.path import exists
|
| 20 |
+
from os.path import join as pjoin
|
| 21 |
+
|
| 22 |
+
import pandas as pd
|
| 23 |
+
import plotly.express as px
|
| 24 |
+
import plotly.graph_objects as go
|
| 25 |
+
import torch
|
| 26 |
+
import transformers
|
| 27 |
+
from datasets import load_dataset
|
| 28 |
+
from huggingface_hub import HfApi
|
| 29 |
+
from tqdm import tqdm
|
| 30 |
+
|
| 31 |
+
# from .dataset_utils import prepare_clustering_dataset
|
| 32 |
+
|
| 33 |
+
pd.options.display.max_colwidth = 256
|
| 34 |
+
|
| 35 |
+
_CACHE_DIR = "cache_dir"
|
| 36 |
+
|
| 37 |
+
_DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2"
|
| 38 |
+
|
| 39 |
+
_MAX_MERGE = 20000000 # to run on 64GB RAM laptop
|
| 40 |
+
|
| 41 |
+
def sentence_mean_pooling(model_output, attention_mask):
|
| 42 |
+
token_embeddings = model_output[
|
| 43 |
+
0
|
| 44 |
+
] # First element of model_output contains all token embeddings
|
| 45 |
+
input_mask_expanded = (
|
| 46 |
+
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
| 47 |
+
)
|
| 48 |
+
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
|
| 49 |
+
input_mask_expanded.sum(1), min=1e-9
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# get nearest neighbors of a centroid by dot product
|
| 54 |
+
def get_examplars(example_ids, centroid, embeddings, dset, n_examplars):
|
| 55 |
+
example_embeds = embeddings[example_ids]
|
| 56 |
+
example_scores = torch.mv(example_embeds, centroid)
|
| 57 |
+
s_scores, s_ids = example_scores.sort(dim=-1, descending=True)
|
| 58 |
+
examplars = [
|
| 59 |
+
(example_ids[i.item()], s.item())
|
| 60 |
+
for i, s in zip(s_ids[:n_examplars], s_scores[:n_examplars])
|
| 61 |
+
]
|
| 62 |
+
res = []
|
| 63 |
+
for eid, score in examplars:
|
| 64 |
+
dct = dict(dset[eid])
|
| 65 |
+
dct["score"] = score
|
| 66 |
+
res += [dct]
|
| 67 |
+
return res
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# order node children so that the large ones are in the middle
|
| 71 |
+
# makes visualization more balanced
|
| 72 |
+
def pretty_order(nodes, node_ids):
|
| 73 |
+
sorted_ids = sorted(node_ids, key=lambda nid: nodes[nid]["weight"])
|
| 74 |
+
sorted_a = [nid for i, nid in enumerate(sorted_ids) if i % 2 == 0]
|
| 75 |
+
sorted_b = [nid for i, nid in enumerate(sorted_ids) if i % 2 == 1]
|
| 76 |
+
sorted_b.reverse()
|
| 77 |
+
return sorted_a + sorted_b
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def make_tree_plot(node_list, root_id, max_depth=-1):
|
| 81 |
+
# make plot nodes
|
| 82 |
+
plot_nodes = [{} for _ in node_list]
|
| 83 |
+
|
| 84 |
+
root = {
|
| 85 |
+
"parent_id": -1,
|
| 86 |
+
"node_id": root_id,
|
| 87 |
+
"label": node_list[root_id]["hover_text"],
|
| 88 |
+
"weight": node_list[root_id]["weight"],
|
| 89 |
+
"num_leaves": 0,
|
| 90 |
+
"children_ids": node_list[root_id]["children_ids"],
|
| 91 |
+
"Xmin": 0,
|
| 92 |
+
"Y": 0,
|
| 93 |
+
}
|
| 94 |
+
plot_nodes[root_id] = root
|
| 95 |
+
|
| 96 |
+
root_depth = node_list[root_id]["depth"]
|
| 97 |
+
|
| 98 |
+
def rec_make_coordinates(node):
|
| 99 |
+
total_weight = 0
|
| 100 |
+
recurse = (max_depth == -1) or (
|
| 101 |
+
node_list[node["node_id"]]["depth"] - root_depth < max_depth - 1
|
| 102 |
+
)
|
| 103 |
+
for cid in node["children_ids"]:
|
| 104 |
+
plot_nodes[cid] = {
|
| 105 |
+
"parent_id": node["node_id"],
|
| 106 |
+
"node_id": cid,
|
| 107 |
+
"label": node_list[cid]["hover_text"],
|
| 108 |
+
"weight": node_list[cid]["weight"],
|
| 109 |
+
"children_ids": node_list[cid]["children_ids"] if recurse else [],
|
| 110 |
+
"Xmin": node["Xmin"] + total_weight,
|
| 111 |
+
"Y": node["Y"] - 1,
|
| 112 |
+
}
|
| 113 |
+
plot_nodes[cid]["num_leaves"] = 1 if len(plot_nodes[cid]["children_ids"]) == 0 else 0
|
| 114 |
+
rec_make_coordinates(plot_nodes[cid])
|
| 115 |
+
total_weight += plot_nodes[cid]["num_leaves"]
|
| 116 |
+
node["num_leaves"] += plot_nodes[cid]["num_leaves"]
|
| 117 |
+
node["Xmax"] = node["Xmin"] + node["num_leaves"]
|
| 118 |
+
node["X"] = node["Xmin"] + (node["num_leaves"] / 2)
|
| 119 |
+
|
| 120 |
+
rec_make_coordinates(root)
|
| 121 |
+
|
| 122 |
+
subtree_nodes = [node for node in plot_nodes if len(node) > 0]
|
| 123 |
+
nid_map = dict([(node["node_id"], nid) for nid, node in enumerate(subtree_nodes)])
|
| 124 |
+
labels = [node["label"] for node in subtree_nodes]
|
| 125 |
+
|
| 126 |
+
E = [] # list of edges
|
| 127 |
+
Xn = []
|
| 128 |
+
Yn = []
|
| 129 |
+
Xe = []
|
| 130 |
+
Ye = []
|
| 131 |
+
for nid, node in enumerate(subtree_nodes):
|
| 132 |
+
Xn += [node["X"]]
|
| 133 |
+
Yn += [node["Y"]]
|
| 134 |
+
for cid in node["children_ids"]:
|
| 135 |
+
child = plot_nodes[cid]
|
| 136 |
+
E += [(nid, nid_map[child["node_id"]])]
|
| 137 |
+
Xe += [node["X"], child["X"], None]
|
| 138 |
+
Ye += [node["Y"], child["Y"], None]
|
| 139 |
+
|
| 140 |
+
# make figure
|
| 141 |
+
fig = go.Figure()
|
| 142 |
+
fig.add_trace(
|
| 143 |
+
go.Scatter(
|
| 144 |
+
x=Xe,
|
| 145 |
+
y=Ye,
|
| 146 |
+
mode="lines",
|
| 147 |
+
name="",
|
| 148 |
+
line=dict(color="rgb(210,210,210)", width=1),
|
| 149 |
+
hoverinfo="none",
|
| 150 |
+
)
|
| 151 |
+
)
|
| 152 |
+
fig.add_trace(
|
| 153 |
+
go.Scatter(
|
| 154 |
+
x=Xn,
|
| 155 |
+
y=Yn,
|
| 156 |
+
mode="markers",
|
| 157 |
+
name="nodes",
|
| 158 |
+
marker=dict(
|
| 159 |
+
symbol="circle-dot",
|
| 160 |
+
size=18,
|
| 161 |
+
color="#6175c1",
|
| 162 |
+
line=dict(color="rgb(50,50,50)", width=1)
|
| 163 |
+
# '#DB4551',
|
| 164 |
+
),
|
| 165 |
+
text=labels,
|
| 166 |
+
hoverinfo="text",
|
| 167 |
+
opacity=0.8,
|
| 168 |
+
)
|
| 169 |
+
)
|
| 170 |
+
fig.layout.showlegend = False
|
| 171 |
+
return fig
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class ClusteringBuilder:
|
| 175 |
+
def __init__(
|
| 176 |
+
self,
|
| 177 |
+
dataset_name,
|
| 178 |
+
config_name,
|
| 179 |
+
split_name,
|
| 180 |
+
input_field_path,
|
| 181 |
+
label_name,
|
| 182 |
+
num_rows,
|
| 183 |
+
model_name=_DEFAULT_MODEL,
|
| 184 |
+
):
|
| 185 |
+
"""Item embeddings and clustering"""
|
| 186 |
+
self.dataset_name = dataset_name
|
| 187 |
+
self.config_name = config_name
|
| 188 |
+
self.split_name = split_name
|
| 189 |
+
self.input_field_path = input_field_path
|
| 190 |
+
self.label_name = label_name
|
| 191 |
+
self.num_rows = num_rows
|
| 192 |
+
self.cache_path_list = [
|
| 193 |
+
_CACHE_DIR,
|
| 194 |
+
dataset_name.replace("/", "---"),
|
| 195 |
+
f"{'default' if config_name is None else config_name}",
|
| 196 |
+
f"{'train' if split_name is None else split_name}",
|
| 197 |
+
f"field-{'->'.join(input_field_path)}-label-{label_name}",
|
| 198 |
+
f"{num_rows}_rows",
|
| 199 |
+
model_name.replace("/", "---"),
|
| 200 |
+
]
|
| 201 |
+
self.cache_path = pjoin(*self.cache_path_list)
|
| 202 |
+
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 203 |
+
self.model_name = model_name
|
| 204 |
+
|
| 205 |
+
# prepare embeddings for the dataset
|
| 206 |
+
def set_model(self):
|
| 207 |
+
self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
|
| 208 |
+
self.model = transformers.AutoModel.from_pretrained(self.model_name).to(
|
| 209 |
+
self.device
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
def set_features_dataset(self, use_streaming, use_auth_token, use_dataset):
|
| 213 |
+
dset, dset_path = prepare_clustering_dataset(
|
| 214 |
+
dataset_name=self.dataset_name,
|
| 215 |
+
input_field_path=self.input_field_path,
|
| 216 |
+
label_name=self.label_name,
|
| 217 |
+
config_name=self.config_name,
|
| 218 |
+
split_name=self.split_name,
|
| 219 |
+
num_rows=self.num_rows,
|
| 220 |
+
use_streaming=use_streaming,
|
| 221 |
+
use_auth_token=use_auth_token,
|
| 222 |
+
use_dataset=use_dataset,
|
| 223 |
+
)
|
| 224 |
+
self.features_dset = dset
|
| 225 |
+
|
| 226 |
+
def compute_feature_embeddings(self, sentences):
|
| 227 |
+
batch = self.tokenizer(
|
| 228 |
+
sentences, padding=True, truncation=True, return_tensors="pt"
|
| 229 |
+
)
|
| 230 |
+
batch = {k: v.to(self.device) for k, v in batch.items()}
|
| 231 |
+
with torch.no_grad():
|
| 232 |
+
model_output = self.model(**batch)
|
| 233 |
+
sentence_embeds = sentence_mean_pooling(
|
| 234 |
+
model_output, batch["attention_mask"]
|
| 235 |
+
)
|
| 236 |
+
sentence_embeds /= sentence_embeds.norm(dim=-1, keepdim=True)
|
| 237 |
+
return sentence_embeds
|
| 238 |
+
|
| 239 |
+
def set_embeddings_dataset(self):
|
| 240 |
+
def batch_embed(examples):
|
| 241 |
+
return {
|
| 242 |
+
"embedding": [
|
| 243 |
+
embed.tolist()
|
| 244 |
+
for embed in self.compute_feature_embeddings(examples["field"])
|
| 245 |
+
]
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
if not exists(self.cache_path):
|
| 249 |
+
os.mkdir(self.cache_path)
|
| 250 |
+
|
| 251 |
+
self.embeddings_dset = self.features_dset.map(
|
| 252 |
+
batch_embed,
|
| 253 |
+
batched=True,
|
| 254 |
+
batch_size=32,
|
| 255 |
+
cache_file_name=pjoin(self.cache_path, "embeddings_dset"),
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
def prepare_embeddings(
|
| 259 |
+
self,
|
| 260 |
+
use_streaming=True,
|
| 261 |
+
use_auth_token=None,
|
| 262 |
+
use_dataset=None,
|
| 263 |
+
):
|
| 264 |
+
self.set_model()
|
| 265 |
+
self.set_features_dataset(use_streaming, use_auth_token, use_dataset)
|
| 266 |
+
self.set_embeddings_dataset()
|
| 267 |
+
|
| 268 |
+
# make cluster tree
|
| 269 |
+
def prepare_merges(self, batch_size, low_thres):
|
| 270 |
+
self.embeddings = torch.Tensor(self.embeddings_dset["embedding"])
|
| 271 |
+
all_indices = torch.LongTensor(torch.Size([0, 2]))
|
| 272 |
+
all_scores = torch.Tensor(torch.Size([0]))
|
| 273 |
+
n_batches = math.ceil(self.embeddings_dset.num_rows / batch_size)
|
| 274 |
+
for a in range(n_batches):
|
| 275 |
+
for b in tqdm(range(a, n_batches)):
|
| 276 |
+
cos_scores = torch.mm(
|
| 277 |
+
self.embeddings[a * batch_size : (a + 1) * batch_size],
|
| 278 |
+
self.embeddings[b * batch_size : (b + 1) * batch_size].t(),
|
| 279 |
+
)
|
| 280 |
+
if a == b:
|
| 281 |
+
cos_scores = cos_scores.triu(diagonal=1)
|
| 282 |
+
merge_indices = torch.nonzero(cos_scores > low_thres)
|
| 283 |
+
merge_indices[:, 0] += a * batch_size
|
| 284 |
+
merge_indices[:, 1] += b * batch_size
|
| 285 |
+
merge_scores = cos_scores[cos_scores > low_thres]
|
| 286 |
+
all_indices = torch.cat([all_indices, merge_indices], dim=0)
|
| 287 |
+
all_scores = torch.cat([all_scores, merge_scores], dim=0)
|
| 288 |
+
self.sorted_scores, sorted_score_ids = all_scores.sort(dim=0, descending=True)
|
| 289 |
+
self.sorted_scores = self.sorted_scores[:_MAX_MERGE]
|
| 290 |
+
sorted_score_ids = sorted_score_ids[:_MAX_MERGE]
|
| 291 |
+
self.sorted_indices = all_indices[sorted_score_ids]
|
| 292 |
+
|
| 293 |
+
def make_starting_nodes(self, identical_threshold):
|
| 294 |
+
identical_indices = self.sorted_indices[
|
| 295 |
+
self.sorted_scores >= identical_threshold
|
| 296 |
+
]
|
| 297 |
+
identical_inter = identical_indices[
|
| 298 |
+
identical_indices[:, 1].sort(stable=True).indices
|
| 299 |
+
]
|
| 300 |
+
identical_sorted = identical_inter[
|
| 301 |
+
identical_inter[:, 0].sort(stable=True).indices
|
| 302 |
+
]
|
| 303 |
+
self.parents = {}
|
| 304 |
+
for a_pre, b_pre in identical_sorted:
|
| 305 |
+
a = a_pre.item()
|
| 306 |
+
b = b_pre.item()
|
| 307 |
+
while self.parents.get(a, -1) != -1:
|
| 308 |
+
a = self.parents[a]
|
| 309 |
+
self.parents[b] = a
|
| 310 |
+
self.duplicates = {}
|
| 311 |
+
for a, b in self.parents.items():
|
| 312 |
+
self.duplicates[b] = self.duplicates.get(b, []) + [a]
|
| 313 |
+
self.nodes = {}
|
| 314 |
+
for node_id in range(self.features_dset.num_rows):
|
| 315 |
+
if node_id in self.parents:
|
| 316 |
+
continue
|
| 317 |
+
else:
|
| 318 |
+
self.nodes[node_id] = {
|
| 319 |
+
"node_id": node_id,
|
| 320 |
+
"parent_id": -1,
|
| 321 |
+
"children": [],
|
| 322 |
+
"children_ids": [],
|
| 323 |
+
"example_ids": [node_id],
|
| 324 |
+
"weight": 1,
|
| 325 |
+
"merge_threshold": 0.98,
|
| 326 |
+
"depth": 0,
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
def make_merge_nodes(self, identical_threshold, thres_step):
|
| 330 |
+
new_node_id = self.features_dset.num_rows
|
| 331 |
+
current_thres = identical_threshold
|
| 332 |
+
depth = 1
|
| 333 |
+
merge_ids = self.sorted_indices[self.sorted_scores < identical_threshold]
|
| 334 |
+
merge_scores = self.sorted_scores[self.sorted_scores < identical_threshold]
|
| 335 |
+
for (node_id_a, node_id_b), merge_score in tqdm(
|
| 336 |
+
zip(merge_ids, merge_scores), total=len(merge_ids)
|
| 337 |
+
):
|
| 338 |
+
if merge_score.item() < current_thres:
|
| 339 |
+
current_thres -= thres_step
|
| 340 |
+
merge_a = node_id_a.item()
|
| 341 |
+
while self.parents.get(merge_a, -1) != -1:
|
| 342 |
+
merge_a = self.parents[merge_a]
|
| 343 |
+
self.parents[node_id_a] = merge_a
|
| 344 |
+
merge_b = node_id_b.item()
|
| 345 |
+
while self.parents.get(merge_b, -1) != -1:
|
| 346 |
+
merge_b = self.parents[merge_b]
|
| 347 |
+
self.parents[node_id_b] = merge_b
|
| 348 |
+
if merge_a == merge_b:
|
| 349 |
+
continue
|
| 350 |
+
else:
|
| 351 |
+
merge_b, merge_a = sorted([merge_a, merge_b])
|
| 352 |
+
node_a = self.nodes[merge_a]
|
| 353 |
+
node_b = self.nodes[merge_b]
|
| 354 |
+
if (node_a["depth"]) > 0 and min(
|
| 355 |
+
node_a["merge_threshold"], node_b["merge_threshold"]
|
| 356 |
+
) == current_thres:
|
| 357 |
+
node_a["depth"] = max(node_a["depth"], node_b["depth"])
|
| 358 |
+
node_a["weight"] += node_b["weight"]
|
| 359 |
+
node_a["children_ids"] += (
|
| 360 |
+
node_b["children_ids"]
|
| 361 |
+
if node_b["depth"] > 0
|
| 362 |
+
else [node_b["node_id"]]
|
| 363 |
+
)
|
| 364 |
+
for cid in node_b["children_ids"]:
|
| 365 |
+
self.nodes[cid]["parent_id"] = node_a["node_id"]
|
| 366 |
+
self.parents[cid] = node_a["node_id"]
|
| 367 |
+
node_b["parent_id"] = node_a["node_id"]
|
| 368 |
+
self.parents[node_b["node_id"]] = node_a["node_id"]
|
| 369 |
+
else:
|
| 370 |
+
new_nid = new_node_id
|
| 371 |
+
new_node_id += 1
|
| 372 |
+
new_node = {
|
| 373 |
+
"node_id": new_nid,
|
| 374 |
+
"parent_id": -1,
|
| 375 |
+
"children_ids": [node_a["node_id"], node_b["node_id"]],
|
| 376 |
+
"example_ids": [],
|
| 377 |
+
"weight": node_a["weight"] + node_b["weight"],
|
| 378 |
+
"merge_threshold": current_thres,
|
| 379 |
+
"depth": max(node_a["depth"], node_b["depth"]) + 1,
|
| 380 |
+
}
|
| 381 |
+
depth = max(depth, new_node["depth"])
|
| 382 |
+
node_a["parent_id"] = new_nid
|
| 383 |
+
node_b["parent_id"] = new_nid
|
| 384 |
+
self.parents[node_a["node_id"]] = new_nid
|
| 385 |
+
self.parents[node_b["node_id"]] = new_nid
|
| 386 |
+
self.parents[node_id_a] = new_nid
|
| 387 |
+
self.parents[node_id_b] = new_nid
|
| 388 |
+
self.nodes[new_nid] = new_node
|
| 389 |
+
return new_node_id
|
| 390 |
+
|
| 391 |
+
def collapse_nodes(self, node, min_weight):
|
| 392 |
+
children = [
|
| 393 |
+
self.collapse_nodes(self.nodes[cid], min_weight)
|
| 394 |
+
for cid in node["children_ids"]
|
| 395 |
+
if self.nodes[cid]["weight"] >= min_weight
|
| 396 |
+
]
|
| 397 |
+
extras = [
|
| 398 |
+
lid
|
| 399 |
+
for cid in node["children_ids"]
|
| 400 |
+
if self.nodes[cid]["weight"] < min_weight
|
| 401 |
+
for lid in self.collapse_nodes(self.nodes[cid], min_weight)["example_ids"]
|
| 402 |
+
] + node["example_ids"]
|
| 403 |
+
extras_embed = (
|
| 404 |
+
torch.cat(
|
| 405 |
+
[self.embeddings[eid][None, :] for eid in extras],
|
| 406 |
+
dim=0,
|
| 407 |
+
).sum(dim=0)
|
| 408 |
+
if len(extras) > 0
|
| 409 |
+
else torch.zeros(self.embeddings.shape[-1])
|
| 410 |
+
)
|
| 411 |
+
if len(children) == 0:
|
| 412 |
+
node["extras"] = extras
|
| 413 |
+
node["children_ids"] = []
|
| 414 |
+
node["example_ids"] = extras
|
| 415 |
+
node["embedding_sum"] = extras_embed
|
| 416 |
+
elif len(children) == 1:
|
| 417 |
+
node["extras"] = extras + children[0]["extras"]
|
| 418 |
+
node["children_ids"] = children[0]["children_ids"]
|
| 419 |
+
node["example_ids"] = extras + children[0]["example_ids"]
|
| 420 |
+
node["embedding_sum"] = extras_embed + children[0]["embedding_sum"]
|
| 421 |
+
else:
|
| 422 |
+
node["extras"] = extras
|
| 423 |
+
node["children_ids"] = [child["node_id"] for child in children]
|
| 424 |
+
node["example_ids"] = extras + [
|
| 425 |
+
eid for child in children for eid in child["example_ids"]
|
| 426 |
+
]
|
| 427 |
+
node["embedding_sum"] = (
|
| 428 |
+
extras_embed
|
| 429 |
+
+ torch.cat(
|
| 430 |
+
[child["embedding_sum"][None, :] for child in children],
|
| 431 |
+
dim=0,
|
| 432 |
+
).sum(dim=0)
|
| 433 |
+
)
|
| 434 |
+
assert (
|
| 435 |
+
len(node["example_ids"]) == node["weight"]
|
| 436 |
+
), f"stuck at {node['node_id']} - {len(node['example_ids'])} - {node['weight']}"
|
| 437 |
+
return node
|
| 438 |
+
|
| 439 |
+
def finalize_node(self, node, parent_id, n_examplars, with_labels):
|
| 440 |
+
new_node_id = len(self.tree_node_list)
|
| 441 |
+
new_node = {
|
| 442 |
+
"node_id": new_node_id,
|
| 443 |
+
"parent_id": parent_id,
|
| 444 |
+
"depth": 0
|
| 445 |
+
if parent_id == -1
|
| 446 |
+
else self.tree_node_list[parent_id]["depth"] + 1,
|
| 447 |
+
"merged_at": node["merge_threshold"],
|
| 448 |
+
"weight": node["weight"],
|
| 449 |
+
"is_extra": False,
|
| 450 |
+
}
|
| 451 |
+
self.tree_node_list += [new_node]
|
| 452 |
+
centroid = node["embedding_sum"] / node["embedding_sum"].norm()
|
| 453 |
+
new_node["centroid"] = centroid.tolist()
|
| 454 |
+
new_node["examplars"] = get_examplars(
|
| 455 |
+
node["example_ids"],
|
| 456 |
+
centroid,
|
| 457 |
+
self.embeddings,
|
| 458 |
+
self.features_dset,
|
| 459 |
+
n_examplars,
|
| 460 |
+
)
|
| 461 |
+
label_counts = {}
|
| 462 |
+
if with_labels:
|
| 463 |
+
for eid in node["example_ids"]:
|
| 464 |
+
label = self.features_dset[eid]["label"]
|
| 465 |
+
label_counts[label] = label_counts.get(label, 0) + 1
|
| 466 |
+
new_node["label_counts"] = sorted(
|
| 467 |
+
label_counts.items(), key=lambda x: x[1], reverse=True
|
| 468 |
+
)
|
| 469 |
+
if len(node["children_ids"]) == 0:
|
| 470 |
+
new_node["children_ids"] = []
|
| 471 |
+
else:
|
| 472 |
+
children = [
|
| 473 |
+
self.nodes[cid]
|
| 474 |
+
for cid in pretty_order(self.nodes, node["children_ids"])
|
| 475 |
+
]
|
| 476 |
+
children_ids = [
|
| 477 |
+
self.finalize_node(child, new_node_id, n_examplars, with_labels)
|
| 478 |
+
for child in children
|
| 479 |
+
]
|
| 480 |
+
new_node["children_ids"] = children_ids
|
| 481 |
+
if len(node["extras"]) > 0:
|
| 482 |
+
extra_node = {
|
| 483 |
+
"node_id": len(self.tree_node_list),
|
| 484 |
+
"parent_id": new_node_id,
|
| 485 |
+
"depth": new_node["depth"] + 1,
|
| 486 |
+
"merged_at": node["merge_threshold"],
|
| 487 |
+
"weight": len(node["extras"]),
|
| 488 |
+
"is_extra": True,
|
| 489 |
+
"centroid": new_node["centroid"],
|
| 490 |
+
"examplars": get_examplars(
|
| 491 |
+
node["extras"],
|
| 492 |
+
centroid,
|
| 493 |
+
self.embeddings,
|
| 494 |
+
self.features_dset,
|
| 495 |
+
n_examplars,
|
| 496 |
+
),
|
| 497 |
+
}
|
| 498 |
+
self.tree_node_list += [extra_node]
|
| 499 |
+
label_counts = {}
|
| 500 |
+
if with_labels:
|
| 501 |
+
for eid in node["extras"]:
|
| 502 |
+
label = self.features_dset[eid]["label"]
|
| 503 |
+
label_counts[label] = label_counts.get(label, 0) + 1
|
| 504 |
+
extra_node["label_counts"] = sorted(
|
| 505 |
+
label_counts.items(), key=lambda x: x[1], reverse=True
|
| 506 |
+
)
|
| 507 |
+
extra_node["children_ids"] = []
|
| 508 |
+
new_node["children_ids"] += [extra_node["node_id"]]
|
| 509 |
+
return new_node_id
|
| 510 |
+
|
| 511 |
+
def make_hover_text(self, num_examples=5, text_width=64, with_labels=False):
|
| 512 |
+
for nid, node in enumerate(self.tree_node_list):
|
| 513 |
+
line_list = [
|
| 514 |
+
f"Node {nid:3d} - {node['weight']:6d} items - Linking threshold: {node['merged_at']:.2f}"
|
| 515 |
+
]
|
| 516 |
+
for examplar in node["examplars"][:num_examples]:
|
| 517 |
+
line_list += [
|
| 518 |
+
f"{examplar['ids']:6d}:{examplar['score']:.2f} - {examplar['field'][:text_width]}"
|
| 519 |
+
+ (f" - {examplar['label']}" if with_labels else "")
|
| 520 |
+
]
|
| 521 |
+
if with_labels:
|
| 522 |
+
line_list += ["Label distribution"]
|
| 523 |
+
for label, count in node["label_counts"]:
|
| 524 |
+
line_list += [f" - label: {label} - {count} items"]
|
| 525 |
+
node["hover_text"] = "<br>".join(line_list)
|
| 526 |
+
|
| 527 |
+
def build_tree(
|
| 528 |
+
self,
|
| 529 |
+
batch_size=10000,
|
| 530 |
+
low_thres=0.5,
|
| 531 |
+
identical_threshold=0.95,
|
| 532 |
+
thres_step=0.05,
|
| 533 |
+
min_weight=10,
|
| 534 |
+
n_examplars=25,
|
| 535 |
+
hover_examples=5,
|
| 536 |
+
hover_text_width=64,
|
| 537 |
+
):
|
| 538 |
+
self.prepare_merges(batch_size, low_thres)
|
| 539 |
+
self.make_starting_nodes(identical_threshold)
|
| 540 |
+
# make a root to join all trees
|
| 541 |
+
root_node_id = self.make_merge_nodes(identical_threshold, thres_step)
|
| 542 |
+
top_nodes = [node for node in self.nodes.values() if node["parent_id"] == -1]
|
| 543 |
+
root_node = {
|
| 544 |
+
"node_id": root_node_id,
|
| 545 |
+
"parent_id": -1,
|
| 546 |
+
"children_ids": [node["node_id"] for node in top_nodes],
|
| 547 |
+
"example_ids": [],
|
| 548 |
+
"weight": sum([node["weight"] for node in top_nodes]),
|
| 549 |
+
"merge_threshold": -1.0,
|
| 550 |
+
"depth": 1 + max([node["depth"] for node in top_nodes]),
|
| 551 |
+
}
|
| 552 |
+
for node in top_nodes:
|
| 553 |
+
node["parent_id"] = root_node_id
|
| 554 |
+
self.nodes[root_node_id] = root_node
|
| 555 |
+
_ = self.collapse_nodes(root_node, min_weight)
|
| 556 |
+
self.tree_node_list = []
|
| 557 |
+
self.finalize_node(
|
| 558 |
+
root_node,
|
| 559 |
+
-1,
|
| 560 |
+
n_examplars,
|
| 561 |
+
with_labels=(self.label_name is not None),
|
| 562 |
+
)
|
| 563 |
+
self.make_hover_text(
|
| 564 |
+
num_examples=hover_examples,
|
| 565 |
+
text_width=hover_text_width,
|
| 566 |
+
with_labels=(self.label_name is not None),
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
+
def push_to_hub(self, use_auth_token=None, file_name=None):
    """Serialize the finalized tree to gzipped JSONL and upload it to the Hub.

    Writes one JSON object per node (one per line) to a gzip file under the
    local cache path, then uploads it to the `yjernite/datasets_clusters`
    dataset repository, mirroring the cache layout (minus the cache root).

    Args:
        use_auth_token (string): HF token used for the upload, if needed
        file_name (string): base name of the file, defaults to "tree"

    Returns:
        the uploaded file location reported by the Hub API
    """
    cache_parts = self.cache_path_list
    base_name = file_name if file_name is not None else "tree"
    local_file = pjoin(pjoin(*cache_parts), f"{base_name}.jsonl.gz")
    with gzip.open(local_file, "w") as handle:
        for node in tqdm(self.tree_node_list):
            _ = handle.write((json.dumps(node) + "\n").encode("utf-8"))
    # the first cache path component is the local cache root: strip it
    # so the repo path starts at the dataset name
    return HfApi().upload_file(
        path_or_fileobj=local_file,
        path_in_repo=pjoin(pjoin(*cache_parts[1:]), f"{base_name}.jsonl.gz"),
        repo_id="yjernite/datasets_clusters",
        token=use_auth_token,
        repo_type="dataset",
    )
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
class Clustering:
    """Read-only view over a pre-computed clustering tree stored on the Hub.

    Loads the node list previously uploaded by `push_to_hub` and lazily builds
    (and memoizes, in `self.node_reps`) the per-node display artifacts: the
    full tree plot, per-node sub-trees, examplar tables, and label pie charts.
    """

    def __init__(
        self,
        dataset_name,
        config_name,
        split_name,
        input_field_path,
        label_name,
        num_rows,
        n_examplars=10,
        model_name=_DEFAULT_MODEL,
        file_name=None,
        max_depth_subtree=3,
    ):
        self.dataset_name = dataset_name
        self.config_name = config_name
        self.split_name = split_name
        self.input_field_path = input_field_path
        self.label_name = label_name
        self.num_rows = num_rows
        self.model_name = model_name
        self.n_examplars = n_examplars
        self.file_name = "tree" if file_name is None else file_name
        # mirrors the layout used by push_to_hub on the dataset repository
        self.repo_path_list = [
            dataset_name.replace("/", "---"),
            f"{'default' if config_name is None else config_name}",
            f"{'train' if split_name is None else split_name}",
            f"field-{'->'.join(input_field_path)}-label-{label_name}",
            f"{num_rows}_rows",
            model_name.replace("/", "---"),
            f"{self.file_name}.jsonl.gz",
        ]
        self.repo_path = pjoin(*self.repo_path_list)
        self.node_list = load_dataset(
            "yjernite/datasets_clusters", data_files=[self.repo_path]
        )["train"]
        # per-node memoization cache for plots / tables, filled on demand
        self.node_reps = [{} for node in self.node_list]
        self.max_depth_subtree = max_depth_subtree

    def set_full_tree(self):
        # NOTE(fix): the previous implementation used
        # `d[k] = d.get(k, make_tree_plot(...))`, which evaluates the expensive
        # default on *every* call and defeats the memoization entirely.
        # Only build the plot when it is not cached yet.
        if "tree" not in self.node_reps[0]:
            self.node_reps[0]["tree"] = make_tree_plot(
                self.node_list,
                0,
            )

    def get_full_tree(self):
        """Return (building it on first use) the plot of the whole tree."""
        self.set_full_tree()
        return self.node_reps[0]["tree"]

    def set_node_subtree(self, node_id):
        # same lazy-memoization fix as set_full_tree
        if "subtree" not in self.node_reps[node_id]:
            self.node_reps[node_id]["subtree"] = make_tree_plot(
                self.node_list,
                node_id,
                self.max_depth_subtree,
            )

    def get_node_subtree(self, node_id):
        """Return (building it on first use) the sub-tree plot rooted at node_id."""
        self.set_node_subtree(node_id)
        return self.node_reps[node_id]["subtree"]

    def set_node_examplars(self, node_id):
        # same lazy-memoization fix as set_full_tree
        if "examplars" not in self.node_reps[node_id]:
            rows = [
                {
                    "id": exple["ids"],
                    "score": exple["score"],
                    "field": exple["field"],
                    "label": exple.get("label", "N/A"),
                }
                for exple in self.node_list[node_id]["examplars"]
            ][: self.n_examplars]
            self.node_reps[node_id]["examplars"] = pd.DataFrame(rows)

    def get_node_examplars(self, node_id):
        """Return (building it on first use) the examplar DataFrame for node_id."""
        self.set_node_examplars(node_id)
        return self.node_reps[node_id]["examplars"]

    def set_node_label_chart(self, node_id):
        # same lazy-memoization fix as set_full_tree
        if "label_chart" not in self.node_reps[node_id]:
            label_counts = self.node_list[node_id]["label_counts"]
            self.node_reps[node_id]["label_chart"] = px.pie(
                values=[ct for lab, ct in label_counts],
                names=[f"Label {lab}" for lab, ct in label_counts],
                color_discrete_sequence=px.colors.sequential.Rainbow,
                width=400,
                height=400,
            )

    def get_node_label_chart(self, node_id):
        """Return (building it on first use) the label pie chart for node_id."""
        self.set_node_label_chart(node_id)
        return self.node_reps[node_id]["label_chart"]
|
data_measurements_clusters/dataset_utils.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
from os.path import exists
|
| 18 |
+
from os.path import join as pjoin
|
| 19 |
+
|
| 20 |
+
from datasets import Dataset, load_dataset, load_from_disk
|
| 21 |
+
from tqdm import tqdm
|
| 22 |
+
|
| 23 |
+
_CACHE_DIR = "cache_dir"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# grab first N rows of a dataset from the hub
|
| 27 |
+
def load_truncated_dataset(
    dataset_name,
    config_name=None,
    split_name=None,
    num_rows=0,
    use_streaming=True,
    use_auth_token=None,
    use_dataset=None,
):
    """
    This function loads the first `num_rows` items of a dataset for a
    given `config_name` and `split_name`.
    When the dataset is streamable, we iterate through the first
    `num_rows` examples in streaming mode, write them to a jsonl file,
    then create a new dataset from the json.
    This is the most direct way to make a Dataset from an IterableDataset
    as of datasets version 1.6.1.
    Otherwise, we download the full dataset and select the first
    `num_rows` items
    Args:
        dataset_name (string):
            dataset id in the dataset library
        config_name (string):
            dataset configuration
        split_name (string):
            optional split name, defaults to `train`
        num_rows (int):
            number of rows to truncate the dataset to, <= 0 means no truncation
        use_streaming (bool):
            whether to use streaming when the dataset supports it
        use_auth_token (string):
            HF authentication token to access private datasets
        use_dataset (Dataset):
            use existing dataset instead of getting one from the hub
    Returns:
        Dataset:
            the truncated dataset as a Dataset object
    """
    split_name = "train" if split_name is None else split_name
    cache_name = f"{dataset_name.replace('/', '---')}_{'default' if config_name is None else config_name}_{split_name}_{num_rows}"
    if use_streaming:
        jsonl_path = pjoin(_CACHE_DIR, "tmp", f"{cache_name}.jsonl")
        if not exists(jsonl_path):
            iterable_dataset = (
                load_dataset(
                    dataset_name,
                    name=config_name,
                    split=split_name,
                    cache_dir=pjoin(_CACHE_DIR, "tmp", cache_name + "_temp"),
                    streaming=True,
                    use_auth_token=use_auth_token,
                )
                if use_dataset is None
                else use_dataset
            )
            if num_rows > 0:
                iterable_dataset = iterable_dataset.take(num_rows)
            # BUG FIX: previously the jsonl file was opened without first
            # ensuring the cache directory exists, which crashes on a cold
            # cache; a `with` block also guarantees the handle is closed
            # even if serialization fails part-way.
            os.makedirs(pjoin(_CACHE_DIR, "tmp"), exist_ok=True)
            with open(jsonl_path, "w", encoding="utf-8") as f:
                for row in tqdm(iterable_dataset):
                    _ = f.write(json.dumps(row) + "\n")
        dataset = Dataset.from_json(
            jsonl_path,
            cache_dir=pjoin(_CACHE_DIR, "tmp", cache_name + "_jsonl"),
        )
    else:
        full_dataset = (
            load_dataset(
                dataset_name,
                name=config_name,
                split=split_name,
                use_auth_token=use_auth_token,
                cache_dir=pjoin(_CACHE_DIR, "tmp", cache_name + "_temp"),
            )
            if use_dataset is None
            else use_dataset
        )
        if num_rows > 0:
            dataset = full_dataset.select(range(num_rows))
        else:
            dataset = full_dataset
    return dataset
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# get all instances of a specific field in a dataset with indices and labels
|
| 113 |
+
def extract_features(examples, indices, input_field_path, label_name=None):
    """
    Unroll a batch of examples into a flat list of field instances.

    Walks `input_field_path` into each example and returns every instance of
    the target field, paired with the index of the example it came from and,
    optionally, its label.

    Args:
        examples (dict):
            a dictionary of lists, provided by dataset.map with batched=True
        indices (list):
            a list of indices, provided by dataset.map with with_indices=True
        input_field_path (tuple):
            the field to extract. A singleton for top-level features
            (e.g. `("text",)`) or a full path for nested ones
            (e.g. `("answers", "text")` for all answer strings in SQuAD)
        label_name (string):
            optionally used to align the field items with labels. Currently,
            the top-most field with this name is used, which may fail in some
            edge cases
            TODO: make it so the label is specified through a full path
    Returns:
        Dict:
            a dictionary of lists, used by dataset.map with batched=True.
            labels are all None if label_name!=None but label_name is not found
            TODO: raise an error if label_name is specified but not found
    """
    root_field = input_field_path[0]
    # seed the record list from the top-level field, attaching labels when
    # they live at the top level of the example
    if label_name is not None and label_name in examples:
        records = [
            {"index": idx, "label": lab, "items": itm}
            for idx, itm, lab in zip(
                indices, examples[root_field], examples[label_name]
            )
        ]
    else:
        records = [
            {"index": idx, "label": None, "items": itm}
            for idx, itm in zip(indices, examples[root_field])
        ]
    # descend one path component at a time, unrolling list-valued fields
    for step in input_field_path[1:]:
        unrolled = []
        for rec in records:
            content = rec["items"]
            if label_name is not None and label_name in content:
                # labels found at this nesting level override earlier ones
                value = content[step]
                if isinstance(value, list):
                    unrolled.extend(
                        {"index": rec["index"], "label": lab, "items": sub}
                        for sub, lab in zip(value, content[label_name])
                    )
                else:
                    unrolled.append(
                        {
                            "index": rec["index"],
                            "label": content[label_name],
                            "items": value,
                        }
                    )
            else:
                value = content[step]
                if isinstance(value, list):
                    unrolled.extend(
                        {
                            "index": rec["index"],
                            "label": rec["label"],
                            "items": sub,
                        }
                        for sub in value
                    )
                else:
                    unrolled.append(
                        {
                            "index": rec["index"],
                            "label": rec["label"],
                            "items": value,
                        }
                    )
        records = unrolled
    if label_name is None:
        return {
            "ids": [rec["index"] for rec in records],
            "field": [rec["items"] for rec in records],
        }
    return {
        "ids": [rec["index"] for rec in records],
        "field": [rec["items"] for rec in records],
        "label": [rec["label"] for rec in records],
    }
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
# grab some examples and extract interesting fields
|
| 207 |
+
def prepare_clustering_dataset(
    dataset_name,
    input_field_path,
    label_name=None,
    config_name=None,
    split_name=None,
    num_rows=0,
    use_streaming=True,
    use_auth_token=None,
    cache_dir=_CACHE_DIR,
    use_dataset=None,
):
    """
    This function loads the first `num_rows` items of a dataset for a
    given `config_name` and `split_name`, and extracts all instances of a field
    of interest denoted by `input_field_path` along with the indices of the
    examples the instances came from and optionally their labels (`label_name`)
    in the original dataset
    Args:
        dataset_name (string):
            dataset id in the dataset library
        input_field_path (tuple):
            a tuple indicating the field we want to extract. Can be a singleton
            for top-level features (e.g. `("text",)`) or a full path for nested
            features (e.g. `("answers", "text")`) to get all answer strings in
            SQuAD
        label_name (string):
            optionally used to align the field items with labels. Currently,
            returns the top-most field that has this name, which fails in edge cases
        config_name (string):
            dataset configuration
        split_name (string):
            optional split name, defaults to `train`
        num_rows (int):
            number of rows to truncate the dataset to, <= 0 means no truncation
        use_streaming (bool):
            whether to use streaming when the dataset supports it
        use_auth_token (string):
            HF authentication token to access private datasets
        use_dataset (Dataset):
            use existing dataset instead of getting one from the hub
    Returns:
        Dataset:
            the extracted dataset as a Dataset object. Note that if there is more
            than one instance of the field per example in the original dataset
            (e.g. multiple answers per QA example), the returned dataset will
            have more than `num_rows` rows
        string:
            the path to the newly created dataset directory
    """
    cache_path = [
        cache_dir,
        dataset_name.replace("/", "---"),
        f"{'default' if config_name is None else config_name}",
        f"{'train' if split_name is None else split_name}",
        f"field-{'->'.join(input_field_path)}-label-{label_name}",
        f"{num_rows}_rows",
        "features_dset",
    ]
    if exists(pjoin(*cache_path)):
        pre_clustering_dset = load_from_disk(pjoin(*cache_path))
    else:
        truncated_dset = load_truncated_dataset(
            dataset_name,
            config_name,
            split_name,
            num_rows,
            use_streaming,
            use_auth_token,
            use_dataset,
        )

        def batch_func(examples, indices):
            # closure binding the extraction parameters for dataset.map
            return extract_features(examples, indices, input_field_path, label_name)

        pre_clustering_dset = truncated_dset.map(
            batch_func,
            remove_columns=truncated_dset.features,
            batched=True,
            with_indices=True,
        )
        # BUG FIX: the previous `for i in range(1, len(cache_path) - 1)` loop
        # stopped one prefix short of the save target's parent directory and
        # used a racy exists-then-mkdir pattern; a single recursive makedirs
        # creates every missing level safely.
        os.makedirs(pjoin(*cache_path), exist_ok=True)
        pre_clustering_dset.save_to_disk(pjoin(*cache_path))
    return pre_clustering_dset, pjoin(*cache_path)
|
posts/conclusion.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
|
| 3 |
+
title = "Key Takeaways"
|
| 4 |
+
description = "Review of the information from previous pages."
|
| 5 |
+
date = "2022-01-26"
|
| 6 |
+
thumbnail = "images/raised_hand.png"
|
| 7 |
+
|
| 8 |
+
__KEY_TAKEAWAYS = """
|
| 9 |
+
# Key Takeaways and Review
|
| 10 |
+
|
| 11 |
+
Here are some of the main ideas we have conveyed in this exploration:
|
| 12 |
+
- Defining hate speech is hard and changes depending on your context and goals.
|
| 13 |
+
- Capturing a snapshot of what you've defined to be hate speech in a dataset is hard.
|
| 14 |
+
- Models learn lots of different things based on the data they see, and that can include things you didn't intend for them to learn.
|
| 15 |
+
|
| 16 |
+
Next, please answer the following questions about the information presented in this demo:
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def run_article():
    """Render the takeaways page: summary markdown, then free-text feedback prompts."""
    feedback_prompts = (
        "Did you click on any of the links provided in the **Hate Speech in ACM** page? If so, which one did you find most surprising?",
        "Of the datasets presented in the **Dataset Exploration** page, which one did you think best represented content that should be moderated? Which worst?",
        "Of the models presented in the **Model Exploration** page, which one did you think performed best? Which worst?",
        "Any additional comments about the materials?",
        # the remaining questions are from the paper
        "How would you describe your role? E.g. model developer, dataset developer, domain expert, policy maker, platform manager, community advocate, platform user, student",
        "Why are you interested in content moderation?",
        "Which modules did you use the most?",
        "Which module did you find the most informative?",
        "Which application were you most interested in learning more about?",
        "What surprised you most about the datasets?",
        "Which models are you most concerned about as a user?",
        "Do you have any comments or suggestions?",
    )
    st.markdown(__KEY_TAKEAWAYS)
    for prompt in feedback_prompts:
        st.text_area(prompt)
|
posts/context.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
|
| 3 |
+
title = "Hate Speech in ACM"
|
| 4 |
+
description = "The history and development of hate speech detection as a modeling task"
|
| 5 |
+
date = "2022-01-26"
|
| 6 |
+
thumbnail = "images/prohibited.png"
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
__ACM_SECTION = """
|
| 10 |
+
Content moderation is a collection of interventions used by online platforms to partially obscure
|
| 11 |
+
or remove entirely from user-facing view content that is objectionable based on the company's values
|
| 12 |
+
or community guidelines, which vary from platform to platform.
|
| 13 |
+
[Sarah T. Roberts (2014)](https://yalebooks.yale.edu/book/9780300261479/behind-the-screen/) describes
|
| 14 |
+
content moderation as "the organized practice of screening user-generated content (UGC)
|
| 15 |
+
posted to Internet sites, social media, and other online outlets" (p. 12).
|
| 16 |
+
[Tarleton Gillespie (2021)](https://yalebooks.yale.edu/book/9780300261431/custodians-internet/) writes
|
| 17 |
+
that platforms moderate content "both to protect one user from another,
|
| 18 |
+
or one group from its antagonists, and to remove the offensive, vile, or illegal.''
|
| 19 |
+
While there are a variety of approaches to this problem, in this tool, we focus on automated content moderation,
|
| 20 |
+
which is the application of algorithms to the classification of problematic content.
|
| 21 |
+
|
| 22 |
+
Content that is subject to moderation can be user-directed (e.g. targeted harassment of a particular user
|
| 23 |
+
in comments or direct messages) or posted to a personal account (e.g. user-created posts that contain hateful
|
| 24 |
+
remarks against a particular social group).
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
__CURRENT_APPROACHES = """
|
| 28 |
+
Automated content moderation has relied both on analysis of the media itself (e.g. using methods from natural
|
| 29 |
+
language processing and computer vision) as well as user dynamics (e.g. whether the user sending the content
|
| 30 |
+
to another user shares followers with the recipient, or whether the user posting the content is a relatively new account).
|
| 31 |
+
Often, the ACM pipeline is fed by user-reported content. Within the realm of text-based ACM, approaches vary
|
| 32 |
+
from wordlist-based approaches to data-driven, machine learning models. Common datasets used for training and
|
| 33 |
+
evaluating hate speech detectors can be found at [https://hatespeechdata.com/](https://hatespeechdata.com/).
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
__CURRENT_CHALLENGES = """
|
| 37 |
+
Combating hateful content on the Internet continues to be a challenge. A 2021 survey of respondents
|
| 38 |
+
in the United States, conducted by Anti-Defamation League, found an increase in online hate & harassment
|
| 39 |
+
directed at LGBTQ+, Asian American, Jewish, and African American individuals.
|
| 40 |
+
|
| 41 |
+
### Technical challenges for data-driven systems
|
| 42 |
+
|
| 43 |
+
With respect to models that are based on training data, datasets encode worldviews, and so a common challenge
|
| 44 |
+
lies in having insufficient data or data that only reflects a limited worldview. For example, a recent
|
| 45 |
+
study found that Tweets posted by drag queens were more often rated by an automated system as toxic than
|
| 46 |
+
Tweets posted by white supremacists.
|
| 47 |
+
This may be due, in part, to the labeling schemes and choices made for the data used in training the model,
|
| 48 |
+
as well as particular company policies that are invoked when making these labeling choices.
|
| 49 |
+
(This all needs to be spelled out better!)
|
| 50 |
+
|
| 51 |
+
### Context matters for content moderation.
|
| 52 |
+
|
| 53 |
+
*Counterspeech* is "any direct response to hateful or harmful speech which seeks to undermine it"
|
| 54 |
+
(from [Dangerous Speech Project](https://dangerousspeech.org/counterspeech/)). Counterspeech has been shown
|
| 55 |
+
to be an important community self-moderation tool for reducing instances of hate speech (see
|
| 56 |
+
[Hangartner et al. 2021](https://www.pnas.org/doi/10.1073/pnas.2116310118)), but counterspeech is often
|
| 57 |
+
incorrectly categorized as hate speech by automatic systems due to the counterspeech making direct reference
|
| 58 |
+
to or quoting the original hate speech. Such system behavior silences those who are trying to push back against
|
| 59 |
+
hateful and toxis speech, and, if the flagged content is hidden automatically, prevents others from seeing the
|
| 60 |
+
counterspeech.
|
| 61 |
+
|
| 62 |
+
See [van Aken et al. 2018](https://aclanthology.org/W18-5105.pdf) for a detailed list of examples that
|
| 63 |
+
automatic systems frequently misclassify.
|
| 64 |
+
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
__SELF_EXAMPLES = """
|
| 68 |
+
- [**(FB)(TOU)** - *Facebook Community Standards*](https://transparency.fb.com/policies/community-standards/)
|
| 69 |
+
- [**(FB)(Blog)** - *What is Hate Speech? (2017)*](https://about.fb.com/news/2017/06/hard-questions-hate-speech/)
|
| 70 |
+
- [**(NYT)(Blog)** - * New York Times on their partnership with JigSaw*](https://open.nytimes.com/to-apply-machine-learning-responsibly-we-use-it-in-moderation-d001f49e0644)
|
| 71 |
+
- [**(NYT)(FAQ)** - *New York Times on their moderation policy*](https://help.nytimes.com/hc/en-us/articles/115014792387-Comments)
|
| 72 |
+
- [**(Reddit)(TOU)** - *Reddit General Content Policies*](https://www.redditinc.com/policies/content-policy)
|
| 73 |
+
- [**(Reddit)(Blog)** - *AutoMod - help scale moderation without ML*](https://mods.reddithelp.com/hc/en-us/articles/360008425592-Moderation-Tools-overview)
|
| 74 |
+
- [**(Google)(Blog)** - *Google Search Results Moderation*](https://blog.google/products/search/when-and-why-we-remove-content-google-search-results/)
|
| 75 |
+
- [**(Google)(Blog)** - *JigSaw Case Studies*](https://www.perspectiveapi.com/case-studies/)
|
| 76 |
+
- [**(YouTube)(TOU)** - *YouTube Community Guidelines*](https://www.youtube.com/howyoutubeworks/policies/community-guidelines/)
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
__CRITIC_EXAMPLES = """
|
| 80 |
+
- [Social Media and Extremism - Questions about January 6th 2021](https://thehill.com/policy/technology/589651-jan-6-panel-subpoenas-facebook-twitter-reddit-and-alphabet/)
|
| 81 |
+
- [Over-Moderation of LGBTQ content on YouTube](https://www.gaystarnews.com/article/youtube-lgbti-content/)
|
| 82 |
+
- [Disparate Impacts of Moderation](https://www.aclu.org/news/free-speech/time-and-again-social-media-giants-get-content-moderation-wrong-silencing-speech-about-al-aqsa-mosque-is-just-the-latest-example/)
|
| 83 |
+
- [Calls for Transparency](https://santaclaraprinciples.org/)
|
| 84 |
+
- [Income Loss from Failures of Moderation](https://foundation.mozilla.org/de/blog/facebook-delivers-a-serious-blow-to-tunisias-music-scene/)
|
| 85 |
+
- [Fighting Hate Speech, Silencing Drag Queens?](https://link.springer.com/article/10.1007/s12119-020-09790-w)
|
| 86 |
+
- [Reddit Self Reflection on Lack of Content Policy](https://www.reddit.com/r/announcements/comments/gxas21/upcoming_changes_to_our_content_policy_our_board/)
|
| 87 |
+
"""
|
| 88 |
+
|
| 89 |
+
def run_article():
    """Render the ACM background page: definition, approaches, challenges, examples."""
    # heading + collapsible body for each single-column section, in page order
    sections = (
        ("## Automatic Content Moderation (ACM)", "ACM definition", __ACM_SECTION),
        ("## Current approaches to ACM", "Current Approaches", __CURRENT_APPROACHES),
        ("## Current challenges in ACM", "Current Challenges", __CURRENT_CHALLENGES),
    )
    for heading, expander_label, body in sections:
        st.markdown(heading)
        with st.expander(expander_label, expanded=False):
            st.markdown(body, unsafe_allow_html=True)
    # final section: two side-by-side link collections
    st.markdown("## Examples of ACM in Use: in the Press and in their own Words")
    left_col, right_col = st.columns([4, 5])
    with left_col.expander("In their own Words"):
        st.markdown(__SELF_EXAMPLES, unsafe_allow_html=True)
    with right_col.expander("Critical Writings"):
        st.markdown(__CRITIC_EXAMPLES, unsafe_allow_html=True)
|
posts/dataset_exploration.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from os import mkdir
|
| 3 |
+
from os.path import isdir
|
| 4 |
+
from os.path import join as pjoin
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import streamlit as st
|
| 8 |
+
|
| 9 |
+
from data_measurements_clusters import Clustering
|
| 10 |
+
|
| 11 |
+
title = "Dataset Exploration"
|
| 12 |
+
description = "Comparison of hate speech detection datasets"
|
| 13 |
+
date = "2022-01-26"
|
| 14 |
+
thumbnail = "images/books.png"
|
| 15 |
+
|
| 16 |
+
__COLLECT = """
|
| 17 |
+
In order to turn observations of the world into data, choices must be made
|
| 18 |
+
about what counts as data, where to collect data, and how to collect data.
|
| 19 |
+
When collecting language data, this often means selecting websites that allow
|
| 20 |
+
for easily collecting samples of text, and hate speech data is frequently
|
| 21 |
+
collected from social media platforms like Twitter or forums like Wikipedia.
|
| 22 |
+
Each of these decisions results in a specific sample of all the possible
|
| 23 |
+
observations.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
__ANNOTATE = """
|
| 27 |
+
Once the data is collected, further decisions must be made about how to
|
| 28 |
+
label the data if the data is being used to train a classification system,
|
| 29 |
+
as is common in hate speech detection. These labels must be defined in order
|
| 30 |
+
for the dataset to be consistently labeled, which helps the classification
|
| 31 |
+
model produce more consistent output. This labeling process, called
|
| 32 |
+
*annotation*, can be done by the data collectors, by a set of trained
|
| 33 |
+
annotators with relevant expert knowledge, or by online crowdworkers. Who
|
| 34 |
+
is doing the annotating has a significant effect on the resulting set of
|
| 35 |
+
labels ([Sap et al., 2019](https://aclanthology.org/P19-1163.pdf)).
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
__STANDARDIZE = """
|
| 39 |
+
As a relatively new task in NLP, the definitions that are used across
|
| 40 |
+
different projects vary. Some projects target just hate speech, but others
|
| 41 |
+
may label their data for ‘toxic’, ‘offensive’, or ‘abusive’ language. Still
|
| 42 |
+
others may address related problems such as bullying and harassment.
|
| 43 |
+
This variation makes it difficult to compare across datasets and their
|
| 44 |
+
respective models. As these modeling paradigms become more established,
|
| 45 |
+
definitions grounded in relevant sociological research will need to be
|
| 46 |
+
agreed upon in order for datasets and models in ACM to appropriately
|
| 47 |
+
capture the problems in the world that they set out to address. For more
|
| 48 |
+
on this discussion, see
|
| 49 |
+
[Madukwe et al 2020](https://aclanthology.org/2020.alw-1.18.pdf) and
|
| 50 |
+
[Fortuna et al 2020](https://aclanthology.org/2020.lrec-1.838.pdf).
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
__HOW_TO = """
|
| 54 |
+
To use the tool, select a dataset. The tool will then show clusters of
|
| 55 |
+
examples in the dataset that have been automatically determined to be similar
|
| 56 |
+
to one another. Below that, you can see specific examples within the cluster,
|
| 57 |
+
the labels for those examples, and the distribution of labels within the
|
| 58 |
+
cluster. Note that cluster 0 will always be the full dataset.
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
DSET_OPTIONS = {'classla/FRENK-hate-en': {'binary': {'train': {('text',): {'label': {100000: {
|
| 62 |
+
'sentence-transformers/all-mpnet-base-v2': {'tree': {'dataset_name': 'classla/FRENK-hate-en',
|
| 63 |
+
'config_name': 'binary',
|
| 64 |
+
'split_name': 'train',
|
| 65 |
+
'input_field_path': ('text',),
|
| 66 |
+
'label_name': 'label',
|
| 67 |
+
'num_rows': 100000,
|
| 68 |
+
'model_name': 'sentence-transformers/all-mpnet-base-v2',
|
| 69 |
+
'file_name': 'tree'}}}}}}}},
|
| 70 |
+
'tweets_hate_speech_detection': {'default': {'train': {('tweet',): {'label': {100000: {
|
| 71 |
+
'sentence-transformers/all-mpnet-base-v2': {'tree': {'dataset_name': 'tweets_hate_speech_detection',
|
| 72 |
+
'config_name': 'default',
|
| 73 |
+
'split_name': 'train',
|
| 74 |
+
'input_field_path': ('tweet',),
|
| 75 |
+
'label_name': 'label',
|
| 76 |
+
'num_rows': 100000,
|
| 77 |
+
'model_name': 'sentence-transformers/all-mpnet-base-v2',
|
| 78 |
+
'file_name': 'tree'}}}}}}}},
|
| 79 |
+
'ucberkeley-dlab/measuring-hate-speech': {'default': {'train': {('text',): {'hatespeech': {100000: {
|
| 80 |
+
'sentence-transformers/all-mpnet-base-v2': {'tree': {'dataset_name': 'ucberkeley-dlab/measuring-hate-speech',
|
| 81 |
+
'config_name': 'default',
|
| 82 |
+
'split_name': 'train',
|
| 83 |
+
'input_field_path': ('text',),
|
| 84 |
+
'label_name': 'hatespeech',
|
| 85 |
+
'num_rows': 100000,
|
| 86 |
+
'model_name': 'sentence-transformers/all-mpnet-base-v2',
|
| 87 |
+
'file_name': 'tree'}}}}}}}},
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
@st.cache(allow_output_mutation=True)
def download_tree(args):
    """Build (and cache) the ``Clustering`` object described by ``args``.

    Streamlit caches the result so the clustering tree is only downloaded
    and constructed once per argument set; ``allow_output_mutation`` lets
    the cached object be reused without re-hashing it on every rerun.
    """
    return Clustering(**args)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def run_article():
    """Render the dataset-exploration page.

    Shows background text on dataset collection/annotation/standardization,
    then an interactive clustering visualization for a selected hate speech
    dataset: the full cluster tree, plus exemplars and label distribution
    for a chosen cluster node.
    """
    st.markdown("# Making a Hate Speech Dataset")
    st.markdown("## Collecting observations of the world")
    with st.expander("Collection"):
        st.markdown(__COLLECT, unsafe_allow_html=True)
    st.markdown("## Annotating observations with task labels")
    with st.expander("Annotation"):
        st.markdown(__ANNOTATE, unsafe_allow_html=True)
    st.markdown("## Standardizing the task")
    with st.expander("Standardization"):
        st.markdown(__STANDARDIZE, unsafe_allow_html=True)
    st.markdown("# Exploring datasets")
    with st.expander("How to use the tool"):
        st.markdown(__HOW_TO, unsafe_allow_html=True)

    choose_dset = st.selectbox(
        "Select dataset to visualize",
        DSET_OPTIONS,
    )

    # DSET_OPTIONS is a deeply nested mapping (dataset -> config -> split ->
    # ...). Each level has a single child, so walk down until we reach the
    # leaf dict of Clustering keyword arguments (marked by 'dataset_name').
    args = DSET_OPTIONS[choose_dset]
    while "dataset_name" not in args:
        args = list(args.values())[0]

    clustering = download_tree(args)

    st.markdown("---\n")

    full_tree_fig = clustering.get_full_tree()
    st.plotly_chart(full_tree_fig, use_container_width=True)

    st.markdown("---\n")
    show_node = st.selectbox(
        "Visualize cluster node:",
        range(len(clustering.node_list)),
    )
    st.markdown(f"Node {show_node} has {clustering.node_list[show_node]['weight']} examples.")
    st.markdown(f"Node {show_node} was merged at {clustering.node_list[show_node]['merged_at']:.2f}.")
    examplars = clustering.get_node_examplars(show_node)
    st.markdown("---\n")

    label_fig = clustering.get_node_label_chart(show_node)
    examplars_col, labels_col = st.columns([2, 1])
    examplars_col.markdown("#### Node cluster examplars")
    examplars_col.table(examplars)
    labels_col.markdown("#### Node cluster labels")
    labels_col.plotly_chart(label_fig, use_container_width=True)
|
posts/model_exploration.py
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import json
|
| 3 |
+
import random
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
# from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 9 |
+
from transformers import pipeline
|
| 10 |
+
|
| 11 |
+
title = "Model Exploration"
|
| 12 |
+
description = "Comparison of hate speech detection models"
|
| 13 |
+
date = "2022-01-26"
|
| 14 |
+
thumbnail = "images/robot.png"
|
| 15 |
+
|
| 16 |
+
__HATE_DETECTION = """
|
| 17 |
+
Once the data has been collected using the definitions identified for the
|
| 18 |
+
task, you can start training your model. At training, the model takes in
|
| 19 |
+
the data with labels and learns the associated context in the input data
|
| 20 |
+
for each label. Depending on the task design, the labels may be binary like
|
| 21 |
+
'hateful' and 'non-hateful' or multiclass like 'neutral', 'offensive', and
|
| 22 |
+
'attack'.
|
| 23 |
+
|
| 24 |
+
When presented with a new input string, the model then predicts the
|
| 25 |
+
likelihood that the input is classified as each of the available labels and
|
| 26 |
+
returns the label with the highest likelihood as well as how confident the
|
| 27 |
+
model is in its selection using a score from 0 to 1.
|
| 28 |
+
|
| 29 |
+
Neural models such as transformers are frequently trained as general
|
| 30 |
+
language models and then fine-tuned on specific classification tasks.
|
| 31 |
+
These models can vary in their architecture and the optimization
|
| 32 |
+
algorithms, sometimes resulting in very different output for the same
|
| 33 |
+
input text.
|
| 34 |
+
|
| 35 |
+
The models used below include:
|
| 36 |
+
- [RoBERTa trained on FRENK dataset](https://huggingface.co/classla/roberta-base-frenk-hate)
|
| 37 |
+
- [RoBERTa trained on Twitter Hate Speech](https://huggingface.co/cardiffnlp/twitter-roberta-base-hate)
|
| 38 |
+
- [DeHateBERT model (trained on Twitter and StormFront)](https://huggingface.co/Hate-speech-CNERG/dehatebert-mono-english)
|
| 39 |
+
- [RoBERTa trained on 11 English hate speech datasets](https://huggingface.co/facebook/roberta-hate-speech-dynabench-r1-target)
|
| 40 |
+
- [RoBERTa trained on 11 English hate speech datasets and Round 1 of the Dynamically Generated Hate Speech Dataset](https://huggingface.co/facebook/roberta-hate-speech-dynabench-r2-target)
|
| 41 |
+
- [RoBERTa trained on 11 English hate speech datasets and Rounds 1 and 2 of the Dynamically Generated Hate Speech Dataset](https://huggingface.co/facebook/roberta-hate-speech-dynabench-r3-target)
|
| 42 |
+
- [RoBERTa trained on 11 English hate speech datasets and Rounds 1, 2, and 3 of the Dynamically Generated Hate Speech Dataset](https://huggingface.co/facebook/roberta-hate-speech-dynabench-r4-target)
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
__HATECHECK = """
|
| 46 |
+
[Röttger et al. (2021)](https://aclanthology.org/2021.acl-long.4.pdf)
|
| 47 |
+
developed a list of 3,901 test cases for hate speech detection models called
|
| 48 |
+
HateCheck. HateCheck provides a number of templates with placeholders for
|
| 49 |
+
identity categories and hateful terms along with labels indicating whether a
|
| 50 |
+
model should or should not categorize the instance as hate speech. For each
|
| 51 |
+
case, they created several examples with different
|
| 52 |
+
identity attributes to test models' abilities to detect hate speech towards
|
| 53 |
+
a range of groups of people. Additionally, they used more difficult
|
| 54 |
+
linguistic contexts such as adding negation or more nuanced words to try to fool the
|
| 55 |
+
model. See some of their examples using the button or try to make
|
| 56 |
+
your own examples to test the models in the tools below.
|
| 57 |
+
|
| 58 |
+
*** Warning: these examples may include hateful and violent content as
|
| 59 |
+
well as slurs and other offensive languages ***
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
__RANKING = """
|
| 63 |
+
When models process a given input, they calculate the probability of
|
| 64 |
+
that input being labeled with each of the possible labels (in binary
|
| 65 |
+
cases for example, either 'hateful' or 'not hateful'). The label with
|
| 66 |
+
the highest probability is returned. If we test multiple input sentences
|
| 67 |
+
for a given model, we can see which input sentences have the
|
| 68 |
+
highest probabilities, indicating which examples the model is most
|
| 69 |
+
confident in classifying.
|
| 70 |
+
|
| 71 |
+
Try comparing different input sentences for a given model
|
| 72 |
+
using the tool below.
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
__COMPARISON = """
|
| 76 |
+
Depending on their training data and parameters, models can return very
|
| 77 |
+
different outputs for the same input. Knowing how models differ in
|
| 78 |
+
their behavior can help with choosing an appropriate model for your
|
| 79 |
+
given use case.
|
| 80 |
+
|
| 81 |
+
Additionally, models trained on one kind of data can perform very
|
| 82 |
+
differently when tested on novel data. To show the models' performance
|
| 83 |
+
in a variety of settings, we also show the results of each model on
|
| 84 |
+
four datasets: the entirety of the [HateCheck Dataset](https://github.com/paul-rottger/hatecheck-data),
|
| 85 |
+
the [Dynamically Generated Hate Speech Dataset](https://github.com/bvidgen/Dynamically-Generated-Hate-Speech-Dataset),
|
| 86 |
+
the [Hateful Memes Dataset](https://huggingface.co/datasets/limjiayi/hateful_memes_expanded)
|
| 87 |
+
(without the images), and a 2000-instance subsection of the
|
| 88 |
+
[Open Subtitles](https://opus.nlpl.eu/OpenSubtitles-v2018.php)
|
| 89 |
+
(English) Dataset. While the first three are hate speech datasets and have
|
| 90 |
+
labels for whether or not each instance is hateful, the Open Subtitles dataset
|
| 91 |
+
contains movie subtitles and is not labeled for instances of hate.
|
| 92 |
+
|
| 93 |
+
Try comparing two different models' output for the same comparison
|
| 94 |
+
sentence using the tool below.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
__MODELS = [
|
| 98 |
+
"classla/roberta-base-frenk-hate",
|
| 99 |
+
"cardiffnlp/twitter-roberta-base-hate",
|
| 100 |
+
"Hate-speech-CNERG/dehatebert-mono-english",
|
| 101 |
+
"facebook/roberta-hate-speech-dynabench-r1-target",
|
| 102 |
+
"facebook/roberta-hate-speech-dynabench-r2-target",
|
| 103 |
+
"facebook/roberta-hate-speech-dynabench-r3-target",
|
| 104 |
+
"facebook/roberta-hate-speech-dynabench-r4-target"
|
| 105 |
+
]
|
| 106 |
+
|
| 107 |
+
autopop_ranking = [
|
| 108 |
+
[
|
| 109 |
+
"I don't hate black people.",
|
| 110 |
+
"Black people should die.",
|
| 111 |
+
"My friend likes pancakes."
|
| 112 |
+
],
|
| 113 |
+
[
|
| 114 |
+
"Those dirty redskins are all drunks.",
|
| 115 |
+
"I can't stand Chinese food. It's too spicy.",
|
| 116 |
+
"I lived near many Chinese people and I hated every single one of them."
|
| 117 |
+
],
|
| 118 |
+
[
|
| 119 |
+
"Stop saying that Mexicans don't belong in our country!",
|
| 120 |
+
"We walked along the dyke down the road.",
|
| 121 |
+
"Ah shit, I fucked up."
|
| 122 |
+
]
|
| 123 |
+
]
|
| 124 |
+
|
| 125 |
+
def _load_json(path):
    """Read a UTF-8 JSON resource file, closing the handle afterwards."""
    with open(path, encoding="utf-8") as json_file:
        return json.load(json_file)


# Creates the forms for receiving multiple inputs to compare for a single
# model or one input to compare for two models
def run_article():
    """Render the model-exploration page.

    Three interactive sections: HateCheck example generation, ranking one
    model's confidence over three inputs, and side-by-side comparison of
    two models on a single input (with precomputed performance charts).
    """
    st.markdown("# Making a Hate Speech Detection Model")
    with st.expander("Hate Speech Detection Models", expanded=False):
        st.markdown(__HATE_DETECTION)
    # Load the HateCheck resource files once per page render. Using a
    # with-based helper so the file handles are closed (the previous
    # json.load(open(...)) pattern leaked them).
    hc_path = "posts/resources/"
    hc_pholders = _load_json(hc_path + "template_placeholders.json")
    hc_templates = _load_json(hc_path + "hatecheck_category_templates.json")
    hc_info = _load_json(hc_path + "hatecheck_category_info.json")
    # Leading empty string gives the selectbox a "nothing selected" state.
    hc_cats = [""] + list(hc_info.keys())

    st.markdown("## Testing Models' Behavior")
    with st.expander("HateCheck Examples", expanded=False):
        st.markdown(__HATECHECK)
        category = st.selectbox(
            "Select a category of examples from HateCheck",
            hc_cats,
            key="hc_cat_select"
        )
        if category:
            with st.form(key="hate_check"):
                hc_cat = hc_info[category]
                templates = []
                names = []
                for hc_temp in hc_cat:
                    templates.append(hc_temp)
                    names.append(hc_cat[hc_temp]["name"])
                selected_names = st.multiselect(
                    "Select one or more HateCheck templates to generate examples for",
                    names,
                    key="hc_temp_multiselect"
                )
                num_exs = st.number_input(
                    "Select a number of examples to generate for each selected template",
                    min_value=1,
                    max_value=5,
                    value=3
                )
                if st.form_submit_button(label="Generate Examples"):
                    for name in selected_names:
                        index = names.index(name)
                        template = templates[index]
                        examples = generate_hc_ex(
                            hc_templates[template],
                            hc_pholders,
                            num_exs
                        )
                        st.header(name)
                        st.subheader("Label: " + hc_cat[template]["value"])
                        st.caption(hc_cat[template]["desc"])
                        for num in range(num_exs):
                            ex = examples[num]
                            st.write("Example #" + str(num + 1) + ": " + ex)

    st.markdown("## Model Output Ranking")
    with st.expander("Model Output Ranking Tool", expanded=False):
        st.markdown(__RANKING)
        with st.form(key="ranking"):
            model_name = st.selectbox(
                "Select a model to test",
                __MODELS,
            )
            # the placeholder key functionality was added in v1.2 of streamlit
            # and versions on Spaces currently goes up to v1.0
            input_1 = st.text_input(
                "Input 1",
                help="Try a phrase like 'We shouldn't let [IDENTITY] suffer.'",
                # placeholder="We shouldn't let [IDENTITY] suffer."
            )
            input_2 = st.text_input(
                "Input 2",
                help="Try a phrase like 'I'd rather die than date [IDENTITY].'",
                # placeholder="I'd rather die than date [IDENTITY]."
            )
            input_3 = st.text_input(
                "Input 3",
                help="Try a phrase like 'Good morning'",
                # placeholder="Good morning."
            )
            autopop = st.checkbox(
                'Choose examples for me',
                key="rank_autopop_ckbx",
                help="Check this box to run the model with 3 preselected sentences."
            )
            if st.form_submit_button(label="Rank inputs"):
                if autopop:
                    rank_inputs = random.choice(autopop_ranking)
                else:
                    rank_inputs = [input_1, input_2, input_3]
                # Debug trace of which inputs were actually scored.
                sys.stderr.write("\n" + str(rank_inputs) + "\n")
                results = run_ranked(model_name, rank_inputs)
                st.dataframe(results)

    st.markdown("## Model Comparison")
    with st.expander("Model Comparison Tool", expanded=False):
        st.markdown(__COMPARISON)
        with st.form(key="comparison"):
            model_name_1 = st.selectbox(
                "Select a model to compare",
                __MODELS,
                key="compare_model_1",
            )
            model_name_2 = st.selectbox(
                "Select another model to compare",
                __MODELS,
                key="compare_model_2",
            )
            autopop = st.checkbox(
                'Choose an example for me',
                key="comp_autopop_ckbx",
                help="Check this box to compare the models with a preselected sentence."
            )
            input_text = st.text_input("Comparison input")
            if st.form_submit_button(label="Compare models"):
                if autopop:
                    input_text = random.choice(random.choice(autopop_ranking))
                results = run_compare(model_name_1, model_name_2, input_text)
                st.write("### Showing results for: " + input_text)
                st.dataframe(results)
                # Show the precomputed per-dataset score-distribution charts
                # for both selected models side by side.
                outside_ds = [
                    "hatecheck",
                    "dynabench",
                    "hatefulmemes",
                    "opensubtitles"
                ]
                name_1_short = model_name_1.split("/")[1]
                name_2_short = model_name_2.split("/")[1]
                for calib_ds in outside_ds:
                    ds_loc = "posts/resources/charts/" + calib_ds + "/"
                    images, captions = [], []
                    for model in [name_1_short, name_2_short]:
                        images.append(ds_loc + model + "_" + calib_ds + ".png")
                        captions.append("Counts of dataset instances by hate score.")
                    st.write("#### Model performance comparison on " + calib_ds)
                    st.image(images, captions)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
# Takes in a Hate Check template and placeholders and generates the given
# number of random examples from the template, inserting a random instance of
# an identity category if there is a placeholder in the template
def generate_hc_ex(template, placeholders, gen_num):
    """Sample ``gen_num`` example sentences from a HateCheck template list.

    Each placeholder category found in a sampled sentence is replaced by a
    randomly chosen option for that category, and sentences that received at
    least one substitution are capitalized. Sentences with no placeholders
    are returned unchanged. ``template`` must contain at least ``gen_num``
    entries (``random.sample`` raises ValueError otherwise).
    """
    sampled = random.sample(template, gen_num)
    for index, sentence in enumerate(sampled):
        substituted = False
        for ph_cat, options in placeholders.items():
            if ph_cat in sentence:
                # Accumulate replacements in `sentence` so a template with
                # several placeholder categories keeps every substitution
                # (the previous version restarted from the original string
                # each time, so only the last category survived).
                sentence = sentence.replace(ph_cat, random.choice(options))
                substituted = True
        if substituted:
            sampled[index] = sentence.capitalize()
    return sampled
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
# Runs the received input strings through the given model and returns the
# all scores for all possible labels as a DataFrame
def run_ranked(model, input_list):
    """Score every input string with *model* and tabulate the results.

    Returns a DataFrame indexed by the input sentences, with one column per
    classifier label holding that label's probability for each sentence.
    """
    classifier = pipeline(
        "text-classification",
        model=model,
        return_all_scores=True
    )
    scores_by_label = {}
    for sentence_scores in classifier(input_list):
        for entry in sentence_scores:
            scores_by_label.setdefault(entry["label"], []).append(entry["score"])
    return pd.DataFrame(scores_by_label, index=input_list)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
# Takes in two model names and returns the output of both models for that
# given input string
def run_compare(name_1, name_2, text):
    """Classify *text* with both named models.

    Returns a two-element list of dicts, one per model (in argument order),
    each holding the model name plus the top predicted label and its score.
    """
    comparison = []
    for model_name in (name_1, name_2):
        classifier = pipeline("text-classification", model=model_name)
        top_result = classifier(text)[0]
        comparison.append({
            "Model": model_name,
            "Label": top_result["label"],
            "Score": top_result["score"],
        })
    return comparison
|
posts/welcome.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
|
| 3 |
+
title = "Welcome Page"
|
| 4 |
+
description = "Introduction"
|
| 5 |
+
date = "2022-01-26"
|
| 6 |
+
thumbnail = "images/waving_hand.png"
|
| 7 |
+
|
| 8 |
+
__INTRO_TEXT = """
|
| 9 |
+
Welcome to the Task Exploration Activity for hate speech detection!
|
| 10 |
+
In this series of modules, you'll learn about the history of hate speech detection as a task in
|
| 11 |
+
the larger pipeline of automatic content moderation (ACM).
|
| 12 |
+
You'll also be able to interact with and compare datasets and models built for this task.
|
| 13 |
+
|
| 14 |
+
The goal of this exploration is to share the design considerations and challenges faced when using algorithms to detect hate speech.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
__DEF_HATE_SPEECH = """
|
| 18 |
+
Hate speech is hard to define, with definitions shifting across time and location.
|
| 19 |
+
In 2019, the United Nations defined hate speech as "any kind of communication in speech,
|
| 20 |
+
writing or behaviour, that attacks or uses pejorative or discriminatory language with
|
| 21 |
+
reference to a person or a group on the basis of who they are, in other words, based on their religion,
|
| 22 |
+
ethnicity, nationality, race, colour, descent, gender or other identity factor."
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
__DEF_CONTENT = """
|
| 26 |
+
Different platforms have different guidelines about what
|
| 27 |
+
content is sanctioned on the platform. For example, many US-based platforms prohibit posting threats of violence,
|
| 28 |
+
nudity, and hate speech. We discuss hate speech below.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
__CONTENT_WARNING = """
|
| 32 |
+
These modules contain examples of hateful, abusive, and offensive language that have been collected in datasets and
|
| 33 |
+
reproduced by models. These examples are meant to illustrate the variety of content that may be subject to
|
| 34 |
+
moderation.
|
| 35 |
+
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
__DATASET_LIST = """
|
| 39 |
+
- [FRENK hate speech dataset](https://huggingface.co/datasets/classla/FRENK-hate-en)
|
| 40 |
+
- [Twitter Hate Speech dataset](https://huggingface.co/datasets/tweets_hate_speech_detection)
|
| 41 |
+
- [UC Berkeley Measuring Hate Speech](https://huggingface.co/datasets/ucberkeley-dlab/measuring-hate-speech)
|
| 42 |
+
- [Dynamically Generated Hate Speech Dataset](https://github.com/bvidgen/Dynamically-Generated-Hate-Speech-Dataset)
|
| 43 |
+
- [HateCheck](https://github.com/paul-rottger/hatecheck-data)
|
| 44 |
+
- [Hateful Memes Dataset](https://huggingface.co/datasets/limjiayi/hateful_memes_expanded)
|
| 45 |
+
- [Open Subtitles English Dataset](https://opus.nlpl.eu/OpenSubtitles-v2018.php)
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
__MODEL_LIST = """
|
| 49 |
+
- [RoBERTa trained on FRENK dataset](https://huggingface.co/classla/roberta-base-frenk-hate)
|
| 50 |
+
- [RoBERTa trained on Twitter Hate Speech](https://huggingface.co/cardiffnlp/twitter-roberta-base-hate)
|
| 51 |
+
- [DeHateBERT model (trained on Twitter and StormFront)](https://huggingface.co/Hate-speech-CNERG/dehatebert-mono-english)
|
| 52 |
+
- [RoBERTa trained on 11 English hate speech datasets](https://huggingface.co/facebook/roberta-hate-speech-dynabench-r1-target)
|
| 53 |
+
- [RoBERTa trained on 11 English hate speech datasets and Round 1 of the Dynamically Generated Hate Speech Dataset](https://huggingface.co/facebook/roberta-hate-speech-dynabench-r2-target)
|
| 54 |
+
- [RoBERTa trained on 11 English hate speech datasets and Rounds 1 and 2 of the Dynamically Generated Hate Speech Dataset](https://huggingface.co/facebook/roberta-hate-speech-dynabench-r3-target)
|
| 55 |
+
- [RoBERTa trained on 11 English hate speech datasets and Rounds 1, 2, and 3 of the Dynamically Generated Hate Speech Dataset](https://huggingface.co/facebook/roberta-hate-speech-dynabench-r4-target)
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
def run_article():
    """Render the static welcome page: intro text, definitions, a content
    warning, and two columns listing the featured datasets and models."""
    st.markdown("# Welcome!")
    st.markdown(__INTRO_TEXT)
    # Each section is a heading followed by its prose body.
    for heading, body in (
        ("### What is hate speech?", __DEF_HATE_SPEECH),
        ("### What kind of content is subject to moderation?", __DEF_CONTENT),
        ("### Content Warning", __CONTENT_WARNING),
    ):
        st.markdown(heading)
        st.markdown(body)
    st.markdown("---\n\n## Featured datasets and models")
    datasets_col, models_col, _ = st.columns(3)
    with datasets_col:
        st.markdown("### Datasets")
        st.markdown(__DATASET_LIST, unsafe_allow_html=True)
    with models_col:
        st.markdown("### Models")
        st.markdown(__MODEL_LIST, unsafe_allow_html=True)
|