diff --git "a/data/prepare_each_dataset.py" "b/data/prepare_each_dataset.py" new file mode 100644--- /dev/null +++ "b/data/prepare_each_dataset.py" @@ -0,0 +1,3247 @@ +import os +import json +import csv +import yaml +from collections import defaultdict +import pickle +import glob +import math +from functools import partial +import sys +import io +import warnings +import random + +import numpy as np +import torch +import laion_clap + +import librosa +from pydub import AudioSegment +import soundfile as sf + +import faiss + +import multiprocessing +multiprocessing.set_start_method('spawn', force=True) + +try: + from tqdm import tqdm +except: + tqdm = lambda x: x + + +def suppress_all_output(func): + def wrapper(*args, **kwargs): + old_stdout = sys.stdout + old_stderr = sys.stderr + + sys.stdout = io.StringIO() + sys.stderr = io.StringIO() + + old_fd_out = os.dup(1) + old_fd_err = os.dup(2) + null_fd = os.open(os.devnull, os.O_RDWR) + + os.dup2(null_fd, 1) + os.dup2(null_fd, 2) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + try: + result = func(*args, **kwargs) + finally: + os.dup2(old_fd_out, 1) + os.dup2(old_fd_err, 2) + os.close(null_fd) + os.close(old_fd_out) + os.close(old_fd_err) + + sys.stdout = old_stdout + sys.stderr = old_stderr + + return result + return wrapper + + +def filter_file(file_path, file_list, filename): + if file_list is not None: + if filename not in file_list: + print(filename, 'not exist') + return True + else: + if not os.path.exists(os.path.join(file_path, filename)): + print(filename, 'not exist') + return True + + if os.path.getsize(os.path.join(file_path, filename)) < 16000: + print(filename, 'less than 0.5 to 1 second') + return True + + return False + + +# ==================== Prepare dataset files from each data folder ==================== + +EMOTION_MAP_DICT = { + 'amused': 'amused' , + 'anger': 'angry' , 'angry': 'angry' , + 'anxious': 'anxious' , + 'apologetic': 'apologetic' , + 'assertive': 'assertive' , + 'calm': 'calm' , + 'concerned': 'concerned' , + 'contempt': 'contempt' , + 'disgust': 'disgusted' , 'disgusted': 'disgusted' , + 'encouraging': 'encouraging' , + 'excited': 'excited' , + 'fear': 'fearful' , 'fearful': 'fearful' , + 'frustated': 'frustated' , + 'happy': 'happy' , 'joy': 'happy' , + 'neutral': 'neutral' , + 'sad': 'sad' , 'sadness': 'sad' , + 'sleepy': 'sleepy' , + 'surprise': 'surprised' , 'surprised': 'surprised' , + 'pleasantly surprised': 'pleasantly surprised' , +} + +def load_dataset_file(dataset_file): + with open(dataset_file) as f: + contents = f.read() + contents = json.loads(contents) + + audio_files = [ + os.path.join( + contents["dataset_path"], + contents["split_path"], + contents["data"][str(i)]["name"] + ) for i in range(contents["total_num"]) + ] + + return contents, audio_files + + +def compute_label_graph(dataset_name, dataset_path, top_n, output_file): + if os.path.exists(output_file): + print('loading precomputed graph:', output_file) + with open(output_file, 'r') as json_file: + graph = json.load(json_file) + + else: + import torch + from sentence_transformers import SentenceTransformer, util + embedding_model = SentenceTransformer('all-MiniLM-L6-v2') + + print('precomputing graph and save to:', output_file) + + if dataset_name == 'AudioSetSL_singlelabel': + names = [] + with open(os.path.join(dataset_path, 'class_labels_indices.csv'), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in reader: + _, label, name = row # 123, 
/m/02zsn, "Female speech, woman speaking" + names += name.split(', ') + names = [x.lower().strip() for x in names] + + elif dataset_name == "Clotho-AQA_singlelabel": + names = set([]) + with open(os.path.join(dataset_path, 'clotho_aqa_metadata.csv'), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in tqdm(reader): + _, file_name, keywords, _, _, _, _ = row + names |= set(keywords.split(';')) + names = [x.lower().strip() for x in names] + + names_embeddings = embedding_model.encode(names, convert_to_tensor=True) + similarity_matrix = util.pytorch_cos_sim(names_embeddings, names_embeddings) + + similarity_threshold = 0.75 + n_items = len(names) + + graph = {} + for i in range(n_items): + adjusted_top_n = min(top_n, n_items - 1) + values, indices = torch.topk(similarity_matrix[i], adjusted_top_n + 1, largest=True) + + most_similar_items = [] + for value, idx in zip(values, indices): + if idx != i and value <= similarity_threshold: + most_similar_items.append(idx.item()) + if len(most_similar_items) == adjusted_top_n: + break + graph[names[i]] = [names[j] for j in most_similar_items] + + with open(output_file, 'w') as json_file: + json.dump(graph, json_file) + + # graph is a dict: key = each label, value = List[20 similar labels] + return graph + + +def prepare_files(dataset_name, dataset_path, split, flamingo_task, output_file): + + assert not os.path.exists(output_file) + dataset_dic = { + "dataset_path": dataset_path, + "split": split, + "split_path": None, + "flamingo_task": "{}-{}".format(dataset_name, flamingo_task), + "total_num": 0, + "data": {} # {id: {'name': name, 'prompt': prompt, 'output': output}} + } + + if dataset_name == "AudioSet": + assert flamingo_task == "EventClassification" + + assert split == 'train' + map_split = lambda split: 'train_wav' if split == 'train' else '' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + dic = defaultdict(str) + with open(os.path.join(dataset_path, 'class_labels_indices.csv'), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in tqdm(reader): + _, label, name = row # /m/02zsn,"Female speech, woman speaking" + dic[label] = name + + with open(os.path.join(dataset_path, 'train.csv'), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in tqdm(reader): + filename, _, _, labels = row # --aE2O5G5WE /m/03fwl,/m/04rlf,/m/09x0r + filename = filename + '.wav' + if filter_file(file_path, file_list, filename): + continue + + label_list = labels.split(",") + assert all(label in dic for label in label_list) + + text_output = ", ".join([dic[label] for label in label_list]) + if len(text_output) <= 1: + continue + text_prompt = 'this is a sound of' + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "AudioSetFull": + assert flamingo_task == "EventClassification" + + assert split == 'train' + map_split = lambda split: '/mnt/fsx-main/rafaelvalle/datasets/audioset/unbalanced_train_segments/22khz' + file_path = map_split(split) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + 
dataset_dic["split_path"] = map_split(split) + file_list = None + + dic_code2label = defaultdict(str) + with open(os.path.join(dataset_path, 'audioset-processing/data/class_labels_indices.csv'), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in tqdm(reader): + _, code, name = row # /m/02zsn,"Female speech, woman speaking" + dic_code2label[code] = name + + dic_filename2code = {} + with open(os.path.join(dataset_path, 'audioset-processing/data/unbalanced_train_segments.csv'), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + next(reader) + for row in tqdm(reader): + filename, _, _, codes = row # --aE2O5G5WE /m/03fwl,/m/04rlf,/m/09x0r + filename = 'Y' + filename + '.wav' + dic_filename2code[filename] = codes.split(",") + + for part in tqdm(range(41)): + part_str = str(part) + if len(part_str) == 1: + part_str = '0' + part_str + part_folder = 'unbalanced_train_segments_part{}'.format(part_str) + + for filename in os.listdir(os.path.join(file_path, part_folder)): + if not filename.endswith('.wav'): + continue + + if filter_file(file_path, file_list, os.path.join(part_folder, filename)): + continue + + if filename not in dic_filename2code: + continue + + text_output = ", ".join([dic_code2label[code] for code in dic_filename2code[filename] if code in dic_code2label]) + if len(text_output) <= 1: + continue + text_prompt = 'this is a sound of' + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": os.path.join(part_folder, filename), + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "AudioSetFullwoAudioMusicCaps": + assert flamingo_task == "EventClassification" + + assert split == 'train' + map_split = lambda split: '/mnt/fsx-main/rafaelvalle/datasets/audioset/unbalanced_train_segments/22khz' + file_path = map_split(split) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = None + + print('extracting AudioCaps and MusicCaps ytid to avoid these samples') + audiocaps_ytid = [] + for f in ['audiocaps_dataset/train.csv', 'audiocaps_dataset/test.csv', 'audiocaps_dataset/val.csv']: + with open(os.path.join(dataset_path, f), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in reader: + _, ytid, _, _ = row + audiocaps_ytid.append('Y' + ytid + '.wav') + audiocaps_ytid = set(audiocaps_ytid) + + musiccaps_ytid = [] + with open(os.path.join(dataset_path, 'musiccaps_dataset/musiccaps_manifest.json')) as f: + data = f.read() + musiccaps_list = json.loads(data) + for row in musiccaps_list: + musiccaps_ytid.append('Y' + row["ytid"] + '.wav') + musiccaps_ytid = set(musiccaps_ytid) + + print('Will exclude {} samples from MusicCaps and {} from AudioCaps'.format(len(audiocaps_ytid), len(musiccaps_ytid))) + + dic_code2label = defaultdict(str) + with open(os.path.join(dataset_path, '../AudioSetFull/audioset-processing/data/class_labels_indices.csv'), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in tqdm(reader): + _, code, name = row # /m/02zsn,"Female speech, woman speaking" + dic_code2label[code] = name + + dic_filename2code = {} + with open(os.path.join(dataset_path, '../AudioSetFull/audioset-processing/data/unbalanced_train_segments.csv'), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', 
quotechar='"') + next(reader) + next(reader) + for row in tqdm(reader): + filename, _, _, codes = row # --aE2O5G5WE /m/03fwl,/m/04rlf,/m/09x0r + filename = 'Y' + filename + '.wav' + dic_filename2code[filename] = codes.split(",") + + music_audio_caps_excluded = 0 + for part in tqdm(range(41)): + part_str = str(part) + if len(part_str) == 1: + part_str = '0' + part_str + part_folder = 'unbalanced_train_segments_part{}'.format(part_str) + + for filename in os.listdir(os.path.join(file_path, part_folder)): + if not filename.endswith('.wav'): + continue + + if filename in audiocaps_ytid or filename in musiccaps_ytid: + music_audio_caps_excluded += 1 + continue + + if filter_file(file_path, file_list, os.path.join(part_folder, filename)): + continue + + if filename not in dic_filename2code: + continue + + text_output = ", ".join([dic_code2label[code] for code in dic_filename2code[filename] if code in dic_code2label]) + if len(text_output) <= 1: + continue + text_prompt = 'this is a sound of' + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": os.path.join(part_folder, filename), + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "AudioSetSL_singlelabel": + import numpy as np + + assert flamingo_task == "EventClassification" + + assert split == 'train' + map_split = lambda split: '../AudioSet/train_wav' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + dic = defaultdict(str) + with open(os.path.join(dataset_path, 'class_labels_indices.csv'), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in tqdm(reader): + _, label, name = row # /m/02zsn,"Female speech, woman speaking" + dic[label] = name + + graph = compute_label_graph( + dataset_name, + dataset_path, + top_n=200, + output_file=os.path.join(dataset_path, 'label_graph.json') + ) + + with open(os.path.join(dataset_path, 'train.csv'), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in tqdm(reader): + filename, _, _, labels = row # --aE2O5G5WE /m/03fwl,/m/04rlf,/m/09x0r + filename = filename + '.wav' + if filter_file(file_path, file_list, filename): + continue + + label_list = labels.split(",") + assert all(label in dic for label in label_list) + + text_labels = ", ".join([dic[label] for label in label_list]).lower() + text_labels = text_labels.split(', ') + text_output = np.random.choice(text_labels) + if len(text_output) <= 1: + continue + + num_options = np.random.choice( + [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + p=[ 0.05, 0.1, 0.1, 0.1, 0.1, + 0.05, 0.05, 0.05, 0.1, 0.05, + 0.05, 0.1, 0.05, 0.05] + ) + + negative_samples = [x for x in graph[text_output] if x not in set(text_labels)] + candidate_negative_labels = list(np.random.choice( + negative_samples[:num_options*10], + size=num_options-1, + replace=False + )) + if type(candidate_negative_labels) is str: + candidate_negative_labels = [candidate_negative_labels] + + all_options = [text_output] + candidate_negative_labels + np.random.shuffle(all_options) + + text_prompt = 'Classify this sound.\nOPTIONS:\n - {}.'.format( + '.\n - '.join(all_options) + ) + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + 
"output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "AUDIOCAPS13k": + assert flamingo_task == 'AudioCaptioning' + + map_split = lambda split: 'audio_32000Hz/{}'.format(split) + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.flac'), os.listdir(file_path))) + + with open(os.path.join( + dataset_path, + '{}_manifest.json'.format(split + ('_v2' if split == 'train' else '')) + ), 'r') as f: + data = f.readlines() + data = [json.loads(row) for row in data] + + for row in tqdm(data): + filename = row['audio_filepath'].split('/')[-1] + if filter_file(file_path, file_list, filename): + continue + + text_output = row['text'] + if len(text_output) <= 1: + continue + text_prompt = 'generate audio caption' + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "audiocaps": + assert flamingo_task == 'AudioCaptioning' + + map_split = lambda split: 'audio/{}'.format(split if split in ['train', 'test'] else 'valid') + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.flac'), os.listdir(file_path))) + + for filename in tqdm(file_list): + if filter_file(file_path, file_list, filename): + continue + + with open(os.path.join(file_path, filename.replace('.flac', '.json')), 'r') as f: + data = json.load(f) + + captions = data['text'] + for text_output in captions: + if len(text_output) <= 1: + continue + text_prompt = 'generate audio caption' + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == 'BG-Gun-Sound-Dataset': + assert flamingo_task == "SoundClassification" + assert split in ["train", "test"] + + map_split = lambda split: 'data/gun_sound_v2' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = os.listdir(file_path) + + all_cates = set([]) + with open(os.path.join(dataset_path, 'data/v3_exp3_{}.csv'.format(split)), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in tqdm(reader): + filename, cate, dist, dire = row + if filter_file(file_path, file_list, filename): + continue + + text_output = cate + if len(text_output) <= 1: + continue + text_prompt = 'What is the gun of this sound?' 
+ + all_cates.add(cate) + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + print(all_cates) + + elif dataset_name == "BirdsDataset": + assert flamingo_task == "SoundClassification" + assert split == 'train' + + map_split = lambda split: 'Voice_of_Birds' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = None + + for bird_type in tqdm(os.listdir(file_path)): + bird_name = ' '.join(bird_type.split('_')[:-1]) + for filename in os.listdir(os.path.join(file_path, bird_type)): + if filter_file(file_path, file_list, os.path.join(bird_type, filename)): + continue + + text_output = bird_name + if len(text_output) <= 1: + continue + text_prompt = 'What is the name of bird in this sound?' + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": os.path.join(bird_type, filename), + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "BBCSoundEffects": + assert split in ['train'] + assert flamingo_task == 'AudioDescription' + + map_split = lambda split: '../WavCaps/BBC_Sound_Effects_flac' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.flac'), os.listdir(file_path))) + + with open(os.path.join(dataset_path, 'BBCSoundDownloader/BBCSoundEffects.csv'), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in tqdm(reader): + if len(row) != 7: + continue + filename, description, _, _, _, _, _ = row + filename = filename.replace('.wav', '.flac') + + if filter_file(file_path, file_list, filename): + continue + + text_output = description + if len(text_output) <= 1: + continue + text_prompt = 'generate audio description' + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "chime-home": + assert flamingo_task == "EventClassification" + assert split == 'train' + + map_split = lambda split: 'chime_home/chunks' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file48k_list = list(filter(lambda x: x.endswith('48kHz.wav'), os.listdir(file_path))) + file16k_list = list(filter(lambda x: x.endswith('16kHz.wav'), os.listdir(file_path))) + csv_file_list = list(filter(lambda x: x.endswith('.csv'), os.listdir(file_path))) + + label_mapping = { + 'c': 'child speaking', + 'm': 'male speaking', + 'f': 'female speaking', + 'p': 'human activity', + 't': 'television', + 'b': 'household appliances', + 's': 'silence' + } + + for csv_file in tqdm(csv_file_list): + with open(os.path.join(file_path, csv_file), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + + labels = None + for row in reader: + if row[0] == 'majorityvote': + labels = row[1] + break + + if labels is None or len(labels) == 0: + continue + + filename = csv_file.replace('.csv', '.48kHz.wav') + if filter_file(file_path, file48k_list, filename): + filename = 
csv_file.replace('.csv', '.16kHz.wav') + if filter_file(file_path, file16k_list, filename): + continue + + text_output = ", ".join([label_mapping[l] for l in labels if l in label_mapping]) + if len(text_output) <= 1: + continue + text_prompt = 'this is a sound of' + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "CLAP_freesound": + assert flamingo_task == "AudioCaptioning" + assert split in ["train", "test"] + + map_split = lambda split: os.path.join('freesound_no_overlap/split', split) + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.flac'), os.listdir(file_path))) + + with open(os.path.join( + dataset_path, + 'freesound_no_overlap_meta.csv' + ), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in tqdm(reader): + if len(row[0].split('/')) != 2: + continue + if len(row) <= 1: + continue + + file_split, filename = row[0].split('/') + + if file_split != split: + continue + if filter_file(file_path, file_list, filename): + continue + + caption_1 = row[1] # caption_2 = row[2] but not very good + text_output = caption_1 + if len(text_output) <= 2: + continue + + text_prompt = 'generate audio caption' + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "Clotho-AQA": + + map_split = lambda split: 'audio_files' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + if flamingo_task == "EventClassification": + dic = defaultdict(str) + with open(os.path.join(dataset_path, 'clotho_aqa_metadata.csv'), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in tqdm(reader): + _, file_name, keywords, _, _, _, _ = row + dic[file_name] = keywords.replace(';', ', ') + + with open(os.path.join(dataset_path, 'clotho_aqa_{}.csv'.format(split)), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in tqdm(reader): + filename = row[0] + if filename not in dic or filter_file(file_path, file_list, filename): + continue + + text_output = dic[filename] + if len(text_output) <= 1: + continue + text_prompt = 'this is a sound of' + del dic[filename] + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif flamingo_task == "AQA": + dic_qa = defaultdict(list) + with open(os.path.join(dataset_path, 'clotho_aqa_{}.csv'.format(split)), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in tqdm(reader): + filename, question, answer, confidence = row + dic_qa[(filename, question)].append((answer.lower(), confidence.lower())) + + # get binary -> trinary + def preprocess(list_ans_conf): + assert set([x[1] for x in list_ans_conf]) <= set(['yes', 'no', 'maybe']) + + answers = 
set([x[0].lower() for x in list_ans_conf]) + if answers <= set(['yes', 'no']): + if len(answers) > 1: + return ['unsure'] + else: + return list(answers) + else: + return list(answers) + + # get majority vote + def majority_vote(list_ans_conf): + assert set([x[1] for x in list_ans_conf]) <= set(['yes', 'no', 'maybe']) + weight = {'yes': 1.0, 'no': 0.1, 'maybe': 0.6} + + if set([x[0] for x in list_ans_conf]) <= set(['yes', 'no']): + score = {'yes': 1.0, 'no': -1.0} + pred = sum([score[x[0]] * weight[x[1]] for x in list_ans_conf]) + if pred > 0: + return ['yes'] + else: + return ['no'] + else: + return list(set([x[0] for x in list_ans_conf])) + + for key in dic_qa: + filename, question = key + if filter_file(file_path, file_list, filename): + continue + + if split == 'train': + answers = majority_vote(dic_qa[key]) # majority vote + else: + answers = [x[0].strip().lower() for x in dic_qa[key]] + answers = [', '.join(answers)] + + for answer in answers: + text_output = answer + if len(text_output) <= 1: + continue + text_prompt = "Question: " + question + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "Clotho-AQA_singlelabel": + import numpy as np + + assert flamingo_task == "EventClassification" + + map_split = lambda split: '../Clotho-AQA/audio_files' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + dic = defaultdict(str) + with open(os.path.join(dataset_path, 'clotho_aqa_metadata.csv'), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in tqdm(reader): + _, file_name, keywords, _, _, _, _ = row + dic[file_name] = keywords.split(';') + + graph = compute_label_graph( + dataset_name, + dataset_path, + top_n=300, + output_file=os.path.join(dataset_path, 'label_graph.json') + ) + + with open(os.path.join(dataset_path, 'clotho_aqa_{}.csv'.format(split)), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in tqdm(reader): + filename = row[0] + if filename not in dic or filter_file(file_path, file_list, filename): + continue + + text_labels = [x.lower().strip() for x in dic[filename]] + del dic[filename] + + for _ in range(6): + text_output = np.random.choice(text_labels) + if len(text_output) <= 1: + continue + + num_options = np.random.choice( + [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + p=[ 0.05, 0.1, 0.1, 0.1, 0.1, + 0.05, 0.05, 0.05, 0.1, 0.05, + 0.05, 0.1, 0.05, 0.05] + ) + + negative_samples = [x for x in graph[text_output] if x not in set(text_labels)] + candidate_negative_labels = list(np.random.choice( + negative_samples[:num_options*20], + size=num_options-1, + replace=False + )) + if type(candidate_negative_labels) is str: + candidate_negative_labels = [candidate_negative_labels] + + all_options = [text_output] + candidate_negative_labels + np.random.shuffle(all_options) + + text_prompt = 'Classify this sound.\nOPTIONS:\n - {}.'.format( + '.\n - '.join(all_options) + ) + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "Clotho-v2": + assert 
flamingo_task == "AudioCaptioning" + assert split in ["train", "val", "test"] + + map_split = lambda split: 'development' if split == 'train' else ('validation' if split == "val" else "evaluation") + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + with open(os.path.join( + dataset_path, + 'clotho_captions_{}.csv'.format(map_split(split)) + ), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in tqdm(reader): + filename = row[0] + if filter_file(file_path, file_list, filename): + continue + + for text_output in row[1:]: + if len(text_output) <= 1: + continue + text_prompt = 'generate audio caption' + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "CochlScene": + import ndjson + assert flamingo_task == "SceneClassification" + + map_split = lambda split: split.capitalize() + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = None + + with open(os.path.join(dataset_path, 'cochlscene_{}.ndjson'.format(split))) as ndjsonfile: + reader = ndjson.load(ndjsonfile) + for row in tqdm(reader): + filename = "/".join(row["audiopath"].split("/")[1:]) + if filter_file(file_path, file_list, filename): + continue + + text_output = row["labels"].lower() + if len(text_output) <= 1: + continue + text_prompt = 'this acoustic scene is' + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "common-accent": + import ndjson + import re + + assert flamingo_task == "AccentClassification" + assert split in ["train", "test"] + + map_split = lambda split: '22khz' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = os.listdir(file_path) + + all_accent = [] + split_file = [f for f in os.listdir(dataset_path) if f.startswith(split) and f.endswith('.ndjson')][0] + with open(os.path.join(dataset_path, split_file)) as ndjsonfile: + reader = ndjson.load(ndjsonfile) + for row in tqdm(reader): + accent = row["accent"] + accent = re.sub(r'\(.*?\)', '', accent) + accent = accent.replace('English', '') + accent = accent.split(',') + accent = [x.strip() for x in accent if 'school' not in x] + all_accent += accent + + filename = row["filename"] + if filter_file(file_path, file_list, filename): + continue + + for accent_each in accent: + if accent_each == 'Javanese': + accent_each = 'Japanese' + if len(accent_each) > 25: + continue + + text_output = accent_each + if len(text_output) <= 1: + continue + text_prompt = 'Classify the accent of this speech.' 
+ + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + print('all accents:', list(set(all_accent))) + + elif dataset_name == "CREMA-D": + assert flamingo_task == "EmotionClassification" + assert split in ["train"] + + map_split = lambda split: 'AudioWAV' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + split_file = os.path.join( + dataset_path, + 'crema-d_audiopath_text_sid_emotion_filelist.txt' + ) + with open(split_file, 'r') as f: + data = f.readlines() + data = [x.replace('\n', '') for x in data] + + for row in tqdm(data): + if row.count('|') != 3: + continue + filename, utterances, speaker, emotion = row.split('|') + if filter_file(file_path, file_list, filename): + continue + + text_output = emotion + text_prompt = 'this emotion is' + + if len(text_output) <= 1: + continue + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "DCASE17Task4": + assert flamingo_task == "SceneClassification" + assert split in ["test"] + + map_split = lambda split: 'unbalanced_train_segments_testing_set_audio_formatted_and_segmented_downloads' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + split_file = os.path.join( + dataset_path, + 'Task-4-Large-scale-weakly-supervised-sound-event-detection-for-smart-cars', + 'groundtruth_release', + 'groundtruth_strong_label_testing_set.csv' + ) + + dic = defaultdict(list) + all_labels = [] + with open(split_file, newline='') as csvfile: + reader = csv.reader(csvfile, delimiter='\t', quotechar='"') + for row in tqdm(reader): + filename = 'Y' + row[0] + label = row[-1] + + if filter_file(file_path, file_list, filename): + continue + + dic[filename] += label.split(', ') + all_labels += label.split(', ') + + print('all labels:\n', ', '.join(list(set(all_labels)))) + + for filename in dic: + text_output = ', '.join(list(set(dic[filename]))) + text_prompt = 'this acoustic scene is' + + if len(text_output) <= 1: + continue + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "emov-db": + assert flamingo_task == "EmotionClassification" + assert split in ["train", "val"] + + map_split = lambda split: '22khz_from_16khz' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + split_file = os.path.join( + dataset_path, + 'cleaned_emov_db_audiopath_text_sid_emotion_duration_filelist_merged_{}.txt'.format(split) + ) + with open(split_file, 'r') as f: + data = f.readlines() + data = [x.replace('\n', '') for x in data] + + for row in tqdm(data): + if row.count('|') != 4: + continue + filename, utterances, speaker, emotion, 
duration = row.split('|') + if filter_file(file_path, file_list, filename): + continue + + text_output = emotion + text_output = EMOTION_MAP_DICT[text_output] + text_prompt = 'this emotion is' + + if len(text_output) <= 1: + continue + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "Epidemic_sound": + assert split == 'train' + assert flamingo_task in ["AudioCaptioning", "Tagging"] + + map_split = lambda split: 'audio' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.mp3'), os.listdir(file_path))) + + with open(os.path.join(dataset_path, 'Epidemic_all_debiased.csv'), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in tqdm(reader): + if len(row) != 5: + continue + _, caption_1, caption_2, caption_t5, fileid = row + filename = '{}.mp3'.format(fileid) + if filter_file(file_path, file_list, filename): + continue + + if flamingo_task == "AudioCaptioning": + text_output = caption_t5 + if len(text_output) <= 1: + continue + text_prompt = 'generate audio caption' + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif flamingo_task == "Tagging": + if not caption_2.startswith('the sounds of'): + continue + caption_2 = caption_2.replace('the sounds of ', '') + caption_2 = caption_2.replace(', and', ',') + if len(caption_2) < 2: + continue + + tags = caption_2.split(', ') + tags = list(map(lambda x: x.replace("'", "").strip().lower(), tags)) + text_output = '{}'.format(', '.join(tags)) + if len(text_output) <= 1: + continue + text_prompt = 'generate tags' + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "ESC50": + assert flamingo_task in ["EventClassification"] + assert split == 'train' + + map_split = lambda split: 'ESC-50-master/audio' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + with open(os.path.join(dataset_path, 'ESC-50-master/meta/esc50.csv'), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in tqdm(reader): + if len(row) != 7: + continue + + filename, fold, target, category, esc10, src_file, take = row + if filter_file(file_path, file_list, filename): + continue + + text_output = category.replace('_', ' ') + text_prompt = 'classify this sound.' 
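+                # ESC-50 stores categories with underscores (such as
+                # 'crying_baby'); the replace above yields natural-language labels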
+ if len(text_output) <= 1: + continue + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "FMA": + import ast + + assert flamingo_task in ["GenreClassification"] + assert split == 'train' + + map_split = lambda split: 'fma_large' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = None + + with open(os.path.join(dataset_path, 'fma_metadata/raw_tracks.csv'), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in tqdm(reader): + if len(row) != 39: + continue + track_id,album_id,album_title,album_url, \ + artist_id,artist_name,artist_url,artist_website, \ + license_image_file,license_image_file_large, \ + license_parent_id,license_title,license_url, \ + tags,track_bit_rate,track_comments,track_composer, \ + track_copyright_c,track_copyright_p,track_date_created,track_date_recorded, \ + track_disc_number,track_duration,track_explicit,track_explicit_notes, \ + track_favorites,track_file,track_genres,track_image_file,track_information, \ + track_instrumental,track_interest,track_language_code, \ + track_listens,track_lyricist,track_number,track_publisher,track_title,track_url = row + + l = len(str(track_id)) + if l <= 3: + filename = '{}/{}.mp3'.format( + '000', + '0'*(6-l)+str(track_id) + ) + else: + filename = '{}/{}.mp3'.format( + '0'*(6-l)+str(track_id)[:l-3], + '0'*(6-l)+str(track_id) + ) + if filter_file(file_path, file_list, filename): + continue + + if len(track_genres) == 0: + continue + + track_genres = ast.literal_eval(track_genres) + genres = ', '.join([dic['genre_title'].lower().strip() for dic in track_genres]) + text_output = genres + '.' + + text_prompt = "what is the genre of this music?" 
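+                # a track may carry several genres, so text_output can be a
+                # comma-separated list such as 'rock, punk.'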
+ + if len(text_output) <= 1: + continue + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "FSD50k": + import ndjson + assert flamingo_task == "EventClassification" + assert split in ["train", "test"] + + map_split = lambda split: '44khz/dev' if split == 'train' else '44khz/eval' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + with open(os.path.join(dataset_path, '{}.ndjson'.format(map_split(split).replace('44khz/', '')))) as ndjsonfile: + reader = ndjson.load(ndjsonfile) + for row in tqdm(reader): + filename = row["filepath"].split("/")[1] + if filter_file(file_path, file_list, filename): + continue + + labels = [x.replace("_", " ").lower() for x in row["labels"]] + text_output = ", ".join(labels) + if len(text_output) <= 1: + continue + text_prompt = 'this is a sound of' + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "GTZAN": + assert flamingo_task == "GenreClassification" + assert split in ["train"] + + map_split = lambda split: 'gtzan/data/genres' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = None + + for genre in os.listdir(file_path): + genre_wavs = [x for x in os.listdir(os.path.join(file_path, genre)) if x.endswith('.wav')] + + for genre_wav in genre_wavs: + filename = os.path.join(genre, genre_wav) + if filter_file(file_path, file_list, filename): + continue + + text_output = genre + if len(text_output) <= 1: + continue + text_prompt = 'What is the genre of this music?' 
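+                # the directory name under gtzan/data/genres/ is itself the
+                # genre label, so no metadata file is needed here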
+ + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "IEMOCAP": + assert flamingo_task == "EmotionClassification" + assert split in ["train", "test"] + + map_split = lambda split: 'IEMOCAP_full_release/16khz' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = None + + def read_this_ndjson(file_path): + dic_list = [] + with open(file_path, 'r') as f: + for line in f: + turn_name = line.split("'turn_name': ")[-1].split(',')[0].replace("'", "") + emotion = line.split("'emotion': ")[-1].split(',')[0].replace("'", "") + dic = { + 'turn_name': turn_name, + 'emotion': emotion + } + dic_list.append(dic) + return dic_list + + all_emotions = [] + meta_files = [x for x in os.listdir(os.path.join(dataset_path, 'IEMOCAP_full_release/ndjson')) if x.endswith('.ndjson')] + for meta_file in tqdm(meta_files): + main_folder = meta_file.split('_')[0] + sub_folder = (meta_file.split('.ndjson')[0])[len(main_folder)+1:] + + if split == "train" and main_folder == "Session5": + continue + elif split == "test" and main_folder != "Session5": + continue + + metadata_list = read_this_ndjson(os.path.join(dataset_path, 'IEMOCAP_full_release/ndjson', meta_file)) + + for dic in metadata_list: + filename = os.path.join(main_folder, sub_folder, dic['turn_name']+'.wav') + if filter_file(file_path, file_list, filename): + continue + + if dic['emotion'] in ['unknown', 'other']: + continue + + text_output = dic['emotion'] + text_output = EMOTION_MAP_DICT[text_output] + all_emotions.append(text_output) + + text_prompt = 'this emotion is' + + if len(text_output) <= 1: + continue + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + print('all emotions:', list(set(all_emotions))) + + elif dataset_name == "jl-corpus": + assert flamingo_task == "EmotionClassification" + assert split in ["train", "val"] + + map_split = lambda split: '44khz' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + split_file = os.path.join( + dataset_path, + 'jl-corpus_audiopath_text_sid_emotion_duration_{}_filelist.txt'.format(split) + ) + with open(split_file, 'r') as f: + data = f.readlines() + data = [x.replace('\n', '') for x in data] + + for row in tqdm(data): + if row.count('|') != 4: + continue + filename, utterances, speaker, emotion, duration = row.split('|') + if filter_file(file_path, file_list, filename): + continue + + text_output = emotion + text_output = EMOTION_MAP_DICT[text_output] + text_prompt = 'this emotion is' + + if len(text_output) <= 1: + continue + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "LP-MusicCaps-MC": + import pandas as pd + assert flamingo_task in ["AudioCaptioning"] + assert split in ["train", "test"] + + map_split = lambda split: '../MusicCaps/44khz' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert 
os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + parquet_files = [f for f in os.listdir(os.path.join(dataset_path, 'data')) if f.endswith('.parquet') and f.startswith(split)] + print('parquet_files', parquet_files) + metadata_df = pd.concat([pd.read_parquet(os.path.join(dataset_path, 'data', f)) for f in parquet_files]) + + for index, row in tqdm(metadata_df.iterrows()): + filename = row['ytid'] + '.wav' + if filter_file(file_path, file_list, filename): + continue + + text_prompt = 'generate audio caption' + for caption in [row['caption_writing'], row['caption_summary'], row['caption_paraphrase']]: + text_output = caption + + if len(text_output) <= 1: + continue + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "LP-MusicCaps-MSD": + import pandas as pd + assert flamingo_task in ["AudioCaptioning"] + assert split in ["train", "test", "val"] + + map_split = lambda split: '../MSD/mp3s_22khz' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = None + + parquet_files = [f for f in os.listdir(dataset_path) if f.endswith('.parquet') and f.startswith(split)] + print('parquet_files', parquet_files) + metadata_df = pd.concat([pd.read_parquet(os.path.join(dataset_path, f)) for f in parquet_files]) + + for index, row in tqdm(metadata_df.iterrows()): + filename = row['path'] + if filter_file(file_path, file_list, filename): + continue + + text_prompt = 'generate audio caption' + for caption in [row['caption_writing'], row['caption_summary'], row['caption_paraphrase']]: + text_output = caption + + if len(text_output) <= 1: + continue + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "LP-MusicCaps-MTT": + import pandas as pd + assert flamingo_task in ["AudioCaptioning"] + assert split in ["train", "test", "val"] + + map_split = lambda split: '../MagnaTagATune/16khz' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = None + + parquet_files = [f for f in os.listdir(dataset_path) if f.endswith('.parquet') and f.startswith(split)] + print('parquet_files', parquet_files) + metadata_df = pd.concat([pd.read_parquet(os.path.join(dataset_path, f)) for f in parquet_files]) + + for index, row in tqdm(metadata_df.iterrows()): + filename = row['path'] + if filter_file(file_path, file_list, filename): + continue + + text_prompt = 'generate audio caption' + for caption in [row['caption_writing'], row['caption_summary'], row['caption_paraphrase']]: + text_output = caption + + if len(text_output) <= 1: + continue + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "MACS": + assert flamingo_task in ["AudioCaptioning", "Tagging"] + assert split == 'train' + + map_split = lambda split: 
'TAU_Urban_Acoustic_Scenes_2019/TAU-urban-acoustic-scenes-2019-development/audio' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + metadata_list = yaml.load(open(os.path.join(dataset_path, 'MACS.yaml')), Loader=yaml.FullLoader)['files'] + + for file_metadata in tqdm(metadata_list): + filename = file_metadata['filename'] + if filter_file(file_path, file_list, filename): + continue + + for each_annotated in file_metadata['annotations']: + caption = each_annotated['sentence'] + tags = ', '.join(each_annotated['tags']).replace('_', ' ') + + if flamingo_task == "AudioCaptioning": + text_output = caption + text_prompt = 'generate audio caption' + + elif flamingo_task == "Tagging": + raise NotImplementedError + + if len(text_output) <= 1: + continue + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "Medley-solos-DB": + import ndjson + assert flamingo_task in ["InstrClassification"] + + map_split = lambda split: '44khz' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + with open(os.path.join(dataset_path, 'medleysolosdb_manifest.ndjson')) as ndjsonfile: + metadata_list = ndjson.load(ndjsonfile) + + for file_metadata in tqdm(metadata_list): + subset = file_metadata['subset'] + if not subset.startswith(split): + continue + + filename = file_metadata['filepath'] + if filter_file(file_path, file_list, filename): + continue + + instrument = file_metadata["instrument"] + + text_output = instrument + text_prompt = 'this music note is produced by' + + if len(text_output) <= 1: + continue + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "MELD": + import numpy as np + assert flamingo_task in ["EmotionClassification", "SentimentClassification"] + + map_split = lambda split: '44khz' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + split_file = os.path.join( + dataset_path, + '{}.txt'.format(split if split in ['train', 'test'] else 'dev') + ) + with open(split_file, 'r') as f: + data = f.readlines() + data = [x.replace('\n', '') for x in data] + + emotion_count = { + 'neutral': 4703, 'happy': 1739, 'sad': 683, 'surprised': 1204, + 'disgusted': 271, 'angry': 1108, 'fearful': 268, + } + sentiment_count = { + 'neutral': 4703, 'positive': 2330, 'negative': 2943, + } + balancing_factor = 1 + + for row in tqdm(data): + if row.count('|') != 4: + continue + filename, utterances, speaker, emotion, sentiment = row.split('|') + if filter_file(file_path, file_list, filename): + continue + + if flamingo_task == "EmotionClassification": + text_output = emotion + text_output = EMOTION_MAP_DICT[text_output] + text_prompt = 'this emotion is' + + if split == 'train': + 
balancing_factor = float(emotion_count['neutral']) / float(emotion_count[text_output]) + + elif flamingo_task == "SentimentClassification": + text_output = sentiment + text_prompt = 'this sentiment is' + + if split == 'train': + balancing_factor = float(sentiment_count['neutral']) / float(sentiment_count[text_output]) + + if len(text_output) <= 1: + continue + + for _ in range(int(np.floor(balancing_factor))): + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + if np.random.rand() < balancing_factor - np.floor(balancing_factor): + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "MSP-PODCAST-Publish-1.9": + assert flamingo_task == "EmotionClassification" + assert split in ["train", "val", "test"] + + map_split = lambda split: 'Audio' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + + file_list = glob.glob('{}/*/*.wav'.format(file_path)) + file_list = [x[len(file_path)+1:] for x in file_list] + + subfolder_map = {} + for f in tqdm(file_list): + subfolder, filename = f.split('/') + subfolder_map[filename] = subfolder + file_list = None + + emotion_dic = { + 'A': 'Angry', + 'S': 'Sad', + 'H': 'Happy', + 'U': 'Surprise', + 'F': 'Fear', + 'D': 'Disgust', + 'C': 'Contempt', + 'N': 'Neutral', + 'O': 'Other', + 'X': 'Not clear' + } + + with open(os.path.join(dataset_path, 'Labels/labels_concensus.json')) as f: + data = f.read() + metadata_dic = json.loads(data) + + for filename in tqdm(list(metadata_dic.keys())): + values = metadata_dic[filename] + if not values["Split_Set"].lower().startswith(split): + continue + if values["EmoClass"] in ["O", "X"] or values["EmoClass"] not in emotion_dic.keys(): + continue + + subfolder = subfolder_map[filename] + filename = '{}/{}'.format(subfolder, filename) + if filter_file(file_path, file_list, filename): + continue + + text_output = emotion_dic[values["EmoClass"]].lower() + text_output = EMOTION_MAP_DICT[text_output] + text_prompt = 'this emotion is' + + if len(text_output) <= 1: + continue + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "mtg-jamendo": + import ndjson + assert flamingo_task == "MusicTagging" + assert split in ["train", "val"] + + map_split = lambda split: '44khz' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = None + + with open(os.path.join(dataset_path, 'mtg_jamendo_{}_manifest.ndjson'.format(split))) as ndjsonfile: + reader = ndjson.load(ndjsonfile) + for row in tqdm(reader): + filename = row["audiopath"] + if filter_file(file_path, file_list, filename): + continue + + text_output = row["caption"] + text_prompt = 'generate music tags (genre, instrument, mood/theme)' + + if len(text_output) <= 1: + continue + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "MU-LLAMA": + + assert 
flamingo_task in ['AQA'] + assert split in ['train', 'test'] + + map_split = lambda split: 'MusicQA/audios' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + split_file = 'MusicQA/FinetuneMusicQA.json' if split == 'train' else 'MusicQA/EvalMusicQA.json' + with open(os.path.join(dataset_path, split_file), 'r') as f: + data = f.read() + metadata_list = json.loads(data) + + for dic in tqdm(metadata_list): + filename = dic["audio_name"] + if filter_file(file_path, file_list, filename): + continue + + text_prompt = 'Question: ' + dic["conversation"][0]["value"].strip() + if not (text_prompt.endswith('.') or text_prompt.endswith('?')): + text_prompt = text_prompt + '.' + + text_output = dic["conversation"][1]["value"].strip() + if len(text_output) <= 1: + continue + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "musdbhq": + assert flamingo_task in ["InstrClassification"] + assert split in ["train", "test", "val"] + + map_split = lambda split: './' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = None + + with open(os.path.join(dataset_path, 'file_list_44k_{}.txt'.format(split))) as f: + data = f.readlines() + data = [x.replace('\n', '') for x in data] + + for row in tqdm(data): + if row.count('|') != 1: + continue + + filename, duration = row.split('|') + duration = float(duration) + + if filter_file(file_path, file_list, filename): + continue + + text_output = filename.split('/')[-1].split('.wav')[0] + if len(text_output) <= 1: + continue + text_prompt = 'this music is produced by' + + segment_length = 10 + for audio_start_idx in range(int(duration // segment_length)): + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' '), + "audio_start": audio_start_idx * segment_length + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "Music-AVQA": + import ast + import re + + assert flamingo_task in [ + "{}_{}".format(q, t) \ + for q in ['AQA', 'AVQA'] \ + for t in ['Comparative', 'Counting', 'Existential', 'Location', 'Temporal', 'All'] + ] + + def replace_bracketed_words(input_string, replacements): + def replacer(match): + word = next(replacements) + return word + + replacements = iter(replacements) + output_string = re.sub(r'<[^>]*>', replacer, input_string) + return output_string + + map_split = lambda split: 'audio' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + with open(os.path.join(dataset_path, 'MUSIC-AVQA/data/json/avqa-{}.json'.format(split)), 'r') as f: + data = f.read() + metadata_list = json.loads(data) + + for dic in tqdm(metadata_list): + filename = dic["video_id"] + '.wav' + if filter_file(file_path, file_list, filename): + continue + + types = ast.literal_eval(dic["type"]) + if 'Visual' in types: + continue + + if flamingo_task.startswith('AQA_') and 
'Audio-Visual' in types: + continue + + if flamingo_task.startswith('AVQA_') and 'Audio' in types: + continue + + t = flamingo_task.split('_')[1] + if (not t == 'All') and (not t in types): + continue + + text_output = dic["anser"] + if len(text_output) <= 1: + continue + + question = dic["question_content"].replace("\uff1f", '?') + templ_values = ast.literal_eval(dic["templ_values"]) + if len(templ_values) > 0: + question = replace_bracketed_words(question, templ_values) + text_prompt = "Question: " + question + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "MusicCaps": + assert flamingo_task in ["AudioCaptioning", "EventClassification"] + assert split in ["train", "test"] + + map_split = lambda split: '44khz' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + with open(os.path.join(dataset_path, 'musiccaps_manifest.json')) as f: + data = f.read() + metadata_list = json.loads(data) + + for file_metadata in tqdm(metadata_list): + filename = file_metadata['filepath'] + if filter_file(file_path, file_list, filename): + continue + + start_s, end_s = file_metadata["start_s"], file_metadata["end_s"] + caption = file_metadata["caption"] + audioset_positive_labels = file_metadata["audioset_positive_labels"] # audioset classes + aspect_list = file_metadata["aspect_list"] # annotated classes + + if (split == 'train') == file_metadata["is_audioset_eval"]: + continue + + if flamingo_task == "AudioCaptioning": + text_output = caption + text_prompt = 'generate audio caption' + + elif flamingo_task == "EventClassification": + raise NotImplementedError + + if len(text_output) <= 1: + continue + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "NonSpeech7k": + assert flamingo_task in ["EventClassification"] + assert split in ["train", "test"] + + map_split = lambda split: split + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + all_classes = [] + with open(os.path.join(dataset_path, 'metadata of {} set.csv').format(split), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in tqdm(reader): + filename, _, _, _, classname, _, _, _ = row + if filter_file(file_path, file_list, filename): + continue + + text_output = classname.lower() + if len(text_output) <= 1: + continue + text_prompt = 'this is a sound of' + + all_classes.append(classname) + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + print('all classes:', list(set(all_classes))) + + elif dataset_name == "NSynth": + import ndjson + assert flamingo_task in [ + "InstrClassification", + "PitchClassification", + "VelocityClassification", + "SourceClassification", + "QualityClassification", + "MIR" + ] + assert split in ["train", "test", 
"val"] + + map_split = lambda split: 'nsynth-{}/audio'.format('valid' if split == 'val' else split) + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + with open(os.path.join(dataset_path, map_split(split), '../examples.json')) as f: + data = f.read() + reader = json.loads(data) + + for key in tqdm(reader): + filename = key + '.wav' + if filter_file(file_path, file_list, filename): + continue + + if flamingo_task == "InstrClassification": + text_output = reader[key]["instrument_family_str"] + text_prompt = 'this music note is produced by' + + elif flamingo_task == "PitchClassification": + text_output = str(reader[key]["pitch"]) + text_prompt = 'this music note has pitch' + + elif flamingo_task == "VelocityClassification": + text_output = str(reader[key]["velocity"]) + text_prompt = 'this music note has velocity' + + elif flamingo_task == "SourceClassification": + text_output = reader[key]["instrument_source_str"] + text_prompt = 'this music note has sonic source' + + elif flamingo_task == "QualityClassification": + qualities_str = reader[key]["qualities_str"] + if len(qualities_str) >= 1: + text_output = ', '.join(qualities_str).replace('_', ' ') + else: + text_output = 'none' + text_prompt = 'this music note has sonic qualities' + + elif flamingo_task == "MIR": + instrument = reader[key]["instrument_family_str"] + pitch = str(reader[key]["pitch"]) + velocity = str(reader[key]["velocity"]) + source = reader[key]["instrument_source_str"] + qualities_str = ', '.join(reader[key]["qualities_str"]).replace('_', ' ') + + assert len(instrument) > 0 + text_output = 'produced by {}'.format(instrument) + if len(pitch) > 0: + text_output = text_output + ', pitch {}'.format(pitch) + if len(velocity) > 0: + text_output = text_output + ', velocity {}'.format(velocity) + if len(source) > 0: + text_output = text_output + ', source {}'.format(source) + if len(qualities_str) > 0: + text_output = text_output + ', and having qualities like {}'.format(qualities_str) + + text_prompt = 'this music note is' + + if len(text_output) <= 1: + continue + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "OMGEmotion": + import numpy as np + import webrtcvad + import wave + from pydub import AudioSegment + + assert flamingo_task == "EmotionClassification" + assert split in ["train", "val"] + + def convert_to_wav(file_path): + audio = AudioSegment.from_file(file_path).set_frame_rate(16000).set_channels(1) + wav_path = file_path.rsplit('.', 1)[0] + "_converted.wav" + audio.export(wav_path, format="wav") + return wav_path + + def contains_speech(file_path, aggressiveness=0): + # aggressiveness between 0 and 3, 0 for very clean speech, and 3 for noisy speech + wav_path = convert_to_wav(file_path) + vad = webrtcvad.Vad(aggressiveness) + + with wave.open(wav_path, 'rb') as audio: + assert audio.getsampwidth() == 2, "Audio must be 16-bit" + assert audio.getnchannels() == 1, "Audio must be mono" + assert audio.getframerate() == 16000, "Audio must be sampled at 16kHz" + + frame_duration = 10 # ms + frame_size = int(audio.getframerate() * frame_duration / 1000) + num_frames = int(audio.getnframes() / frame_size) + + for _ in range(num_frames): + frame = 
audio.readframes(frame_size) + if vad.is_speech(frame, audio.getframerate()): + return True + + return False + + map_split = lambda split: 'processed-{}_utterance_data'.format(split) + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = None + + dic_code2emotion = { + "0": "anger", + "1": "disgust", + "2": "fear", + "3": "happy", + "4": "neutral", + "5": "sad", + "6": "surprise", + } + + all_emotions = [] + meta_file = os.path.join( + dataset_path, + 'OMGEmotionChallenge', + 'omg_{}Videos.csv'.format('Train' if split == 'train' else 'Validation') + ) + + with open(meta_file, newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in tqdm(reader): + link, start, end, video, utterance, _, _, EmotionMaxVote = row + emotion = dic_code2emotion[str(EmotionMaxVote)] + + filename = os.path.join(video, utterance.replace('.mp4', '.mp3')) + if filter_file(file_path, file_list, filename): + continue + + if not contains_speech(os.path.join(file_path, filename)): + print('{} does not contain speech'.format(filename)) + continue + + text_prompt = 'this emotion is' + text_output = emotion + if len(text_output) <= 1: + continue + + all_emotions.append(EMOTION_MAP_DICT[emotion]) + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + print('all emotions:', list(set(all_emotions))) + + elif dataset_name == "OpenAQA": + + assert flamingo_task == 'AQA' + assert split == 'train' + + map_split = lambda split: './' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = None + + no_word_list = [ + 'cannot determine', 'not provided', 'cannot be determined', 'sorry', 'i cannot', + 'without more information', 'enough information', + 'not possible', 'more context', 'enough', 'impossible', 'cannot be determined', + 'without additional information', + 'unclear', 'cannot', 'not clear', 'do not provide sufficient', 'does not provide', + 'difficult to determine', 'no information provided', + "can't infer", "difficult to infer", "not specified", "no specific", "no information", + "without additional", 'it is difficult to', "no indication" + ] + + print('computing dic_audiosetfull_parts') + audiosetfull_root = '/mnt/fsx-main/rafaelvalle/datasets/audioset/unbalanced_train_segments/22khz/' + part_strings = [('0'*(2-len(str(p))) + str(p)) for p in range(41)] + dic_audiosetfull_parts = { + part: set(os.listdir(os.path.join(audiosetfull_root, 'unbalanced_train_segments_part{}'.format(part)))) \ + for part in part_strings + } + + audioset20k_filelist = set(os.listdir(os.path.join(file_path, '../AudioSet/train_wav'))) + + print('computing dic_clotho_filename') + clotho_files = os.listdir(os.path.join(dataset_path, '../Clotho-AQA/audio_files')) + dic_clotho_filename = { + '_'.join([s for s in f.split(' ') if len(s) > 0]): f \ + for f in clotho_files + } + + print('reading open_ended/all_open_qa.json') + with open(os.path.join(dataset_path, 'openaqa/data/open_ended/all_open_qa.json'), 'r') as f: + data = f.read() + metadata_list = json.loads(data) + + for dic in tqdm(metadata_list): + #keys: instruction, input, dataset, audio_id, output, task + + text_output = dic["output"] 
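+            # Skip answers that are too short or refusal-style (see no_word_list
+            # above), so the training set only keeps question-answer pairs that
+            # are actually answerable from the audio.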
+ if len(text_output) <= 1: + continue + if any(word in text_output.lower() for word in no_word_list): + continue + + question = dic["instruction"] + text_prompt = question + + audio_id = dic["audio_id"] + subset = dic["dataset"] + if subset == 'clotho_development': + filename = audio_id.split('/')[-1] + processed_filename = '_'.join([s for s in filename.split('_') if len(s) > 0]) + if processed_filename in dic_clotho_filename: + filename = os.path.join( + '../Clotho-AQA/audio_files', + dic_clotho_filename[processed_filename] + ) + else: + continue + + elif subset in ['audiocaps_train', 'as_20k', 'as_strong_train']: + found = False + + filename = audio_id.split('/')[-1].split('.flac')[0] + '.wav' + if filename in audioset20k_filelist: + filename = os.path.join('../AudioSet/train_wav', filename) + found = True + else: + filename = 'Y' + filename + for part in part_strings: + if filename in dic_audiosetfull_parts[part]: + filename = os.path.join( + audiosetfull_root, + 'unbalanced_train_segments_part{}'.format(part), + filename + ) + found = True + break + + if not found: + print(filename, 'not found') + continue + + elif subset == 'freesound_10s': + filename = os.path.join( + '../CLAP_freesound/freesound_no_overlap/split/train', + audio_id.split('/')[-1] + ) + + elif subset == 'vggsound_train': + continue + + if filter_file(file_path, file_list, filename): + continue + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "ravdess": + assert flamingo_task == "EmotionClassification" + assert split in ["train", "val"] + + map_split = lambda split: '44khz' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = None + + split_file = os.path.join( + dataset_path, + 'ravdess_audiopath_text_sid_emotion_duration_{}_filelist.txt'.format(split) + ) + with open(split_file, 'r') as f: + data = f.readlines() + data = [x.replace('\n', '') for x in data] + + for row in tqdm(data): + if row.count('|') != 4: + continue + filename, utterances, speaker, emotion, duration = row.split('|') + if filter_file(file_path, file_list, filename): + continue + + text_output = emotion + text_prompt = 'this emotion is' + + if len(text_output) <= 1: + continue + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "SongDescriber": + assert flamingo_task in ["AudioCaptioning"] + assert split in ["train"] + + map_split = lambda split: './audio/audio' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = None + + with open(os.path.join(dataset_path, 'song_describer.csv'), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + + for row in tqdm(reader): + caption_id,track_id,caption,is_valid_subset,familiarity,artist_id,album_id,path,duration = row + filename = '{}/{}.2min.mp3'.format(track_id[-2:], track_id) + duration = float(duration) + + if filter_file(file_path, file_list, filename): + continue + + text_output = caption + if len(text_output) <= 1: + continue + text_prompt = 'generate audio caption' + + 
segment_length = 30 + for audio_start_idx in range(int(duration // segment_length)): + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' '), + "audio_start": audio_start_idx * segment_length + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "SONYC-UST": + import numpy as np + + assert flamingo_task == "EventClassification" + assert split in ["train", "test", "val"] + + map_split = lambda split: 'audio' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + all_labels = [] + with open(os.path.join(dataset_path, 'annotations.csv'), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + for idx, row in tqdm(enumerate(reader)): + if idx == 0: + header = np.array(row) + continue + + if not row[0].startswith(split): + continue + + filename = row[2] + if filter_file(file_path, file_list, filename): + continue + + labels = [header[i] for i in range(12, len(header)-8) if str(row[i]) == "1"] + labels = [x.split("_")[1].replace('-', ' ').lower() for x in labels if 'X_' not in x] + all_labels += labels + + text_output = ", ".join(labels) + if len(text_output) <= 1: + continue + text_prompt = 'this is a sound of' + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + print('all labels:', list(set(all_labels))) + + elif dataset_name == "SoundDescs": + import torch + assert flamingo_task in ["AudioDescription"] + assert split in ["train"] + + map_split = lambda split: 'raw/audios' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + split_file = os.path.join(dataset_path, 'audio-retrieval-benchmark/data/SoundDescs/{}_list.txt'.format(split)) + with open(split_file, 'r') as f: + data = f.readlines() + names = set([x.replace('\n', '') for x in data]) + + with open(os.path.join(dataset_path, 'audio-retrieval-benchmark/sounddescs_data/descriptions.pkl'), 'rb') as f: + obj = f.read() + metadata_dic = pickle.loads(obj, encoding='latin1') + + for name in tqdm(names): + if name not in metadata_dic.keys(): + continue + + filename = '{}.wav'.format(name) + if filter_file(file_path, file_list, filename): + continue + + description = metadata_dic[name] + text_output = description + text_prompt = 'generate audio description' + + if len(text_output) <= 1: + continue + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "tess": + assert flamingo_task == "EmotionClassification" + assert split in ["train", "val"] + + map_split = lambda split: '24414hz' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = None + + split_file = os.path.join( + dataset_path, + 'tess_audiopath_text_sid_emotion_duration_{}_filelist.txt'.format(split) + ) + with open(split_file, 
'r') as f: + data = f.readlines() + data = [x.replace('\n', '') for x in data] + + for row in tqdm(data): + if row.count('|') != 4: + continue + filename, utterances, speaker, emotion, duration = row.split('|') + if filter_file(file_path, file_list, filename): + continue + + text_output = emotion.replace('_', ' ') + text_output = EMOTION_MAP_DICT[text_output] + text_prompt = 'this emotion is' + + if len(text_output) <= 1: + continue + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "UrbanSound8K": + assert flamingo_task in ["EventClassification"] + assert split in ["train"] + + map_split = lambda split: 'audio' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = None + + with open(os.path.join(dataset_path, 'metadata/UrbanSound8K.csv'), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in tqdm(reader): + filename, fsID, start, end, salience, fold, classID, class_name = row + filename = 'fold{}/{}'.format(fold, filename) + if filter_file(file_path, file_list, filename): + continue + + text_output = class_name.replace("_", " ").lower() + if len(text_output) <= 1: + continue + text_prompt = 'this is a sound of' + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "VocalSound": + assert flamingo_task == "VocalClassification" + + map_split = lambda split: 'data_44k' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + split_file = os.path.join( + dataset_path, + 'meta/{}_meta.csv'.format(split[:2] if split in ['train', 'test'] else split[:3]) + ) + + prefix = set([]) + with open(split_file, newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + for row in reader: + prefix.add(row[0]) + + all_labels = set([]) + for filename in tqdm(file_list): + if not filename.split('_')[0] in prefix: + continue + + if filter_file(file_path, file_list, filename): + continue + + label = filename.split('_')[2].split('.wav')[0] + if label == 'throatclearing': + label = 'throat clearing' + + text_output = label + text_prompt = 'this vocal sound is' + all_labels.add(label) + + if len(text_output) <= 1: + continue + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + print('all labels:\n', "\'" + "\', \'".join(list(all_labels)) + "\'") + + elif dataset_name.startswith("WavCaps"): + assert split in ["train"] + + dataset_name, subset_name = dataset_name.split('-') + dataset_path = os.path.join( + '/'.join(dataset_path.split('/')[:-1]), + dataset_name + ) + dataset_dic['dataset_path'] = dataset_path + + map_split = lambda split: subset_name + '_flac' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = 
list(filter(lambda x: x.endswith('.flac'), os.listdir(file_path))) + + metadata_file = os.listdir(os.path.join(dataset_path, "json_files", subset_name)) + metadata_file = [x for x in metadata_file if x.endswith('json')][0] + with open(os.path.join(dataset_path, "json_files", subset_name, metadata_file)) as f: + data = f.read() + reader = json.loads(data) + + if subset_name == "AudioSet_SL": + assert flamingo_task == 'AudioCaptioning' + + for sample in tqdm(reader['data']): + filename = sample["id"].replace('.wav', '.flac') + if filter_file(file_path, file_list, filename): + continue + + text_output = sample['caption'] + if len(text_output) <= 1: + continue + text_prompt = 'generate audio caption' + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + else: + assert flamingo_task in ['AudioCaptioning', 'AudioDescription'] + + for sample in tqdm(reader['data']): + filename = sample["id"] + '.flac' + if filter_file(file_path, file_list, filename): + continue + + if flamingo_task == 'AudioCaptioning': + text_output = sample['caption'] + text_prompt = 'generate audio caption' + + elif flamingo_task == 'AudioDescription': + text_output = sample['description'] + text_prompt = 'generate audio description' + + if len(text_output) <= 1: + continue + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif dataset_name == "WavText5K": + assert split == 'train' + + map_split = lambda split: 'Webcrawl/44100/audios' + file_path = os.path.join( + dataset_path, + map_split(split) + ) + assert os.path.exists(file_path), '{} not exist'.format(file_path) + + dataset_dic["split_path"] = map_split(split) + file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) + + dic = defaultdict(str) + with open(os.path.join(dataset_path, 'WavText5K.csv'), newline='') as csvfile: + reader = csv.reader(csvfile, delimiter=',', quotechar='"') + next(reader) + for row in tqdm(reader): + _, _, title, description, filename, tags = row + dic[filename] = (title, description, tags) + + if flamingo_task == "AudioCaptioning": + for filename in tqdm(dic.keys()): + if filter_file(file_path, file_list, filename): + continue + + title, description, tags = dic[filename] + text_output = description + if len(text_output) <= 1: + continue + text_prompt = 'generate audio caption' + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + elif flamingo_task == "Tagging": + for filename in tqdm(dic.keys()): + if filter_file(file_path, file_list, filename): + continue + + title, description, tags = dic[filename] + if len(tags) < 2 or not tags.startswith('[') or not tags.endswith(']'): + continue + + tags = tags[1:-1].split(', ') + tags = list(map(lambda x: x.replace("'", ""), tags)) + text_output = '{}'.format(', '.join(tags)) + if len(text_output) <= 1: + continue + text_prompt = 'generate tags' + + dataset_dic["data"][dataset_dic["total_num"]] = { + "name": filename, + "prompt": text_prompt, + "output": text_output.replace('\n', ' ') + } + dataset_dic["total_num"] += 1 + + + with open(output_file, 'w') as json_file: + json.dump(dataset_dic, json_file) + + +# ==================== Precompute CLAP and build Hashing ==================== + +def 
int16_to_float32(x): + return (x / 32767.0).astype(np.float32) + + +def float32_to_int16(x): + x = np.clip(x, a_min=-1., a_max=1.) + return (x * 32767.).astype(np.int16) + + +def update_progress_bar(arg): + pbar.update() + + +@suppress_all_output +def load_clap_model(checkpoint): + if checkpoint in ['630k-audioset-best.pt', '630k-best.pt', '630k-audioset-fusion-best.pt', '630k-fusion-best.pt']: + amodel = 'HTSAT-tiny' + elif checkpoint in ['music_speech_audioset_epoch_15_esc_89.98.pt']: + amodel = 'HTSAT-base' + else: + raise NotImplementedError + + model = laion_clap.CLAP_Module( + enable_fusion=('fusion' in checkpoint.lower()), + amodel=amodel + ).cuda() + model.load_ckpt(ckpt=os.path.join( + '/lustre/fsw/portfolios/adlr/users/zkong/audio-flamingo-data/laion-clap-pretrained/laion_clap', + checkpoint + )) + return model + + +def load_audio(file_path, target_sr=44100, duration=30.0, start=0.0): + if file_path.endswith('.mp3'): + audio = AudioSegment.from_file(file_path) + if len(audio) > (start + duration) * 1000: + audio = audio[start * 1000:(start + duration) * 1000] + + if audio.frame_rate != target_sr: + audio = audio.set_frame_rate(target_sr) + + if audio.channels > 1: + audio = audio.set_channels(1) + + data = np.array(audio.get_array_of_samples()) + if audio.sample_width == 2: + data = data.astype(np.float32) / np.iinfo(np.int16).max + elif audio.sample_width == 4: + data = data.astype(np.float32) / np.iinfo(np.int32).max + else: + raise ValueError("Unsupported bit depth: {}".format(audio.sample_width)) + + else: + with sf.SoundFile(file_path) as audio: + original_sr = audio.samplerate + channels = audio.channels + + max_frames = int((start + duration) * original_sr) + + audio.seek(int(start * original_sr)) + frames_to_read = min(max_frames, len(audio)) + data = audio.read(frames_to_read) + + if data.max() > 1 or data.min() < -1: + data = data / max(abs(data.max()), abs(data.min())) + + if original_sr != target_sr: + if channels == 1: + data = librosa.resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr) + else: + data = librosa.resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0] + else: + if channels != 1: + data = data.T[0] + + if data.min() >= 0: + data = 2 * data / abs(data.max()) - 1.0 + else: + data = data / max(abs(data.max()), abs(data.min())) + return data + + +@torch.no_grad() +def compute_clap_each(audio_file, model): + try: + data = load_audio(audio_file, target_sr=48000, duration=10) + print(audio_file, 'loaded') + + except Exception as e: + print(audio_file, 'unsuccessful due to', e) + return None + + audio_data = data.reshape(1, -1) + + audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float().cuda() + audio_embed = model.get_audio_embedding_from_data(x=audio_data_tensor, use_tensor=True) + audio_embed = audio_embed.squeeze(0).cpu() + return audio_embed + + +@torch.no_grad() +def compute_embeddings_batch(batch, audio_files, model): + batch_results = [] + for i in batch: + if i >= len(audio_files): + break + audio_file = audio_files[i] + audio_embed = compute_clap_each(audio_file, model) + batch_results.append((i, audio_file, audio_embed)) + return batch_results + + +@torch.no_grad() +def precompute_clap_for_dataset( + dataset_file, + embedding_output_file, + checkpoint='630k-audioset-fusion-best.pt' +): + contents, audio_files = load_dataset_file(dataset_file) + + model = load_clap_model(checkpoint) + + if os.path.exists(embedding_output_file): + print('loading already computed embedding file from', 
embedding_output_file) + with open(embedding_output_file, 'rb') as f: + saved_data = pickle.load(f) + curr_audio_indices = saved_data['audio_indices'] + curr_audio_files = saved_data['audio_files'] + curr_audio_embeds = saved_data['audio_embeds'] + + else: + curr_audio_indices = [] + curr_audio_files = [] + curr_audio_embeds = [] + + print('computing embeddings for {}'.format(dataset_file)) + start_index = len(curr_audio_files) + remaining_indices = list(range(start_index, len(audio_files))) + + batch_size = 128 + batches = [ + list(range(i, min(i + batch_size, len(audio_files)))) \ + for i in range(start_index, len(audio_files), batch_size) + ] + + with multiprocessing.Pool(processes=4) as pool: + for i, batch in enumerate(batches): + batch_results = pool.map( + partial(compute_embeddings_batch, model=model, audio_files=audio_files), + [batch] + ) + + for result in batch_results[0]: + curr_audio_indices.append(result[0]) + curr_audio_files.append(result[1]) + curr_audio_embeds.append(result[2]) + + with open(embedding_output_file, 'wb') as f: + pickle.dump({ + 'audio_indices': curr_audio_indices, + 'audio_files': curr_audio_files, + 'audio_embeds': curr_audio_embeds + }, f) + + print(f"Saved progress for batch {i+1}/{len(batches)}: \ + audio_indices {len(curr_audio_indices)}, \ + audio_files {len(curr_audio_files)}, \ + audio_embeds {len(curr_audio_embeds)}*{curr_audio_embeds[0].shape}") + + return curr_audio_indices, curr_audio_files, curr_audio_embeds + + +def build_faiss_index(embeddings): + d = embeddings[0].size(0) + index = faiss.IndexFlatL2(d) + np_embeddings = np.vstack([emb.numpy() for emb in embeddings]) + index.add(np_embeddings) + return index + + +def build_faiss_index_dataset( + dataset_file, + embedding_output_file, + faiss_output_file, + checkpoint='630k-audioset-fusion-best.pt', + only_precompute_clap=False +): + audio_indices, audio_files, audio_embeds = precompute_clap_for_dataset(dataset_file, embedding_output_file, checkpoint) + + if only_precompute_clap: + return + + valid_indices, valid_files, valid_embeds = [], [], [] + for audio_index, audio_file, audio_embed in zip(audio_indices, audio_files, audio_embeds): + if audio_embed is not None: + valid_indices.append(audio_index) + valid_files.append(audio_file) + valid_embeds.append(audio_embed) + + print('building faiss index') + faiss_index = build_faiss_index(valid_embeds) + + print('saving faiss index') + faiss.write_index(faiss_index, faiss_output_file) + with open(faiss_output_file + '.filenames', 'wb') as f: + pickle.dump({'audio_indices': valid_indices, 'audio_files': valid_files}, f) + + +# ==================== Generate interleaved dataset files ==================== +# only save index so that one can recover + +def build_interleaved_dataset(dataset_file, interleaved_output_file, embedding_output_file, faiss_output_file, mode='random', n_samples=3): + contents, audio_files = load_dataset_file(dataset_file) + + dataset_dic = { + "dataset_path": contents["dataset_path"], + "split": contents["split"], + "split_path": contents["split_path"], + "flamingo_task": contents["flamingo_task"], + "total_num": 0, + "interleaved_data": {}, + } + + # interleaved_data is + # { + # id: { + # "generation_index_in_split": index of sample in the train or val or test.json, + # "fewshot_indices_in_train": list(indices) of few shot samples in train.json + # } + # } + + if mode == 'knn': + model = load_clap_model(checkpoint='630k-audioset-fusion-best.pt') + + print('loading already computed embedding file from', embedding_output_file) 
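+        # The embeddings loaded below are the query side; the FAISS index read
+        # afterwards was built over the train split only, so retrieved few-shot
+        # neighbors always come from training data.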
+ with open(embedding_output_file, 'rb') as f: + precomputed_data = pickle.load(f) + precomputed_audio_indices = precomputed_data['audio_indices'] + precomputed_audio_files = precomputed_data['audio_files'] + precomputed_audio_embeds = precomputed_data['audio_embeds'] + + faiss_index = faiss.read_index(faiss_output_file) + with open(faiss_output_file+'.filenames', 'rb') as f: + _data = pickle.load(f) + faiss_index_audio_indices = _data['audio_indices'] + faiss_index_audio_files = _data['audio_files'] + + print('looking for few shot samples and building interleaved_{} data'.format(mode)) + for i in tqdm(range(contents["total_num"])): + if mode == 'random': + few_shot_indices = list(np.random.choice( + list(set(list(range(contents["total_num"]))) - set([i])), + size=n_samples-1, + replace=False + )) + few_shot_indices = list(map(int, few_shot_indices)) + + elif mode == 'knn': + if audio_files[i] in precomputed_audio_files: + idx = precomputed_audio_files.index(audio_files[i]) + query_embedding_np = precomputed_audio_embeds[idx] + if query_embedding_np is not None: + query_embedding_np = query_embedding_np.numpy().reshape(1, -1) + else: + continue + + else: + query_embedding_np = compute_clap_each(audio_files[i], model) + if query_embedding_np is not None: + query_embedding_np = query_embedding_np.numpy().reshape(1, -1) + else: + continue + + distances, knn_indices = faiss_index.search(query_embedding_np, n_samples+50) + distances = distances[0] + knn_indices = knn_indices[0] + + knn_filenames = [faiss_index_audio_files[idx] for idx in knn_indices] + combined = list(zip(knn_indices, knn_filenames)) + unique_indices = defaultdict(list) + for idx, filename in combined: + unique_indices[filename].append(idx) + + cleared_knn_indices = [random.choice(unique_indices[filename]) for filename in unique_indices if filename != audio_files[i]] + + if dataset_file.endswith('train.json'): + cleared_knn_indices = [knn_i for knn_i in cleared_knn_indices if faiss_index_audio_indices[knn_i] != i] + cleared_knn_indices = cleared_knn_indices[:n_samples-1] + np.random.shuffle(cleared_knn_indices) + + few_shot_indices = [faiss_index_audio_indices[knn_i] for knn_i in cleared_knn_indices] + + dataset_dic["interleaved_data"][dataset_dic["total_num"]] = { + "generation_index_in_split": i, + "fewshot_indices_in_train": few_shot_indices + } + dataset_dic["total_num"] += 1 + + with open(interleaved_output_file, 'w') as json_file: + json.dump(dataset_dic, json_file) + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('-d', '--dataset_name', type=str, help='dataset name') + parser.add_argument('-f', '--flamingo_task', type=str, help='flamingo task') + parser.add_argument('--interleave', action="store_true", help='prepare the interleave dataset') + args = parser.parse_args() + + ROOT = "/lustre/fsw/portfolios/adlr/users/zkong" + dataset_root = os.path.join(ROOT, "datasets") + output_root = os.path.join(ROOT, "audio-flamingo-data/dataset_files") + os.makedirs(output_root, exist_ok=True) + + dataset_name = args.dataset_name # "Clotho-v2", "AudioSet", "Clotho-AQA", "WavText5K", "FSD50k", ... + flamingo_task = args.flamingo_task # AQA, AudioCaptioning, EventClassification, SceneClassification, Tagging, ... 
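+    # Example invocation (ROOT above is cluster-specific; adjust for your setup):
+    #   python data/prepare_each_dataset.py -d Clotho-v2 -f AudioCaptioning --interleave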
+
+    # the 'train' split must be processed first; otherwise there is no train embedding / FAISS index to query
+    for split in ["train", "val", "test"]:
+        dataset_path = os.path.join(dataset_root, dataset_name)
+
+        output_folder = '{}-{}'.format(dataset_name, flamingo_task)
+        os.makedirs(os.path.join(output_root, output_folder), exist_ok=True)
+
+        dataset_file = os.path.join(output_root, output_folder, '{}.json'.format(split))
+        if not os.path.exists(dataset_file):
+            try:
+                prepare_files(dataset_name, dataset_path, split, flamingo_task, dataset_file)
+            except AssertionError as e:
+                print('split {} does not exist for {}: {}'.format(split, dataset_name, e))
+                continue
+        else:
+            print('{} exists; skipping'.format(dataset_file))
+
+        if args.interleave:
+            faiss_output_file = dataset_file.replace('{}.json'.format(split), "train_faiss_index.index")
+            embedding_output_file = dataset_file.replace('.json', ".embedding")
+
+            if split == 'train':
+                if (not os.path.exists(faiss_output_file)) or (not os.path.exists(faiss_output_file + '.filenames')):
+                    build_faiss_index_dataset(
+                        dataset_file, embedding_output_file, faiss_output_file,
+                        only_precompute_clap=False
+                    )
+                else:
+                    print('{} exists; skipping'.format(faiss_output_file))
+            else:
+                build_faiss_index_dataset(
+                    dataset_file, embedding_output_file,
+                    faiss_output_file=None,
+                    only_precompute_clap=True
+                )
+                print('precomputing embedding for {} subset finished'.format(split))
+
+            for mode in ['knn', 'random']:
+                interleaved_output_file = '/'.join(
+                    dataset_file.split('/')[:-1] + \
+                    ['interleaved_{}-'.format(mode) + dataset_file.split('/')[-1]]
+                )
+                if not os.path.exists(interleaved_output_file):
+                    build_interleaved_dataset(
+                        dataset_file=dataset_file,
+                        interleaved_output_file=interleaved_output_file,
+                        embedding_output_file=embedding_output_file,
+                        faiss_output_file=faiss_output_file,
+                        mode=mode,
+                        n_samples=4
+                    )
+                else:
+                    print('{} exists; skipping'.format(interleaved_output_file))
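+
+# A minimal sketch (not called by the pipeline above) of how the interleaved
+# files written by build_interleaved_dataset can be resolved back into few-shot
+# sequences of {'name', 'prompt', 'output'} dicts. The function name and return
+# format here are illustrative assumptions, not an API used elsewhere.
+def resolve_interleaved_sample(interleaved_file, train_file, split_file, idx):
+    with open(interleaved_file) as f:
+        interleaved = json.load(f)
+    with open(train_file) as f:
+        train_contents = json.load(f)
+    with open(split_file) as f:
+        split_contents = json.load(f)
+
+    entry = interleaved["interleaved_data"][str(idx)]
+    # json serializes the integer ids used above as string keys
+    few_shot = [train_contents["data"][str(j)] for j in entry["fewshot_indices_in_train"]]
+    target = split_contents["data"][str(entry["generation_index_in_split"])]
+    # few-shot examples first, then the sample whose output is to be generated
+    return few_shot + [target]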