# (Page-scrape residue, not part of the script: "Spaces: Runtime error")
'''
Script for downloading all GLUE data.

Note: for legal reasons, we are unable to host MRPC.
You can either use the version hosted by the SentEval team, which is already tokenized,
or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually.
For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example).
You should then rename and place specific files in a folder (see below for an example).

mkdir MRPC
cabextract MSRParaphraseCorpus.msi -d MRPC
cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt
cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt
rm MRPC/_*
rm MSRParaphraseCorpus.msi

1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now.
2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray!
'''
import os | |
import sys | |
import shutil | |
import argparse | |
import tempfile | |
import urllib.request | |
import zipfile | |
# Task names accepted by --tasks ("all" selects every entry in this list).
# NOTE(review): the functions in this file look up download URLs in a
# TASK2PATH mapping that is not defined in the visible source -- confirm
# where it is supposed to come from.
TASKS = [
    "CoLA",
    "SST",
    "MRPC",
    "QQP",
    "STS",
    "MNLI",
    "SNLI",
    "QNLI",
    "RTE",
    "WNLI",
]

# URLs of the raw MRPC train/test files, used when no local copy is given.
MRPC_TRAIN = (
    'https://dl.fbaipublicfiles.com/senteval/senteval_data/'
    'msr_paraphrase_train.txt'
)
MRPC_TEST = (
    'https://dl.fbaipublicfiles.com/senteval/senteval_data/'
    'msr_paraphrase_test.txt'
)
def download_and_extract(task, data_dir):
    """Download the zip archive for ``task`` and unpack it into ``data_dir``.

    The archive is saved as ``<task>.zip`` in the current working directory
    and is removed afterwards, whether or not the download and extraction
    succeed.

    Args:
        task: Key into ``TASK2PATH`` selecting the archive URL.
        data_dir: Directory the archive contents are extracted into.

    Raises:
        KeyError: If ``task`` has no entry in ``TASK2PATH``.
    """
    print("Downloading and extracting %s..." % task)
    # Resolve the URL first so an unknown task fails before touching disk.
    url = TASK2PATH[task]
    data_file = "%s.zip" % task
    try:
        urllib.request.urlretrieve(url, data_file)
        with zipfile.ZipFile(data_file) as zip_ref:
            zip_ref.extractall(data_dir)
    finally:
        # Do not leave a partial or corrupt archive behind on failure.
        if os.path.exists(data_file):
            os.remove(data_file)
    print("\tCompleted!")
def format_mrpc(data_dir, path_to_data):
    """Write MRPC ``train.tsv``, ``dev.tsv`` and ``test.tsv`` under ``data_dir/MRPC``.

    The raw training file is split into train and dev according to the ID
    pairs listed in ``dev_ids.tsv``, which is downloaded from
    ``TASK2PATH["MRPC"]``. The raw test file is rewritten with a running
    index in place of the label column.

    Args:
        data_dir: Directory under which the ``MRPC`` folder is created.
        path_to_data: Directory containing ``msr_paraphrase_train.txt`` and
            ``msr_paraphrase_test.txt``. If empty, both files are downloaded
            from ``MRPC_TRAIN`` / ``MRPC_TEST`` into the ``MRPC`` folder.

    Raises:
        AssertionError: If either raw MRPC file is missing.
        ValueError: If a data row does not have exactly five tab-separated
            fields.
    """
    print("Processing MRPC...")
    mrpc_dir = os.path.join(data_dir, "MRPC")
    # makedirs also creates data_dir itself when it does not exist yet.
    os.makedirs(mrpc_dir, exist_ok=True)

    if path_to_data:
        mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")
    else:
        print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN)
        mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
        urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)
        urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file)
    assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
    assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file

    dev_ids_file = os.path.join(mrpc_dir, "dev_ids.tsv")
    urllib.request.urlretrieve(TASK2PATH["MRPC"], dev_ids_file)
    # A set of tuples gives O(1) membership tests; the previous list of
    # lists was scanned linearly for every training row.
    with open(dev_ids_file, encoding="utf8") as ids_fh:
        dev_ids = {tuple(row.strip().split('\t')) for row in ids_fh}

    with open(mrpc_train_file, encoding="utf8") as data_fh, \
            open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \
            open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh:
        # Both output files get a copy of the raw header line.
        header = data_fh.readline()
        train_fh.write(header)
        dev_fh.write(header)
        for row in data_fh:
            label, id1, id2, s1, s2 = row.strip().split('\t')
            out_fh = dev_fh if (id1, id2) in dev_ids else train_fh
            out_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))

    with open(mrpc_test_file, encoding="utf8") as data_fh, \
            open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh:
        # Skip the raw header; the test file gets its own header with an
        # index column instead of the label.
        data_fh.readline()
        test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
        for idx, row in enumerate(data_fh):
            _label, id1, id2, s1, s2 = row.strip().split('\t')
            test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))
    print("\tCompleted!")
def download_diagnostic(data_dir):
    """Download the diagnostic set to ``data_dir/diagnostic/diagnostic.tsv``.

    Args:
        data_dir: Existing directory in which the ``diagnostic`` folder is
            created if it is not already there.
    """
    print("Downloading and extracting diagnostic...")
    diagnostic_dir = os.path.join(data_dir, "diagnostic")
    if not os.path.isdir(diagnostic_dir):
        os.mkdir(diagnostic_dir)
    target_path = os.path.join(diagnostic_dir, "diagnostic.tsv")
    urllib.request.urlretrieve(TASK2PATH["diagnostic"], target_path)
    print("\tCompleted!")
def get_tasks(task_names):
    """Turn a comma-separated string of task names into a list of tasks.

    Args:
        task_names: Comma-separated task names; if any entry is ``"all"``,
            the full ``TASKS`` list is returned.

    Returns:
        ``TASKS`` itself when "all" was requested, otherwise a new list of
        the requested names in the order given.

    Raises:
        AssertionError: If a requested name is not in ``TASKS``.
    """
    requested = task_names.split(',')
    if "all" in requested:
        return TASKS

    selected = []
    for name in requested:
        assert name in TASKS, "Task %s not found!" % name
        selected.append(name)
    return selected
def main(arguments):
    """Parse command-line arguments and fetch the data for each requested task.

    Args:
        arguments: Command-line argument strings, without the program name.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', help='directory to save data to', type=str, default='./')
    parser.add_argument('--tasks', help='tasks to download data for as a comma separated string',
                        type=str, default='all')
    parser.add_argument('--path_to_mrpc',
                        help='path to directory containing extracted MRPC data, '
                             'msr_paraphrase_train.txt and msr_paraphrase_test.txt',
                        type=str, default='')
    args = parser.parse_args(arguments)

    # makedirs (unlike mkdir) also copes with a nested --data_dir path.
    os.makedirs(args.data_dir, exist_ok=True)
    tasks = get_tasks(args.tasks)

    for task in tasks:
        if task == 'MRPC':
            # MRPC is assembled from raw files rather than a ready-made zip.
            format_mrpc(args.data_dir, args.path_to_mrpc)
        elif task == 'diagnostic':
            # NOTE(review): only reached if get_tasks yields 'diagnostic';
            # confirm that the task list is meant to include it.
            download_diagnostic(args.data_dir)
        else:
            download_and_extract(task, args.data_dir)
# Script entry point: forward the CLI arguments (minus the program name) to
# main() and use its return value as the process exit status.
if __name__ == "__main__":
    cli_args = sys.argv[1:]
    sys.exit(main(cli_args))