import os
import json
import csv
import yaml
from collections import defaultdict
import pickle
import glob
import math
from functools import partial
import sys
import io
import warnings
import random
import numpy as np
import torch
import laion_clap
import librosa
from pydub import AudioSegment
import soundfile as sf
import faiss
import multiprocessing
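# Assumption: 'spawn' (rather than the default 'fork' on Linux) is forced here
# because CUDA state held by torch/laion_clap cannot be shared across forked workers.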
multiprocessing.set_start_method('spawn', force=True)
try:
    from tqdm import tqdm
except ImportError:
    tqdm = lambda x: x  # no-op fallback: iterate without a progress bar
def suppress_all_output(func):
    """Decorator that silences Python-level and OS-level stdout/stderr while func runs."""
    def wrapper(*args, **kwargs):
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        sys.stdout = io.StringIO()
        sys.stderr = io.StringIO()
        old_fd_out = os.dup(1)
        old_fd_err = os.dup(2)
        null_fd = os.open(os.devnull, os.O_RDWR)
        os.dup2(null_fd, 1)
        os.dup2(null_fd, 2)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                result = func(*args, **kwargs)
            finally:
                os.dup2(old_fd_out, 1)
                os.dup2(old_fd_err, 2)
                os.close(null_fd)
                os.close(old_fd_out)
                os.close(old_fd_err)
                sys.stdout = old_stdout
                sys.stderr = old_stderr
        return result
    return wrapper
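# Hypothetical usage sketch: wrap a noisy loader so neither Python prints nor
# C-level writes to stdout/stderr leak through, e.g.
#
#     @suppress_all_output
#     def load_clap():
#         model = laion_clap.CLAP_Module(enable_fusion=False)
#         model.load_ckpt()  # checkpoint loading logs verbosely
#         return model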
def filter_file(file_path, file_list, filename):
    """Return True if the file should be skipped (missing, or too small to be usable audio)."""
    if file_list is not None:
        if filename not in file_list:
            print(filename, 'does not exist')
            return True
    else:
        if not os.path.exists(os.path.join(file_path, filename)):
            print(filename, 'does not exist')
            return True
    if os.path.getsize(os.path.join(file_path, filename)) < 16000:
        print(filename, 'is under 16000 bytes (roughly 0.5 to 1 second of audio); skipping')
        return True
    return False
# ==================== Prepare dataset files from each data folder ====================
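# Normalizes emotion label variants across the datasets below, e.g.
# EMOTION_MAP_DICT['joy'] -> 'happy', EMOTION_MAP_DICT['sadness'] -> 'sad'.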
EMOTION_MAP_DICT = {
    'amused': 'amused',
    'anger': 'angry', 'angry': 'angry',
    'anxious': 'anxious',
    'apologetic': 'apologetic',
    'assertive': 'assertive',
    'calm': 'calm',
    'concerned': 'concerned',
    'contempt': 'contempt',
    'disgust': 'disgusted', 'disgusted': 'disgusted',
    'encouraging': 'encouraging',
    'excited': 'excited',
    'fear': 'fearful', 'fearful': 'fearful',
    # keep the misspelled key since some metadata uses it, but emit the correct spelling
    'frustated': 'frustrated', 'frustrated': 'frustrated',
    'happy': 'happy', 'joy': 'happy',
    'neutral': 'neutral',
    'sad': 'sad', 'sadness': 'sad',
    'sleepy': 'sleepy',
    'surprise': 'surprised', 'surprised': 'surprised',
    'pleasantly surprised': 'pleasantly surprised',
}
def load_dataset_file(dataset_file):
    with open(dataset_file) as f:
        contents = f.read()
    contents = json.loads(contents)
    audio_files = [
        os.path.join(
            contents["dataset_path"],
            contents["split_path"],
            contents["data"][str(i)]["name"]
        ) for i in range(contents["total_num"])
    ]
    return contents, audio_files
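# The manifest produced by prepare_files (below) and consumed here looks like:
# {
#   "dataset_path": "...", "split_path": "...",
#   "flamingo_task": "<dataset>-<task>", "total_num": N,
#   "data": {"0": {"name": ..., "prompt": ..., "output": ...}, ...}
# }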
def compute_label_graph(dataset_name, dataset_path, top_n, output_file):
    if os.path.exists(output_file):
        print('loading precomputed graph:', output_file)
        with open(output_file, 'r') as json_file:
            graph = json.load(json_file)
    else:
        import torch
        from sentence_transformers import SentenceTransformer, util
        embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        print('precomputing graph and saving to:', output_file)
        if dataset_name == 'AudioSetSL_singlelabel':
            names = []
            with open(os.path.join(dataset_path, 'class_labels_indices.csv'), newline='') as csvfile:
                reader = csv.reader(csvfile, delimiter=',', quotechar='"')
                next(reader)
                for row in reader:
                    _, label, name = row  # e.g. 123, /m/02zsn, "Female speech, woman speaking"
                    names += name.split(', ')
            names = [x.lower().strip() for x in names]
        elif dataset_name == "Clotho-AQA_singlelabel":
            names = set()
            with open(os.path.join(dataset_path, 'clotho_aqa_metadata.csv'), newline='') as csvfile:
                reader = csv.reader(csvfile, delimiter=',', quotechar='"')
                next(reader)
                for row in tqdm(reader):
                    _, file_name, keywords, _, _, _, _ = row
                    names |= set(keywords.split(';'))
            names = [x.lower().strip() for x in names]
        names_embeddings = embedding_model.encode(names, convert_to_tensor=True)
        similarity_matrix = util.pytorch_cos_sim(names_embeddings, names_embeddings)
        similarity_threshold = 0.75
        n_items = len(names)
        graph = {}
        for i in range(n_items):
            adjusted_top_n = min(top_n, n_items - 1)
            values, indices = torch.topk(similarity_matrix[i], adjusted_top_n + 1, largest=True)
            most_similar_items = []
            for value, idx in zip(values, indices):
                # skip the label itself and near-duplicates (similarity above the threshold)
                if idx != i and value <= similarity_threshold:
                    most_similar_items.append(idx.item())
                if len(most_similar_items) == adjusted_top_n:
                    break
            graph[names[i]] = [names[j] for j in most_similar_items]
        with open(output_file, 'w') as json_file:
            json.dump(graph, json_file)
    # graph is a dict: key = each label, value = List[up to top_n similar labels]
    return graph
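# Example of the returned graph (hypothetical labels):
#   {'dog barking': ['dog howling', 'animal', ...],
#    'rain': ['water dripping', 'thunder', ...]}
# Each value lists related-but-distinct labels used as hard negatives below.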
def prepare_files(dataset_name, dataset_path, split, flamingo_task, output_file):
    assert not os.path.exists(output_file)
    dataset_dic = {
        "dataset_path": dataset_path,
        "split": split,
        "split_path": None,
        "flamingo_task": "{}-{}".format(dataset_name, flamingo_task),
        "total_num": 0,
        "data": {}  # {id: {'name': name, 'prompt': prompt, 'output': output}}
    }
    if dataset_name == "AudioSet":
        assert flamingo_task == "EventClassification"
        assert split == 'train'
        map_split = lambda split: 'train_wav' if split == 'train' else ''
        file_path = os.path.join(
            dataset_path,
            map_split(split)
        )
        assert os.path.exists(file_path), '{} does not exist'.format(file_path)
        dataset_dic["split_path"] = map_split(split)
        file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
        dic = defaultdict(str)
        with open(os.path.join(dataset_path, 'class_labels_indices.csv'), newline='') as csvfile:
            reader = csv.reader(csvfile, delimiter=',', quotechar='"')
            next(reader)
            for row in tqdm(reader):
                _, label, name = row  # /m/02zsn, "Female speech, woman speaking"
                dic[label] = name
        with open(os.path.join(dataset_path, 'train.csv'), newline='') as csvfile:
            reader = csv.reader(csvfile, delimiter=',', quotechar='"')
            next(reader)
            for row in tqdm(reader):
                filename, _, _, labels = row  # --aE2O5G5WE, /m/03fwl,/m/04rlf,/m/09x0r
                filename = filename + '.wav'
                if filter_file(file_path, file_list, filename):
                    continue
                label_list = labels.split(",")
                assert all(label in dic for label in label_list)
                text_output = ", ".join([dic[label] for label in label_list])
                if len(text_output) <= 1:
                    continue
                text_prompt = 'this is a sound of'
                dataset_dic["data"][dataset_dic["total_num"]] = {
                    "name": filename,
                    "prompt": text_prompt,
                    "output": text_output.replace('\n', ' ')
                }
                dataset_dic["total_num"] += 1
elif dataset_name == "AudioSetFull": | |
assert flamingo_task == "EventClassification" | |
assert split == 'train' | |
map_split = lambda split: '/mnt/fsx-main/rafaelvalle/datasets/audioset/unbalanced_train_segments/22khz' | |
file_path = map_split(split) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = None | |
dic_code2label = defaultdict(str) | |
with open(os.path.join(dataset_path, 'audioset-processing/data/class_labels_indices.csv'), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
next(reader) | |
for row in tqdm(reader): | |
_, code, name = row # /m/02zsn,"Female speech, woman speaking" | |
dic_code2label[code] = name | |
dic_filename2code = {} | |
with open(os.path.join(dataset_path, 'audioset-processing/data/unbalanced_train_segments.csv'), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
next(reader) | |
next(reader) | |
for row in tqdm(reader): | |
filename, _, _, codes = row # --aE2O5G5WE /m/03fwl,/m/04rlf,/m/09x0r | |
filename = 'Y' + filename + '.wav' | |
dic_filename2code[filename] = codes.split(",") | |
for part in tqdm(range(41)): | |
part_str = str(part) | |
if len(part_str) == 1: | |
part_str = '0' + part_str | |
part_folder = 'unbalanced_train_segments_part{}'.format(part_str) | |
for filename in os.listdir(os.path.join(file_path, part_folder)): | |
if not filename.endswith('.wav'): | |
continue | |
if filter_file(file_path, file_list, os.path.join(part_folder, filename)): | |
continue | |
if filename not in dic_filename2code: | |
continue | |
text_output = ", ".join([dic_code2label[code] for code in dic_filename2code[filename] if code in dic_code2label]) | |
if len(text_output) <= 1: | |
continue | |
text_prompt = 'this is a sound of' | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": os.path.join(part_folder, filename), | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "AudioSetFullwoAudioMusicCaps": | |
assert flamingo_task == "EventClassification" | |
assert split == 'train' | |
map_split = lambda split: '/mnt/fsx-main/rafaelvalle/datasets/audioset/unbalanced_train_segments/22khz' | |
file_path = map_split(split) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = None | |
print('extracting AudioCaps and MusicCaps ytid to avoid these samples') | |
audiocaps_ytid = [] | |
for f in ['audiocaps_dataset/train.csv', 'audiocaps_dataset/test.csv', 'audiocaps_dataset/val.csv']: | |
with open(os.path.join(dataset_path, f), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
next(reader) | |
for row in reader: | |
_, ytid, _, _ = row | |
audiocaps_ytid.append('Y' + ytid + '.wav') | |
audiocaps_ytid = set(audiocaps_ytid) | |
musiccaps_ytid = [] | |
with open(os.path.join(dataset_path, 'musiccaps_dataset/musiccaps_manifest.json')) as f: | |
data = f.read() | |
musiccaps_list = json.loads(data) | |
for row in musiccaps_list: | |
musiccaps_ytid.append('Y' + row["ytid"] + '.wav') | |
musiccaps_ytid = set(musiccaps_ytid) | |
print('Will exclude {} samples from MusicCaps and {} from AudioCaps'.format(len(audiocaps_ytid), len(musiccaps_ytid))) | |
dic_code2label = defaultdict(str) | |
with open(os.path.join(dataset_path, '../AudioSetFull/audioset-processing/data/class_labels_indices.csv'), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
next(reader) | |
for row in tqdm(reader): | |
_, code, name = row # /m/02zsn,"Female speech, woman speaking" | |
dic_code2label[code] = name | |
dic_filename2code = {} | |
with open(os.path.join(dataset_path, '../AudioSetFull/audioset-processing/data/unbalanced_train_segments.csv'), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
next(reader) | |
next(reader) | |
for row in tqdm(reader): | |
filename, _, _, codes = row # --aE2O5G5WE /m/03fwl,/m/04rlf,/m/09x0r | |
filename = 'Y' + filename + '.wav' | |
dic_filename2code[filename] = codes.split(",") | |
music_audio_caps_excluded = 0 | |
for part in tqdm(range(41)): | |
part_str = str(part) | |
if len(part_str) == 1: | |
part_str = '0' + part_str | |
part_folder = 'unbalanced_train_segments_part{}'.format(part_str) | |
for filename in os.listdir(os.path.join(file_path, part_folder)): | |
if not filename.endswith('.wav'): | |
continue | |
if filename in audiocaps_ytid or filename in musiccaps_ytid: | |
music_audio_caps_excluded += 1 | |
continue | |
if filter_file(file_path, file_list, os.path.join(part_folder, filename)): | |
continue | |
if filename not in dic_filename2code: | |
continue | |
text_output = ", ".join([dic_code2label[code] for code in dic_filename2code[filename] if code in dic_code2label]) | |
if len(text_output) <= 1: | |
continue | |
text_prompt = 'this is a sound of' | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": os.path.join(part_folder, filename), | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "AudioSetSL_singlelabel": | |
import numpy as np | |
assert flamingo_task == "EventClassification" | |
assert split == 'train' | |
map_split = lambda split: '../AudioSet/train_wav' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
dic = defaultdict(str) | |
with open(os.path.join(dataset_path, 'class_labels_indices.csv'), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
next(reader) | |
for row in tqdm(reader): | |
_, label, name = row # /m/02zsn,"Female speech, woman speaking" | |
dic[label] = name | |
graph = compute_label_graph( | |
dataset_name, | |
dataset_path, | |
top_n=200, | |
output_file=os.path.join(dataset_path, 'label_graph.json') | |
) | |
with open(os.path.join(dataset_path, 'train.csv'), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
next(reader) | |
for row in tqdm(reader): | |
filename, _, _, labels = row # --aE2O5G5WE /m/03fwl,/m/04rlf,/m/09x0r | |
filename = filename + '.wav' | |
if filter_file(file_path, file_list, filename): | |
continue | |
label_list = labels.split(",") | |
assert all(label in dic for label in label_list) | |
text_labels = ", ".join([dic[label] for label in label_list]).lower() | |
text_labels = text_labels.split(', ') | |
text_output = np.random.choice(text_labels) | |
if len(text_output) <= 1: | |
continue | |
num_options = np.random.choice( | |
[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], | |
p=[ 0.05, 0.1, 0.1, 0.1, 0.1, | |
0.05, 0.05, 0.05, 0.1, 0.05, | |
0.05, 0.1, 0.05, 0.05] | |
) | |
negative_samples = [x for x in graph[text_output] if x not in set(text_labels)] | |
candidate_negative_labels = list(np.random.choice( | |
negative_samples[:num_options*10], | |
size=num_options-1, | |
replace=False | |
)) | |
if type(candidate_negative_labels) is str: | |
candidate_negative_labels = [candidate_negative_labels] | |
all_options = [text_output] + candidate_negative_labels | |
np.random.shuffle(all_options) | |
text_prompt = 'Classify this sound.\nOPTIONS:\n - {}.'.format( | |
'.\n - '.join(all_options) | |
) | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "AUDIOCAPS13k": | |
assert flamingo_task == 'AudioCaptioning' | |
map_split = lambda split: 'audio_32000Hz/{}'.format(split) | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.flac'), os.listdir(file_path))) | |
with open(os.path.join( | |
dataset_path, | |
'{}_manifest.json'.format(split + ('_v2' if split == 'train' else '')) | |
), 'r') as f: | |
data = f.readlines() | |
data = [json.loads(row) for row in data] | |
for row in tqdm(data): | |
filename = row['audio_filepath'].split('/')[-1] | |
if filter_file(file_path, file_list, filename): | |
continue | |
text_output = row['text'] | |
if len(text_output) <= 1: | |
continue | |
text_prompt = 'generate audio caption' | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "audiocaps": | |
assert flamingo_task == 'AudioCaptioning' | |
map_split = lambda split: 'audio/{}'.format(split if split in ['train', 'test'] else 'valid') | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.flac'), os.listdir(file_path))) | |
for filename in tqdm(file_list): | |
if filter_file(file_path, file_list, filename): | |
continue | |
with open(os.path.join(file_path, filename.replace('.flac', '.json')), 'r') as f: | |
data = json.load(f) | |
captions = data['text'] | |
for text_output in captions: | |
if len(text_output) <= 1: | |
continue | |
text_prompt = 'generate audio caption' | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
    elif dataset_name == 'BG-Gun-Sound-Dataset':
        assert flamingo_task == "SoundClassification"
        assert split in ["train", "test"]
        map_split = lambda split: 'data/gun_sound_v2'
        file_path = os.path.join(
            dataset_path,
            map_split(split)
        )
        assert os.path.exists(file_path), '{} does not exist'.format(file_path)
        dataset_dic["split_path"] = map_split(split)
        file_list = os.listdir(file_path)
        all_cates = set()
        with open(os.path.join(dataset_path, 'data/v3_exp3_{}.csv'.format(split)), newline='') as csvfile:
            reader = csv.reader(csvfile, delimiter=',', quotechar='"')
            next(reader)
            for row in tqdm(reader):
                filename, cate, dist, dire = row
                if filter_file(file_path, file_list, filename):
                    continue
                text_output = cate
                if len(text_output) <= 1:
                    continue
                text_prompt = 'What is the gun of this sound?'
                all_cates.add(cate)
                dataset_dic["data"][dataset_dic["total_num"]] = {
                    "name": filename,
                    "prompt": text_prompt,
                    "output": text_output.replace('\n', ' ')
                }
                dataset_dic["total_num"] += 1
        print(all_cates)
elif dataset_name == "BirdsDataset": | |
assert flamingo_task == "SoundClassification" | |
assert split == 'train' | |
map_split = lambda split: 'Voice_of_Birds' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = None | |
for bird_type in tqdm(os.listdir(file_path)): | |
bird_name = ' '.join(bird_type.split('_')[:-1]) | |
for filename in os.listdir(os.path.join(file_path, bird_type)): | |
if filter_file(file_path, file_list, os.path.join(bird_type, filename)): | |
continue | |
text_output = bird_name | |
if len(text_output) <= 1: | |
continue | |
text_prompt = 'What is the name of bird in this sound?' | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": os.path.join(bird_type, filename), | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "BBCSoundEffects": | |
assert split in ['train'] | |
assert flamingo_task == 'AudioDescription' | |
map_split = lambda split: '../WavCaps/BBC_Sound_Effects_flac' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.flac'), os.listdir(file_path))) | |
with open(os.path.join(dataset_path, 'BBCSoundDownloader/BBCSoundEffects.csv'), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
next(reader) | |
for row in tqdm(reader): | |
if len(row) != 7: | |
continue | |
filename, description, _, _, _, _, _ = row | |
filename = filename.replace('.wav', '.flac') | |
if filter_file(file_path, file_list, filename): | |
continue | |
text_output = description | |
if len(text_output) <= 1: | |
continue | |
text_prompt = 'generate audio description' | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "chime-home": | |
assert flamingo_task == "EventClassification" | |
assert split == 'train' | |
map_split = lambda split: 'chime_home/chunks' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file48k_list = list(filter(lambda x: x.endswith('48kHz.wav'), os.listdir(file_path))) | |
file16k_list = list(filter(lambda x: x.endswith('16kHz.wav'), os.listdir(file_path))) | |
csv_file_list = list(filter(lambda x: x.endswith('.csv'), os.listdir(file_path))) | |
label_mapping = { | |
'c': 'child speaking', | |
'm': 'male speaking', | |
'f': 'female speaking', | |
'p': 'human activity', | |
't': 'television', | |
'b': 'household appliances', | |
's': 'silence' | |
} | |
for csv_file in tqdm(csv_file_list): | |
with open(os.path.join(file_path, csv_file), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
labels = None | |
for row in reader: | |
if row[0] == 'majorityvote': | |
labels = row[1] | |
break | |
if labels is None or len(labels) == 0: | |
continue | |
filename = csv_file.replace('.csv', '.48kHz.wav') | |
if filter_file(file_path, file48k_list, filename): | |
filename = csv_file.replace('.csv', '.16kHz.wav') | |
if filter_file(file_path, file16k_list, filename): | |
continue | |
text_output = ", ".join([label_mapping[l] for l in labels if l in label_mapping]) | |
if len(text_output) <= 1: | |
continue | |
text_prompt = 'this is a sound of' | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "CLAP_freesound": | |
assert flamingo_task == "AudioCaptioning" | |
assert split in ["train", "test"] | |
map_split = lambda split: os.path.join('freesound_no_overlap/split', split) | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.flac'), os.listdir(file_path))) | |
with open(os.path.join( | |
dataset_path, | |
'freesound_no_overlap_meta.csv' | |
), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
next(reader) | |
for row in tqdm(reader): | |
if len(row[0].split('/')) != 2: | |
continue | |
if len(row) <= 1: | |
continue | |
file_split, filename = row[0].split('/') | |
if file_split != split: | |
continue | |
if filter_file(file_path, file_list, filename): | |
continue | |
caption_1 = row[1] # caption_2 = row[2] but not very good | |
text_output = caption_1 | |
if len(text_output) <= 2: | |
continue | |
text_prompt = 'generate audio caption' | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "Clotho-AQA": | |
map_split = lambda split: 'audio_files' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
if flamingo_task == "EventClassification": | |
dic = defaultdict(str) | |
with open(os.path.join(dataset_path, 'clotho_aqa_metadata.csv'), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
next(reader) | |
for row in tqdm(reader): | |
_, file_name, keywords, _, _, _, _ = row | |
dic[file_name] = keywords.replace(';', ', ') | |
with open(os.path.join(dataset_path, 'clotho_aqa_{}.csv'.format(split)), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
next(reader) | |
for row in tqdm(reader): | |
filename = row[0] | |
if filename not in dic or filter_file(file_path, file_list, filename): | |
continue | |
text_output = dic[filename] | |
if len(text_output) <= 1: | |
continue | |
text_prompt = 'this is a sound of' | |
del dic[filename] | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif flamingo_task == "AQA": | |
dic_qa = defaultdict(list) | |
with open(os.path.join(dataset_path, 'clotho_aqa_{}.csv'.format(split)), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
next(reader) | |
for row in tqdm(reader): | |
filename, question, answer, confidence = row | |
dic_qa[(filename, question)].append((answer.lower(), confidence.lower())) | |
            # map binary yes/no disagreements to a third 'unsure' class
            # (note: currently unused; majority_vote below is applied instead)
            def preprocess(list_ans_conf):
                assert set([x[1] for x in list_ans_conf]) <= set(['yes', 'no', 'maybe'])
                answers = set([x[0].lower() for x in list_ans_conf])
                if answers <= set(['yes', 'no']):
                    if len(answers) > 1:
                        return ['unsure']
                    else:
                        return list(answers)
                else:
                    return list(answers)
            # get majority vote, weighting each answer by annotator confidence
            def majority_vote(list_ans_conf):
                assert set([x[1] for x in list_ans_conf]) <= set(['yes', 'no', 'maybe'])
                weight = {'yes': 1.0, 'no': 0.1, 'maybe': 0.6}
                if set([x[0] for x in list_ans_conf]) <= set(['yes', 'no']):
                    score = {'yes': 1.0, 'no': -1.0}
                    pred = sum([score[x[0]] * weight[x[1]] for x in list_ans_conf])
                    if pred > 0:
                        return ['yes']
                    else:
                        return ['no']
                else:
                    return list(set([x[0] for x in list_ans_conf]))
            for key in dic_qa:
                filename, question = key
                if filter_file(file_path, file_list, filename):
                    continue
                if split == 'train':
                    answers = majority_vote(dic_qa[key])  # majority vote
                else:
                    answers = [x[0].strip().lower() for x in dic_qa[key]]
                    answers = [', '.join(answers)]
                for answer in answers:
                    text_output = answer
                    if len(text_output) <= 1:
                        continue
                    text_prompt = "Question: " + question
                    dataset_dic["data"][dataset_dic["total_num"]] = {
                        "name": filename,
                        "prompt": text_prompt,
                        "output": text_output.replace('\n', ' ')
                    }
                    dataset_dic["total_num"] += 1
elif dataset_name == "Clotho-AQA_singlelabel": | |
import numpy as np | |
assert flamingo_task == "EventClassification" | |
map_split = lambda split: '../Clotho-AQA/audio_files' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
dic = defaultdict(str) | |
with open(os.path.join(dataset_path, 'clotho_aqa_metadata.csv'), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
next(reader) | |
for row in tqdm(reader): | |
_, file_name, keywords, _, _, _, _ = row | |
dic[file_name] = keywords.split(';') | |
graph = compute_label_graph( | |
dataset_name, | |
dataset_path, | |
top_n=300, | |
output_file=os.path.join(dataset_path, 'label_graph.json') | |
) | |
with open(os.path.join(dataset_path, 'clotho_aqa_{}.csv'.format(split)), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
next(reader) | |
for row in tqdm(reader): | |
filename = row[0] | |
if filename not in dic or filter_file(file_path, file_list, filename): | |
continue | |
text_labels = [x.lower().strip() for x in dic[filename]] | |
del dic[filename] | |
for _ in range(6): | |
text_output = np.random.choice(text_labels) | |
if len(text_output) <= 1: | |
continue | |
num_options = np.random.choice( | |
[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], | |
p=[ 0.05, 0.1, 0.1, 0.1, 0.1, | |
0.05, 0.05, 0.05, 0.1, 0.05, | |
0.05, 0.1, 0.05, 0.05] | |
) | |
negative_samples = [x for x in graph[text_output] if x not in set(text_labels)] | |
candidate_negative_labels = list(np.random.choice( | |
negative_samples[:num_options*20], | |
size=num_options-1, | |
replace=False | |
)) | |
if type(candidate_negative_labels) is str: | |
candidate_negative_labels = [candidate_negative_labels] | |
all_options = [text_output] + candidate_negative_labels | |
np.random.shuffle(all_options) | |
text_prompt = 'Classify this sound.\nOPTIONS:\n - {}.'.format( | |
'.\n - '.join(all_options) | |
) | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "Clotho-v2": | |
assert flamingo_task == "AudioCaptioning" | |
assert split in ["train", "val", "test"] | |
map_split = lambda split: 'development' if split == 'train' else ('validation' if split == "val" else "evaluation") | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
with open(os.path.join( | |
dataset_path, | |
'clotho_captions_{}.csv'.format(map_split(split)) | |
), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
next(reader) | |
for row in tqdm(reader): | |
filename = row[0] | |
if filter_file(file_path, file_list, filename): | |
continue | |
for text_output in row[1:]: | |
if len(text_output) <= 1: | |
continue | |
text_prompt = 'generate audio caption' | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "CochlScene": | |
import ndjson | |
assert flamingo_task == "SceneClassification" | |
map_split = lambda split: split.capitalize() | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = None | |
with open(os.path.join(dataset_path, 'cochlscene_{}.ndjson'.format(split))) as ndjsonfile: | |
reader = ndjson.load(ndjsonfile) | |
for row in tqdm(reader): | |
filename = "/".join(row["audiopath"].split("/")[1:]) | |
if filter_file(file_path, file_list, filename): | |
continue | |
text_output = row["labels"].lower() | |
if len(text_output) <= 1: | |
continue | |
text_prompt = 'this acoustic scene is' | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "common-accent": | |
import ndjson | |
import re | |
assert flamingo_task == "AccentClassification" | |
assert split in ["train", "test"] | |
map_split = lambda split: '22khz' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = os.listdir(file_path) | |
all_accent = [] | |
split_file = [f for f in os.listdir(dataset_path) if f.startswith(split) and f.endswith('.ndjson')][0] | |
with open(os.path.join(dataset_path, split_file)) as ndjsonfile: | |
reader = ndjson.load(ndjsonfile) | |
for row in tqdm(reader): | |
accent = row["accent"] | |
accent = re.sub(r'\(.*?\)', '', accent) | |
accent = accent.replace('English', '') | |
accent = accent.split(',') | |
accent = [x.strip() for x in accent if 'school' not in x] | |
all_accent += accent | |
filename = row["filename"] | |
if filter_file(file_path, file_list, filename): | |
continue | |
for accent_each in accent: | |
if accent_each == 'Javanese': | |
accent_each = 'Japanese' | |
if len(accent_each) > 25: | |
continue | |
text_output = accent_each | |
if len(text_output) <= 1: | |
continue | |
text_prompt = 'Classify the accent of this speech.' | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
print('all accents:', list(set(all_accent))) | |
elif dataset_name == "CREMA-D": | |
assert flamingo_task == "EmotionClassification" | |
assert split in ["train"] | |
map_split = lambda split: 'AudioWAV' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
split_file = os.path.join( | |
dataset_path, | |
'crema-d_audiopath_text_sid_emotion_filelist.txt' | |
) | |
with open(split_file, 'r') as f: | |
data = f.readlines() | |
data = [x.replace('\n', '') for x in data] | |
for row in tqdm(data): | |
if row.count('|') != 3: | |
continue | |
filename, utterances, speaker, emotion = row.split('|') | |
if filter_file(file_path, file_list, filename): | |
continue | |
text_output = emotion | |
text_prompt = 'this emotion is' | |
if len(text_output) <= 1: | |
continue | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "DCASE17Task4": | |
assert flamingo_task == "SceneClassification" | |
assert split in ["test"] | |
map_split = lambda split: 'unbalanced_train_segments_testing_set_audio_formatted_and_segmented_downloads' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
split_file = os.path.join( | |
dataset_path, | |
'Task-4-Large-scale-weakly-supervised-sound-event-detection-for-smart-cars', | |
'groundtruth_release', | |
'groundtruth_strong_label_testing_set.csv' | |
) | |
dic = defaultdict(list) | |
all_labels = [] | |
with open(split_file, newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter='\t', quotechar='"') | |
for row in tqdm(reader): | |
filename = 'Y' + row[0] | |
label = row[-1] | |
if filter_file(file_path, file_list, filename): | |
continue | |
dic[filename] += label.split(', ') | |
all_labels += label.split(', ') | |
print('all labels:\n', ', '.join(list(set(all_labels)))) | |
for filename in dic: | |
text_output = ', '.join(list(set(dic[filename]))) | |
text_prompt = 'this acoustic scene is' | |
if len(text_output) <= 1: | |
continue | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "emov-db": | |
assert flamingo_task == "EmotionClassification" | |
assert split in ["train", "val"] | |
map_split = lambda split: '22khz_from_16khz' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
split_file = os.path.join( | |
dataset_path, | |
'cleaned_emov_db_audiopath_text_sid_emotion_duration_filelist_merged_{}.txt'.format(split) | |
) | |
with open(split_file, 'r') as f: | |
data = f.readlines() | |
data = [x.replace('\n', '') for x in data] | |
for row in tqdm(data): | |
if row.count('|') != 4: | |
continue | |
filename, utterances, speaker, emotion, duration = row.split('|') | |
if filter_file(file_path, file_list, filename): | |
continue | |
text_output = emotion | |
text_output = EMOTION_MAP_DICT[text_output] | |
text_prompt = 'this emotion is' | |
if len(text_output) <= 1: | |
continue | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "Epidemic_sound": | |
assert split == 'train' | |
assert flamingo_task in ["AudioCaptioning", "Tagging"] | |
map_split = lambda split: 'audio' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.mp3'), os.listdir(file_path))) | |
with open(os.path.join(dataset_path, 'Epidemic_all_debiased.csv'), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
next(reader) | |
for row in tqdm(reader): | |
if len(row) != 5: | |
continue | |
_, caption_1, caption_2, caption_t5, fileid = row | |
filename = '{}.mp3'.format(fileid) | |
if filter_file(file_path, file_list, filename): | |
continue | |
if flamingo_task == "AudioCaptioning": | |
text_output = caption_t5 | |
if len(text_output) <= 1: | |
continue | |
text_prompt = 'generate audio caption' | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif flamingo_task == "Tagging": | |
if not caption_2.startswith('the sounds of'): | |
continue | |
caption_2 = caption_2.replace('the sounds of ', '') | |
caption_2 = caption_2.replace(', and', ',') | |
if len(caption_2) < 2: | |
continue | |
tags = caption_2.split(', ') | |
tags = list(map(lambda x: x.replace("'", "").strip().lower(), tags)) | |
text_output = '{}'.format(', '.join(tags)) | |
if len(text_output) <= 1: | |
continue | |
text_prompt = 'generate tags' | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "ESC50": | |
assert flamingo_task in ["EventClassification"] | |
assert split == 'train' | |
map_split = lambda split: 'ESC-50-master/audio' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
with open(os.path.join(dataset_path, 'ESC-50-master/meta/esc50.csv'), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
next(reader) | |
for row in tqdm(reader): | |
if len(row) != 7: | |
continue | |
filename, fold, target, category, esc10, src_file, take = row | |
if filter_file(file_path, file_list, filename): | |
continue | |
text_output = category.replace('_', ' ') | |
text_prompt = 'classify this sound.' | |
if len(text_output) <= 1: | |
continue | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "FMA": | |
import ast | |
assert flamingo_task in ["GenreClassification"] | |
assert split == 'train' | |
map_split = lambda split: 'fma_large' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = None | |
with open(os.path.join(dataset_path, 'fma_metadata/raw_tracks.csv'), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
next(reader) | |
for row in tqdm(reader): | |
if len(row) != 39: | |
continue | |
track_id,album_id,album_title,album_url, \ | |
artist_id,artist_name,artist_url,artist_website, \ | |
license_image_file,license_image_file_large, \ | |
license_parent_id,license_title,license_url, \ | |
tags,track_bit_rate,track_comments,track_composer, \ | |
track_copyright_c,track_copyright_p,track_date_created,track_date_recorded, \ | |
track_disc_number,track_duration,track_explicit,track_explicit_notes, \ | |
track_favorites,track_file,track_genres,track_image_file,track_information, \ | |
track_instrumental,track_interest,track_language_code, \ | |
track_listens,track_lyricist,track_number,track_publisher,track_title,track_url = row | |
l = len(str(track_id)) | |
if l <= 3: | |
filename = '{}/{}.mp3'.format( | |
'000', | |
'0'*(6-l)+str(track_id) | |
) | |
else: | |
filename = '{}/{}.mp3'.format( | |
'0'*(6-l)+str(track_id)[:l-3], | |
'0'*(6-l)+str(track_id) | |
) | |
if filter_file(file_path, file_list, filename): | |
continue | |
if len(track_genres) == 0: | |
continue | |
track_genres = ast.literal_eval(track_genres) | |
genres = ', '.join([dic['genre_title'].lower().strip() for dic in track_genres]) | |
text_output = genres + '.' | |
text_prompt = "what is the genre of this music?" | |
if len(text_output) <= 1: | |
continue | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "FSD50k": | |
import ndjson | |
assert flamingo_task == "EventClassification" | |
assert split in ["train", "test"] | |
map_split = lambda split: '44khz/dev' if split == 'train' else '44khz/eval' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
with open(os.path.join(dataset_path, '{}.ndjson'.format(map_split(split).replace('44khz/', '')))) as ndjsonfile: | |
reader = ndjson.load(ndjsonfile) | |
for row in tqdm(reader): | |
filename = row["filepath"].split("/")[1] | |
if filter_file(file_path, file_list, filename): | |
continue | |
labels = [x.replace("_", " ").lower() for x in row["labels"]] | |
text_output = ", ".join(labels) | |
if len(text_output) <= 1: | |
continue | |
text_prompt = 'this is a sound of' | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "GTZAN": | |
assert flamingo_task == "GenreClassification" | |
assert split in ["train"] | |
map_split = lambda split: 'gtzan/data/genres' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = None | |
for genre in os.listdir(file_path): | |
genre_wavs = [x for x in os.listdir(os.path.join(file_path, genre)) if x.endswith('.wav')] | |
for genre_wav in genre_wavs: | |
filename = os.path.join(genre, genre_wav) | |
if filter_file(file_path, file_list, filename): | |
continue | |
text_output = genre | |
if len(text_output) <= 1: | |
continue | |
text_prompt = 'What is the genre of this music?' | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "IEMOCAP": | |
assert flamingo_task == "EmotionClassification" | |
assert split in ["train", "test"] | |
map_split = lambda split: 'IEMOCAP_full_release/16khz' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = None | |
        def read_this_ndjson(meta_path):
            # ad-hoc line parser: pulls the 'turn_name' and 'emotion' fields out of
            # each dict-like line with string splitting instead of json.loads
            dic_list = []
            with open(meta_path, 'r') as f:
                for line in f:
                    turn_name = line.split("'turn_name': ")[-1].split(',')[0].replace("'", "")
                    emotion = line.split("'emotion': ")[-1].split(',')[0].replace("'", "")
                    dic = {
                        'turn_name': turn_name,
                        'emotion': emotion
                    }
                    dic_list.append(dic)
            return dic_list
        all_emotions = []
        meta_files = [x for x in os.listdir(os.path.join(dataset_path, 'IEMOCAP_full_release/ndjson')) if x.endswith('.ndjson')]
        for meta_file in tqdm(meta_files):
            main_folder = meta_file.split('_')[0]
            sub_folder = (meta_file.split('.ndjson')[0])[len(main_folder)+1:]
            if split == "train" and main_folder == "Session5":
                continue
            elif split == "test" and main_folder != "Session5":
                continue
            metadata_list = read_this_ndjson(os.path.join(dataset_path, 'IEMOCAP_full_release/ndjson', meta_file))
            for dic in metadata_list:
                filename = os.path.join(main_folder, sub_folder, dic['turn_name'] + '.wav')
                if filter_file(file_path, file_list, filename):
                    continue
                if dic['emotion'] in ['unknown', 'other']:
                    continue
                text_output = dic['emotion']
                text_output = EMOTION_MAP_DICT[text_output]
                all_emotions.append(text_output)
                text_prompt = 'this emotion is'
                if len(text_output) <= 1:
                    continue
                dataset_dic["data"][dataset_dic["total_num"]] = {
                    "name": filename,
                    "prompt": text_prompt,
                    "output": text_output.replace('\n', ' ')
                }
                dataset_dic["total_num"] += 1
        print('all emotions:', list(set(all_emotions)))
    elif dataset_name == "jl-corpus":
        assert flamingo_task == "EmotionClassification"
        assert split in ["train", "val"]
        map_split = lambda split: '44khz'
        file_path = os.path.join(
            dataset_path,
            map_split(split)
        )
        assert os.path.exists(file_path), '{} does not exist'.format(file_path)
        dataset_dic["split_path"] = map_split(split)
        file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
        split_file = os.path.join(
            dataset_path,
            'jl-corpus_audiopath_text_sid_emotion_duration_{}_filelist.txt'.format(split)
        )
        with open(split_file, 'r') as f:
            data = f.readlines()
        data = [x.replace('\n', '') for x in data]
        for row in tqdm(data):
            if row.count('|') != 4:
                continue
            filename, utterances, speaker, emotion, duration = row.split('|')
            if filter_file(file_path, file_list, filename):
                continue
            text_output = emotion
            text_output = EMOTION_MAP_DICT[text_output]
            text_prompt = 'this emotion is'
            if len(text_output) <= 1:
                continue
            dataset_dic["data"][dataset_dic["total_num"]] = {
                "name": filename,
                "prompt": text_prompt,
                "output": text_output.replace('\n', ' ')
            }
            dataset_dic["total_num"] += 1
elif dataset_name == "LP-MusicCaps-MC": | |
import pandas as pd | |
assert flamingo_task in ["AudioCaptioning"] | |
assert split in ["train", "test"] | |
map_split = lambda split: '../MusicCaps/44khz' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
parquet_files = [f for f in os.listdir(os.path.join(dataset_path, 'data')) if f.endswith('.parquet') and f.startswith(split)] | |
print('parquet_files', parquet_files) | |
metadata_df = pd.concat([pd.read_parquet(os.path.join(dataset_path, 'data', f)) for f in parquet_files]) | |
for index, row in tqdm(metadata_df.iterrows()): | |
filename = row['ytid'] + '.wav' | |
if filter_file(file_path, file_list, filename): | |
continue | |
text_prompt = 'generate audio caption' | |
for caption in [row['caption_writing'], row['caption_summary'], row['caption_paraphrase']]: | |
text_output = caption | |
if len(text_output) <= 1: | |
continue | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "LP-MusicCaps-MSD": | |
import pandas as pd | |
assert flamingo_task in ["AudioCaptioning"] | |
assert split in ["train", "test", "val"] | |
map_split = lambda split: '../MSD/mp3s_22khz' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = None | |
parquet_files = [f for f in os.listdir(dataset_path) if f.endswith('.parquet') and f.startswith(split)] | |
print('parquet_files', parquet_files) | |
metadata_df = pd.concat([pd.read_parquet(os.path.join(dataset_path, f)) for f in parquet_files]) | |
for index, row in tqdm(metadata_df.iterrows()): | |
filename = row['path'] | |
if filter_file(file_path, file_list, filename): | |
continue | |
text_prompt = 'generate audio caption' | |
for caption in [row['caption_writing'], row['caption_summary'], row['caption_paraphrase']]: | |
text_output = caption | |
if len(text_output) <= 1: | |
continue | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "LP-MusicCaps-MTT": | |
import pandas as pd | |
assert flamingo_task in ["AudioCaptioning"] | |
assert split in ["train", "test", "val"] | |
map_split = lambda split: '../MagnaTagATune/16khz' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = None | |
parquet_files = [f for f in os.listdir(dataset_path) if f.endswith('.parquet') and f.startswith(split)] | |
print('parquet_files', parquet_files) | |
metadata_df = pd.concat([pd.read_parquet(os.path.join(dataset_path, f)) for f in parquet_files]) | |
for index, row in tqdm(metadata_df.iterrows()): | |
filename = row['path'] | |
if filter_file(file_path, file_list, filename): | |
continue | |
text_prompt = 'generate audio caption' | |
for caption in [row['caption_writing'], row['caption_summary'], row['caption_paraphrase']]: | |
text_output = caption | |
if len(text_output) <= 1: | |
continue | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "MACS": | |
assert flamingo_task in ["AudioCaptioning", "Tagging"] | |
assert split == 'train' | |
map_split = lambda split: 'TAU_Urban_Acoustic_Scenes_2019/TAU-urban-acoustic-scenes-2019-development/audio' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
metadata_list = yaml.load(open(os.path.join(dataset_path, 'MACS.yaml')), Loader=yaml.FullLoader)['files'] | |
for file_metadata in tqdm(metadata_list): | |
filename = file_metadata['filename'] | |
if filter_file(file_path, file_list, filename): | |
continue | |
for each_annotated in file_metadata['annotations']: | |
caption = each_annotated['sentence'] | |
tags = ', '.join(each_annotated['tags']).replace('_', ' ') | |
if flamingo_task == "AudioCaptioning": | |
text_output = caption | |
text_prompt = 'generate audio caption' | |
elif flamingo_task == "Tagging": | |
raise NotImplementedError | |
if len(text_output) <= 1: | |
continue | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "Medley-solos-DB": | |
import ndjson | |
assert flamingo_task in ["InstrClassification"] | |
map_split = lambda split: '44khz' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
with open(os.path.join(dataset_path, 'medleysolosdb_manifest.ndjson')) as ndjsonfile: | |
metadata_list = ndjson.load(ndjsonfile) | |
for file_metadata in tqdm(metadata_list): | |
subset = file_metadata['subset'] | |
if not subset.startswith(split): | |
continue | |
filename = file_metadata['filepath'] | |
if filter_file(file_path, file_list, filename): | |
continue | |
instrument = file_metadata["instrument"] | |
text_output = instrument | |
text_prompt = 'this music note is produced by' | |
if len(text_output) <= 1: | |
continue | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "MELD": | |
import numpy as np | |
assert flamingo_task in ["EmotionClassification", "SentimentClassification"] | |
map_split = lambda split: '44khz' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
split_file = os.path.join( | |
dataset_path, | |
'{}.txt'.format(split if split in ['train', 'test'] else 'dev') | |
) | |
with open(split_file, 'r') as f: | |
data = f.readlines() | |
data = [x.replace('\n', '') for x in data] | |
emotion_count = { | |
'neutral': 4703, 'happy': 1739, 'sad': 683, 'surprised': 1204, | |
'disgusted': 271, 'angry': 1108, 'fearful': 268, | |
} | |
sentiment_count = { | |
'neutral': 4703, 'positive': 2330, 'negative': 2943, | |
} | |
balancing_factor = 1 | |
for row in tqdm(data): | |
if row.count('|') != 4: | |
continue | |
filename, utterances, speaker, emotion, sentiment = row.split('|') | |
if filter_file(file_path, file_list, filename): | |
continue | |
if flamingo_task == "EmotionClassification": | |
text_output = emotion | |
text_output = EMOTION_MAP_DICT[text_output] | |
text_prompt = 'this emotion is' | |
if split == 'train': | |
balancing_factor = float(emotion_count['neutral']) / float(emotion_count[text_output]) | |
elif flamingo_task == "SentimentClassification": | |
text_output = sentiment | |
text_prompt = 'this sentiment is' | |
if split == 'train': | |
balancing_factor = float(sentiment_count['neutral']) / float(sentiment_count[text_output]) | |
if len(text_output) <= 1: | |
continue | |
for _ in range(int(np.floor(balancing_factor))): | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
if np.random.rand() < balancing_factor - np.floor(balancing_factor): | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "MSP-PODCAST-Publish-1.9": | |
assert flamingo_task == "EmotionClassification" | |
assert split in ["train", "val", "test"] | |
map_split = lambda split: 'Audio' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = glob.glob('{}/*/*.wav'.format(file_path)) | |
file_list = [x[len(file_path)+1:] for x in file_list] | |
subfolder_map = {} | |
for f in tqdm(file_list): | |
subfolder, filename = f.split('/') | |
subfolder_map[filename] = subfolder | |
file_list = None | |
emotion_dic = { | |
'A': 'Angry', | |
'S': 'Sad', | |
'H': 'Happy', | |
'U': 'Surprise', | |
'F': 'Fear', | |
'D': 'Disgust', | |
'C': 'Contempt', | |
'N': 'Neutral', | |
'O': 'Other', | |
'X': 'Not clear' | |
} | |
with open(os.path.join(dataset_path, 'Labels/labels_concensus.json')) as f: | |
data = f.read() | |
metadata_dic = json.loads(data) | |
for filename in tqdm(list(metadata_dic.keys())): | |
values = metadata_dic[filename] | |
if not values["Split_Set"].lower().startswith(split): | |
continue | |
if values["EmoClass"] in ["O", "X"] or values["EmoClass"] not in emotion_dic.keys(): | |
continue | |
subfolder = subfolder_map[filename] | |
filename = '{}/{}'.format(subfolder, filename) | |
if filter_file(file_path, file_list, filename): | |
continue | |
text_output = emotion_dic[values["EmoClass"]].lower() | |
text_output = EMOTION_MAP_DICT[text_output] | |
text_prompt = 'this emotion is' | |
if len(text_output) <= 1: | |
continue | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "mtg-jamendo": | |
import ndjson | |
assert flamingo_task == "MusicTagging" | |
assert split in ["train", "val"] | |
map_split = lambda split: '44khz' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = None | |
with open(os.path.join(dataset_path, 'mtg_jamendo_{}_manifest.ndjson'.format(split))) as ndjsonfile: | |
reader = ndjson.load(ndjsonfile) | |
for row in tqdm(reader): | |
filename = row["audiopath"] | |
if filter_file(file_path, file_list, filename): | |
continue | |
text_output = row["caption"] | |
text_prompt = 'generate music tags (genre, instrument, mood/theme)' | |
if len(text_output) <= 1: | |
continue | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "MU-LLAMA": | |
assert flamingo_task in ['AQA'] | |
assert split in ['train', 'test'] | |
map_split = lambda split: 'MusicQA/audios' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
split_file = 'MusicQA/FinetuneMusicQA.json' if split == 'train' else 'MusicQA/EvalMusicQA.json' | |
with open(os.path.join(dataset_path, split_file), 'r') as f: | |
data = f.read() | |
metadata_list = json.loads(data) | |
for dic in tqdm(metadata_list): | |
filename = dic["audio_name"] | |
if filter_file(file_path, file_list, filename): | |
continue | |
text_prompt = 'Question: ' + dic["conversation"][0]["value"].strip() | |
if not (text_prompt.endswith('.') or text_prompt.endswith('?')): | |
text_prompt = text_prompt + '.' | |
text_output = dic["conversation"][1]["value"].strip() | |
if len(text_output) <= 1: | |
continue | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "musdbhq": | |
assert flamingo_task in ["InstrClassification"] | |
assert split in ["train", "test", "val"] | |
map_split = lambda split: './' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = None | |
with open(os.path.join(dataset_path, 'file_list_44k_{}.txt'.format(split))) as f: | |
data = f.readlines() | |
data = [x.replace('\n', '') for x in data] | |
for row in tqdm(data): | |
if row.count('|') != 1: | |
continue | |
filename, duration = row.split('|') | |
duration = float(duration) | |
if filter_file(file_path, file_list, filename): | |
continue | |
text_output = filename.split('/')[-1].split('.wav')[0] | |
if len(text_output) <= 1: | |
continue | |
text_prompt = 'this music is produced by' | |
segment_length = 10 | |
for audio_start_idx in range(int(duration // segment_length)): | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' '), | |
"audio_start": audio_start_idx * segment_length | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "Music-AVQA": | |
import ast | |
import re | |
assert flamingo_task in [ | |
"{}_{}".format(q, t) \ | |
for q in ['AQA', 'AVQA'] \ | |
for t in ['Comparative', 'Counting', 'Existential', 'Location', 'Temporal', 'All'] | |
] | |
def replace_bracketed_words(input_string, replacements): | |
def replacer(match): | |
word = next(replacements) | |
return word | |
replacements = iter(replacements) | |
output_string = re.sub(r'<[^>]*>', replacer, input_string) | |
return output_string | |
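# illustrative example (hypothetical template and values): | |
#   replace_bracketed_words('how many <Object> in the <Place>?', ['violins', 'video']) | |
#   -> 'how many violins in the video?' | |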
map_split = lambda split: 'audio' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
with open(os.path.join(dataset_path, 'MUSIC-AVQA/data/json/avqa-{}.json'.format(split)), 'r') as f: | |
data = f.read() | |
metadata_list = json.loads(data) | |
for dic in tqdm(metadata_list): | |
filename = dic["video_id"] + '.wav' | |
if filter_file(file_path, file_list, filename): | |
continue | |
types = ast.literal_eval(dic["type"]) | |
if 'Visual' in types: | |
continue | |
if flamingo_task.startswith('AQA_') and 'Audio-Visual' in types: | |
continue | |
if flamingo_task.startswith('AVQA_') and 'Audio' in types: | |
continue | |
t = flamingo_task.split('_')[1] | |
if t != 'All' and t not in types: | |
continue | |
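# note: 'anser' is the actual (misspelled) key in the MUSIC-AVQA annotation json | |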
text_output = dic["anser"] | |
if len(text_output) <= 1: | |
continue | |
question = dic["question_content"].replace("\uff1f", '?') | |
templ_values = ast.literal_eval(dic["templ_values"]) | |
if len(templ_values) > 0: | |
question = replace_bracketed_words(question, templ_values) | |
text_prompt = "Question: " + question | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "MusicCaps": | |
assert flamingo_task in ["AudioCaptioning", "EventClassification"] | |
assert split in ["train", "test"] | |
map_split = lambda split: '44khz' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
with open(os.path.join(dataset_path, 'musiccaps_manifest.json')) as f: | |
data = f.read() | |
metadata_list = json.loads(data) | |
for file_metadata in tqdm(metadata_list): | |
filename = file_metadata['filepath'] | |
if filter_file(file_path, file_list, filename): | |
continue | |
start_s, end_s = file_metadata["start_s"], file_metadata["end_s"] | |
caption = file_metadata["caption"] | |
audioset_positive_labels = file_metadata["audioset_positive_labels"] # audioset classes | |
aspect_list = file_metadata["aspect_list"] # annotated classes | |
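# keep AudioSet-eval clips for the test split and non-eval clips for the train split | |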
if (split == 'train') == file_metadata["is_audioset_eval"]: | |
continue | |
if flamingo_task == "AudioCaptioning": | |
text_output = caption | |
text_prompt = 'generate audio caption' | |
elif flamingo_task == "EventClassification": | |
raise NotImplementedError | |
if len(text_output) <= 1: | |
continue | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "NonSpeech7k": | |
assert flamingo_task in ["EventClassification"] | |
assert split in ["train", "test"] | |
map_split = lambda split: split | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
all_classes = [] | |
with open(os.path.join(dataset_path, 'metadata of {} set.csv'.format(split)), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
next(reader) | |
for row in tqdm(reader): | |
filename, _, _, _, classname, _, _, _ = row | |
if filter_file(file_path, file_list, filename): | |
continue | |
text_output = classname.lower() | |
if len(text_output) <= 1: | |
continue | |
text_prompt = 'this is a sound of' | |
all_classes.append(classname) | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
print('all classes:', list(set(all_classes))) | |
elif dataset_name == "NSynth": | |
import ndjson | |
assert flamingo_task in [ | |
"InstrClassification", | |
"PitchClassification", | |
"VelocityClassification", | |
"SourceClassification", | |
"QualityClassification", | |
"MIR" | |
] | |
assert split in ["train", "test", "val"] | |
map_split = lambda split: 'nsynth-{}/audio'.format('valid' if split == 'val' else split) | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
with open(os.path.join(dataset_path, map_split(split), '../examples.json')) as f: | |
data = f.read() | |
reader = json.loads(data) | |
for key in tqdm(reader): | |
filename = key + '.wav' | |
if filter_file(file_path, file_list, filename): | |
continue | |
if flamingo_task == "InstrClassification": | |
text_output = reader[key]["instrument_family_str"] | |
text_prompt = 'this music note is produced by' | |
elif flamingo_task == "PitchClassification": | |
text_output = str(reader[key]["pitch"]) | |
text_prompt = 'this music note has pitch' | |
elif flamingo_task == "VelocityClassification": | |
text_output = str(reader[key]["velocity"]) | |
text_prompt = 'this music note has velocity' | |
elif flamingo_task == "SourceClassification": | |
text_output = reader[key]["instrument_source_str"] | |
text_prompt = 'this music note has sonic source' | |
elif flamingo_task == "QualityClassification": | |
qualities_str = reader[key]["qualities_str"] | |
if len(qualities_str) >= 1: | |
text_output = ', '.join(qualities_str).replace('_', ' ') | |
else: | |
text_output = 'none' | |
text_prompt = 'this music note has sonic qualities' | |
elif flamingo_task == "MIR": | |
instrument = reader[key]["instrument_family_str"] | |
pitch = str(reader[key]["pitch"]) | |
velocity = str(reader[key]["velocity"]) | |
source = reader[key]["instrument_source_str"] | |
qualities_str = ', '.join(reader[key]["qualities_str"]).replace('_', ' ') | |
assert len(instrument) > 0 | |
text_output = 'produced by {}'.format(instrument) | |
if len(pitch) > 0: | |
text_output = text_output + ', pitch {}'.format(pitch) | |
if len(velocity) > 0: | |
text_output = text_output + ', velocity {}'.format(velocity) | |
if len(source) > 0: | |
text_output = text_output + ', source {}'.format(source) | |
if len(qualities_str) > 0: | |
text_output = text_output + ', and having qualities like {}'.format(qualities_str) | |
text_prompt = 'this music note is' | |
if len(text_output) <= 1: | |
continue | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "OMGEmotion": | |
import numpy as np | |
import webrtcvad | |
import wave | |
from pydub import AudioSegment | |
assert flamingo_task == "EmotionClassification" | |
assert split in ["train", "val"] | |
def convert_to_wav(file_path): | |
audio = AudioSegment.from_file(file_path).set_frame_rate(16000).set_channels(1) | |
wav_path = file_path.rsplit('.', 1)[0] + "_converted.wav" | |
audio.export(wav_path, format="wav") | |
return wav_path | |
def contains_speech(file_path, aggressiveness=0): | |
# aggressiveness between 0 and 3, 0 for very clean speech, and 3 for noisy speech | |
wav_path = convert_to_wav(file_path) | |
vad = webrtcvad.Vad(aggressiveness) | |
with wave.open(wav_path, 'rb') as audio: | |
assert audio.getsampwidth() == 2, "Audio must be 16-bit" | |
assert audio.getnchannels() == 1, "Audio must be mono" | |
assert audio.getframerate() == 16000, "Audio must be sampled at 16kHz" | |
frame_duration = 10 # ms | |
frame_size = int(audio.getframerate() * frame_duration / 1000) | |
num_frames = int(audio.getnframes() / frame_size) | |
for _ in range(num_frames): | |
frame = audio.readframes(frame_size) | |
if vad.is_speech(frame, audio.getframerate()): | |
return True | |
return False | |
map_split = lambda split: 'processed-{}_utterance_data'.format(split) | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = None | |
dic_code2emotion = { | |
"0": "anger", | |
"1": "disgust", | |
"2": "fear", | |
"3": "happy", | |
"4": "neutral", | |
"5": "sad", | |
"6": "surprise", | |
} | |
all_emotions = [] | |
meta_file = os.path.join( | |
dataset_path, | |
'OMGEmotionChallenge', | |
'omg_{}Videos.csv'.format('Train' if split == 'train' else 'Validation') | |
) | |
with open(meta_file, newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
next(reader) | |
for row in tqdm(reader): | |
link, start, end, video, utterance, _, _, EmotionMaxVote = row | |
emotion = dic_code2emotion[str(EmotionMaxVote)] | |
filename = os.path.join(video, utterance.replace('.mp4', '.mp3')) | |
if filter_file(file_path, file_list, filename): | |
continue | |
if not contains_speech(os.path.join(file_path, filename)): | |
print('{} does not contain speech'.format(filename)) | |
continue | |
text_prompt = 'this emotion is' | |
text_output = emotion | |
if len(text_output) <= 1: | |
continue | |
all_emotions.append(EMOTION_MAP_DICT[emotion]) | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
print('all emotions:', list(set(all_emotions))) | |
elif dataset_name == "OpenAQA": | |
assert flamingo_task == 'AQA' | |
assert split == 'train' | |
map_split = lambda split: './' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = None | |
no_word_list = [ | |
'cannot determine', 'not provided', 'cannot be determined', 'sorry', 'i cannot', | |
'without more information', 'enough information', | |
'not possible', 'more context', 'enough', 'impossible', | |
'without additional information', | |
'unclear', 'cannot', 'not clear', 'do not provide sufficient', 'does not provide', | |
'difficult to determine', 'no information provided', | |
"can't infer", "difficult to infer", "not specified", "no specific", "no information", | |
"without additional", 'it is difficult to', "no indication" | |
] | |
print('computing dic_audiosetfull_parts') | |
audiosetfull_root = '/mnt/fsx-main/rafaelvalle/datasets/audioset/unbalanced_train_segments/22khz/' | |
part_strings = [str(p).zfill(2) for p in range(41)] | |
dic_audiosetfull_parts = { | |
part: set(os.listdir(os.path.join(audiosetfull_root, 'unbalanced_train_segments_part{}'.format(part)))) \ | |
for part in part_strings | |
} | |
audioset20k_filelist = set(os.listdir(os.path.join(file_path, '../AudioSet/train_wav'))) | |
print('computing dic_clotho_filename') | |
clotho_files = os.listdir(os.path.join(dataset_path, '../Clotho-AQA/audio_files')) | |
dic_clotho_filename = { | |
'_'.join([s for s in f.split(' ') if len(s) > 0]): f \ | |
for f in clotho_files | |
} | |
print('reading open_ended/all_open_qa.json') | |
with open(os.path.join(dataset_path, 'openaqa/data/open_ended/all_open_qa.json'), 'r') as f: | |
data = f.read() | |
metadata_list = json.loads(data) | |
for dic in tqdm(metadata_list): | |
#keys: instruction, input, dataset, audio_id, output, task | |
text_output = dic["output"] | |
if len(text_output) <= 1: | |
continue | |
if any(word in text_output.lower() for word in no_word_list): | |
continue | |
question = dic["instruction"] | |
text_prompt = question | |
audio_id = dic["audio_id"] | |
subset = dic["dataset"] | |
if subset == 'clotho_development': | |
filename = audio_id.split('/')[-1] | |
processed_filename = '_'.join([s for s in filename.split('_') if len(s) > 0]) | |
if processed_filename in dic_clotho_filename: | |
filename = os.path.join( | |
'../Clotho-AQA/audio_files', | |
dic_clotho_filename[processed_filename] | |
) | |
else: | |
continue | |
elif subset in ['audiocaps_train', 'as_20k', 'as_strong_train']: | |
found = False | |
filename = audio_id.split('/')[-1].split('.flac')[0] + '.wav' | |
if filename in audioset20k_filelist: | |
filename = os.path.join('../AudioSet/train_wav', filename) | |
found = True | |
else: | |
filename = 'Y' + filename | |
for part in part_strings: | |
if filename in dic_audiosetfull_parts[part]: | |
filename = os.path.join( | |
audiosetfull_root, | |
'unbalanced_train_segments_part{}'.format(part), | |
filename | |
) | |
found = True | |
break | |
if not found: | |
print(filename, 'not found') | |
continue | |
elif subset == 'freesound_10s': | |
filename = os.path.join( | |
'../CLAP_freesound/freesound_no_overlap/split/train', | |
audio_id.split('/')[-1] | |
) | |
elif subset == 'vggsound_train': | |
continue | |
if filter_file(file_path, file_list, filename): | |
continue | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "ravdess": | |
assert flamingo_task == "EmotionClassification" | |
assert split in ["train", "val"] | |
map_split = lambda split: '44khz' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = None | |
split_file = os.path.join( | |
dataset_path, | |
'ravdess_audiopath_text_sid_emotion_duration_{}_filelist.txt'.format(split) | |
) | |
with open(split_file, 'r') as f: | |
data = f.readlines() | |
data = [x.replace('\n', '') for x in data] | |
for row in tqdm(data): | |
if row.count('|') != 4: | |
continue | |
filename, utterances, speaker, emotion, duration = row.split('|') | |
if filter_file(file_path, file_list, filename): | |
continue | |
text_output = emotion | |
text_prompt = 'this emotion is' | |
if len(text_output) <= 1: | |
continue | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "SongDescriber": | |
assert flamingo_task in ["AudioCaptioning"] | |
assert split in ["train"] | |
map_split = lambda split: './audio/audio' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = None | |
with open(os.path.join(dataset_path, 'song_describer.csv'), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
next(reader) | |
for row in tqdm(reader): | |
caption_id, track_id, caption, is_valid_subset, familiarity, artist_id, album_id, path, duration = row | |
filename = '{}/{}.2min.mp3'.format(track_id[-2:], track_id) | |
duration = float(duration) | |
if filter_file(file_path, file_list, filename): | |
continue | |
text_output = caption | |
if len(text_output) <= 1: | |
continue | |
text_prompt = 'generate audio caption' | |
segment_length = 30 | |
for audio_start_idx in range(int(duration // segment_length)): | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' '), | |
"audio_start": audio_start_idx * segment_length | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "SONYC-UST": | |
import numpy as np | |
assert flamingo_task == "EventClassification" | |
assert split in ["train", "test", "val"] | |
map_split = lambda split: 'audio' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
all_labels = [] | |
with open(os.path.join(dataset_path, 'annotations.csv'), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
for idx, row in tqdm(enumerate(reader)): | |
if idx == 0: | |
header = np.array(row) | |
continue | |
if not row[0].startswith(split): | |
continue | |
filename = row[2] | |
if filter_file(file_path, file_list, filename): | |
continue | |
labels = [header[i] for i in range(12, len(header)-8) if str(row[i]) == "1"] | |
labels = [x.split("_")[1].replace('-', ' ').lower() for x in labels if 'X_' not in x] | |
all_labels += labels | |
text_output = ", ".join(labels) | |
if len(text_output) <= 1: | |
continue | |
text_prompt = 'this is a sound of' | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
print('all labels:', list(set(all_labels))) | |
elif dataset_name == "SoundDescs": | |
import torch | |
assert flamingo_task in ["AudioDescription"] | |
assert split in ["train"] | |
map_split = lambda split: 'raw/audios' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
split_file = os.path.join(dataset_path, 'audio-retrieval-benchmark/data/SoundDescs/{}_list.txt'.format(split)) | |
with open(split_file, 'r') as f: | |
data = f.readlines() | |
names = set([x.replace('\n', '') for x in data]) | |
with open(os.path.join(dataset_path, 'audio-retrieval-benchmark/sounddescs_data/descriptions.pkl'), 'rb') as f: | |
obj = f.read() | |
metadata_dic = pickle.loads(obj, encoding='latin1') | |
for name in tqdm(names): | |
if name not in metadata_dic.keys(): | |
continue | |
filename = '{}.wav'.format(name) | |
if filter_file(file_path, file_list, filename): | |
continue | |
description = metadata_dic[name] | |
text_output = description | |
text_prompt = 'generate audio description' | |
if len(text_output) <= 1: | |
continue | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "tess": | |
assert flamingo_task == "EmotionClassification" | |
assert split in ["train", "val"] | |
map_split = lambda split: '24414hz' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = None | |
split_file = os.path.join( | |
dataset_path, | |
'tess_audiopath_text_sid_emotion_duration_{}_filelist.txt'.format(split) | |
) | |
with open(split_file, 'r') as f: | |
data = f.readlines() | |
data = [x.replace('\n', '') for x in data] | |
for row in tqdm(data): | |
if row.count('|') != 4: | |
continue | |
filename, utterances, speaker, emotion, duration = row.split('|') | |
if filter_file(file_path, file_list, filename): | |
continue | |
text_output = emotion.replace('_', ' ') | |
text_output = EMOTION_MAP_DICT[text_output] | |
text_prompt = 'this emotion is' | |
if len(text_output) <= 1: | |
continue | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "UrbanSound8K": | |
assert flamingo_task in ["EventClassification"] | |
assert split in ["train"] | |
map_split = lambda split: 'audio' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = None | |
with open(os.path.join(dataset_path, 'metadata/UrbanSound8K.csv'), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
next(reader) | |
for row in tqdm(reader): | |
filename, fsID, start, end, salience, fold, classID, class_name = row | |
filename = 'fold{}/{}'.format(fold, filename) | |
if filter_file(file_path, file_list, filename): | |
continue | |
text_output = class_name.replace("_", " ").lower() | |
if len(text_output) <= 1: | |
continue | |
text_prompt = 'this is a sound of' | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "VocalSound": | |
assert flamingo_task == "VocalClassification" | |
map_split = lambda split: 'data_44k' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
split_file = os.path.join( | |
dataset_path, | |
'meta/{}_meta.csv'.format(split[:2] if split in ['train', 'test'] else split[:3]) | |
) | |
prefix = set([]) | |
with open(split_file, newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
for row in reader: | |
prefix.add(row[0]) | |
all_labels = set([]) | |
for filename in tqdm(file_list): | |
if filename.split('_')[0] not in prefix: | |
continue | |
if filter_file(file_path, file_list, filename): | |
continue | |
label = filename.split('_')[2].split('.wav')[0] | |
if label == 'throatclearing': | |
label = 'throat clearing' | |
text_output = label | |
text_prompt = 'this vocal sound is' | |
all_labels.add(label) | |
if len(text_output) <= 1: | |
continue | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
print('all labels:\n', "\'" + "\', \'".join(list(all_labels)) + "\'") | |
elif dataset_name.startswith("WavCaps"): | |
assert split in ["train"] | |
dataset_name, subset_name = dataset_name.split('-', 1) | |
dataset_path = os.path.join( | |
'/'.join(dataset_path.split('/')[:-1]), | |
dataset_name | |
) | |
dataset_dic['dataset_path'] = dataset_path | |
map_split = lambda split: subset_name + '_flac' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.flac'), os.listdir(file_path))) | |
metadata_file = os.listdir(os.path.join(dataset_path, "json_files", subset_name)) | |
metadata_file = [x for x in metadata_file if x.endswith('json')][0] | |
with open(os.path.join(dataset_path, "json_files", subset_name, metadata_file)) as f: | |
data = f.read() | |
reader = json.loads(data) | |
if subset_name == "AudioSet_SL": | |
assert flamingo_task == 'AudioCaptioning' | |
for sample in tqdm(reader['data']): | |
filename = sample["id"].replace('.wav', '.flac') | |
if filter_file(file_path, file_list, filename): | |
continue | |
text_output = sample['caption'] | |
if len(text_output) <= 1: | |
continue | |
text_prompt = 'generate audio caption' | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
else: | |
assert flamingo_task in ['AudioCaptioning', 'AudioDescription'] | |
for sample in tqdm(reader['data']): | |
filename = sample["id"] + '.flac' | |
if filter_file(file_path, file_list, filename): | |
continue | |
if flamingo_task == 'AudioCaptioning': | |
text_output = sample['caption'] | |
text_prompt = 'generate audio caption' | |
elif flamingo_task == 'AudioDescription': | |
text_output = sample['description'] | |
text_prompt = 'generate audio description' | |
if len(text_output) <= 1: | |
continue | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif dataset_name == "WavText5K": | |
assert split == 'train' | |
map_split = lambda split: 'Webcrawl/44100/audios' | |
file_path = os.path.join( | |
dataset_path, | |
map_split(split) | |
) | |
assert os.path.exists(file_path), '{} not exist'.format(file_path) | |
dataset_dic["split_path"] = map_split(split) | |
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path))) | |
dic = defaultdict(str) | |
with open(os.path.join(dataset_path, 'WavText5K.csv'), newline='') as csvfile: | |
reader = csv.reader(csvfile, delimiter=',', quotechar='"') | |
next(reader) | |
for row in tqdm(reader): | |
_, _, title, description, filename, tags = row | |
dic[filename] = (title, description, tags) | |
if flamingo_task == "AudioCaptioning": | |
for filename in tqdm(dic.keys()): | |
if filter_file(file_path, file_list, filename): | |
continue | |
title, description, tags = dic[filename] | |
text_output = description | |
if len(text_output) <= 1: | |
continue | |
text_prompt = 'generate audio caption' | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
elif flamingo_task == "Tagging": | |
for filename in tqdm(dic.keys()): | |
if filter_file(file_path, file_list, filename): | |
continue | |
title, description, tags = dic[filename] | |
if len(tags) < 2 or not tags.startswith('[') or not tags.endswith(']'): | |
continue | |
tags = tags[1:-1].split(', ') | |
tags = list(map(lambda x: x.replace("'", ""), tags)) | |
text_output = '{}'.format(', '.join(tags)) | |
if len(text_output) <= 1: | |
continue | |
text_prompt = 'generate tags' | |
dataset_dic["data"][dataset_dic["total_num"]] = { | |
"name": filename, | |
"prompt": text_prompt, | |
"output": text_output.replace('\n', ' ') | |
} | |
dataset_dic["total_num"] += 1 | |
with open(output_file, 'w') as json_file: | |
json.dump(dataset_dic, json_file) | |
# ==================== Precompute CLAP embeddings and build FAISS index ==================== | |
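# the int16 <-> float32 helpers below reproduce the quantization LAION-CLAP applies | |
# to its training audio, so embeddings computed here match the pretrained checkpoint's inputs | |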
def int16_to_float32(x): | |
return (x / 32767.0).astype(np.float32) | |
def float32_to_int16(x): | |
x = np.clip(x, a_min=-1., a_max=1.) | |
return (x * 32767.).astype(np.int16) | |
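# note: update_progress_bar below assumes a global tqdm bar named `pbar`; it is the usual | |
# shape for a pool.apply_async callback but appears unused in the current pipeline | |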
def update_progress_bar(arg): | |
pbar.update() | |
def load_clap_model(checkpoint): | |
if checkpoint in ['630k-audioset-best.pt', '630k-best.pt', '630k-audioset-fusion-best.pt', '630k-fusion-best.pt']: | |
amodel = 'HTSAT-tiny' | |
elif checkpoint in ['music_speech_audioset_epoch_15_esc_89.98.pt']: | |
amodel = 'HTSAT-base' | |
else: | |
raise NotImplementedError | |
model = laion_clap.CLAP_Module( | |
enable_fusion=('fusion' in checkpoint.lower()), | |
amodel=amodel | |
).cuda() | |
model.load_ckpt(ckpt=os.path.join( | |
'/lustre/fsw/portfolios/adlr/users/zkong/audio-flamingo-data/laion-clap-pretrained/laion_clap', | |
checkpoint | |
)) | |
return model | |
def load_audio(file_path, target_sr=44100, duration=30.0, start=0.0): | |
if file_path.endswith('.mp3'): | |
audio = AudioSegment.from_file(file_path) | |
# trim to [start, start + duration]; pydub slices in milliseconds and clamps at the clip end | |
audio = audio[int(start * 1000):int((start + duration) * 1000)] | |
if audio.frame_rate != target_sr: | |
audio = audio.set_frame_rate(target_sr) | |
if audio.channels > 1: | |
audio = audio.set_channels(1) | |
data = np.array(audio.get_array_of_samples()) | |
if audio.sample_width == 2: | |
data = data.astype(np.float32) / np.iinfo(np.int16).max | |
elif audio.sample_width == 4: | |
data = data.astype(np.float32) / np.iinfo(np.int32).max | |
else: | |
raise ValueError("Unsupported bit depth: {}".format(audio.sample_width)) | |
else: | |
with sf.SoundFile(file_path) as audio: | |
original_sr = audio.samplerate | |
channels = audio.channels | |
start_frame = int(start * original_sr) | |
audio.seek(start_frame) | |
# read at most `duration` seconds from the seek position (previously up to | |
# start + duration seconds were read, overshooting whenever start > 0) | |
frames_to_read = max(0, min(int(duration * original_sr), len(audio) - start_frame)) | |
data = audio.read(frames_to_read) | |
if data.max() > 1 or data.min() < -1: | |
data = data / max(abs(data.max()), abs(data.min())) | |
if original_sr != target_sr: | |
if channels == 1: | |
data = librosa.resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr) | |
else: | |
data = librosa.resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0] | |
else: | |
if channels != 1: | |
data = data.T[0] | |
if data.min() >= 0: | |
data = 2 * data / abs(data.max()) - 1.0 | |
else: | |
data = data / max(abs(data.max()), abs(data.min())) | |
return data | |
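# minimal usage sketch for load_audio (path is illustrative); returns mono float32 samples in [-1, 1]: | |
#   data = load_audio('/path/to/clip.wav', target_sr=48000, duration=10) | |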
def compute_clap_each(audio_file, model): | |
try: | |
data = load_audio(audio_file, target_sr=48000, duration=10) | |
print(audio_file, 'loaded') | |
except Exception as e: | |
print(audio_file, 'unsuccessful due to', e) | |
return None | |
audio_data = data.reshape(1, -1) | |
audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float().cuda() | |
audio_embed = model.get_audio_embedding_from_data(x=audio_data_tensor, use_tensor=True) | |
audio_embed = audio_embed.squeeze(0).cpu() | |
return audio_embed | |
def compute_embeddings_batch(batch, audio_files, model): | |
batch_results = [] | |
for i in batch: | |
if i >= len(audio_files): | |
break | |
audio_file = audio_files[i] | |
audio_embed = compute_clap_each(audio_file, model) | |
batch_results.append((i, audio_file, audio_embed)) | |
return batch_results | |
def precompute_clap_for_dataset( | |
dataset_file, | |
embedding_output_file, | |
checkpoint='630k-audioset-fusion-best.pt' | |
): | |
contents, audio_files = load_dataset_file(dataset_file) | |
model = load_clap_model(checkpoint) | |
if os.path.exists(embedding_output_file): | |
print('loading already computed embedding file from', embedding_output_file) | |
with open(embedding_output_file, 'rb') as f: | |
saved_data = pickle.load(f) | |
curr_audio_indices = saved_data['audio_indices'] | |
curr_audio_files = saved_data['audio_files'] | |
curr_audio_embeds = saved_data['audio_embeds'] | |
else: | |
curr_audio_indices = [] | |
curr_audio_files = [] | |
curr_audio_embeds = [] | |
print('computing embeddings for {}'.format(dataset_file)) | |
start_index = len(curr_audio_files) | |
batch_size = 128 | |
batches = [ | |
list(range(i, min(i + batch_size, len(audio_files)))) \ | |
for i in range(start_index, len(audio_files), batch_size) | |
] | |
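# caution: `partial` below captures the CUDA-resident CLAP model, which must be pickled | |
# into each spawned worker; this commonly fails or duplicates GPU memory, and mapping a | |
# single batch per pool.map call processes batches serially anyway | |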
with multiprocessing.Pool(processes=4) as pool: | |
for i, batch in enumerate(batches): | |
batch_results = pool.map( | |
partial(compute_embeddings_batch, model=model, audio_files=audio_files), | |
[batch] | |
) | |
for result in batch_results[0]: | |
curr_audio_indices.append(result[0]) | |
curr_audio_files.append(result[1]) | |
curr_audio_embeds.append(result[2]) | |
with open(embedding_output_file, 'wb') as f: | |
pickle.dump({ | |
'audio_indices': curr_audio_indices, | |
'audio_files': curr_audio_files, | |
'audio_embeds': curr_audio_embeds | |
}, f) | |
print(f"Saved progress for batch {i+1}/{len(batches)}: \ | |
audio_indices {len(curr_audio_indices)}, \ | |
audio_files {len(curr_audio_files)}, \ | |
audio_embeds {len(curr_audio_embeds)}*{curr_audio_embeds[0].shape}") | |
return curr_audio_indices, curr_audio_files, curr_audio_embeds | |
def build_faiss_index(embeddings): | |
d = embeddings[0].size(0) | |
index = faiss.IndexFlatL2(d) | |
np_embeddings = np.vstack([emb.numpy() for emb in embeddings]) | |
index.add(np_embeddings) | |
return index | |
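# build_faiss_index uses IndexFlatL2, i.e. exact (brute-force) L2 nearest-neighbor search; | |
# `embeddings` is assumed to be a non-empty list of equal-dimension CPU torch tensors | |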
def build_faiss_index_dataset( | |
dataset_file, | |
embedding_output_file, | |
faiss_output_file, | |
checkpoint='630k-audioset-fusion-best.pt', | |
only_precompute_clap=False | |
): | |
audio_indices, audio_files, audio_embeds = precompute_clap_for_dataset(dataset_file, embedding_output_file, checkpoint) | |
if only_precompute_clap: | |
return | |
valid_indices, valid_files, valid_embeds = [], [], [] | |
for audio_index, audio_file, audio_embed in zip(audio_indices, audio_files, audio_embeds): | |
if audio_embed is not None: | |
valid_indices.append(audio_index) | |
valid_files.append(audio_file) | |
valid_embeds.append(audio_embed) | |
print('building faiss index') | |
faiss_index = build_faiss_index(valid_embeds) | |
print('saving faiss index') | |
faiss.write_index(faiss_index, faiss_output_file) | |
with open(faiss_output_file + '.filenames', 'wb') as f: | |
pickle.dump({'audio_indices': valid_indices, 'audio_files': valid_files}, f) | |
# ==================== Generate interleaved dataset files ==================== | |
# only indices are saved so that the underlying samples can be recovered from the split json files | |
def build_interleaved_dataset(dataset_file, interleaved_output_file, embedding_output_file, faiss_output_file, mode='random', n_samples=3): | |
contents, audio_files = load_dataset_file(dataset_file) | |
dataset_dic = { | |
"dataset_path": contents["dataset_path"], | |
"split": contents["split"], | |
"split_path": contents["split_path"], | |
"flamingo_task": contents["flamingo_task"], | |
"total_num": 0, | |
"interleaved_data": {}, | |
} | |
# interleaved_data is | |
# { | |
# id: { | |
# "generation_index_in_split": index of sample in the train or val or test.json, | |
# "fewshot_indices_in_train": list(indices) of few shot samples in train.json | |
# } | |
# } | |
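# a minimal recovery sketch (hypothetical variable names; json turns integer keys | |
# into strings on dump/reload): | |
#   entry = interleaved["interleaved_data"]["0"] | |
#   split_contents, _ = load_dataset_file(dataset_file)  # the split this file was built from | |
#   query = split_contents["data"][str(entry["generation_index_in_split"])] | |
#   train_contents, _ = load_dataset_file(train_dataset_file) | |
#   few_shot = [train_contents["data"][str(j)] for j in entry["fewshot_indices_in_train"]] | |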
if mode == 'knn': | |
model = load_clap_model(checkpoint='630k-audioset-fusion-best.pt') | |
print('loading already computed embedding file from', embedding_output_file) | |
with open(embedding_output_file, 'rb') as f: | |
precomputed_data = pickle.load(f) | |
precomputed_audio_indices = precomputed_data['audio_indices'] | |
precomputed_audio_files = precomputed_data['audio_files'] | |
precomputed_audio_embeds = precomputed_data['audio_embeds'] | |
faiss_index = faiss.read_index(faiss_output_file) | |
with open(faiss_output_file+'.filenames', 'rb') as f: | |
_data = pickle.load(f) | |
faiss_index_audio_indices = _data['audio_indices'] | |
faiss_index_audio_files = _data['audio_files'] | |
print('looking for few shot samples and building interleaved_{} data'.format(mode)) | |
for i in tqdm(range(contents["total_num"])): | |
if mode == 'random': | |
few_shot_indices = list(np.random.choice( | |
list(set(list(range(contents["total_num"]))) - set([i])), | |
size=n_samples-1, | |
replace=False | |
)) | |
few_shot_indices = list(map(int, few_shot_indices)) | |
elif mode == 'knn': | |
if audio_files[i] in precomputed_audio_files: | |
idx = precomputed_audio_files.index(audio_files[i]) | |
query_embedding_np = precomputed_audio_embeds[idx] | |
if query_embedding_np is not None: | |
query_embedding_np = query_embedding_np.numpy().reshape(1, -1) | |
else: | |
continue | |
else: | |
query_embedding_np = compute_clap_each(audio_files[i], model) | |
if query_embedding_np is not None: | |
query_embedding_np = query_embedding_np.numpy().reshape(1, -1) | |
else: | |
continue | |
distances, knn_indices = faiss_index.search(query_embedding_np, n_samples+50) | |
distances = distances[0] | |
knn_indices = knn_indices[0] | |
knn_filenames = [faiss_index_audio_files[idx] for idx in knn_indices] | |
combined = list(zip(knn_indices, knn_filenames)) | |
unique_indices = defaultdict(list) | |
for idx, filename in combined: | |
unique_indices[filename].append(idx) | |
cleared_knn_indices = [random.choice(unique_indices[filename]) for filename in unique_indices if filename != audio_files[i]] | |
if dataset_file.endswith('train.json'): | |
cleared_knn_indices = [knn_i for knn_i in cleared_knn_indices if faiss_index_audio_indices[knn_i] != i] | |
cleared_knn_indices = cleared_knn_indices[:n_samples-1] | |
np.random.shuffle(cleared_knn_indices) | |
few_shot_indices = [faiss_index_audio_indices[knn_i] for knn_i in cleared_knn_indices] | |
dataset_dic["interleaved_data"][dataset_dic["total_num"]] = { | |
"generation_index_in_split": i, | |
"fewshot_indices_in_train": few_shot_indices | |
} | |
dataset_dic["total_num"] += 1 | |
with open(interleaved_output_file, 'w') as json_file: | |
json.dump(dataset_dic, json_file) | |
if __name__ == '__main__': | |
import argparse | |
parser = argparse.ArgumentParser() | |
parser.add_argument('-d', '--dataset_name', type=str, help='dataset name') | |
parser.add_argument('-f', '--flamingo_task', type=str, help='flamingo task') | |
parser.add_argument('--interleave', action="store_true", help='prepare the interleave dataset') | |
args = parser.parse_args() | |
ROOT = "/lustre/fsw/portfolios/adlr/users/zkong" | |
dataset_root = os.path.join(ROOT, "datasets") | |
output_root = os.path.join(ROOT, "audio-flamingo-data/dataset_files") | |
os.makedirs(output_root, exist_ok=True) | |
dataset_name = args.dataset_name # "Clotho-v2", "AudioSet", "Clotho-AQA", "WavText5K", "FSD50k", ... | |
flamingo_task = args.flamingo_task # AQA, AudioCaptioning, EventClassification, SceneClassification, Tagging, ... | |
# the train split must be processed first; otherwise train.embedding does not exist for kNN queries | |
for split in ["train", "val", "test"]: | |
dataset_path = os.path.join(dataset_root, dataset_name) | |
output_folder = '{}-{}'.format(dataset_name, flamingo_task) | |
os.makedirs(os.path.join(output_root, output_folder), exist_ok=True) | |
dataset_file = os.path.join(output_root, output_folder, '{}.json'.format(split)) | |
if not os.path.exists(dataset_file): | |
try: | |
prepare_files(dataset_name, dataset_path, split, flamingo_task, dataset_file) | |
except AssertionError as e: | |
print('split {} does not exist for {}: {}'.format(split, dataset_name, e)) | |
continue | |
else: | |
print('{} exists; skipping'.format(dataset_file)) | |
if args.interleave: | |
faiss_output_file = dataset_file.replace('{}.json'.format(split), "train_faiss_index.index") | |
embedding_output_file = dataset_file.replace('.json', ".embedding") | |
if split == 'train': | |
if (not os.path.exists(faiss_output_file)) or (not os.path.exists(faiss_output_file + '.filenames')): | |
build_faiss_index_dataset( | |
dataset_file, embedding_output_file, faiss_output_file, | |
only_precompute_clap=False | |
) | |
else: | |
print('{} exists; skipping'.format(faiss_output_file)) | |
else: | |
build_faiss_index_dataset( | |
dataset_file, embedding_output_file, | |
faiss_output_file=None, | |
only_precompute_clap=True | |
) | |
print('precomputing embedding for {} subset finished'.format(split)) | |
for mode in ['knn', 'random']: | |
interleaved_output_file = '/'.join( | |
dataset_file.split('/')[:-1] + \ | |
['interleaved_{}-'.format(mode) + dataset_file.split('/')[-1]] | |
) | |
if not os.path.exists(interleaved_output_file): | |
build_interleaved_dataset( | |
dataset_file=dataset_file, | |
interleaved_output_file=interleaved_output_file, | |
embedding_output_file=embedding_output_file, | |
faiss_output_file=faiss_output_file, | |
mode=mode, | |
n_samples=4 | |
) | |
else: | |
print('{} exists; skipping'.format(interleaved_output_file)) | |