import os
import json
import csv
import yaml
from collections import defaultdict
import pickle
import glob
import math
from functools import partial
import sys
import io
import warnings
import random
import numpy as np
import torch
import laion_clap
import librosa
from pydub import AudioSegment
import soundfile as sf
import faiss
import multiprocessing

multiprocessing.set_start_method('spawn', force=True)
try:
    from tqdm import tqdm
except ImportError:
    tqdm = lambda x: x  # fallback: identity, no progress bar
def suppress_all_output(func):
    def wrapper(*args, **kwargs):
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        sys.stdout = io.StringIO()
        sys.stderr = io.StringIO()
        old_fd_out = os.dup(1)
        old_fd_err = os.dup(2)
        null_fd = os.open(os.devnull, os.O_RDWR)
        os.dup2(null_fd, 1)
        os.dup2(null_fd, 2)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                result = func(*args, **kwargs)
            finally:
                os.dup2(old_fd_out, 1)
                os.dup2(old_fd_err, 2)
                os.close(null_fd)
                os.close(old_fd_out)
                os.close(old_fd_err)
                sys.stdout = old_stdout
                sys.stderr = old_stderr
        return result
    return wrapper
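
# Minimal usage sketch (hypothetical function): the decorator silences both the
# Python-level streams and the OS-level file descriptors, which is what catches
# prints coming from C extensions such as audio codecs.
#
#   @suppress_all_output
#   def load_noisy_model():
#       ...  # chatty third-party initialization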
def filter_file(file_path, file_list, filename):
    if file_list is not None:
        if filename not in file_list:
            print(filename, 'does not exist')
            return True
    else:
        if not os.path.exists(os.path.join(file_path, filename)):
            print(filename, 'does not exist')
            return True
    if os.path.getsize(os.path.join(file_path, filename)) < 16000:
        print(filename, 'is under 16 kB (~0.5 to 1 second of audio)')
        return True
    return False
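
# filter_file returns True when a sample should be skipped. When file_list is
# None, existence is checked directly on disk; passing a precomputed listing
# avoids repeated filesystem checks inside the large loops below.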
# ==================== Prepare dataset files from each data folder ====================
EMOTION_MAP_DICT = {
    'amused': 'amused',
    'anger': 'angry', 'angry': 'angry',
    'anxious': 'anxious',
    'apologetic': 'apologetic',
    'assertive': 'assertive',
    'calm': 'calm',
    'concerned': 'concerned',
    'contempt': 'contempt',
    'disgust': 'disgusted', 'disgusted': 'disgusted',
    'encouraging': 'encouraging',
    'excited': 'excited',
    'fear': 'fearful', 'fearful': 'fearful',
    'frustated': 'frustrated',  # key keeps the dataset's misspelling; value corrected
    'happy': 'happy', 'joy': 'happy',
    'neutral': 'neutral',
    'sad': 'sad', 'sadness': 'sad',
    'sleepy': 'sleepy',
    'surprise': 'surprised', 'surprised': 'surprised',
    'pleasantly surprised': 'pleasantly surprised',
}

def load_dataset_file(dataset_file):
    with open(dataset_file) as f:
        contents = f.read()
    contents = json.loads(contents)
    audio_files = [
        os.path.join(
            contents["dataset_path"],
            contents["split_path"],
            contents["data"][str(i)]["name"]
        ) for i in range(contents["total_num"])
    ]
    return contents, audio_files
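
# Expected manifest layout (written by prepare_files below; values hypothetical):
# {
#   "dataset_path": "/path/to/dataset",
#   "split_path": "train_wav",
#   "total_num": 2,
#   "data": {"0": {"name": "a.wav", "prompt": "...", "output": "..."},
#            "1": {"name": "b.wav", "prompt": "...", "output": "..."}}
# }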

def compute_label_graph(dataset_name, dataset_path, top_n, output_file):
    if os.path.exists(output_file):
        print('loading precomputed graph:', output_file)
        with open(output_file, 'r') as json_file:
            graph = json.load(json_file)
    else:
        import torch
        from sentence_transformers import SentenceTransformer, util
        embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        print('precomputing graph and saving to:', output_file)
        if dataset_name == 'AudioSetSL_singlelabel':
            names = []
            with open(os.path.join(dataset_path, 'class_labels_indices.csv'), newline='') as csvfile:
                reader = csv.reader(csvfile, delimiter=',', quotechar='"')
                next(reader)
                for row in reader:
                    _, label, name = row  # e.g. 123, /m/02zsn, "Female speech, woman speaking"
                    names += name.split(', ')
            names = [x.lower().strip() for x in names]
        elif dataset_name == "Clotho-AQA_singlelabel":
            names = set()
            with open(os.path.join(dataset_path, 'clotho_aqa_metadata.csv'), newline='') as csvfile:
                reader = csv.reader(csvfile, delimiter=',', quotechar='"')
                next(reader)
                for row in tqdm(reader):
                    _, file_name, keywords, _, _, _, _ = row
                    names |= set(keywords.split(';'))
            names = [x.lower().strip() for x in names]
        names_embeddings = embedding_model.encode(names, convert_to_tensor=True)
        similarity_matrix = util.pytorch_cos_sim(names_embeddings, names_embeddings)
        similarity_threshold = 0.75
        n_items = len(names)
        graph = {}
        for i in range(n_items):
            adjusted_top_n = min(top_n, n_items - 1)
            values, indices = torch.topk(similarity_matrix[i], adjusted_top_n + 1, largest=True)
            most_similar_items = []
            for value, idx in zip(values, indices):
                # skip the label itself and near-duplicates above the similarity threshold
                if idx != i and value <= similarity_threshold:
                    most_similar_items.append(idx.item())
                if len(most_similar_items) == adjusted_top_n:
                    break
            graph[names[i]] = [names[j] for j in most_similar_items]
        with open(output_file, 'w') as json_file:
            json.dump(graph, json_file)
    # graph is a dict: key = each label, value = list of up to top_n similar labels
    return graph
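
# Usage sketch (hypothetical labels): graph['dog barking'] might map to a list of
# semantically close labels such as ['puppy whimpering', 'animal growl', ...],
# which the *_singlelabel branches below sample as negative multiple-choice options.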

def prepare_files(dataset_name, dataset_path, split, flamingo_task, output_file):
    assert not os.path.exists(output_file)
    dataset_dic = {
        "dataset_path": dataset_path,
        "split": split,
        "split_path": None,
        "flamingo_task": "{}-{}".format(dataset_name, flamingo_task),
        "total_num": 0,
        "data": {}  # {id: {'name': name, 'prompt': prompt, 'output': output}}
    }
    if dataset_name == "AudioSet":
        assert flamingo_task == "EventClassification"
        assert split == 'train'
        map_split = lambda split: 'train_wav' if split == 'train' else ''
        file_path = os.path.join(
            dataset_path,
            map_split(split)
        )
        assert os.path.exists(file_path), '{} does not exist'.format(file_path)
        dataset_dic["split_path"] = map_split(split)
        file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
        dic = defaultdict(str)
        with open(os.path.join(dataset_path, 'class_labels_indices.csv'), newline='') as csvfile:
            reader = csv.reader(csvfile, delimiter=',', quotechar='"')
            next(reader)
            for row in tqdm(reader):
                _, label, name = row  # /m/02zsn,"Female speech, woman speaking"
                dic[label] = name
        with open(os.path.join(dataset_path, 'train.csv'), newline='') as csvfile:
            reader = csv.reader(csvfile, delimiter=',', quotechar='"')
            next(reader)
            for row in tqdm(reader):
                filename, _, _, labels = row  # --aE2O5G5WE /m/03fwl,/m/04rlf,/m/09x0r
                filename = filename + '.wav'
                if filter_file(file_path, file_list, filename):
                    continue
                label_list = labels.split(",")
                assert all(label in dic for label in label_list)
                text_output = ", ".join([dic[label] for label in label_list])
                if len(text_output) <= 1:
                    continue
                text_prompt = 'this is a sound of'
                dataset_dic["data"][dataset_dic["total_num"]] = {
                    "name": filename,
                    "prompt": text_prompt,
                    "output": text_output.replace('\n', ' ')
                }
                dataset_dic["total_num"] += 1
| elif dataset_name == "AudioSetFull": | |
| assert flamingo_task == "EventClassification" | |
| assert split == 'train' | |
| map_split = lambda split: '/mnt/fsx-main/rafaelvalle/datasets/audioset/unbalanced_train_segments/22khz' | |
| file_path = map_split(split) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = None | |
| dic_code2label = defaultdict(str) | |
| with open(os.path.join(dataset_path, 'audioset-processing/data/class_labels_indices.csv'), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| next(reader) | |
| for row in tqdm(reader): | |
| _, code, name = row # /m/02zsn,"Female speech, woman speaking" | |
| dic_code2label[code] = name | |
| dic_filename2code = {} | |
| with open(os.path.join(dataset_path, 'audioset-processing/data/unbalanced_train_segments.csv'), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| next(reader) | |
| next(reader) | |
| for row in tqdm(reader): | |
| filename, _, _, codes = row # --aE2O5G5WE /m/03fwl,/m/04rlf,/m/09x0r | |
| filename = 'Y' + filename + '.wav' | |
| dic_filename2code[filename] = codes.split(",") | |
| for part in tqdm(range(41)): | |
| part_str = str(part) | |
| if len(part_str) == 1: | |
| part_str = '0' + part_str | |
| part_folder = 'unbalanced_train_segments_part{}'.format(part_str) | |
| for filename in os.listdir(os.path.join(file_path, part_folder)): | |
| if not filename.endswith('.wav'): | |
| continue | |
| if filter_file(file_path, file_list, os.path.join(part_folder, filename)): | |
| continue | |
| if filename not in dic_filename2code: | |
| continue | |
| text_output = ", ".join([dic_code2label[code] for code in dic_filename2code[filename] if code in dic_code2label]) | |
| if len(text_output) <= 1: | |
| continue | |
| text_prompt = 'this is a sound of' | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": os.path.join(part_folder, filename), | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "AudioSetFullwoAudioMusicCaps": | |
| assert flamingo_task == "EventClassification" | |
| assert split == 'train' | |
| map_split = lambda split: '/mnt/fsx-main/rafaelvalle/datasets/audioset/unbalanced_train_segments/22khz' | |
| file_path = map_split(split) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = None | |
| print('extracting AudioCaps and MusicCaps ytid to avoid these samples') | |
| audiocaps_ytid = [] | |
| for f in ['audiocaps_dataset/train.csv', 'audiocaps_dataset/test.csv', 'audiocaps_dataset/val.csv']: | |
| with open(os.path.join(dataset_path, f), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| next(reader) | |
| for row in reader: | |
| _, ytid, _, _ = row | |
| audiocaps_ytid.append('Y' + ytid + '.wav') | |
| audiocaps_ytid = set(audiocaps_ytid) | |
| musiccaps_ytid = [] | |
| with open(os.path.join(dataset_path, 'musiccaps_dataset/musiccaps_manifest.json')) as f: | |
| data = f.read() | |
| musiccaps_list = json.loads(data) | |
| for row in musiccaps_list: | |
| musiccaps_ytid.append('Y' + row["ytid"] + '.wav') | |
| musiccaps_ytid = set(musiccaps_ytid) | |
| print('Will exclude {} samples from MusicCaps and {} from AudioCaps'.format(len(audiocaps_ytid), len(musiccaps_ytid))) | |
| dic_code2label = defaultdict(str) | |
| with open(os.path.join(dataset_path, '../AudioSetFull/audioset-processing/data/class_labels_indices.csv'), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| next(reader) | |
| for row in tqdm(reader): | |
| _, code, name = row # /m/02zsn,"Female speech, woman speaking" | |
| dic_code2label[code] = name | |
| dic_filename2code = {} | |
| with open(os.path.join(dataset_path, '../AudioSetFull/audioset-processing/data/unbalanced_train_segments.csv'), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| next(reader) | |
| next(reader) | |
| for row in tqdm(reader): | |
| filename, _, _, codes = row # --aE2O5G5WE /m/03fwl,/m/04rlf,/m/09x0r | |
| filename = 'Y' + filename + '.wav' | |
| dic_filename2code[filename] = codes.split(",") | |
| music_audio_caps_excluded = 0 | |
| for part in tqdm(range(41)): | |
| part_str = str(part) | |
| if len(part_str) == 1: | |
| part_str = '0' + part_str | |
| part_folder = 'unbalanced_train_segments_part{}'.format(part_str) | |
| for filename in os.listdir(os.path.join(file_path, part_folder)): | |
| if not filename.endswith('.wav'): | |
| continue | |
| if filename in audiocaps_ytid or filename in musiccaps_ytid: | |
| music_audio_caps_excluded += 1 | |
| continue | |
| if filter_file(file_path, file_list, os.path.join(part_folder, filename)): | |
| continue | |
| if filename not in dic_filename2code: | |
| continue | |
| text_output = ", ".join([dic_code2label[code] for code in dic_filename2code[filename] if code in dic_code2label]) | |
| if len(text_output) <= 1: | |
| continue | |
| text_prompt = 'this is a sound of' | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": os.path.join(part_folder, filename), | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "AudioSetSL_singlelabel": | |
| import numpy as np | |
| assert flamingo_task == "EventClassification" | |
| assert split == 'train' | |
| map_split = lambda split: '../AudioSet/train_wav' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
| dic = defaultdict(str) | |
| with open(os.path.join(dataset_path, 'class_labels_indices.csv'), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| next(reader) | |
| for row in tqdm(reader): | |
| _, label, name = row # /m/02zsn,"Female speech, woman speaking" | |
| dic[label] = name | |
| graph = compute_label_graph( | |
| dataset_name, | |
| dataset_path, | |
| top_n=200, | |
| output_file=os.path.join(dataset_path, 'label_graph.json') | |
| ) | |
| with open(os.path.join(dataset_path, 'train.csv'), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| next(reader) | |
| for row in tqdm(reader): | |
| filename, _, _, labels = row # --aE2O5G5WE /m/03fwl,/m/04rlf,/m/09x0r | |
| filename = filename + '.wav' | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| label_list = labels.split(",") | |
| assert all(label in dic for label in label_list) | |
| text_labels = ", ".join([dic[label] for label in label_list]).lower() | |
| text_labels = text_labels.split(', ') | |
| text_output = np.random.choice(text_labels) | |
| if len(text_output) <= 1: | |
| continue | |
| num_options = np.random.choice( | |
| [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], | |
| p=[ 0.05, 0.1, 0.1, 0.1, 0.1, | |
| 0.05, 0.05, 0.05, 0.1, 0.05, | |
| 0.05, 0.1, 0.05, 0.05] | |
| ) | |
| negative_samples = [x for x in graph[text_output] if x not in set(text_labels)] | |
| candidate_negative_labels = list(np.random.choice( | |
| negative_samples[:num_options*10], | |
| size=num_options-1, | |
| replace=False | |
| )) | |
| if type(candidate_negative_labels) is str: | |
| candidate_negative_labels = [candidate_negative_labels] | |
| all_options = [text_output] + candidate_negative_labels | |
| np.random.shuffle(all_options) | |
| text_prompt = 'Classify this sound.\nOPTIONS:\n - {}.'.format( | |
| '.\n - '.join(all_options) | |
| ) | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "AUDIOCAPS13k": | |
| assert flamingo_task == 'AudioCaptioning' | |
| map_split = lambda split: 'audio_32000Hz/{}'.format(split) | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.flac'), os.listdir(file_path))) | |
| with open(os.path.join( | |
| dataset_path, | |
| '{}_manifest.json'.format(split + ('_v2' if split == 'train' else '')) | |
| ), 'r') as f: | |
| data = f.readlines() | |
| data = [json.loads(row) for row in data] | |
| for row in tqdm(data): | |
| filename = row['audio_filepath'].split('/')[-1] | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| text_output = row['text'] | |
| if len(text_output) <= 1: | |
| continue | |
| text_prompt = 'generate audio caption' | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "audiocaps": | |
| assert flamingo_task == 'AudioCaptioning' | |
| map_split = lambda split: 'audio/{}'.format(split if split in ['train', 'test'] else 'valid') | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.flac'), os.listdir(file_path))) | |
| for filename in tqdm(file_list): | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| with open(os.path.join(file_path, filename.replace('.flac', '.json')), 'r') as f: | |
| data = json.load(f) | |
| captions = data['text'] | |
| for text_output in captions: | |
| if len(text_output) <= 1: | |
| continue | |
| text_prompt = 'generate audio caption' | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
    elif dataset_name == 'BG-Gun-Sound-Dataset':
        assert flamingo_task == "SoundClassification"
        assert split in ["train", "test"]
        map_split = lambda split: 'data/gun_sound_v2'
        file_path = os.path.join(
            dataset_path,
            map_split(split)
        )
        assert os.path.exists(file_path), '{} does not exist'.format(file_path)
        dataset_dic["split_path"] = map_split(split)
        file_list = os.listdir(file_path)
        all_cates = set()
        with open(os.path.join(dataset_path, 'data/v3_exp3_{}.csv'.format(split)), newline='') as csvfile:
            reader = csv.reader(csvfile, delimiter=',', quotechar='"')
            next(reader)
            for row in tqdm(reader):
                filename, cate, dist, dire = row
                if filter_file(file_path, file_list, filename):
                    continue
                text_output = cate
                if len(text_output) <= 1:
                    continue
                text_prompt = 'What is the gun type of this sound?'
                all_cates.add(cate)
                dataset_dic["data"][dataset_dic["total_num"]] = {
                    "name": filename,
                    "prompt": text_prompt,
                    "output": text_output.replace('\n', ' ')
                }
                dataset_dic["total_num"] += 1
        print(all_cates)
| elif dataset_name == "BirdsDataset": | |
| assert flamingo_task == "SoundClassification" | |
| assert split == 'train' | |
| map_split = lambda split: 'Voice_of_Birds' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = None | |
| for bird_type in tqdm(os.listdir(file_path)): | |
| bird_name = ' '.join(bird_type.split('_')[:-1]) | |
| for filename in os.listdir(os.path.join(file_path, bird_type)): | |
| if filter_file(file_path, file_list, os.path.join(bird_type, filename)): | |
| continue | |
| text_output = bird_name | |
| if len(text_output) <= 1: | |
| continue | |
| text_prompt = 'What is the name of bird in this sound?' | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": os.path.join(bird_type, filename), | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "BBCSoundEffects": | |
| assert split in ['train'] | |
| assert flamingo_task == 'AudioDescription' | |
| map_split = lambda split: '../WavCaps/BBC_Sound_Effects_flac' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.flac'), os.listdir(file_path))) | |
| with open(os.path.join(dataset_path, 'BBCSoundDownloader/BBCSoundEffects.csv'), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| next(reader) | |
| for row in tqdm(reader): | |
| if len(row) != 7: | |
| continue | |
| filename, description, _, _, _, _, _ = row | |
| filename = filename.replace('.wav', '.flac') | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| text_output = description | |
| if len(text_output) <= 1: | |
| continue | |
| text_prompt = 'generate audio description' | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "chime-home": | |
| assert flamingo_task == "EventClassification" | |
| assert split == 'train' | |
| map_split = lambda split: 'chime_home/chunks' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file48k_list = list(filter(lambda x: x.endswith('48kHz.wav'), os.listdir(file_path))) | |
| file16k_list = list(filter(lambda x: x.endswith('16kHz.wav'), os.listdir(file_path))) | |
| csv_file_list = list(filter(lambda x: x.endswith('.csv'), os.listdir(file_path))) | |
| label_mapping = { | |
| 'c': 'child speaking', | |
| 'm': 'male speaking', | |
| 'f': 'female speaking', | |
| 'p': 'human activity', | |
| 't': 'television', | |
| 'b': 'household appliances', | |
| 's': 'silence' | |
| } | |
| for csv_file in tqdm(csv_file_list): | |
| with open(os.path.join(file_path, csv_file), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| labels = None | |
| for row in reader: | |
| if row[0] == 'majorityvote': | |
| labels = row[1] | |
| break | |
| if labels is None or len(labels) == 0: | |
| continue | |
| filename = csv_file.replace('.csv', '.48kHz.wav') | |
| if filter_file(file_path, file48k_list, filename): | |
| filename = csv_file.replace('.csv', '.16kHz.wav') | |
| if filter_file(file_path, file16k_list, filename): | |
| continue | |
| text_output = ", ".join([label_mapping[l] for l in labels if l in label_mapping]) | |
| if len(text_output) <= 1: | |
| continue | |
| text_prompt = 'this is a sound of' | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "CLAP_freesound": | |
| assert flamingo_task == "AudioCaptioning" | |
| assert split in ["train", "test"] | |
| map_split = lambda split: os.path.join('freesound_no_overlap/split', split) | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.flac'), os.listdir(file_path))) | |
| with open(os.path.join( | |
| dataset_path, | |
| 'freesound_no_overlap_meta.csv' | |
| ), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| next(reader) | |
| for row in tqdm(reader): | |
| if len(row[0].split('/')) != 2: | |
| continue | |
| if len(row) <= 1: | |
| continue | |
| file_split, filename = row[0].split('/') | |
| if file_split != split: | |
| continue | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| caption_1 = row[1] # caption_2 = row[2] but not very good | |
| text_output = caption_1 | |
| if len(text_output) <= 2: | |
| continue | |
| text_prompt = 'generate audio caption' | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "Clotho-AQA": | |
| map_split = lambda split: 'audio_files' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
| if flamingo_task == "EventClassification": | |
| dic = defaultdict(str) | |
| with open(os.path.join(dataset_path, 'clotho_aqa_metadata.csv'), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| next(reader) | |
| for row in tqdm(reader): | |
| _, file_name, keywords, _, _, _, _ = row | |
| dic[file_name] = keywords.replace(';', ', ') | |
| with open(os.path.join(dataset_path, 'clotho_aqa_{}.csv'.format(split)), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| next(reader) | |
| for row in tqdm(reader): | |
| filename = row[0] | |
| if filename not in dic or filter_file(file_path, file_list, filename): | |
| continue | |
| text_output = dic[filename] | |
| if len(text_output) <= 1: | |
| continue | |
| text_prompt = 'this is a sound of' | |
| del dic[filename] | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif flamingo_task == "AQA": | |
| dic_qa = defaultdict(list) | |
| with open(os.path.join(dataset_path, 'clotho_aqa_{}.csv'.format(split)), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| next(reader) | |
| for row in tqdm(reader): | |
| filename, question, answer, confidence = row | |
| dic_qa[(filename, question)].append((answer.lower(), confidence.lower())) | |
| # get binary -> trinary | |
| def preprocess(list_ans_conf): | |
| assert set([x[1] for x in list_ans_conf]) <= set(['yes', 'no', 'maybe']) | |
| answers = set([x[0].lower() for x in list_ans_conf]) | |
| if answers <= set(['yes', 'no']): | |
| if len(answers) > 1: | |
| return ['unsure'] | |
| else: | |
| return list(answers) | |
| else: | |
| return list(answers) | |
| # get majority vote | |
| def majority_vote(list_ans_conf): | |
| assert set([x[1] for x in list_ans_conf]) <= set(['yes', 'no', 'maybe']) | |
| weight = {'yes': 1.0, 'no': 0.1, 'maybe': 0.6} | |
| if set([x[0] for x in list_ans_conf]) <= set(['yes', 'no']): | |
| score = {'yes': 1.0, 'no': -1.0} | |
| pred = sum([score[x[0]] * weight[x[1]] for x in list_ans_conf]) | |
| if pred > 0: | |
| return ['yes'] | |
| else: | |
| return ['no'] | |
| else: | |
| return list(set([x[0] for x in list_ans_conf])) | |
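
            # e.g. majority_vote([('yes', 'yes'), ('no', 'maybe')]):
            #   1.0 * 1.0 + (-1.0) * 0.6 = 0.4 > 0  ->  ['yes']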
            for key in dic_qa:
                filename, question = key
                if filter_file(file_path, file_list, filename):
                    continue
                if split == 'train':
                    answers = majority_vote(dic_qa[key])  # majority vote
                else:
                    answers = [x[0].strip().lower() for x in dic_qa[key]]
                    answers = [', '.join(answers)]
                for answer in answers:
                    text_output = answer
                    if len(text_output) <= 1:
                        continue
                    text_prompt = "Question: " + question
                    dataset_dic["data"][dataset_dic["total_num"]] = {
                        "name": filename,
                        "prompt": text_prompt,
                        "output": text_output.replace('\n', ' ')
                    }
                    dataset_dic["total_num"] += 1
| elif dataset_name == "Clotho-AQA_singlelabel": | |
| import numpy as np | |
| assert flamingo_task == "EventClassification" | |
| map_split = lambda split: '../Clotho-AQA/audio_files' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
| dic = defaultdict(str) | |
| with open(os.path.join(dataset_path, 'clotho_aqa_metadata.csv'), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| next(reader) | |
| for row in tqdm(reader): | |
| _, file_name, keywords, _, _, _, _ = row | |
| dic[file_name] = keywords.split(';') | |
| graph = compute_label_graph( | |
| dataset_name, | |
| dataset_path, | |
| top_n=300, | |
| output_file=os.path.join(dataset_path, 'label_graph.json') | |
| ) | |
| with open(os.path.join(dataset_path, 'clotho_aqa_{}.csv'.format(split)), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| next(reader) | |
| for row in tqdm(reader): | |
| filename = row[0] | |
| if filename not in dic or filter_file(file_path, file_list, filename): | |
| continue | |
| text_labels = [x.lower().strip() for x in dic[filename]] | |
| del dic[filename] | |
| for _ in range(6): | |
| text_output = np.random.choice(text_labels) | |
| if len(text_output) <= 1: | |
| continue | |
| num_options = np.random.choice( | |
| [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], | |
| p=[ 0.05, 0.1, 0.1, 0.1, 0.1, | |
| 0.05, 0.05, 0.05, 0.1, 0.05, | |
| 0.05, 0.1, 0.05, 0.05] | |
| ) | |
| negative_samples = [x for x in graph[text_output] if x not in set(text_labels)] | |
| candidate_negative_labels = list(np.random.choice( | |
| negative_samples[:num_options*20], | |
| size=num_options-1, | |
| replace=False | |
| )) | |
| if type(candidate_negative_labels) is str: | |
| candidate_negative_labels = [candidate_negative_labels] | |
| all_options = [text_output] + candidate_negative_labels | |
| np.random.shuffle(all_options) | |
| text_prompt = 'Classify this sound.\nOPTIONS:\n - {}.'.format( | |
| '.\n - '.join(all_options) | |
| ) | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "Clotho-v2": | |
| assert flamingo_task == "AudioCaptioning" | |
| assert split in ["train", "val", "test"] | |
| map_split = lambda split: 'development' if split == 'train' else ('validation' if split == "val" else "evaluation") | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
| with open(os.path.join( | |
| dataset_path, | |
| 'clotho_captions_{}.csv'.format(map_split(split)) | |
| ), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| next(reader) | |
| for row in tqdm(reader): | |
| filename = row[0] | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| for text_output in row[1:]: | |
| if len(text_output) <= 1: | |
| continue | |
| text_prompt = 'generate audio caption' | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "CochlScene": | |
| import ndjson | |
| assert flamingo_task == "SceneClassification" | |
| map_split = lambda split: split.capitalize() | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = None | |
| with open(os.path.join(dataset_path, 'cochlscene_{}.ndjson'.format(split))) as ndjsonfile: | |
| reader = ndjson.load(ndjsonfile) | |
| for row in tqdm(reader): | |
| filename = "/".join(row["audiopath"].split("/")[1:]) | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| text_output = row["labels"].lower() | |
| if len(text_output) <= 1: | |
| continue | |
| text_prompt = 'this acoustic scene is' | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "common-accent": | |
| import ndjson | |
| import re | |
| assert flamingo_task == "AccentClassification" | |
| assert split in ["train", "test"] | |
| map_split = lambda split: '22khz' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = os.listdir(file_path) | |
| all_accent = [] | |
| split_file = [f for f in os.listdir(dataset_path) if f.startswith(split) and f.endswith('.ndjson')][0] | |
| with open(os.path.join(dataset_path, split_file)) as ndjsonfile: | |
| reader = ndjson.load(ndjsonfile) | |
| for row in tqdm(reader): | |
| accent = row["accent"] | |
| accent = re.sub(r'\(.*?\)', '', accent) | |
| accent = accent.replace('English', '') | |
| accent = accent.split(',') | |
| accent = [x.strip() for x in accent if 'school' not in x] | |
| all_accent += accent | |
| filename = row["filename"] | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| for accent_each in accent: | |
| if accent_each == 'Javanese': | |
| accent_each = 'Japanese' | |
| if len(accent_each) > 25: | |
| continue | |
| text_output = accent_each | |
| if len(text_output) <= 1: | |
| continue | |
| text_prompt = 'Classify the accent of this speech.' | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| print('all accents:', list(set(all_accent))) | |
| elif dataset_name == "CREMA-D": | |
| assert flamingo_task == "EmotionClassification" | |
| assert split in ["train"] | |
| map_split = lambda split: 'AudioWAV' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
| split_file = os.path.join( | |
| dataset_path, | |
| 'crema-d_audiopath_text_sid_emotion_filelist.txt' | |
| ) | |
| with open(split_file, 'r') as f: | |
| data = f.readlines() | |
| data = [x.replace('\n', '') for x in data] | |
| for row in tqdm(data): | |
| if row.count('|') != 3: | |
| continue | |
| filename, utterances, speaker, emotion = row.split('|') | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| text_output = emotion | |
| text_prompt = 'this emotion is' | |
| if len(text_output) <= 1: | |
| continue | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "DCASE17Task4": | |
| assert flamingo_task == "SceneClassification" | |
| assert split in ["test"] | |
| map_split = lambda split: 'unbalanced_train_segments_testing_set_audio_formatted_and_segmented_downloads' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
| split_file = os.path.join( | |
| dataset_path, | |
| 'Task-4-Large-scale-weakly-supervised-sound-event-detection-for-smart-cars', | |
| 'groundtruth_release', | |
| 'groundtruth_strong_label_testing_set.csv' | |
| ) | |
| dic = defaultdict(list) | |
| all_labels = [] | |
| with open(split_file, newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter='\t', quotechar='"') | |
| for row in tqdm(reader): | |
| filename = 'Y' + row[0] | |
| label = row[-1] | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| dic[filename] += label.split(', ') | |
| all_labels += label.split(', ') | |
| print('all labels:\n', ', '.join(list(set(all_labels)))) | |
| for filename in dic: | |
| text_output = ', '.join(list(set(dic[filename]))) | |
| text_prompt = 'this acoustic scene is' | |
| if len(text_output) <= 1: | |
| continue | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "emov-db": | |
| assert flamingo_task == "EmotionClassification" | |
| assert split in ["train", "val"] | |
| map_split = lambda split: '22khz_from_16khz' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
| split_file = os.path.join( | |
| dataset_path, | |
| 'cleaned_emov_db_audiopath_text_sid_emotion_duration_filelist_merged_{}.txt'.format(split) | |
| ) | |
| with open(split_file, 'r') as f: | |
| data = f.readlines() | |
| data = [x.replace('\n', '') for x in data] | |
| for row in tqdm(data): | |
| if row.count('|') != 4: | |
| continue | |
| filename, utterances, speaker, emotion, duration = row.split('|') | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| text_output = emotion | |
| text_output = EMOTION_MAP_DICT[text_output] | |
| text_prompt = 'this emotion is' | |
| if len(text_output) <= 1: | |
| continue | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "Epidemic_sound": | |
| assert split == 'train' | |
| assert flamingo_task in ["AudioCaptioning", "Tagging"] | |
| map_split = lambda split: 'audio' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.mp3'), os.listdir(file_path))) | |
| with open(os.path.join(dataset_path, 'Epidemic_all_debiased.csv'), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| next(reader) | |
| for row in tqdm(reader): | |
| if len(row) != 5: | |
| continue | |
| _, caption_1, caption_2, caption_t5, fileid = row | |
| filename = '{}.mp3'.format(fileid) | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| if flamingo_task == "AudioCaptioning": | |
| text_output = caption_t5 | |
| if len(text_output) <= 1: | |
| continue | |
| text_prompt = 'generate audio caption' | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif flamingo_task == "Tagging": | |
| if not caption_2.startswith('the sounds of'): | |
| continue | |
| caption_2 = caption_2.replace('the sounds of ', '') | |
| caption_2 = caption_2.replace(', and', ',') | |
| if len(caption_2) < 2: | |
| continue | |
| tags = caption_2.split(', ') | |
| tags = list(map(lambda x: x.replace("'", "").strip().lower(), tags)) | |
| text_output = '{}'.format(', '.join(tags)) | |
| if len(text_output) <= 1: | |
| continue | |
| text_prompt = 'generate tags' | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "ESC50": | |
| assert flamingo_task in ["EventClassification"] | |
| assert split == 'train' | |
| map_split = lambda split: 'ESC-50-master/audio' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
| with open(os.path.join(dataset_path, 'ESC-50-master/meta/esc50.csv'), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| next(reader) | |
| for row in tqdm(reader): | |
| if len(row) != 7: | |
| continue | |
| filename, fold, target, category, esc10, src_file, take = row | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| text_output = category.replace('_', ' ') | |
| text_prompt = 'classify this sound.' | |
| if len(text_output) <= 1: | |
| continue | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "FMA": | |
| import ast | |
| assert flamingo_task in ["GenreClassification"] | |
| assert split == 'train' | |
| map_split = lambda split: 'fma_large' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = None | |
| with open(os.path.join(dataset_path, 'fma_metadata/raw_tracks.csv'), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| next(reader) | |
| for row in tqdm(reader): | |
| if len(row) != 39: | |
| continue | |
| track_id,album_id,album_title,album_url, \ | |
| artist_id,artist_name,artist_url,artist_website, \ | |
| license_image_file,license_image_file_large, \ | |
| license_parent_id,license_title,license_url, \ | |
| tags,track_bit_rate,track_comments,track_composer, \ | |
| track_copyright_c,track_copyright_p,track_date_created,track_date_recorded, \ | |
| track_disc_number,track_duration,track_explicit,track_explicit_notes, \ | |
| track_favorites,track_file,track_genres,track_image_file,track_information, \ | |
| track_instrumental,track_interest,track_language_code, \ | |
| track_listens,track_lyricist,track_number,track_publisher,track_title,track_url = row | |
| l = len(str(track_id)) | |
| if l <= 3: | |
| filename = '{}/{}.mp3'.format( | |
| '000', | |
| '0'*(6-l)+str(track_id) | |
| ) | |
| else: | |
| filename = '{}/{}.mp3'.format( | |
| '0'*(6-l)+str(track_id)[:l-3], | |
| '0'*(6-l)+str(track_id) | |
| ) | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| if len(track_genres) == 0: | |
| continue | |
| track_genres = ast.literal_eval(track_genres) | |
| genres = ', '.join([dic['genre_title'].lower().strip() for dic in track_genres]) | |
| text_output = genres + '.' | |
| text_prompt = "what is the genre of this music?" | |
| if len(text_output) <= 1: | |
| continue | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "FSD50k": | |
| import ndjson | |
| assert flamingo_task == "EventClassification" | |
| assert split in ["train", "test"] | |
| map_split = lambda split: '44khz/dev' if split == 'train' else '44khz/eval' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
| with open(os.path.join(dataset_path, '{}.ndjson'.format(map_split(split).replace('44khz/', '')))) as ndjsonfile: | |
| reader = ndjson.load(ndjsonfile) | |
| for row in tqdm(reader): | |
| filename = row["filepath"].split("/")[1] | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| labels = [x.replace("_", " ").lower() for x in row["labels"]] | |
| text_output = ", ".join(labels) | |
| if len(text_output) <= 1: | |
| continue | |
| text_prompt = 'this is a sound of' | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "GTZAN": | |
| assert flamingo_task == "GenreClassification" | |
| assert split in ["train"] | |
| map_split = lambda split: 'gtzan/data/genres' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = None | |
| for genre in os.listdir(file_path): | |
| genre_wavs = [x for x in os.listdir(os.path.join(file_path, genre)) if x.endswith('.wav')] | |
| for genre_wav in genre_wavs: | |
| filename = os.path.join(genre, genre_wav) | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| text_output = genre | |
| if len(text_output) <= 1: | |
| continue | |
| text_prompt = 'What is the genre of this music?' | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "IEMOCAP": | |
| assert flamingo_task == "EmotionClassification" | |
| assert split in ["train", "test"] | |
| map_split = lambda split: 'IEMOCAP_full_release/16khz' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = None | |
| def read_this_ndjson(file_path): | |
| dic_list = [] | |
| with open(file_path, 'r') as f: | |
| for line in f: | |
| turn_name = line.split("'turn_name': ")[-1].split(',')[0].replace("'", "") | |
| emotion = line.split("'emotion': ")[-1].split(',')[0].replace("'", "") | |
| dic = { | |
| 'turn_name': turn_name, | |
| 'emotion': emotion | |
| } | |
| dic_list.append(dic) | |
| return dic_list | |
| all_emotions = [] | |
| meta_files = [x for x in os.listdir(os.path.join(dataset_path, 'IEMOCAP_full_release/ndjson')) if x.endswith('.ndjson')] | |
| for meta_file in tqdm(meta_files): | |
| main_folder = meta_file.split('_')[0] | |
| sub_folder = (meta_file.split('.ndjson')[0])[len(main_folder)+1:] | |
| if split == "train" and main_folder == "Session5": | |
| continue | |
| elif split == "test" and main_folder != "Session5": | |
| continue | |
| metadata_list = read_this_ndjson(os.path.join(dataset_path, 'IEMOCAP_full_release/ndjson', meta_file)) | |
| for dic in metadata_list: | |
| filename = os.path.join(main_folder, sub_folder, dic['turn_name']+'.wav') | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| if dic['emotion'] in ['unknown', 'other']: | |
| continue | |
| text_output = dic['emotion'] | |
| text_output = EMOTION_MAP_DICT[text_output] | |
| all_emotions.append(text_output) | |
| text_prompt = 'this emotion is' | |
| if len(text_output) <= 1: | |
| continue | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| print('all emotions:', list(set(all_emotions))) | |
| elif dataset_name == "jl-corpus": | |
| assert flamingo_task == "EmotionClassification" | |
| assert split in ["train", "val"] | |
| map_split = lambda split: '44khz' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
| split_file = os.path.join( | |
| dataset_path, | |
| 'jl-corpus_audiopath_text_sid_emotion_duration_{}_filelist.txt'.format(split) | |
| ) | |
| with open(split_file, 'r') as f: | |
| data = f.readlines() | |
| data = [x.replace('\n', '') for x in data] | |
| for row in tqdm(data): | |
| if row.count('|') != 4: | |
| continue | |
| filename, utterances, speaker, emotion, duration = row.split('|') | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| text_output = emotion | |
| text_output = EMOTION_MAP_DICT[text_output] | |
| text_prompt = 'this emotion is' | |
| if len(text_output) <= 1: | |
| continue | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "LP-MusicCaps-MC": | |
| import pandas as pd | |
| assert flamingo_task in ["AudioCaptioning"] | |
| assert split in ["train", "test"] | |
| map_split = lambda split: '../MusicCaps/44khz' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
| parquet_files = [f for f in os.listdir(os.path.join(dataset_path, 'data')) if f.endswith('.parquet') and f.startswith(split)] | |
| print('parquet_files', parquet_files) | |
| metadata_df = pd.concat([pd.read_parquet(os.path.join(dataset_path, 'data', f)) for f in parquet_files]) | |
| for index, row in tqdm(metadata_df.iterrows()): | |
| filename = row['ytid'] + '.wav' | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| text_prompt = 'generate audio caption' | |
| for caption in [row['caption_writing'], row['caption_summary'], row['caption_paraphrase']]: | |
| text_output = caption | |
| if len(text_output) <= 1: | |
| continue | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "LP-MusicCaps-MSD": | |
| import pandas as pd | |
| assert flamingo_task in ["AudioCaptioning"] | |
| assert split in ["train", "test", "val"] | |
| map_split = lambda split: '../MSD/mp3s_22khz' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = None | |
| parquet_files = [f for f in os.listdir(dataset_path) if f.endswith('.parquet') and f.startswith(split)] | |
| print('parquet_files', parquet_files) | |
| metadata_df = pd.concat([pd.read_parquet(os.path.join(dataset_path, f)) for f in parquet_files]) | |
| for index, row in tqdm(metadata_df.iterrows()): | |
| filename = row['path'] | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| text_prompt = 'generate audio caption' | |
| for caption in [row['caption_writing'], row['caption_summary'], row['caption_paraphrase']]: | |
| text_output = caption | |
| if len(text_output) <= 1: | |
| continue | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "LP-MusicCaps-MTT": | |
| import pandas as pd | |
| assert flamingo_task in ["AudioCaptioning"] | |
| assert split in ["train", "test", "val"] | |
| map_split = lambda split: '../MagnaTagATune/16khz' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = None | |
| parquet_files = [f for f in os.listdir(dataset_path) if f.endswith('.parquet') and f.startswith(split)] | |
| print('parquet_files', parquet_files) | |
| metadata_df = pd.concat([pd.read_parquet(os.path.join(dataset_path, f)) for f in parquet_files]) | |
| for index, row in tqdm(metadata_df.iterrows()): | |
| filename = row['path'] | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| text_prompt = 'generate audio caption' | |
| for caption in [row['caption_writing'], row['caption_summary'], row['caption_paraphrase']]: | |
| text_output = caption | |
| if len(text_output) <= 1: | |
| continue | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "MACS": | |
| assert flamingo_task in ["AudioCaptioning", "Tagging"] | |
| assert split == 'train' | |
| map_split = lambda split: 'TAU_Urban_Acoustic_Scenes_2019/TAU-urban-acoustic-scenes-2019-development/audio' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
| metadata_list = yaml.load(open(os.path.join(dataset_path, 'MACS.yaml')), Loader=yaml.FullLoader)['files'] | |
| for file_metadata in tqdm(metadata_list): | |
| filename = file_metadata['filename'] | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| for each_annotated in file_metadata['annotations']: | |
| caption = each_annotated['sentence'] | |
| tags = ', '.join(each_annotated['tags']).replace('_', ' ') | |
| if flamingo_task == "AudioCaptioning": | |
| text_output = caption | |
| text_prompt = 'generate audio caption' | |
| elif flamingo_task == "Tagging": | |
| raise NotImplementedError | |
| if len(text_output) <= 1: | |
| continue | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "Medley-solos-DB": | |
| import ndjson | |
| assert flamingo_task in ["InstrClassification"] | |
| map_split = lambda split: '44khz' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
| with open(os.path.join(dataset_path, 'medleysolosdb_manifest.ndjson')) as ndjsonfile: | |
| metadata_list = ndjson.load(ndjsonfile) | |
| for file_metadata in tqdm(metadata_list): | |
| subset = file_metadata['subset'] | |
| if not subset.startswith(split): | |
| continue | |
| filename = file_metadata['filepath'] | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| instrument = file_metadata["instrument"] | |
| text_output = instrument | |
| text_prompt = 'this music note is produced by' | |
| if len(text_output) <= 1: | |
| continue | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "MELD": | |
| import numpy as np | |
| assert flamingo_task in ["EmotionClassification", "SentimentClassification"] | |
| map_split = lambda split: '44khz' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
| split_file = os.path.join( | |
| dataset_path, | |
| '{}.txt'.format(split if split in ['train', 'test'] else 'dev') | |
| ) | |
| with open(split_file, 'r') as f: | |
| data = f.readlines() | |
| data = [x.replace('\n', '') for x in data] | |
| emotion_count = { | |
| 'neutral': 4703, 'happy': 1739, 'sad': 683, 'surprised': 1204, | |
| 'disgusted': 271, 'angry': 1108, 'fearful': 268, | |
| } | |
| sentiment_count = { | |
| 'neutral': 4703, 'positive': 2330, 'negative': 2943, | |
| } | |
| balancing_factor = 1 | |
| for row in tqdm(data): | |
| if row.count('|') != 4: | |
| continue | |
| filename, utterances, speaker, emotion, sentiment = row.split('|') | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| if flamingo_task == "EmotionClassification": | |
| text_output = emotion | |
| text_output = EMOTION_MAP_DICT[text_output] | |
| text_prompt = 'this emotion is' | |
| if split == 'train': | |
| balancing_factor = float(emotion_count['neutral']) / float(emotion_count[text_output]) | |
| elif flamingo_task == "SentimentClassification": | |
| text_output = sentiment | |
| text_prompt = 'this sentiment is' | |
| if split == 'train': | |
| balancing_factor = float(sentiment_count['neutral']) / float(sentiment_count[text_output]) | |
| if len(text_output) <= 1: | |
| continue | |
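| # oversample minority classes: emit floor(balancing_factor) copies, plus one | |
| # extra copy with probability equal to the fractional part, so the expected | |
| # per-class count matches the 'neutral' count | |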
| for _ in range(int(np.floor(balancing_factor))): | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| if np.random.rand() < balancing_factor - np.floor(balancing_factor): | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "MSP-PODCAST-Publish-1.9": | |
| assert flamingo_task == "EmotionClassification" | |
| assert split in ["train", "val", "test"] | |
| map_split = lambda split: 'Audio' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = glob.glob('{}/*/*.wav'.format(file_path)) | |
| file_list = [x[len(file_path)+1:] for x in file_list] | |
| subfolder_map = {} | |
| for f in tqdm(file_list): | |
| subfolder, filename = f.split('/') | |
| subfolder_map[filename] = subfolder | |
| file_list = None | |
| emotion_dic = { | |
| 'A': 'Angry', | |
| 'S': 'Sad', | |
| 'H': 'Happy', | |
| 'U': 'Surprise', | |
| 'F': 'Fear', | |
| 'D': 'Disgust', | |
| 'C': 'Contempt', | |
| 'N': 'Neutral', | |
| 'O': 'Other', | |
| 'X': 'Not clear' | |
| } | |
| with open(os.path.join(dataset_path, 'Labels/labels_concensus.json')) as f: | |
| data = f.read() | |
| metadata_dic = json.loads(data) | |
| for filename in tqdm(list(metadata_dic.keys())): | |
| values = metadata_dic[filename] | |
| if not values["Split_Set"].lower().startswith(split): | |
| continue | |
| if values["EmoClass"] in ["O", "X"] or values["EmoClass"] not in emotion_dic.keys(): | |
| continue | |
| subfolder = subfolder_map[filename] | |
| filename = '{}/{}'.format(subfolder, filename) | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| text_output = emotion_dic[values["EmoClass"]].lower() | |
| text_output = EMOTION_MAP_DICT[text_output] | |
| text_prompt = 'this emotion is' | |
| if len(text_output) <= 1: | |
| continue | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "mtg-jamendo": | |
| import ndjson | |
| assert flamingo_task == "MusicTagging" | |
| assert split in ["train", "val"] | |
| map_split = lambda split: '44khz' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = None | |
| with open(os.path.join(dataset_path, 'mtg_jamendo_{}_manifest.ndjson'.format(split))) as ndjsonfile: | |
| reader = ndjson.load(ndjsonfile) | |
| for row in tqdm(reader): | |
| filename = row["audiopath"] | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| text_output = row["caption"] | |
| text_prompt = 'generate music tags (genre, instrument, mood/theme)' | |
| if len(text_output) <= 1: | |
| continue | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "MU-LLAMA": | |
| assert flamingo_task in ['AQA'] | |
| assert split in ['train', 'test'] | |
| map_split = lambda split: 'MusicQA/audios' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
| split_file = 'MusicQA/FinetuneMusicQA.json' if split == 'train' else 'MusicQA/EvalMusicQA.json' | |
| with open(os.path.join(dataset_path, split_file), 'r') as f: | |
| data = f.read() | |
| metadata_list = json.loads(data) | |
| for dic in tqdm(metadata_list): | |
| filename = dic["audio_name"] | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| text_prompt = 'Question: ' + dic["conversation"][0]["value"].strip() | |
| if not (text_prompt.endswith('.') or text_prompt.endswith('?')): | |
| text_prompt = text_prompt + '.' | |
| text_output = dic["conversation"][1]["value"].strip() | |
| if len(text_output) <= 1: | |
| continue | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "musdbhq": | |
| assert flamingo_task in ["InstrClassification"] | |
| assert split in ["train", "test", "val"] | |
| map_split = lambda split: './' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = None | |
| with open(os.path.join(dataset_path, 'file_list_44k_{}.txt'.format(split))) as f: | |
| data = f.readlines() | |
| data = [x.replace('\n', '') for x in data] | |
| for row in tqdm(data): | |
| if row.count('|') != 1: | |
| continue | |
| filename, duration = row.split('|') | |
| duration = float(duration) | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| text_output = filename.split('/')[-1].split('.wav')[0] | |
| if len(text_output) <= 1: | |
| continue | |
| text_prompt = 'this music is produced by' | |
| segment_length = 10 | |
| for audio_start_idx in range(int(duration // segment_length)): | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' '), | |
| "audio_start": audio_start_idx * segment_length | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "Music-AVQA": | |
| import ast | |
| import re | |
| assert flamingo_task in [ | |
| "{}_{}".format(q, t) \ | |
| for q in ['AQA', 'AVQA'] \ | |
| for t in ['Comparative', 'Counting', 'Existential', 'Location', 'Temporal', 'All'] | |
| ] | |
| def replace_bracketed_words(input_string, replacements): | |
| def replacer(match): | |
| word = next(replacements) | |
| return word | |
| replacements = iter(replacements) | |
| output_string = re.sub(r'<[^>]*>', replacer, input_string) | |
| return output_string | |
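| # e.g. replace_bracketed_words('How many <Object> are in the <Place>?', ['violins', 'video']) | |
| # returns 'How many violins are in the video?' (each <...> token is replaced in order) | |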
| map_split = lambda split: 'audio' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
| with open(os.path.join(dataset_path, 'MUSIC-AVQA/data/json/avqa-{}.json'.format(split)), 'r') as f: | |
| data = f.read() | |
| metadata_list = json.loads(data) | |
| for dic in tqdm(metadata_list): | |
| filename = dic["video_id"] + '.wav' | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| types = ast.literal_eval(dic["type"]) | |
| if 'Visual' in types: | |
| continue | |
| if flamingo_task.startswith('AQA_') and 'Audio-Visual' in types: | |
| continue | |
| if flamingo_task.startswith('AVQA_') and 'Audio' in types: | |
| continue | |
| t = flamingo_task.split('_')[1] | |
| if t != 'All' and t not in types: | |
| continue | |
| text_output = dic["anser"] | |
| if len(text_output) <= 1: | |
| continue | |
| question = dic["question_content"].replace("\uff1f", '?') | |
| templ_values = ast.literal_eval(dic["templ_values"]) | |
| if len(templ_values) > 0: | |
| question = replace_bracketed_words(question, templ_values) | |
| text_prompt = "Question: " + question | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "MusicCaps": | |
| assert flamingo_task in ["AudioCaptioning", "EventClassification"] | |
| assert split in ["train", "test"] | |
| map_split = lambda split: '44khz' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
| with open(os.path.join(dataset_path, 'musiccaps_manifest.json')) as f: | |
| data = f.read() | |
| metadata_list = json.loads(data) | |
| for file_metadata in tqdm(metadata_list): | |
| filename = file_metadata['filepath'] | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| start_s, end_s = file_metadata["start_s"], file_metadata["end_s"] | |
| caption = file_metadata["caption"] | |
| audioset_positive_labels = file_metadata["audioset_positive_labels"] # audioset classes | |
| aspect_list = file_metadata["aspect_list"] # annotated classes | |
| if (split == 'train') == file_metadata["is_audioset_eval"]: # train uses non-eval clips; test uses audioset-eval clips | |
| continue | |
| if flamingo_task == "AudioCaptioning": | |
| text_output = caption | |
| text_prompt = 'generate audio caption' | |
| elif flamingo_task == "EventClassification": | |
| raise NotImplementedError | |
| if len(text_output) <= 1: | |
| continue | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "NonSpeech7k": | |
| assert flamingo_task in ["EventClassification"] | |
| assert split in ["train", "test"] | |
| map_split = lambda split: split | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
| all_classes = [] | |
| with open(os.path.join(dataset_path, 'metadata of {} set.csv'.format(split)), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| next(reader) | |
| for row in tqdm(reader): | |
| filename, _, _, _, classname, _, _, _ = row | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| text_output = classname.lower() | |
| if len(text_output) <= 1: | |
| continue | |
| text_prompt = 'this is a sound of' | |
| all_classes.append(classname) | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| print('all classes:', list(set(all_classes))) | |
| elif dataset_name == "NSynth": | |
| import ndjson | |
| assert flamingo_task in [ | |
| "InstrClassification", | |
| "PitchClassification", | |
| "VelocityClassification", | |
| "SourceClassification", | |
| "QualityClassification", | |
| "MIR" | |
| ] | |
| assert split in ["train", "test", "val"] | |
| map_split = lambda split: 'nsynth-{}/audio'.format('valid' if split == 'val' else split) | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
| with open(os.path.join(dataset_path, map_split(split), '../examples.json')) as f: | |
| data = f.read() | |
| reader = json.loads(data) | |
| for key in tqdm(reader): | |
| filename = key + '.wav' | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| if flamingo_task == "InstrClassification": | |
| text_output = reader[key]["instrument_family_str"] | |
| text_prompt = 'this music note is produced by' | |
| elif flamingo_task == "PitchClassification": | |
| text_output = str(reader[key]["pitch"]) | |
| text_prompt = 'this music note has pitch' | |
| elif flamingo_task == "VelocityClassification": | |
| text_output = str(reader[key]["velocity"]) | |
| text_prompt = 'this music note has velocity' | |
| elif flamingo_task == "SourceClassification": | |
| text_output = reader[key]["instrument_source_str"] | |
| text_prompt = 'this music note has sonic source' | |
| elif flamingo_task == "QualityClassification": | |
| qualities_str = reader[key]["qualities_str"] | |
| if len(qualities_str) >= 1: | |
| text_output = ', '.join(qualities_str).replace('_', ' ') | |
| else: | |
| text_output = 'none' | |
| text_prompt = 'this music note has sonic qualities' | |
| elif flamingo_task == "MIR": | |
| instrument = reader[key]["instrument_family_str"] | |
| pitch = str(reader[key]["pitch"]) | |
| velocity = str(reader[key]["velocity"]) | |
| source = reader[key]["instrument_source_str"] | |
| qualities_str = ', '.join(reader[key]["qualities_str"]).replace('_', ' ') | |
| assert len(instrument) > 0 | |
| text_output = 'produced by {}'.format(instrument) | |
| if len(pitch) > 0: | |
| text_output = text_output + ', pitch {}'.format(pitch) | |
| if len(velocity) > 0: | |
| text_output = text_output + ', velocity {}'.format(velocity) | |
| if len(source) > 0: | |
| text_output = text_output + ', source {}'.format(source) | |
| if len(qualities_str) > 0: | |
| text_output = text_output + ', and having qualities like {}'.format(qualities_str) | |
| text_prompt = 'this music note is' | |
| if len(text_output) <= 1: | |
| continue | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "OMGEmotion": | |
| import numpy as np | |
| import webrtcvad | |
| import wave | |
| from pydub import AudioSegment | |
| assert flamingo_task == "EmotionClassification" | |
| assert split in ["train", "val"] | |
| def convert_to_wav(file_path): | |
| audio = AudioSegment.from_file(file_path).set_frame_rate(16000).set_channels(1) | |
| wav_path = file_path.rsplit('.', 1)[0] + "_converted.wav" | |
| audio.export(wav_path, format="wav") | |
| return wav_path | |
| def contains_speech(file_path, aggressiveness=0): | |
| # aggressiveness is an integer from 0 to 3: 0 is the least aggressive about filtering out non-speech, 3 the most aggressive | |
| wav_path = convert_to_wav(file_path) | |
| vad = webrtcvad.Vad(aggressiveness) | |
| with wave.open(wav_path, 'rb') as audio: | |
| assert audio.getsampwidth() == 2, "Audio must be 16-bit" | |
| assert audio.getnchannels() == 1, "Audio must be mono" | |
| assert audio.getframerate() == 16000, "Audio must be sampled at 16kHz" | |
| frame_duration = 10 # ms | |
| frame_size = int(audio.getframerate() * frame_duration / 1000) | |
| num_frames = int(audio.getnframes() / frame_size) | |
| for _ in range(num_frames): | |
| frame = audio.readframes(frame_size) | |
| if vad.is_speech(frame, audio.getframerate()): | |
| return True | |
| return False | |
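| # note: webrtcvad only accepts 10/20/30 ms frames of 16-bit mono PCM at | |
| # 8/16/32/48 kHz; the 16 kHz mono conversion and 10 ms frames above satisfy this | |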
| map_split = lambda split: 'processed-{}_utterance_data'.format(split) | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = None | |
| dic_code2emotion = { | |
| "0": "anger", | |
| "1": "disgust", | |
| "2": "fear", | |
| "3": "happy", | |
| "4": "neutral", | |
| "5": "sad", | |
| "6": "surprise", | |
| } | |
| all_emotions = [] | |
| meta_file = os.path.join( | |
| dataset_path, | |
| 'OMGEmotionChallenge', | |
| 'omg_{}Videos.csv'.format('Train' if split == 'train' else 'Validation') | |
| ) | |
| with open(meta_file, newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| next(reader) | |
| for row in tqdm(reader): | |
| link, start, end, video, utterance, _, _, EmotionMaxVote = row | |
| emotion = dic_code2emotion[str(EmotionMaxVote)] | |
| filename = os.path.join(video, utterance.replace('.mp4', '.mp3')) | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| if not contains_speech(os.path.join(file_path, filename)): | |
| print('{} does not contain speech'.format(filename)) | |
| continue | |
| text_prompt = 'this emotion is' | |
| text_output = emotion | |
| if len(text_output) <= 1: | |
| continue | |
| all_emotions.append(EMOTION_MAP_DICT[emotion]) | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| print('all emotions:', list(set(all_emotions))) | |
| elif dataset_name == "OpenAQA": | |
| assert flamingo_task == 'AQA' | |
| assert split == 'train' | |
| map_split = lambda split: './' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = None | |
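| # answers containing any of these phrases are refusals / non-answers and are skipped below | |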
| no_word_list = [ | |
| 'cannot determine', 'not provided', 'cannot be determined', 'sorry', 'i cannot', | |
| 'without more information', 'enough information', | |
| 'not possible', 'more context', 'enough', 'impossible', 'cannot be determined', | |
| 'without additional information', | |
| 'unclear', 'cannot', 'not clear', 'do not provide sufficient', 'does not provide', | |
| 'difficult to determine', 'no information provided', | |
| "can't infer", "difficult to infer", "not specified", "no specific", "no information", | |
| "without additional", 'it is difficult to', "no indication" | |
| ] | |
| print('computing dic_audiosetfull_parts') | |
| audiosetfull_root = '/mnt/fsx-main/rafaelvalle/datasets/audioset/unbalanced_train_segments/22khz/' | |
| part_strings = [str(p).zfill(2) for p in range(41)] # '00' through '40' | |
| dic_audiosetfull_parts = { | |
| part: set(os.listdir(os.path.join(audiosetfull_root, 'unbalanced_train_segments_part{}'.format(part)))) \ | |
| for part in part_strings | |
| } | |
| audioset20k_filelist = set(os.listdir(os.path.join(file_path, '../AudioSet/train_wav'))) | |
| print('computing dic_clotho_filename') | |
| clotho_files = os.listdir(os.path.join(dataset_path, '../Clotho-AQA/audio_files')) | |
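| # keys replace whitespace runs with single underscores so that OpenAQA's | |
| # underscored audio ids can be matched back to the on-disk Clotho filenames | |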
| dic_clotho_filename = { | |
| '_'.join([s for s in f.split(' ') if len(s) > 0]): f \ | |
| for f in clotho_files | |
| } | |
| print('reading open_ended/all_open_qa.json') | |
| with open(os.path.join(dataset_path, 'openaqa/data/open_ended/all_open_qa.json'), 'r') as f: | |
| data = f.read() | |
| metadata_list = json.loads(data) | |
| for dic in tqdm(metadata_list): | |
| # keys: instruction, input, dataset, audio_id, output, task | |
| text_output = dic["output"] | |
| if len(text_output) <= 1: | |
| continue | |
| if any(word in text_output.lower() for word in no_word_list): | |
| continue | |
| question = dic["instruction"] | |
| text_prompt = question | |
| audio_id = dic["audio_id"] | |
| subset = dic["dataset"] | |
| if subset == 'clotho_development': | |
| filename = audio_id.split('/')[-1] | |
| processed_filename = '_'.join([s for s in filename.split('_') if len(s) > 0]) | |
| if processed_filename in dic_clotho_filename: | |
| filename = os.path.join( | |
| '../Clotho-AQA/audio_files', | |
| dic_clotho_filename[processed_filename] | |
| ) | |
| else: | |
| continue | |
| elif subset in ['audiocaps_train', 'as_20k', 'as_strong_train']: | |
| found = False | |
| filename = audio_id.split('/')[-1].split('.flac')[0] + '.wav' | |
| if filename in audioset20k_filelist: | |
| filename = os.path.join('../AudioSet/train_wav', filename) | |
| found = True | |
| else: | |
| filename = 'Y' + filename | |
| for part in part_strings: | |
| if filename in dic_audiosetfull_parts[part]: | |
| filename = os.path.join( | |
| audiosetfull_root, | |
| 'unbalanced_train_segments_part{}'.format(part), | |
| filename | |
| ) | |
| found = True | |
| break | |
| if not found: | |
| print(filename, 'not found') | |
| continue | |
| elif subset == 'freesound_10s': | |
| filename = os.path.join( | |
| '../CLAP_freesound/freesound_no_overlap/split/train', | |
| audio_id.split('/')[-1] | |
| ) | |
| elif subset == 'vggsound_train': | |
| continue | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "ravdess": | |
| assert flamingo_task == "EmotionClassification" | |
| assert split in ["train", "val"] | |
| map_split = lambda split: '44khz' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = None | |
| split_file = os.path.join( | |
| dataset_path, | |
| 'ravdess_audiopath_text_sid_emotion_duration_{}_filelist.txt'.format(split) | |
| ) | |
| with open(split_file, 'r') as f: | |
| data = f.readlines() | |
| data = [x.replace('\n', '') for x in data] | |
| for row in tqdm(data): | |
| if row.count('|') != 4: | |
| continue | |
| filename, utterances, speaker, emotion, duration = row.split('|') | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| text_output = emotion | |
| text_prompt = 'this emotion is' | |
| if len(text_output) <= 1: | |
| continue | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "SongDescriber": | |
| assert flamingo_task in ["AudioCaptioning"] | |
| assert split in ["train"] | |
| map_split = lambda split: './audio/audio' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = None | |
| with open(os.path.join(dataset_path, 'song_describer.csv'), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| next(reader) | |
| for row in tqdm(reader): | |
| caption_id, track_id, caption, is_valid_subset, familiarity, artist_id, album_id, path, duration = row | |
| filename = '{}/{}.2min.mp3'.format(track_id[-2:], track_id) | |
| duration = float(duration) | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| text_output = caption | |
| if len(text_output) <= 1: | |
| continue | |
| text_prompt = 'generate audio caption' | |
| segment_length = 30 | |
| for audio_start_idx in range(int(duration // segment_length)): | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' '), | |
| "audio_start": audio_start_idx * segment_length | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "SONYC-UST": | |
| import numpy as np | |
| assert flamingo_task == "EventClassification" | |
| assert split in ["train", "test", "val"] | |
| map_split = lambda split: 'audio' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
| all_labels = [] | |
| with open(os.path.join(dataset_path, 'annotations.csv'), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| for idx, row in tqdm(enumerate(reader)): | |
| if idx == 0: | |
| header = np.array(row) | |
| continue | |
| if not row[0].startswith(split): | |
| continue | |
| filename = row[2] | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| labels = [header[i] for i in range(12, len(header)-8) if str(row[i]) == "1"] | |
| labels = [x.split("_")[1].replace('-', ' ').lower() for x in labels if 'X_' not in x] | |
| all_labels += labels | |
| text_output = ", ".join(labels) | |
| if len(text_output) <= 1: | |
| continue | |
| text_prompt = 'this is a sound of' | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| print('all labels:', list(set(all_labels))) | |
| elif dataset_name == "SoundDescs": | |
| import torch | |
| assert flamingo_task in ["AudioDescription"] | |
| assert split in ["train"] | |
| map_split = lambda split: 'raw/audios' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
| split_file = os.path.join(dataset_path, 'audio-retrieval-benchmark/data/SoundDescs/{}_list.txt'.format(split)) | |
| with open(split_file, 'r') as f: | |
| data = f.readlines() | |
| names = set([x.replace('\n', '') for x in data]) | |
| with open(os.path.join(dataset_path, 'audio-retrieval-benchmark/sounddescs_data/descriptions.pkl'), 'rb') as f: | |
| obj = f.read() | |
| metadata_dic = pickle.loads(obj, encoding='latin1') | |
| for name in tqdm(names): | |
| if name not in metadata_dic.keys(): | |
| continue | |
| filename = '{}.wav'.format(name) | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| description = metadata_dic[name] | |
| text_output = description | |
| text_prompt = 'generate audio description' | |
| if len(text_output) <= 1: | |
| continue | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "tess": | |
| assert flamingo_task == "EmotionClassification" | |
| assert split in ["train", "val"] | |
| map_split = lambda split: '24414hz' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = None | |
| split_file = os.path.join( | |
| dataset_path, | |
| 'tess_audiopath_text_sid_emotion_duration_{}_filelist.txt'.format(split) | |
| ) | |
| with open(split_file, 'r') as f: | |
| data = f.readlines() | |
| data = [x.replace('\n', '') for x in data] | |
| for row in tqdm(data): | |
| if row.count('|') != 4: | |
| continue | |
| filename, utterances, speaker, emotion, duration = row.split('|') | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| text_output = emotion.replace('_', ' ') | |
| text_output = EMOTION_MAP_DICT[text_output] | |
| text_prompt = 'this emotion is' | |
| if len(text_output) <= 1: | |
| continue | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "UrbanSound8K": | |
| assert flamingo_task in ["EventClassification"] | |
| assert split in ["train"] | |
| map_split = lambda split: 'audio' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = None | |
| with open(os.path.join(dataset_path, 'metadata/UrbanSound8K.csv'), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| next(reader) | |
| for row in tqdm(reader): | |
| filename, fsID, start, end, salience, fold, classID, class_name = row | |
| filename = 'fold{}/{}'.format(fold, filename) | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| text_output = class_name.replace("_", " ").lower() | |
| if len(text_output) <= 1: | |
| continue | |
| text_prompt = 'this is a sound of' | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "VocalSound": | |
| assert flamingo_task == "VocalClassification" | |
| map_split = lambda split: 'data_44k' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
| split_file = os.path.join( | |
| dataset_path, | |
| 'meta/{}_meta.csv'.format(split[:2] if split in ['train', 'test'] else split[:3]) | |
| ) | |
| prefix = set() | |
| with open(split_file, newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| for row in reader: | |
| prefix.add(row[0]) | |
| all_labels = set() | |
| for filename in tqdm(file_list): | |
| if not filename.split('_')[0] in prefix: | |
| continue | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| label = filename.split('_')[2].split('.wav')[0] | |
| if label == 'throatclearing': | |
| label = 'throat clearing' | |
| text_output = label | |
| text_prompt = 'this vocal sound is' | |
| all_labels.add(label) | |
| if len(text_output) <= 1: | |
| continue | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| print('all labels:\n', "\'" + "\', \'".join(list(all_labels)) + "\'") | |
| elif dataset_name.startswith("WavCaps"): | |
| assert split in ["train"] | |
| dataset_name, subset_name = dataset_name.split('-') # e.g. "WavCaps-SoundBible" -> ("WavCaps", "SoundBible") | |
| dataset_path = os.path.join( | |
| '/'.join(dataset_path.split('/')[:-1]), | |
| dataset_name | |
| ) | |
| dataset_dic['dataset_path'] = dataset_path | |
| map_split = lambda split: subset_name + '_flac' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.flac'), os.listdir(file_path))) | |
| metadata_file = os.listdir(os.path.join(dataset_path, "json_files", subset_name)) | |
| metadata_file = [x for x in metadata_file if x.endswith('json')][0] | |
| with open(os.path.join(dataset_path, "json_files", subset_name, metadata_file)) as f: | |
| data = f.read() | |
| reader = json.loads(data) | |
| if subset_name == "AudioSet_SL": | |
| assert flamingo_task == 'AudioCaptioning' | |
| for sample in tqdm(reader['data']): | |
| filename = sample["id"].replace('.wav', '.flac') | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| text_output = sample['caption'] | |
| if len(text_output) <= 1: | |
| continue | |
| text_prompt = 'generate audio caption' | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| else: | |
| assert flamingo_task in ['AudioCaptioning', 'AudioDescription'] | |
| for sample in tqdm(reader['data']): | |
| filename = sample["id"] + '.flac' | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| if flamingo_task == 'AudioCaptioning': | |
| text_output = sample['caption'] | |
| text_prompt = 'generate audio caption' | |
| elif flamingo_task == 'AudioDescription': | |
| text_output = sample['description'] | |
| text_prompt = 'generate audio description' | |
| if len(text_output) <= 1: | |
| continue | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif dataset_name == "WavText5K": | |
| assert split == 'train' | |
| map_split = lambda split: 'Webcrawl/44100/audios' | |
| file_path = os.path.join( | |
| dataset_path, | |
| map_split(split) | |
| ) | |
| assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
| dataset_dic["split_path"] = map_split(split) | |
| file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
| dic = defaultdict(str) | |
| with open(os.path.join(dataset_path, 'WavText5K.csv'), newline='') as csvfile: | |
| reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
| next(reader) | |
| for row in tqdm(reader): | |
| _, _, title, description, filename, tags = row | |
| dic[filename] = (title, description, tags) | |
| if flamingo_task == "AudioCaptioning": | |
| for filename in tqdm(dic.keys()): | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| title, description, tags = dic[filename] | |
| text_output = description | |
| if len(text_output) <= 1: | |
| continue | |
| text_prompt = 'generate audio caption' | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| elif flamingo_task == "Tagging": | |
| for filename in tqdm(dic.keys()): | |
| if filter_file(file_path, file_list, filename): | |
| continue | |
| title, description, tags = dic[filename] | |
| if len(tags) < 2 or not tags.startswith('[') or not tags.endswith(']'): | |
| continue | |
| tags = tags[1:-1].split(', ') | |
| tags = list(map(lambda x: x.replace("'", ""), tags)) | |
| text_output = '{}'.format(', '.join(tags)) | |
| if len(text_output) <= 1: | |
| continue | |
| text_prompt = 'generate tags' | |
| dataset_dic["data"][dataset_dic["total_num"]] = { | |
| "name": filename, | |
| "prompt": text_prompt, | |
| "output": text_output.replace('\n', ' ') | |
| } | |
| dataset_dic["total_num"] += 1 | |
| with open(output_file, 'w') as json_file: | |
| json.dump(dataset_dic, json_file) | |
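| # example (hypothetical paths), mirroring the call in __main__ below: | |
| # prepare_files('Clotho-v2', '/datasets/Clotho-v2', 'train', 'AudioCaptioning', | |
| #                '/outputs/Clotho-v2-AudioCaptioning/train.json') | |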
| # ==================== Precompute CLAP and build FAISS index ==================== | |
| def int16_to_float32(x): | |
| return (x / 32767.0).astype(np.float32) | |
| def float32_to_int16(x): | |
| x = np.clip(x, a_min=-1., a_max=1.) | |
| return (x * 32767.).astype(np.int16) | |
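| # note: compute_clap_each below deliberately round-trips audio through int16 | |
| # (int16_to_float32(float32_to_int16(x))); this quantization step follows the | |
| # preprocessing shown in the LAION-CLAP examples and is not a no-op | |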
| def update_progress_bar(arg): | |
| pbar.update() | |
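| # update_progress_bar expects a global tqdm instance named `pbar` and is meant | |
| # as a completion callback (e.g. for pool.apply_async); it is not wired up in | |
| # the sequential pool.map loop below | |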
| def load_clap_model(checkpoint): | |
| if checkpoint in ['630k-audioset-best.pt', '630k-best.pt', '630k-audioset-fusion-best.pt', '630k-fusion-best.pt']: | |
| amodel = 'HTSAT-tiny' | |
| elif checkpoint in ['music_speech_audioset_epoch_15_esc_89.98.pt']: | |
| amodel = 'HTSAT-base' | |
| else: | |
| raise NotImplementedError | |
| model = laion_clap.CLAP_Module( | |
| enable_fusion=('fusion' in checkpoint.lower()), | |
| amodel=amodel | |
| ).cuda() | |
| model.load_ckpt(ckpt=os.path.join( | |
| '/lustre/fsw/portfolios/adlr/users/zkong/audio-flamingo-data/laion-clap-pretrained/laion_clap', | |
| checkpoint | |
| )) | |
| return model | |
| def load_audio(file_path, target_sr=44100, duration=30.0, start=0.0): | |
| if file_path.endswith('.mp3'): | |
| audio = AudioSegment.from_file(file_path) | |
| if len(audio) > (start + duration) * 1000: | |
| audio = audio[start * 1000:(start + duration) * 1000] | |
| if audio.frame_rate != target_sr: | |
| audio = audio.set_frame_rate(target_sr) | |
| if audio.channels > 1: | |
| audio = audio.set_channels(1) | |
| data = np.array(audio.get_array_of_samples()) | |
| if audio.sample_width == 2: | |
| data = data.astype(np.float32) / np.iinfo(np.int16).max | |
| elif audio.sample_width == 4: | |
| data = data.astype(np.float32) / np.iinfo(np.int32).max | |
| else: | |
| raise ValueError("Unsupported bit depth: {}".format(audio.sample_width)) | |
| else: | |
| with sf.SoundFile(file_path) as audio: | |
| original_sr = audio.samplerate | |
| channels = audio.channels | |
| start_frame = int(start * original_sr) | |
| audio.seek(start_frame) | |
| # read at most `duration` seconds starting from the seek position | |
| frames_to_read = min(int(duration * original_sr), max(0, len(audio) - start_frame)) | |
| data = audio.read(frames_to_read) | |
| if data.max() > 1 or data.min() < -1: | |
| data = data / max(abs(data.max()), abs(data.min())) | |
| if original_sr != target_sr: | |
| if channels == 1: | |
| data = librosa.resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr) | |
| else: | |
| data = librosa.resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0] | |
| else: | |
| if channels != 1: | |
| data = data.T[0] | |
| if data.min() >= 0: | |
| data = 2 * data / abs(data.max()) - 1.0 | |
| else: | |
| data = data / max(abs(data.max()), abs(data.min())) | |
| return data | |
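| # example (hypothetical path), mirroring the call in compute_clap_each below: | |
| # data = load_audio('/path/to/clip.wav', target_sr=48000, duration=10) | |
| # returns a 1-D float array at 48 kHz with values roughly in [-1, 1] | |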
| def compute_clap_each(audio_file, model): | |
| try: | |
| data = load_audio(audio_file, target_sr=48000, duration=10) | |
| print(audio_file, 'loaded') | |
| except Exception as e: | |
| print(audio_file, 'unsuccessful due to', e) | |
| return None | |
| audio_data = data.reshape(1, -1) | |
| audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float().cuda() | |
| audio_embed = model.get_audio_embedding_from_data(x=audio_data_tensor, use_tensor=True) | |
| audio_embed = audio_embed.squeeze(0).cpu() | |
| return audio_embed | |
| def compute_embeddings_batch(batch, audio_files, model): | |
| batch_results = [] | |
| for i in batch: | |
| if i >= len(audio_files): | |
| break | |
| audio_file = audio_files[i] | |
| audio_embed = compute_clap_each(audio_file, model) | |
| batch_results.append((i, audio_file, audio_embed)) | |
| return batch_results | |
| def precompute_clap_for_dataset( | |
| dataset_file, | |
| embedding_output_file, | |
| checkpoint='630k-audioset-fusion-best.pt' | |
| ): | |
| contents, audio_files = load_dataset_file(dataset_file) | |
| model = load_clap_model(checkpoint) | |
| if os.path.exists(embedding_output_file): | |
| print('loading already computed embedding file from', embedding_output_file) | |
| with open(embedding_output_file, 'rb') as f: | |
| saved_data = pickle.load(f) | |
| curr_audio_indices = saved_data['audio_indices'] | |
| curr_audio_files = saved_data['audio_files'] | |
| curr_audio_embeds = saved_data['audio_embeds'] | |
| else: | |
| curr_audio_indices = [] | |
| curr_audio_files = [] | |
| curr_audio_embeds = [] | |
| print('computing embeddings for {}'.format(dataset_file)) | |
| start_index = len(curr_audio_files) | |
| remaining_indices = list(range(start_index, len(audio_files))) | |
| batch_size = 128 | |
| batches = [ | |
| list(range(i, min(i + batch_size, len(audio_files)))) \ | |
| for i in range(start_index, len(audio_files), batch_size) | |
| ] | |
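| # each pool.map call submits exactly one batch, so embeddings are computed | |
| # batch by batch and progress can be checkpointed to disk after every batch | |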
| with multiprocessing.Pool(processes=4) as pool: | |
| for i, batch in enumerate(batches): | |
| batch_results = pool.map( | |
| partial(compute_embeddings_batch, model=model, audio_files=audio_files), | |
| [batch] | |
| ) | |
| for result in batch_results[0]: | |
| curr_audio_indices.append(result[0]) | |
| curr_audio_files.append(result[1]) | |
| curr_audio_embeds.append(result[2]) | |
| with open(embedding_output_file, 'wb') as f: | |
| pickle.dump({ | |
| 'audio_indices': curr_audio_indices, | |
| 'audio_files': curr_audio_files, | |
| 'audio_embeds': curr_audio_embeds | |
| }, f) | |
| print(f"Saved progress for batch {i+1}/{len(batches)}: \ | |
| audio_indices {len(curr_audio_indices)}, \ | |
| audio_files {len(curr_audio_files)}, \ | |
| audio_embeds {len(curr_audio_embeds)}*{curr_audio_embeds[0].shape}") | |
| return curr_audio_indices, curr_audio_files, curr_audio_embeds | |
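| # precompute_clap_for_dataset is resumable: it reloads any partial .embedding | |
| # pickle and continues from len(curr_audio_files), so interrupted runs pick up | |
| # where they left off | |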
| def build_faiss_index(embeddings): | |
| d = embeddings[0].size(0) | |
| index = faiss.IndexFlatL2(d) | |
| np_embeddings = np.vstack([emb.numpy() for emb in embeddings]) | |
| index.add(np_embeddings) | |
| return index | |
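| # querying the flat index is a brute-force nearest-neighbor search; a minimal | |
| # sketch (illustrative names): given a [1, d] float32 query embedding, | |
| #     distances, indices = index.search(query_embedding_np, k) | |
| # returns the k nearest stored embeddings by squared L2 distance | |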
| def build_faiss_index_dataset( | |
| dataset_file, | |
| embedding_output_file, | |
| faiss_output_file, | |
| checkpoint='630k-audioset-fusion-best.pt', | |
| only_precompute_clap=False | |
| ): | |
| audio_indices, audio_files, audio_embeds = precompute_clap_for_dataset(dataset_file, embedding_output_file, checkpoint) | |
| if only_precompute_clap: | |
| return | |
| valid_indices, valid_files, valid_embeds = [], [], [] | |
| for audio_index, audio_file, audio_embed in zip(audio_indices, audio_files, audio_embeds): | |
| if audio_embed is not None: | |
| valid_indices.append(audio_index) | |
| valid_files.append(audio_file) | |
| valid_embeds.append(audio_embed) | |
| print('building faiss index') | |
| faiss_index = build_faiss_index(valid_embeds) | |
| print('saving faiss index') | |
| faiss.write_index(faiss_index, faiss_output_file) | |
| with open(faiss_output_file + '.filenames', 'wb') as f: | |
| pickle.dump({'audio_indices': valid_indices, 'audio_files': valid_files}, f) | |
| # ==================== Generate interleaved dataset files ==================== | |
| # only indices are saved, so the full samples can be recovered from the split json files | |
| def build_interleaved_dataset(dataset_file, interleaved_output_file, embedding_output_file, faiss_output_file, mode='random', n_samples=3): | |
| contents, audio_files = load_dataset_file(dataset_file) | |
| dataset_dic = { | |
| "dataset_path": contents["dataset_path"], | |
| "split": contents["split"], | |
| "split_path": contents["split_path"], | |
| "flamingo_task": contents["flamingo_task"], | |
| "total_num": 0, | |
| "interleaved_data": {}, | |
| } | |
| # interleaved_data is | |
| # { | |
| # id: { | |
| # "generation_index_in_split": index of sample in the train or val or test.json, | |
| # "fewshot_indices_in_train": list(indices) of few shot samples in train.json | |
| # } | |
| # } | |
| if mode == 'knn': | |
| model = load_clap_model(checkpoint='630k-audioset-fusion-best.pt') | |
| print('loading already computed embedding file from', embedding_output_file) | |
| with open(embedding_output_file, 'rb') as f: | |
| precomputed_data = pickle.load(f) | |
| precomputed_audio_indices = precomputed_data['audio_indices'] | |
| precomputed_audio_files = precomputed_data['audio_files'] | |
| precomputed_audio_embeds = precomputed_data['audio_embeds'] | |
| faiss_index = faiss.read_index(faiss_output_file) | |
| with open(faiss_output_file+'.filenames', 'rb') as f: | |
| _data = pickle.load(f) | |
| faiss_index_audio_indices = _data['audio_indices'] | |
| faiss_index_audio_files = _data['audio_files'] | |
| print('looking for few shot samples and building interleaved_{} data'.format(mode)) | |
| for i in tqdm(range(contents["total_num"])): | |
| if mode == 'random': | |
| few_shot_indices = list(np.random.choice( | |
| list(set(list(range(contents["total_num"]))) - set([i])), | |
| size=n_samples-1, | |
| replace=False | |
| )) | |
| few_shot_indices = list(map(int, few_shot_indices)) | |
| elif mode == 'knn': | |
| if audio_files[i] in precomputed_audio_files: | |
| idx = precomputed_audio_files.index(audio_files[i]) | |
| query_embedding_np = precomputed_audio_embeds[idx] | |
| if query_embedding_np is not None: | |
| query_embedding_np = query_embedding_np.numpy().reshape(1, -1) | |
| else: | |
| continue | |
| else: | |
| query_embedding_np = compute_clap_each(audio_files[i], model) | |
| if query_embedding_np is not None: | |
| query_embedding_np = query_embedding_np.numpy().reshape(1, -1) | |
| else: | |
| continue | |
| distances, knn_indices = faiss_index.search(query_embedding_np, n_samples+50) # over-fetch; self matches and duplicate files are filtered below | |
| distances = distances[0] | |
| knn_indices = knn_indices[0] | |
| knn_filenames = [faiss_index_audio_files[idx] for idx in knn_indices] | |
| combined = list(zip(knn_indices, knn_filenames)) | |
| unique_indices = defaultdict(list) | |
| for idx, filename in combined: | |
| unique_indices[filename].append(idx) | |
| cleared_knn_indices = [random.choice(unique_indices[filename]) for filename in unique_indices if filename != audio_files[i]] | |
| if dataset_file.endswith('train.json'): | |
| cleared_knn_indices = [knn_i for knn_i in cleared_knn_indices if faiss_index_audio_indices[knn_i] != i] | |
| cleared_knn_indices = cleared_knn_indices[:n_samples-1] | |
| np.random.shuffle(cleared_knn_indices) | |
| few_shot_indices = [faiss_index_audio_indices[knn_i] for knn_i in cleared_knn_indices] | |
| dataset_dic["interleaved_data"][dataset_dic["total_num"]] = { | |
| "generation_index_in_split": i, | |
| "fewshot_indices_in_train": few_shot_indices | |
| } | |
| dataset_dic["total_num"] += 1 | |
| with open(interleaved_output_file, 'w') as json_file: | |
| json.dump(dataset_dic, json_file) | |
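| # A minimal sketch (hypothetical helper, not called anywhere in this script) of | |
| # how an interleaved entry maps back to concrete samples, assuming the json | |
| # layout documented above: | |
| def recover_interleaved_sample(train_contents, split_contents, entry): | |
| # few-shot examples are always drawn from train.json | |
| shots = [train_contents["data"][str(j)] for j in entry["fewshot_indices_in_train"]] | |
| # the generation target comes from the split the interleaved file was built for | |
| target = split_contents["data"][str(entry["generation_index_in_split"])] | |
| return shots + [target] | |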
| if __name__ == '__main__': | |
| import argparse | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument('-d', '--dataset_name', type=str, help='dataset name') | |
| parser.add_argument('-f', '--flamingo_task', type=str, help='flamingo task') | |
| parser.add_argument('--interleave', action="store_true", help='prepare the interleave dataset') | |
| args = parser.parse_args() | |
| ROOT = "/lustre/fsw/portfolios/adlr/users/zkong" | |
| dataset_root = os.path.join(ROOT, "datasets") | |
| output_root = os.path.join(ROOT, "audio-flamingo-data/dataset_files") | |
| os.makedirs(output_root, exist_ok=True) | |
| dataset_name = args.dataset_name # "Clotho-v2", "AudioSet", "Clotho-AQA", "WavText5K", "FSD50k", ... | |
| flamingo_task = args.flamingo_task # AQA, AudioCaptioning, EventClassification, SceneClassification, Tagging, ... | |
| # 'train' must come first: interleaving for every split queries the faiss index built from the train embeddings | |
| for split in ["train", "val", "test"]: | |
| dataset_path = os.path.join(dataset_root, dataset_name) | |
| output_folder = '{}-{}'.format(dataset_name, flamingo_task) | |
| os.makedirs(os.path.join(output_root, output_folder), exist_ok=True) | |
| dataset_file = os.path.join(output_root, output_folder, '{}.json'.format(split)) | |
| if not os.path.exists(dataset_file): | |
| try: | |
| prepare_files(dataset_name, dataset_path, split, flamingo_task, dataset_file) | |
| except AssertionError as e: | |
| print('split {} not exist for {}: {}'.format(split, dataset_name, e)) | |
| continue | |
| else: | |
| print('{} exists; skipping'.format(dataset_file)) | |
| if args.interleave: | |
| faiss_output_file = dataset_file.replace('{}.json'.format(split), "train_faiss_index.index") | |
| embedding_output_file = dataset_file.replace('.json', ".embedding") | |
| if split == 'train': | |
| if (not os.path.exists(faiss_output_file)) or (not os.path.exists(faiss_output_file + '.filenames')): | |
| build_faiss_index_dataset( | |
| dataset_file, embedding_output_file, faiss_output_file, | |
| only_precompute_clap=False | |
| ) | |
| else: | |
| print('{} exists; skipping'.format(faiss_output_file)) | |
| else: | |
| build_faiss_index_dataset( | |
| dataset_file, embedding_output_file, | |
| faiss_output_file=None, | |
| only_precompute_clap=True | |
| ) | |
| print('finished precomputing embeddings for the {} split'.format(split)) | |
| for mode in ['knn', 'random']: | |
| interleaved_output_file = '/'.join( | |
| dataset_file.split('/')[:-1] + \ | |
| ['interleaved_{}-'.format(mode) + dataset_file.split('/')[-1]] | |
| ) | |
| if not os.path.exists(interleaved_output_file): | |
| build_interleaved_dataset( | |
| dataset_file=dataset_file, | |
| interleaved_output_file=interleaved_output_file, | |
| embedding_output_file=embedding_output_file, | |
| faiss_output_file=faiss_output_file, | |
| mode=mode, | |
| n_samples=4 | |
| ) | |
| else: | |
| print('{} exists; skipping'.format(interleaved_output_file)) | |
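| # example invocation (script name is illustrative; -d/-f/--interleave are the flags defined above): | |
| #     python prepare_datasets.py -d Clotho-v2 -f AudioCaptioning --interleave | |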