"""Load the HC3 dataset and prepare length-matched human-written (HWT) and
machine-generated (MGT) text pairs."""

import re

import datasets
import numpy as np
import tqdm
import transformers

from utils import MGT, HWT

# Lightweight tokenizer used only to measure token lengths when filtering.
preproc_tokenizer = transformers.AutoTokenizer.from_pretrained(
    "google-t5/t5-small", model_max_length=512
)


def process_spaces(text):
    """Normalize detokenization artifacts: spacing around punctuation, quotes, and newlines."""
    text = (
        text.replace(" ,", ",")
        .replace(" .", ".")
        .replace(" ?", "?")
        .replace(" !", "!")
        .replace(" ;", ";")
        .replace(" '", "'")
        .replace(" ’ ", "'")
        .replace(" :", ":")
        .replace("<newline>", "\n")
        .replace("`` ", '"')
        .replace(" ''", '"')
        .replace("''", '"')
        .replace(".. ", "... ")
        .replace(" )", ")")
        .replace("( ", "(")
        .replace(" n't", "n't")
        .replace(" i ", " I ")
        .replace(" i'", " I'")
        .replace("\\'", "'")
        .replace("\n ", "\n")
        .strip()
    )
    text = text.replace("\r\n", "\n").replace("\\n", "").replace("!\n", "")
    return re.sub("\n+", "\n", text)


def trim_to_shorter_length(texta, textb):
    # truncate both texts to the word count of the shorter one
    shorter_length = min(len(texta.split(" ")), len(textb.split(" ")))
    texta = " ".join(texta.split(" ")[:shorter_length])
    textb = " ".join(textb.split(" ")[:shorter_length])
    return texta, textb


def load_HC3():
    """Load HC3 and return {"text": [...], "label": [...]} with alternating HWT/MGT answers."""
    ds = datasets.load_dataset("Hello-SimpleAI/HC3", name="all")
    ds = ds["train"]  # DatasetDict -> Dataset
    filtered_ds = [
        item
        for item in ds
        if (
            len(item["human_answers"]) > 0
            and len(item["chatgpt_answers"]) > 0
            and len(item["human_answers"][0].split()) > 5
            and len(item["chatgpt_answers"][0].split()) > 5
        )
    ]
    # print("DEBUG: filtered_ds[0]:", filtered_ds[0])

    data_new = {"text": [], "label": []}

    for i in tqdm.tqdm(range(len(filtered_ds)), desc="Parsing data"):
        data_new["text"].append(process_spaces(filtered_ds[i]["human_answers"][0]))
        data_new["label"].append(HWT)
        data_new["text"].append(process_spaces(filtered_ds[i]["chatgpt_answers"][0]))
        data_new["label"].append(MGT)
    return data_new


def filter_data(data_o, long_train_threshold_low=150, long_train_threshold_high=512):
    """Keep texts whose token length lies in the given range and trim HWT/MGT pairs to equal word counts."""
    data_HWT = [
        text for text, label in zip(data_o["text"], data_o["label"]) if label == HWT
    ]
    data_MGT = [
        text for text, label in zip(data_o["text"], data_o["label"]) if label == MGT
    ]

    # keep only examples whose preproc_tokenizer length lies within
    # [long_train_threshold_low, long_train_threshold_high] (150-512 tokens by default);
    # this step has the extra effect of removing examples with low-quality/garbage content
    tokenized_data = preproc_tokenizer(data_HWT)
    long_HWT = [
        x
        for x, y in zip(data_HWT, tokenized_data["input_ids"])
        if long_train_threshold_low <= len(y) <= long_train_threshold_high
    ]
    tokenized_data = preproc_tokenizer(data_MGT)
    long_MGT = [
        x
        for x, y in zip(data_MGT, tokenized_data["input_ids"])
        if long_train_threshold_low <= len(y) <= long_train_threshold_high
    ]

    # print stats about remaining data
    print(f"Total number of HWT samples: {len(long_HWT)}")
    print(f"Average number of words: {np.mean([len(x.split()) for x in long_HWT])}")

    data = {
        HWT: [],
        MGT: [],
    }

    print(f"Kept {len(long_HWT)} HWT and {len(long_MGT)} MGT samples")
    for o, s in zip(long_HWT, long_MGT):
        o, s = trim_to_shorter_length(o, s)

        # add to the data
        data[HWT].append(o)
        data[MGT].append(s)

    return data


# Test code
# data_o = load_HC3()
# data = filter_data(data_o)
# real = data[HWT]  # [:args.train_real_num]  len== n_samples, many sentences of words
# generated = data[MGT]
# print(real[:5])
# print(generated[:5])
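

# A minimal, runnable sketch of the commented-out test flow above. It assumes
# `utils` defines the HWT and MGT label constants (as imported at the top) and
# that the Hello-SimpleAI/HC3 dataset is reachable; downloading and tokenizing
# may take a while on first run.
if __name__ == "__main__":
    data_o = load_HC3()
    data = filter_data(data_o)
    real = data[HWT]        # length-matched human-written texts
    generated = data[MGT]   # length-matched machine-generated texts
    print(real[:2])
    print(generated[:2])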