|
from everything import *

# Explicit imports for everything this module actually uses.
import csv
import random

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

from bert import BertModel
from optimizer import AdamW
from tokenizer import BertTokenizer
|
|
|
|
|
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') |
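
# The tokenizer call used throughout this file returns a dict that includes
# 'input_ids' and 'attention_mask' (the two keys read below). A minimal sketch of
# the shapes, using made-up sentences:
#
#   enc = tokenizer(["a first sentence", "a slightly longer second sentence"],
#                   return_tensors='pt', padding=True, truncation=True)
#   enc['input_ids'].shape       # (2, max_len_in_batch)
#   enc['attention_mask'].shape  # (2, max_len_in_batch); 1 for real tokens, 0 for padding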
|
|
|
|
|
class SentimentDataset(Dataset): |
|
def __init__(self, dataset): |
|
self.dataset = dataset |
|
|
|
def __len__(self): |
|
return len(self.dataset) |
|
|
|
def __getitem__(self, idx): |
|
return self.dataset[idx] |
|
|
|
def pad_data(self, data): |
|
sents = [x[0] for x in data] |
|
labels = [x[1] for x in data] |
|
sent_ids = [x[2] for x in data] |
|
|
|
encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True) |
|
token_ids = torch.LongTensor(encoding['input_ids']) |
|
attention_mask = torch.LongTensor(encoding['attention_mask']) |
|
labels = torch.LongTensor(labels) |
|
|
|
return token_ids, attention_mask, labels, sents, sent_ids |
|
|
|
def collate_fn(self, all_data): |
|
token_ids, attention_mask, labels, sents, sent_ids = self.pad_data(all_data) |
|
|
|
batched_data = { |
|
'token_ids': token_ids, |
|
'attention_mask': attention_mask, |
|
'labels': labels, |
|
'sents': sents, |
|
'sent_ids': sent_ids |
|
} |
|
|
|
return batched_data |
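
# Illustrative usage (not part of the original module): the dataset is meant to be
# wrapped in a DataLoader with `collate_fn` so batches come back as the dict built
# above. `train_data`, the batch size, and DataLoader (from torch.utils.data) are
# assumptions here.
#
#   train_dataset = SentimentDataset(train_data)
#   train_loader = DataLoader(train_dataset, shuffle=True, batch_size=8,
#                             collate_fn=train_dataset.collate_fn)
#   for batch in train_loader:
#       b_ids, b_mask, b_labels = batch['token_ids'], batch['attention_mask'], batch['labels']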
|
|
|
|
|
class SentimentTestDataset(Dataset): |
|
def __init__(self, dataset): |
|
self.dataset = dataset |
|
|
|
def __len__(self): |
|
return len(self.dataset) |
|
|
|
def __getitem__(self, idx): |
|
return self.dataset[idx] |
|
|
|
def pad_data(self, data): |
|
sents = [x[0] for x in data] |
|
sent_ids = [x[1] for x in data] |
|
|
|
encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True) |
|
token_ids = torch.LongTensor(encoding['input_ids']) |
|
attention_mask = torch.LongTensor(encoding['attention_mask']) |
|
|
|
return token_ids, attention_mask, sents, sent_ids |
|
|
|
def collate_fn(self, all_data): |
|
        token_ids, attention_mask, sents, sent_ids = self.pad_data(all_data)
|
|
|
batched_data = { |
|
'token_ids': token_ids, |
|
'attention_mask': attention_mask, |
|
'sents': sents, |
|
'sent_ids': sent_ids |
|
} |
|
|
|
return batched_data |
|
|
|
|
|
class AmazonDataset(Dataset): |
|
def __init__(self, dataset): |
|
self.dataset = dataset |
|
|
|
def __len__(self): |
|
return len(self.dataset) |
|
|
|
def __getitem__(self, idx): |
|
return self.dataset[idx] |
|
|
|
def pad_data(self, data): |
|
sents = [x[0] for x in data] |
|
sent_ids = [x[1] for x in data] |
|
encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True) |
|
token_ids = torch.LongTensor(encoding['input_ids']) |
|
        attention_mask = torch.LongTensor(encoding['attention_mask'])
|
|
|
        return token_ids, attention_mask, sent_ids
|
|
|
def collate_fn(self, data): |
|
token_ids, attention_mask, sent_ids = self.pad_data(data) |
|
|
|
batched_data = { |
|
'token_ids': token_ids, |
|
'attention_mask': attention_mask, |
|
'sent_ids': sent_ids |
|
} |
|
|
|
return batched_data |
|
|
|
|
|
class SemanticDataset(Dataset): |
|
def __init__(self, dataset): |
|
self.dataset = dataset |
|
|
|
def __len__(self): |
|
return len(self.dataset) |
|
|
|
def __getitem__(self, idx): |
|
return self.dataset[idx] |
|
|
|
def pad_data(self, data): |
|
sents1 = [x[0] for x in data] |
|
sents2 = [x[1] for x in data] |
|
score = [x[2] for x in data] |
|
sent_ids = [x[3] for x in data] |
|
encoding = tokenizer(sents1 + sents2, return_tensors='pt', padding=True, truncation=True) |
|
token_ids = torch.LongTensor(encoding['input_ids']) |
|
        attention_mask = torch.LongTensor(encoding['attention_mask'])
|
|
|
        return token_ids, attention_mask, score, sent_ids
|
|
|
def collate_fn(self, data): |
|
token_ids, attention_mask, score, sent_ids = self.pad_data(data) |
|
n = len(sent_ids) |
|
|
|
batched_data = { |
|
'token_ids_1': token_ids[:n], |
|
'token_ids_2': token_ids[n:], |
|
'attention_mask_1': attention_mask[:n], |
|
'attention_mask_2': attention_mask[n:], |
|
'score': score, |
|
'sent_ids': sent_ids |
|
} |
|
|
|
return batched_data |
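
# Note on pair batching (illustrative sketch): `pad_data` tokenizes sents1 + sents2 in
# a single call so both halves share one padded length L, and `collate_fn` splits the
# result back at n = len(sent_ids). With two hypothetical pairs:
#
#   pairs = [("a cat sits", "a cat is sitting", 4.5, 0),
#            ("it is raining", "the sun is out", 0.8, 1)]
#   batch = SemanticDataset(pairs).collate_fn(pairs)
#   batch['token_ids_1'].shape  # (2, L) -- first sentences
#   batch['token_ids_2'].shape  # (2, L) -- second sentences
#   batch['score']              # [4.5, 0.8] (kept as a plain list, not a tensor)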
|
|
|
|
|
class InferenceDataset(Dataset): |
|
def __init__(self, dataset): |
|
self.dataset = dataset |
|
|
|
def __len__(self): |
|
return len(self.dataset) |
|
|
|
def __getitem__(self, idx): |
|
return self.dataset[idx] |
|
|
|
def pad_data(self, data): |
|
anchor = [x[0] for x in data] |
|
positive = [x[1] for x in data] |
|
negative = [x[2] for x in data] |
|
sent_ids = [x[3] for x in data] |
|
encoding = tokenizer(anchor + positive + negative, return_tensors='pt', padding=True, truncation=True) |
|
token_ids = torch.LongTensor(encoding['input_ids']) |
|
        attention_mask = torch.LongTensor(encoding['attention_mask'])
|
|
|
        return token_ids, attention_mask, sent_ids
|
|
|
def collate_fn(self, data): |
|
token_ids, attention_mask, sent_ids = self.pad_data(data) |
|
n = len(sent_ids) |
|
|
|
batched_data = { |
|
'anchor_ids': token_ids[:n], |
|
'positive_ids': token_ids[n:2*n], |
|
'negative_ids': token_ids[2*n:], |
|
'anchor_masks': attention_mask[:n], |
|
'positive_masks': attention_mask[n:2*n], |
|
'negative_masks': attention_mask[2*n:], |
|
'sent_ids': sent_ids |
|
} |
|
|
|
return batched_data |
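
# Note (illustrative sketch): the NLI triplets are tokenized as one batch of size 3n
# (anchor + positive + negative) and split back apart in `collate_fn`, which makes the
# batch dict convenient for a triplet-style contrastive objective. `model.embed`, `F`
# (torch.nn.functional) and the margin value are assumptions, not defined in this module.
#
#   emb_a = model.embed(batch['anchor_ids'], batch['anchor_masks'])
#   emb_p = model.embed(batch['positive_ids'], batch['positive_masks'])
#   emb_n = model.embed(batch['negative_ids'], batch['negative_masks'])
#   loss = F.triplet_margin_loss(emb_a, emb_p, emb_n, margin=1.0)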
|
|
|
|
|
def load_data(filename, flag='train'): |
|
    '''
    Load `filename` into a list of records; the record shape depends on `flag`:

    - 'amazon': list of (sent, id)
    - 'nli':    list of (anchor, positive, negative, id)
    - 'stsb':   list of (sentence1, sentence2, score, id)
    - 'test':   list of (sent, id)
    - 'train':  list of (sent, label, id); also returns the number of distinct labels
    '''
|
|
|
if flag == 'amazon': |
|
df = pd.read_parquet(filename) |
|
data = list(zip(df['content'], df.index)) |
|
elif flag == 'nli': |
|
df = pd.read_parquet(filename) |
|
data = list(zip(df['anchor'], df['positive'], df['negative'], df.index)) |
|
elif flag == 'stsb': |
|
df = pd.read_parquet(filename) |
|
data = list(zip(df['sentence1'], df['sentence2'], df['score'], df.index)) |
|
else: |
|
data, num_labels = [], set() |
|
|
|
with open(filename, 'r') as fp: |
|
if flag == 'test': |
|
for record in csv.DictReader(fp, delimiter = '\t'): |
|
sent = record['sentence'].lower().strip() |
|
sent_id = record['id'].lower().strip() |
|
data.append((sent,sent_id)) |
|
else: |
|
for record in csv.DictReader(fp, delimiter = '\t'): |
|
sent = record['sentence'].lower().strip() |
|
sent_id = record['id'].lower().strip() |
|
label = int(record['sentiment'].strip()) |
|
num_labels.add(label) |
|
data.append((sent, label, sent_id)) |
|
|
|
print(f"load {len(data)} data from {filename}") |
|
if flag == "train": |
|
return data, len(num_labels) |
|
else: |
|
return data |
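
# Example calls (illustrative; the file paths are made up):
#
#   train_data, num_labels = load_data('data/sst-train.csv', flag='train')
#   test_data = load_data('data/sst-test.csv', flag='test')
#   nli_data = load_data('data/nli-train.parquet', flag='nli')
#   stsb_data = load_data('data/stsb-train.parquet', flag='stsb')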
|
|
|
|
|
def seed_everything(seed=11711): |
|
random.seed(seed) |
|
np.random.seed(seed) |
|
torch.manual_seed(seed) |
|
torch.cuda.manual_seed(seed) |
|
torch.cuda.manual_seed_all(seed) |
|
torch.backends.cudnn.benchmark = False |
|
torch.backends.cudnn.deterministic = True |