jer233 committed on
Commit bff0409 · verified · 1 Parent(s): c319e25

Create data_loader.py

Files changed (1)
  1. data_loader.py +125 -0
data_loader.py ADDED
import re

import datasets
import numpy as np
import tqdm
import transformers

from utils import MGT, HWT

preproc_tokenizer = transformers.AutoTokenizer.from_pretrained(
    "google-t5/t5-small", model_max_length=512
)


def process_spaces(text):
    # Undo tokenizer-style spacing: reattach punctuation, normalize quotes,
    # restore newlines, and collapse repeated blank lines.
    text = (
        text.replace(" ,", ",")
        .replace(" .", ".")
        .replace(" ?", "?")
        .replace(" !", "!")
        .replace(" ;", ";")
        .replace(" '", "'")
        .replace(" ’ ", "'")
        .replace(" :", ":")
        .replace("<newline>", "\n")
        .replace("`` ", '"')
        .replace(" ''", '"')
        .replace("''", '"')
        .replace(".. ", "... ")
        .replace(" )", ")")
        .replace("( ", "(")
        .replace(" n't", "n't")
        .replace(" i ", " I ")
        .replace(" i'", " I'")
        .replace("\\'", "'")
        .replace("\n ", "\n")
        .strip()
    )
    text = text.replace("\r\n", "\n").replace("\\n", "").replace("!\n", "")
    return re.sub("\n+", "\n", text)

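# A quick illustration of process_spaces on a hypothetical input:
#   process_spaces("he said `` hi '' , i think")  ->  'he said "hi", I think'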

def trim_to_shorter_length(texta, textb):
    # Truncate both texts to the word count of the shorter one.
    shorter_length = min(len(texta.split(" ")), len(textb.split(" ")))
    texta = " ".join(texta.split(" ")[:shorter_length])
    textb = " ".join(textb.split(" ")[:shorter_length])
    return texta, textb

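# For example, with hypothetical inputs:
#   trim_to_shorter_length("a b c d", "x y z")  ->  ("a b c", "x y z")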

def load_HC3():
    # Load the HC3 corpus and keep only QA pairs whose first human answer
    # and first ChatGPT answer each contain more than five words.
    ds = datasets.load_dataset("Hello-SimpleAI/HC3", name="all")
    ds = ds["train"]  # DatasetDict -> Dataset
    filtered_ds = [
        item
        for item in ds
        if (
            len(item["human_answers"]) > 0
            and len(item["chatgpt_answers"]) > 0
            and len(item["human_answers"][0].split()) > 5
            and len(item["chatgpt_answers"][0].split()) > 5
        )
    ]
    # print("DEBUG: filtered_ds[0]:", filtered_ds[0])

    # Interleave human-written (HWT) and machine-generated (MGT) answers.
    data_new = {"text": [], "label": []}

    for i in tqdm.tqdm(range(len(filtered_ds)), desc="Parsing data"):
        data_new["text"].append(process_spaces(filtered_ds[i]["human_answers"][0]))
        data_new["label"].append(HWT)
        data_new["text"].append(process_spaces(filtered_ds[i]["chatgpt_answers"][0]))
        data_new["label"].append(MGT)
    return data_new


def filter_data(data_o, long_train_threshold_low=150, long_train_threshold_high=512):
    data_HWT = [
        text for text, label in zip(data_o["text"], data_o["label"]) if label == HWT
    ]
    data_MGT = [
        text for text, label in zip(data_o["text"], data_o["label"]) if label == MGT
    ]

    # keep only examples whose preproc_tokenizer length falls between
    # long_train_threshold_low and long_train_threshold_high tokens;
    # this step has the extra effect of removing examples with low-quality/garbage content
    tokenized_data = preproc_tokenizer(data_HWT)
    long_HWT = [
        x
        for x, y in zip(data_HWT, tokenized_data["input_ids"])
        if long_train_threshold_low <= len(y) <= long_train_threshold_high
    ]
    tokenized_data = preproc_tokenizer(data_MGT)
    long_MGT = [
        x
        for x, y in zip(data_MGT, tokenized_data["input_ids"])
        if long_train_threshold_low <= len(y) <= long_train_threshold_high
    ]

    # print stats about the remaining data
    print(f"Total number of samples: {len(long_HWT)}")
    print(f"Average number of words: {np.mean([len(x.split()) for x in long_HWT])}")

    data = {
        HWT: [],
        MGT: [],
    }

    print(len(long_HWT), len(long_MGT))
    # Pair the texts and trim each pair to a common word count so that
    # length alone cannot distinguish the two classes.
    for o, s in zip(long_HWT, long_MGT):
        o, s = trim_to_shorter_length(o, s)

        # add to the data
        data[HWT].append(o)
        data[MGT].append(s)

    return data


# Test code
# data_o = load_HC3()
# data = filter_data(data_o)
# real = data[HWT]  # [:args.train_real_num] len== n_samples, many sentences of words
# generated = data[MGT]
# print(real[:5])
# print(generated[:5])
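

# A minimal end-to-end sketch, assuming utils.py defines the HWT and MGT
# label constants imported above and the HC3 dataset is reachable on the Hub:
if __name__ == "__main__":
    data_o = load_HC3()  # download HC3 and pair human/ChatGPT answers
    data = filter_data(data_o)  # keep 150-512 token samples, trim pairs to equal length
    print(f"{len(data[HWT])} human-written / {len(data[MGT])} machine-generated samples")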