Create data_loader.py
data_loader.py +125 -0
data_loader.py
ADDED
@@ -0,0 +1,125 @@
import random
import tqdm
import datasets
import re
import transformers
import numpy as np
from utils import MGT, HWT

preproc_tokenizer = transformers.AutoTokenizer.from_pretrained(
    "google-t5/t5-small", model_max_length=512
)


def process_spaces(text):
    # undo tokenization artifacts: detached punctuation, escaped quotes, <newline> markers
    text = (
        text.replace(" ,", ",")
        .replace(" .", ".")
        .replace(" ?", "?")
        .replace(" !", "!")
        .replace(" ;", ";")
        .replace(" '", "'")
        .replace(" ’ ", "'")
        .replace(" :", ":")
        .replace("<newline>", "\n")
        .replace("`` ", '"')
        .replace(" ''", '"')
        .replace("''", '"')
        .replace(".. ", "... ")
        .replace(" )", ")")
        .replace("( ", "(")
        .replace(" n't", "n't")
        .replace(" i ", " I ")
        .replace(" i'", " I'")
        .replace("\\'", "'")
        .replace("\n ", "\n")
        .strip()
    )
    text = text.replace("\r\n", "\n").replace("\\n", "").replace("!\n", "")
    # collapse runs of newlines into a single newline
    return re.sub("\n+", "\n", text)


def trim_to_shorter_length(texta, textb):
    # truncate both texts to the word count of the shorter one
    shorter_length = min(len(texta.split(" ")), len(textb.split(" ")))
    texta = " ".join(texta.split(" ")[:shorter_length])
    textb = " ".join(textb.split(" ")[:shorter_length])
    return texta, textb


def load_HC3():
    ds = datasets.load_dataset("Hello-SimpleAI/HC3", name="all")
    ds = ds["train"]  # DatasetDict -> Dataset
    filtered_ds = [
        item
        for item in ds
        if (
            len(item["human_answers"]) > 0
            and len(item["chatgpt_answers"]) > 0
            and len(item["human_answers"][0].split()) > 5
            and len(item["chatgpt_answers"][0].split()) > 5
        )
    ]
    # print("DEBUG: filtered_ds[0]:", filtered_ds[0])

    data_new = {"text": [], "label": []}

    for i in tqdm.tqdm(range(len(filtered_ds)), desc="Parsing data"):
        data_new["text"].append(process_spaces(filtered_ds[i]["human_answers"][0]))
        data_new["label"].append(HWT)
        data_new["text"].append(process_spaces(filtered_ds[i]["chatgpt_answers"][0]))
        data_new["label"].append(MGT)
    return data_new


def filter_data(data_o, long_train_threshold_low=150, long_train_threshold_high=512):
    data_HWT = [
        text for text, label in zip(data_o["text"], data_o["label"]) if label == HWT
    ]
    data_MGT = [
        text for text, label in zip(data_o["text"], data_o["label"]) if label == MGT
    ]

    # keep only examples whose preproc_tokenizer length falls within the threshold range;
    # this step has the extra effect of removing examples with low-quality/garbage content
    tokenized_data = preproc_tokenizer(data_HWT)
    long_HWT = [
        x
        for x, y in zip(data_HWT, tokenized_data["input_ids"])
        if long_train_threshold_low <= len(y) <= long_train_threshold_high
    ]
    tokenized_data = preproc_tokenizer(data_MGT)
    long_MGT = [
        x
        for x, y in zip(data_MGT, tokenized_data["input_ids"])
        if long_train_threshold_low <= len(y) <= long_train_threshold_high
    ]

    # print stats about remaining data
    print(f"Total number of samples: {len(long_HWT)}")
    print(f"Average number of words: {np.mean([len(x.split()) for x in long_HWT])}")

    data = {
        HWT: [],
        MGT: [],
    }

    print(len(long_HWT), len(long_MGT))
    for o, s in zip(long_HWT, long_MGT):
        o, s = trim_to_shorter_length(o, s)

        # add to the data
        data[HWT].append(o)
        data[MGT].append(s)

    return data


# Test code
# data_o = load_HC3()
# data = filter_data(data_o)
# real = data[HWT]  # [:args.train_real_num]  len == n_samples, many sentences of words
# generated = data[MGT]
# print(real[:5])
# print(generated[:5])
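Note that data_loader.py imports MGT and HWT from a local utils module that is not part of this commit. Those names are used only as label values: load_HC3 appends them to the "label" list, and filter_data compares against them and uses them as dictionary keys, so any two distinct hashable values would work. The sketch below shows what utils.py is assumed to provide; the concrete string values are illustrative guesses, not the actual contents of that module.

# utils.py (hypothetical sketch -- the real module is not shown in this diff)
# Label constants for human-written text (HWT) and machine-generated text (MGT).
# The string values here are assumptions for illustration only.
HWT = "human-written"
MGT = "machine-generated"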