File size: 980 Bytes
9ee3bcd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import tqdm
import numpy as np
import nltk

from utils import DEVICE, FeatureExtractor, HWT, MGT
from roberta_model_loader import roberta_model
from meta_train import net
from data_loader import load_HC3, filter_data


# Which subset of the filtered HC3 data the reference features are built for
# (HWT here; MGT is the other label imported from utils).
target = HWT

# Extractor built from the imported RoBERTa model and the `net` from meta_train.
feature_extractor = FeatureExtractor(roberta_model, net)

# Load the raw HC3 data, filter it, and keep only the entries for `target`.
data_o = load_HC3()
data = filter_data(data_o)[target]

# Split every paragraph into sentences with NLTK's Punkt tokenizer.
# Fetch the tokenizer resources quietly first.
for resource in ("punkt", "punkt_tab"):
    nltk.download(resource, quiet=True)

# Tokenize each paragraph and discard its first and last sentence.
paragraphs = [sentences[1:-1] for sentences in map(nltk.sent_tokenize, data)]

# Flatten to a single list of sentences, keeping only those with more than
# five whitespace-separated tokens.
data = [
    sent
    for inner_sentences in paragraphs
    for sent in inner_sentences
    if len(sent.split()) > 5
]

# extract features
# Use at most 2000 sentences as the reference set. Slicing (instead of
# indexing data[i] over a fixed range) avoids an IndexError when fewer than
# 2000 sentences are available.
ref_sentences = data[:2000]
if not ref_sentences:
    # Fail early with a clear message rather than an opaque error from
    # torch.cat on an empty list.
    raise ValueError(f"No sentences available to build feature ref for {target}")

feature_ref = []
for ref_sentence in tqdm.tqdm(
    ref_sentences, desc=f"Generating feature ref for {target}"
):
    # NOTE(review): process() is assumed to return a tensor (it is detached
    # and later concatenated) -- confirm against FeatureExtractor.
    feature_ref.append(
        feature_extractor.process(ref_sentence, False).detach()
    )  # detach to save memory

# Concatenate the per-sentence features along dim 0 and persist them to disk.
torch.save(torch.cat(feature_ref, dim=0), f"feature_ref_{target}.pt")