|
from . import InputExample |
|
import csv |
|
import gzip |
|
import os |
|
|
|
class TripletReader(object): |
|
""" |
|
Reads in the a Triplet Dataset: Each line contains (at least) 3 columns, one anchor column (s1), |
|
one positive example (s2) and one negative example (s3) |
|
""" |
|
def __init__(self, dataset_folder, s1_col_idx=0, s2_col_idx=1, s3_col_idx=2, has_header=False, delimiter="\t", |
|
quoting=csv.QUOTE_NONE): |
|
self.dataset_folder = dataset_folder |
|
self.s1_col_idx = s1_col_idx |
|
self.s2_col_idx = s2_col_idx |
|
self.s3_col_idx = s3_col_idx |
|
self.has_header = has_header |
|
self.delimiter = delimiter |
|
self.quoting = quoting |
|
|
|
def get_examples(self, filename, max_examples=0): |
|
""" |
|
|
|
""" |
|
data = csv.reader(open(os.path.join(self.dataset_folder, filename), encoding="utf-8"), delimiter=self.delimiter, |
|
quoting=self.quoting) |
|
examples = [] |
|
if self.has_header: |
|
next(data) |
|
|
|
for id, row in enumerate(data): |
|
s1 = row[self.s1_col_idx] |
|
s2 = row[self.s2_col_idx] |
|
s3 = row[self.s3_col_idx] |
|
|
|
examples.append(InputExample(texts=[s1, s2, s3])) |
|
if max_examples > 0 and len(examples) >= max_examples: |
|
break |
|
|
|
return examples |