Spaces:
Runtime error
Runtime error
Douwe Kiela
commited on
Commit
·
d5b2eed
1
Parent(s):
d23bce8
Initial import
Browse files- .gitignore +1 -0
- README.md +1 -1
- app.py +82 -0
- collect.py +39 -0
- config.py.example +6 -0
- requirements.txt +2 -0
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
config.py
|
README.md
CHANGED
|
@@ -10,4 +10,4 @@ pinned: false
|
|
| 10 |
license: bigscience-bloom-rail-1.0
|
| 11 |
---
|
| 12 |
|
| 13 |
-
|
|
|
|
| 10 |
license: bigscience-bloom-rail-1.0
|
| 11 |
---
|
| 12 |
|
| 13 |
+
A basic example of dynamic adversarial data collection with a Gradio app.
|
app.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Basic example for doing model-in-the-loop dynamic adversarial data collection
|
| 2 |
+
# using Gradio Blocks.
|
| 3 |
+
|
| 4 |
+
import random
|
| 5 |
+
from urllib.parse import parse_qs
|
| 6 |
+
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import requests
|
| 9 |
+
from transformers import pipeline
|
| 10 |
+
|
| 11 |
+
demo = gr.Blocks()
|
| 12 |
+
|
| 13 |
+
with demo:
|
| 14 |
+
total_cnt = 2 # How many examples per HIT
|
| 15 |
+
dummy = gr.Textbox(visible=False) # dummy for passing assignmentId
|
| 16 |
+
|
| 17 |
+
# We keep track of state as a Variable
|
| 18 |
+
state_dict = {"assignmentId": "", "cnt": 0, "fooled": 0, "data": [], "metadata": {}}
|
| 19 |
+
state = gr.Variable(state_dict)
|
| 20 |
+
|
| 21 |
+
gr.Markdown("# DADC in Gradio example")
|
| 22 |
+
gr.Markdown("Try to fool the model and find an example where it predicts the wrong label!")
|
| 23 |
+
|
| 24 |
+
state_display = gr.Markdown(f"State: 0/{total_cnt} (0 fooled)")
|
| 25 |
+
|
| 26 |
+
# Generate model prediction
|
| 27 |
+
# Default model: distilbert-base-uncased-finetuned-sst-2-english
|
| 28 |
+
def _predict(txt, tgt, state):
|
| 29 |
+
pipe = pipeline("sentiment-analysis")
|
| 30 |
+
pred = pipe(txt)[0]
|
| 31 |
+
|
| 32 |
+
pred["label"] = pred["label"].title()
|
| 33 |
+
ret = f"Target: {tgt}. Model prediction: {pred['label']} ({pred['score']} confidence). {pred['label'] != tgt}\n\n"
|
| 34 |
+
if pred["label"] != tgt:
|
| 35 |
+
state["fooled"] += 1
|
| 36 |
+
ret += " You fooled the model! Well done!"
|
| 37 |
+
else:
|
| 38 |
+
ret += " You did not fool the model! Too bad, try again!"
|
| 39 |
+
state["data"].append(ret)
|
| 40 |
+
state["cnt"] += 1
|
| 41 |
+
|
| 42 |
+
done = state["cnt"] == total_cnt
|
| 43 |
+
toggle_final_submit = gr.update(visible=done)
|
| 44 |
+
toggle_example_submit = gr.update(visible=not done)
|
| 45 |
+
new_state_md = f"State: {state['cnt']}/{total_cnt} ({state['fooled']} fooled)"
|
| 46 |
+
return ret, state, toggle_example_submit, toggle_final_submit, new_state_md
|
| 47 |
+
|
| 48 |
+
# Input fields
|
| 49 |
+
text_input = gr.Textbox(placeholder="Enter model-fooling statement", show_label=False)
|
| 50 |
+
labels = ["Positive", "Negative"]
|
| 51 |
+
random.shuffle(labels)
|
| 52 |
+
label_input = gr.Radio(choices=labels, label="Target (correct) label")
|
| 53 |
+
text_output = gr.Markdown()
|
| 54 |
+
with gr.Column() as example_submit:
|
| 55 |
+
submit_ex_button = gr.Button("Submit")
|
| 56 |
+
with gr.Column(visible=False) as final_submit:
|
| 57 |
+
submit_hit_button = gr.Button("Submit HIT")
|
| 58 |
+
|
| 59 |
+
# Submit state to MTurk backend for ExternalQuestion
|
| 60 |
+
# Update the URL below to switch from Sandbox to real data collection
|
| 61 |
+
def _submit(state, dummy):
|
| 62 |
+
query = parse_qs(dummy[1:])
|
| 63 |
+
assert "assignmentId" in query, "No assignment ID provided, unable to submit"
|
| 64 |
+
state["assignmentId"] = query["assignmentId"]
|
| 65 |
+
url = "https://workersandbox.mturk.com/mturk/externalSubmit"
|
| 66 |
+
return requests.post(url, data=state)
|
| 67 |
+
|
| 68 |
+
# Button event handlers
|
| 69 |
+
submit_ex_button.click(
|
| 70 |
+
_predict,
|
| 71 |
+
inputs=[text_input, label_input, state],
|
| 72 |
+
outputs=[text_output, state, example_submit, final_submit, state_display],
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
submit_hit_button.click(
|
| 76 |
+
_submit,
|
| 77 |
+
inputs=[state, dummy],
|
| 78 |
+
outputs=None,
|
| 79 |
+
_js="function(state, dummy) { return [state, window.location.search]; }",
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
demo.launch()
|
collect.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Basic example for running MTurk data collection against a Space
|
| 2 |
+
# For more information see https://docs.aws.amazon.com/mturk/index.html
|
| 3 |
+
|
| 4 |
+
import boto3
|
| 5 |
+
from boto.mturk.question import ExternalQuestion
|
| 6 |
+
|
| 7 |
+
from config import MTURK_KEY, MTURK_SECRET
|
| 8 |
+
|
| 9 |
+
MTURK_REGION = "us-east-1"
|
| 10 |
+
MTURK_SANDBOX = "https://mturk-requester-sandbox.us-east-1.amazonaws.com"
|
| 11 |
+
|
| 12 |
+
mturk = boto3.client(
|
| 13 |
+
"mturk",
|
| 14 |
+
aws_access_key_id=MTURK_KEY,
|
| 15 |
+
aws_secret_access_key=MTURK_SECRET,
|
| 16 |
+
region_name=MTURK_REGION,
|
| 17 |
+
endpoint_url=MTURK_SANDBOX,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
question = ExternalQuestion(
|
| 21 |
+
"https://huggingface.co/spaces/douwekiela/dadc", frame_height=600
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
new_hit = mturk.create_hit(
|
| 25 |
+
Title="DADC with Gradio",
|
| 26 |
+
Description="Hello world",
|
| 27 |
+
Keywords="fool the model",
|
| 28 |
+
Reward="0.15",
|
| 29 |
+
MaxAssignments=1,
|
| 30 |
+
LifetimeInSeconds=172800,
|
| 31 |
+
AssignmentDurationInSeconds=600,
|
| 32 |
+
AutoApprovalDelayInSeconds=14400,
|
| 33 |
+
Question=question.get_as_xml(),
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
print(
|
| 37 |
+
"Sandbox link: https://workersandbox.mturk.com/mturk/preview?groupId="
|
| 38 |
+
+ new_hit["HIT"]["HITGroupId"]
|
| 39 |
+
)
|
config.py.example
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Fill in the information and rename this file config.py
|
| 2 |
+
# You can obtain the key and secret in the AWS Identity
|
| 3 |
+
# and Access Management (IAM) panel.
|
| 4 |
+
|
| 5 |
+
MTURK_KEY = ''
|
| 6 |
+
MTURK_SECRET = ''
|
requirements.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
requests
|
| 2 |
+
transformers
|