Spaces:
Runtime error
Runtime error
import csv
import os

import numpy as np
import torch
from torch.utils.data import Dataset
class HighlightDataset(Dataset):
    """TVSum highlight dataset: video paths paired with mean importance scores.

    Every item is a ``(video_path, scores)`` tuple where ``scores`` is the
    element-wise mean, over all annotators of that video, of the importance
    scores read from ``ydata-tvsum50-anno.tsv``.
    """

    def __init__(self, root_dir, transform=None):
        """Read the annotation file and index the annotated videos.

        Arguments:
            root_dir (string): Directory with all data: a ``videos``
                sub-directory and the ``ydata-tvsum50-anno.tsv`` file.
            transform (callable, optional): Stored as ``self.transform`` for
                callers; it is not applied here because items are video
                paths, not decoded frames.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.video_dir = os.path.join(root_dir, "videos")
        self.anno_path = os.path.join(root_dir, "ydata-tvsum50-anno.tsv")

        # The annotation file is tab-separated, one row per (video, annotator).
        with open(self.anno_path, newline='') as f:
            raw_annotations = list(csv.reader(f, delimiter='\t'))

        # Rows are consumed in consecutive groups of this many per video.
        self.num_annotator = 20
        self.annotations = self.parse_annotations(raw_annotations)  # {video_id: mean scores}

        # Keep only files that have annotations (a stray file would otherwise
        # raise KeyError in __getitem__) and sort them, because os.listdir
        # returns entries in arbitrary order.
        self.video_list = sorted(
            name for name in os.listdir(self.video_dir)
            if self._video_id(name) in self.annotations
        )

    @staticmethod
    def _video_id(video_name):
        """Return the annotation key of a video file: the text before the first dot."""
        return video_name.split('.')[0]

    def parse_annotations(self, annotations):
        """Average the per-annotator importance scores of each video.

        Arguments:
            annotations (list): Rows of the form
                ``[video_id, video_category, importance score]`` where the
                score field is a string of numbers separated by commas. Rows
                are taken in consecutive chunks of ``self.num_annotator``;
                the video id of a chunk is read from its first row. Empty
                rows are ignored and the input rows are left unmodified.

        Returns:
            dict: ``{video_id: numpy array}`` holding, position by position,
            the mean score over the rows of each chunk.
        """
        # Blank lines come back from csv.reader as empty lists; dropping them
        # keeps the fixed-size chunks aligned.
        rows = [row for row in annotations if row]

        parsed_annotations = {}
        for start in range(0, len(rows), self.num_annotator):
            chunk = rows[start:start + self.num_annotator]
            video_id = chunk[0][0]
            # One array row per annotator, one column per scored position.
            scores = np.array(
                [[float(score) for score in row[2].split(',')] for row in chunk]
            )
            parsed_annotations[video_id] = np.mean(scores, axis=0)
        return parsed_annotations

    def __len__(self):
        """Return the number of annotated videos found in the video directory."""
        return len(self.video_list)

    def __getitem__(self, idx):
        """Return ``(video_path, scores)`` for the video at position ``idx``.

        Arguments:
            idx (int or tensor): Position in ``self.video_list``; a tensor
                index is converted with ``tolist()`` first.

        Returns:
            tuple: The path of the video file and the mean importance scores
            stored for its video id.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        video_name = self.video_list[idx]
        video_path = os.path.join(self.video_dir, video_name)
        return video_path, self.annotations[self._video_id(video_name)]