# txt2audio's picture
# update
# fa25a07
import collections
import csv
import logging
import os
import random
from glob import glob
from pathlib import Path
import numpy as np
import torch
import torchvision
logger = logging.getLogger(f'main.{__name__}')


class VGGSound(torch.utils.data.Dataset):
    """VGGSound dataset of precomputed spectrograms stored as ``<clip_id>_mel.npy`` files.

    Each item is a dict with keys ``'input'`` (the array returned by ``np.load``),
    ``'input_path'``, ``'target'`` (integer class index) and ``'label'`` (class name),
    passed through ``transforms`` when one is given.

    Args:
        split: name of the split to load. The clip list is read from
            ``<splits_path>/vggsound_<split>.txt``; if that file is missing, the
            ``train``/``valid``/``test`` files are generated by ``make_split_files``.
        specs_dir: directory containing the ``<clip_id>_mel.npy`` files.
        transforms: optional callable applied to every item dict.
        splits_path: directory holding the split files.
        meta_path: VGGSound metadata CSV. Rows are read as ``row[0]`` = YouTube id,
            ``row[2]`` = class label, ``row[3]`` = original split ('train'/'test').
    """

    # Clip ids look like 'zyTX_1BXKDE_16000_26000': the first 11 characters are the
    # YouTube video id, followed by the clip timestamps.
    _VIDEO_ID_LEN = 11
    _SPEC_SUFFIX = '_mel.npy'

    def __init__(self, split, specs_dir, transforms=None, splits_path='./data', meta_path='./data/vggsound.csv'):
        super().__init__()
        self.split = split
        self.specs_dir = specs_dir
        self.transforms = transforms
        self.splits_path = splits_path
        self.meta_path = meta_path

        _, self.label2target, self.video2target = self._load_meta(meta_path)
        self.target2label = {target: label for label, target in self.label2target.items()}

        split_clip_ids_path = os.path.join(splits_path, f'vggsound_{split}.txt')
        if not os.path.exists(split_clip_ids_path):
            self.make_split_files()
        with open(split_clip_ids_path) as f:
            clip_ids_with_timestamp = f.read().splitlines()
        self.dataset = [os.path.join(specs_dir, clip_id + self._SPEC_SUFFIX) for clip_id in clip_ids_with_timestamp]

        class2count = collections.Counter(self.video2target[self._video_id(path)] for path in self.dataset)
        # Index by every known class, not only those present in this split, so that
        # class_counts[t] is always the count of target t (zero when t is absent).
        self.class_counts = torch.tensor([class2count[target] for target in range(len(self.label2target))])

    def __getitem__(self, idx):
        spec_path = self.dataset[idx]
        item = {'input': np.load(spec_path), 'input_path': spec_path}
        item['target'] = self.video2target[self._video_id(spec_path)]
        item['label'] = self.target2label[item['target']]
        if self.transforms is not None:
            item = self.transforms(item)
        return item

    def __len__(self):
        return len(self.dataset)

    @classmethod
    def _video_id(cls, path):
        """Return the YouTube id of a spectrogram path.

        e.g. '.../zyTX_1BXKDE_16000_26000_mel.npy' -> 'zyTX_1BXKDE'
        """
        return Path(path).stem[:cls._VIDEO_ID_LEN]

    @staticmethod
    def _load_meta(meta_path):
        """Read the metadata CSV and derive the class index mappings.

        Returns:
            ``(rows, label2target, video2target)``: the raw CSV rows, a mapping from
            class label to integer target (labels sorted alphabetically), and a
            mapping from YouTube id to target.
        """
        with open(meta_path, newline='') as f:
            rows = list(csv.reader(f, quotechar='"'))
        unique_classes = sorted({row[2] for row in rows})
        label2target = {label: target for target, label in enumerate(unique_classes)}
        video2target = {row[0]: label2target[row[2]] for row in rows}
        return rows, label2target, video2target

    def make_split_files(self):
        """Generate ``vggsound_{train,valid,test}.txt`` inside ``self.splits_path``.

        Only clips whose spectrograms exist in ``self.specs_dir`` are listed (some
        videos went missing on YouTube). The original test set is kept; a validation
        set is sampled from the original train set with the same per-class counts as
        the test set, and the remaining train videos form the new train set. Sampling
        uses a fixed seed (1337) on sorted inputs, so the same files and metadata give
        the same split on every run. The global ``random`` state is left untouched.

        Raises:
            ValueError: if an available clip is in none of the three splits. In that
                case no split file is written.
        """
        logger.info(f'The split files do not exist @ {self.splits_path}. Calculating the new ones.')
        # The downloaded videos (some went missing on YouTube and no longer available)
        available_vid_paths = sorted(glob(os.path.join(self.specs_dir, '*' + self._SPEC_SUFFIX)))
        logger.info(f'The number of clips available after download: {len(available_vid_paths)}')

        # original (full) train and test sets
        vggsound_meta, label2target, video2target = self._load_meta(self.meta_path)
        train_vids = {row[0] for row in vggsound_meta if row[3] == 'train'}
        test_vids = {row[0] for row in vggsound_meta if row[3] == 'test'}
        logger.info(f'The number of videos in vggsound train set: {len(train_vids)}')
        logger.info(f'The number of videos in vggsound test set: {len(test_vids)}')

        train_vids_wo_valid, valid_vids = self._sample_valid(train_vids, test_vids, video2target, len(label2target))
        self._write_split_files(available_vid_paths, train_vids_wo_valid, valid_vids, test_vids)

    @staticmethod
    def _sample_valid(train_vids, test_vids, video2target, num_classes, seed=1337):
        """Split train videos into ``(train_wo_valid, valid)``, with valid mirroring the test-set class counts."""
        rng = random.Random(seed)  # local RNG: do not reseed the global `random` module
        test_target2count = collections.Counter(video2target[vid] for vid in test_vids)
        class2train_vids = collections.defaultdict(list)
        for vid in train_vids:
            class2train_vids[video2target[vid]].append(vid)

        train_vids_wo_valid, valid_vids = set(), set()
        for target in range(num_classes):
            # Sort before shuffling: set iteration order of strings depends on
            # PYTHONHASHSEED, so shuffling it directly is not reproducible across runs.
            class_train_vids = sorted(class2train_vids[target])
            rng.shuffle(class_train_vids)
            count = test_target2count[target]
            valid_vids.update(class_train_vids[:count])
            train_vids_wo_valid.update(class_train_vids[count:])
        return train_vids_wo_valid, valid_vids

    def _write_split_files(self, available_vid_paths, train_vids, valid_vids, test_vids):
        """Assign each available clip to a split and write one clip id per line.

        Every clip is classified before any file is written. An unknown clip therefore
        cannot leave truncated split files behind that later runs would silently reuse.
        """
        split2vids = {'train': train_vids, 'valid': valid_vids, 'test': test_vids}
        split2clips = {split: [] for split in split2vids}
        for path in available_vid_paths:
            # '.../zyTX_1BXKDE_16000_26000_mel.npy' -> 'zyTX_1BXKDE_16000_26000'
            clip_id = Path(path).name[:-len(self._SPEC_SUFFIX)]
            vid = clip_id[:self._VIDEO_ID_LEN]
            for split, vids in split2vids.items():  # checked in train, valid, test order
                if vid in vids:
                    split2clips[split].append(clip_id)
                    break
            else:
                raise ValueError(f'Clip {clip_id} is neither in train, valid nor test. Strange.')

        os.makedirs(self.splits_path, exist_ok=True)
        for split, clips in split2clips.items():
            split_path = os.path.join(self.splits_path, f'vggsound_{split}.txt')
            with open(split_path, 'w') as f:
                f.writelines(clip + '\n' for clip in clips)
            logger.info(f'Put {len(clips)} clips to the {split} set and saved it to {split_path}')
if __name__ == '__main__':
    # Smoke test: build every split, then print one item per split followed by the
    # per-class counts of each split.
    import sys

    from transforms import Crop, StandardNormalizeAudio, ToTensor

    # The spectrogram directory may be passed as the first command-line argument.
    specs_path = sys.argv[1] if len(sys.argv) > 1 else '/home/nvme/data/vggsound/features/melspec_10s_22050hz/'
    transforms = torchvision.transforms.Compose([
        StandardNormalizeAudio(specs_path),
        ToTensor(),
        Crop([80, 848]),
    ])
    datasets = {split: VGGSound(split, specs_path, transforms) for split in ('train', 'valid', 'test')}
    for dataset in datasets.values():
        print(dataset[0])
    for dataset in datasets.values():
        print(dataset.class_counts)