Create data_loader.py

data_loader.py (added, +125 -0)
import random
import tqdm
import datasets
import re
import transformers
import numpy as np
from utils import MGT, HWT

# Tokenizer used only to measure example lengths during filtering (see filter_data).
preproc_tokenizer = transformers.AutoTokenizer.from_pretrained(
    "google-t5/t5-small", model_max_length=512
)
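
The utils module imported above is not part of this commit; data_loader.py only needs the two label constants it exports. A minimal sketch of what it might contain, where the concrete values are an assumption:

# utils.py -- hypothetical sketch, not part of this commit
# Label constants for human-written text (HWT) and machine-generated text (MGT);
# any two distinct hashable values would work, since they are also used as dict keys.
HWT = "human"    # assumed value
MGT = "machine"  # assumed value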

def process_spaces(text):
    text = (
        text.replace(" ,", ",")
        .replace(" .", ".")
        .replace(" ?", "?")
        .replace(" !", "!")
        .replace(" ;", ";")
        .replace(" '", "'")
        .replace(" ’ ", "'")
        .replace(" :", ":")
        .replace("<newline>", "\n")
        .replace("`` ", '"')
        .replace(" ''", '"')
        .replace("''", '"')
        .replace(".. ", "... ")
        .replace(" )", ")")
        .replace("( ", "(")
        .replace(" n't", "n't")
        .replace(" i ", " I ")
        .replace(" i'", " I'")
        .replace("\\'", "'")
        .replace("\n ", "\n")
        .strip()
    )
    text = text.replace("\r\n", "\n").replace("\\n", "").replace("!\n", "")
    return re.sub("\n+", "\n", text)
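
As an illustration of the normalization above (a hand-worked example, not output taken from the original repo):

# >>> process_spaces("it 's fine , is n't it ?")
# "it's fine, isn't it?"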

def trim_to_shorter_length(texta, textb):
    # truncate both texts to the word count of the shorter one
    shorter_length = min(len(texta.split(" ")), len(textb.split(" ")))
    texta = " ".join(texta.split(" ")[:shorter_length])
    textb = " ".join(textb.split(" ")[:shorter_length])
    return texta, textb


def load_HC3():
    ds = datasets.load_dataset("Hello-SimpleAI/HC3", name="all")
    ds = ds["train"]  # DatasetDict -> Dataset
    filtered_ds = [
        item
        for item in ds
        if (
            len(item["human_answers"]) > 0
            and len(item["chatgpt_answers"]) > 0
            and len(item["human_answers"][0].split()) > 5
            and len(item["chatgpt_answers"][0].split()) > 5
        )
    ]
    # print("DEBUG: filtered_ds[0]:", filtered_ds[0])

    data_new = {"text": [], "label": []}

    for i in tqdm.tqdm(range(len(filtered_ds)), desc="Parsing data"):
        data_new["text"].append(process_spaces(filtered_ds[i]["human_answers"][0]))
        data_new["label"].append(HWT)
        data_new["text"].append(process_spaces(filtered_ds[i]["chatgpt_answers"][0]))
        data_new["label"].append(MGT)
    return data_new
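
For orientation, load_HC3 interleaves the two sources, so the returned dictionary has the following shape (placeholder values, shown only to illustrate the layout):

# {
#     "text":  [human_0, chatgpt_0, human_1, chatgpt_1, ...],
#     "label": [HWT,     MGT,       HWT,     MGT,       ...],
# }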

def filter_data(data_o, long_train_threshold_low=150, long_train_threshold_high=512):
    data_HWT = [
        text for text, label in zip(data_o["text"], data_o["label"]) if label == HWT
    ]
    data_MGT = [
        text for text, label in zip(data_o["text"], data_o["label"]) if label == MGT
    ]

    # keep only examples whose preproc_tokenizer length falls within
    # [long_train_threshold_low, long_train_threshold_high] tokens;
    # this step has the extra effect of removing examples with low-quality/garbage content
    tokenized_data = preproc_tokenizer(data_HWT)
    long_HWT = [
        x
        for x, y in zip(data_HWT, tokenized_data["input_ids"])
        if long_train_threshold_low <= len(y) <= long_train_threshold_high
    ]
    tokenized_data = preproc_tokenizer(data_MGT)
    long_MGT = [
        x
        for x, y in zip(data_MGT, tokenized_data["input_ids"])
        if long_train_threshold_low <= len(y) <= long_train_threshold_high
    ]

    # print stats about remaining data
    print(f"Total number of samples: {len(long_HWT)}")
    print(f"Average number of words: {np.mean([len(x.split()) for x in long_HWT])}")

    data = {
        HWT: [],
        MGT: [],
    }

    print(len(long_HWT), len(long_MGT))
    for o, s in zip(long_HWT, long_MGT):
        # pair each human-written text with a machine-generated one by position
        # and trim both to the same word count
        o, s = trim_to_shorter_length(o, s)

        # add to the data
        data[HWT].append(o)
        data[MGT].append(s)

    return data


# Test code
# data_o = load_HC3()
# data = filter_data(data_o)
# real = data[HWT]  # [:args.train_real_num] len == n_samples, many sentences of words
# generated = data[MGT]
# print(real[:5])
# print(generated[:5])
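
For reference, a small driver script exercising the loader might look like the following. This is a hypothetical example, not part of the commit; it assumes data_loader.py and utils.py are importable and that the HC3 download succeeds:

# hypothetical driver script
from data_loader import load_HC3, filter_data
from utils import HWT, MGT

data_o = load_HC3()           # {"text": [...], "label": [...]}
data = filter_data(data_o)    # {HWT: [...], MGT: [...]} with length-matched pairs
print(len(data[HWT]), len(data[MGT]))
print(data[HWT][0][:200])     # first human-written sample, truncated for display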