# audio-flamingo-2-0.5B/data/prepare_each_dataset.py
import os
import json
import csv
import yaml
from collections import defaultdict
import pickle
import glob
import math
from functools import partial
import sys
import io
import warnings
import random
import numpy as np
import torch
import laion_clap
import librosa
from pydub import AudioSegment
import soundfile as sf
import faiss
import multiprocessing
multiprocessing.set_start_method('spawn', force=True)
try:
from tqdm import tqdm
except ImportError:
tqdm = lambda x: x
def suppress_all_output(func):
    # Silence both Python-level and OS-level (fd 1/2) output while func runs.
    def wrapper(*args, **kwargs):
        # Redirect Python-level stdout/stderr into throwaway buffers.
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        sys.stdout = io.StringIO()
        sys.stderr = io.StringIO()
        # Duplicate and redirect the underlying file descriptors so output
        # from C extensions is suppressed as well.
        old_fd_out = os.dup(1)
        old_fd_err = os.dup(2)
        null_fd = os.open(os.devnull, os.O_RDWR)
        os.dup2(null_fd, 1)
        os.dup2(null_fd, 2)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
try:
result = func(*args, **kwargs)
finally:
os.dup2(old_fd_out, 1)
os.dup2(old_fd_err, 2)
os.close(null_fd)
os.close(old_fd_out)
os.close(old_fd_err)
sys.stdout = old_stdout
sys.stderr = old_stderr
return result
return wrapper
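# Usage sketch (hypothetical, not from this file): wrap a chatty loader so both
# Python-level and C-extension output is silenced, e.g.
#   @suppress_all_output
#   def load_clap():
#       return laion_clap.CLAP_Module(enable_fusion=False)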
def filter_file(file_path, file_list, filename):
    # Returns True if the sample should be skipped: the file is missing, or it
    # is too small to contain meaningful audio.
    if file_list is not None:
        if filename not in file_list:
            print(filename, 'does not exist')
            return True
    else:
        if not os.path.exists(os.path.join(file_path, filename)):
            print(filename, 'does not exist')
            return True
    if os.path.getsize(os.path.join(file_path, filename)) < 16000:
        print(filename, 'is under 16000 bytes (roughly 0.5 to 1 second of audio)')
        return True
    return False
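# Example (illustrative paths): filter_file('/data/wavs', None, 'a.wav') returns
# True when the file is missing or under 16000 bytes (roughly 0.5-1 s of 16 kHz
# 16-bit mono audio); callers then skip the sample with `continue`.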
# ==================== Prepare dataset files from each data folder ====================
EMOTION_MAP_DICT = {
    'amused': 'amused',
    'anger': 'angry', 'angry': 'angry',
    'anxious': 'anxious',
    'apologetic': 'apologetic',
    'assertive': 'assertive',
    'calm': 'calm',
    'concerned': 'concerned',
    'contempt': 'contempt',
    'disgust': 'disgusted', 'disgusted': 'disgusted',
    'encouraging': 'encouraging',
    'excited': 'excited',
    'fear': 'fearful', 'fearful': 'fearful',
    'frustated': 'frustrated', 'frustrated': 'frustrated',  # 'frustated' key kept as spelled in the source data
    'happy': 'happy', 'joy': 'happy',
    'neutral': 'neutral',
    'sad': 'sad', 'sadness': 'sad',
    'sleepy': 'sleepy',
    'surprise': 'surprised', 'surprised': 'surprised',
    'pleasantly surprised': 'pleasantly surprised',
}
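# EMOTION_MAP_DICT normalizes dataset-specific emotion labels to a shared
# vocabulary, e.g. EMOTION_MAP_DICT['joy'] == 'happy' and
# EMOTION_MAP_DICT['fear'] == 'fearful'.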
def load_dataset_file(dataset_file):
with open(dataset_file) as f:
contents = f.read()
contents = json.loads(contents)
audio_files = [
os.path.join(
contents["dataset_path"],
contents["split_path"],
contents["data"][str(i)]["name"]
) for i in range(contents["total_num"])
]
return contents, audio_files
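# A minimal manifest consumed by load_dataset_file, matching the records that
# prepare_files() writes below (values are illustrative):
# {
#   "dataset_path": "/datasets/ESC50", "split_path": "ESC-50-master/audio",
#   "total_num": 1,
#   "data": {"0": {"name": "1-100032-A-0.wav",
#                  "prompt": "classify this sound.", "output": "dog"}}
# }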
def compute_label_graph(dataset_name, dataset_path, top_n, output_file):
if os.path.exists(output_file):
print('loading precomputed graph:', output_file)
with open(output_file, 'r') as json_file:
graph = json.load(json_file)
else:
import torch
from sentence_transformers import SentenceTransformer, util
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
print('precomputing graph and save to:', output_file)
if dataset_name == 'AudioSetSL_singlelabel':
names = []
with open(os.path.join(dataset_path, 'class_labels_indices.csv'), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in reader:
_, label, name = row # 123, /m/02zsn, "Female speech, woman speaking"
names += name.split(', ')
names = [x.lower().strip() for x in names]
elif dataset_name == "Clotho-AQA_singlelabel":
names = set([])
with open(os.path.join(dataset_path, 'clotho_aqa_metadata.csv'), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
_, file_name, keywords, _, _, _, _ = row
names |= set(keywords.split(';'))
names = [x.lower().strip() for x in names]
names_embeddings = embedding_model.encode(names, convert_to_tensor=True)
similarity_matrix = util.pytorch_cos_sim(names_embeddings, names_embeddings)
similarity_threshold = 0.75
n_items = len(names)
graph = {}
for i in range(n_items):
adjusted_top_n = min(top_n, n_items - 1)
values, indices = torch.topk(similarity_matrix[i], adjusted_top_n + 1, largest=True)
most_similar_items = []
for value, idx in zip(values, indices):
if idx != i and value <= similarity_threshold:
most_similar_items.append(idx.item())
if len(most_similar_items) == adjusted_top_n:
break
graph[names[i]] = [names[j] for j in most_similar_items]
with open(output_file, 'w') as json_file:
json.dump(graph, json_file)
    # graph is a dict: key = each label, value = list of up to top_n similar labels
return graph
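# Usage sketch: graph = compute_label_graph('AudioSetSL_singlelabel', path,
# top_n=200, output_file=...) gives graph['dog'] as a list of up to 200 labels
# ordered by sentence-embedding similarity (near-duplicates with cosine
# similarity above 0.75 are dropped); it is used below to draw hard negative
# options for multiple-choice prompts.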
def prepare_files(dataset_name, dataset_path, split, flamingo_task, output_file):
assert not os.path.exists(output_file)
dataset_dic = {
"dataset_path": dataset_path,
"split": split,
"split_path": None,
"flamingo_task": "{}-{}".format(dataset_name, flamingo_task),
"total_num": 0,
"data": {} # {id: {'name': name, 'prompt': prompt, 'output': output}}
}
if dataset_name == "AudioSet":
assert flamingo_task == "EventClassification"
assert split == 'train'
map_split = lambda split: 'train_wav' if split == 'train' else ''
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
dic = defaultdict(str)
with open(os.path.join(dataset_path, 'class_labels_indices.csv'), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
_, label, name = row # /m/02zsn,"Female speech, woman speaking"
dic[label] = name
with open(os.path.join(dataset_path, 'train.csv'), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
filename, _, _, labels = row # --aE2O5G5WE /m/03fwl,/m/04rlf,/m/09x0r
filename = filename + '.wav'
if filter_file(file_path, file_list, filename):
continue
label_list = labels.split(",")
assert all(label in dic for label in label_list)
text_output = ", ".join([dic[label] for label in label_list])
if len(text_output) <= 1:
continue
text_prompt = 'this is a sound of'
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "AudioSetFull":
assert flamingo_task == "EventClassification"
assert split == 'train'
map_split = lambda split: '/mnt/fsx-main/rafaelvalle/datasets/audioset/unbalanced_train_segments/22khz'
file_path = map_split(split)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = None
dic_code2label = defaultdict(str)
with open(os.path.join(dataset_path, 'audioset-processing/data/class_labels_indices.csv'), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
_, code, name = row # /m/02zsn,"Female speech, woman speaking"
dic_code2label[code] = name
dic_filename2code = {}
with open(os.path.join(dataset_path, 'audioset-processing/data/unbalanced_train_segments.csv'), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
next(reader)
for row in tqdm(reader):
filename, _, _, codes = row # --aE2O5G5WE /m/03fwl,/m/04rlf,/m/09x0r
filename = 'Y' + filename + '.wav'
dic_filename2code[filename] = codes.split(",")
for part in tqdm(range(41)):
            part_str = str(part).zfill(2)
part_folder = 'unbalanced_train_segments_part{}'.format(part_str)
for filename in os.listdir(os.path.join(file_path, part_folder)):
if not filename.endswith('.wav'):
continue
if filter_file(file_path, file_list, os.path.join(part_folder, filename)):
continue
if filename not in dic_filename2code:
continue
text_output = ", ".join([dic_code2label[code] for code in dic_filename2code[filename] if code in dic_code2label])
if len(text_output) <= 1:
continue
text_prompt = 'this is a sound of'
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": os.path.join(part_folder, filename),
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "AudioSetFullwoAudioMusicCaps":
assert flamingo_task == "EventClassification"
assert split == 'train'
map_split = lambda split: '/mnt/fsx-main/rafaelvalle/datasets/audioset/unbalanced_train_segments/22khz'
file_path = map_split(split)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = None
print('extracting AudioCaps and MusicCaps ytid to avoid these samples')
audiocaps_ytid = []
for f in ['audiocaps_dataset/train.csv', 'audiocaps_dataset/test.csv', 'audiocaps_dataset/val.csv']:
with open(os.path.join(dataset_path, f), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in reader:
_, ytid, _, _ = row
audiocaps_ytid.append('Y' + ytid + '.wav')
audiocaps_ytid = set(audiocaps_ytid)
musiccaps_ytid = []
with open(os.path.join(dataset_path, 'musiccaps_dataset/musiccaps_manifest.json')) as f:
data = f.read()
musiccaps_list = json.loads(data)
for row in musiccaps_list:
musiccaps_ytid.append('Y' + row["ytid"] + '.wav')
musiccaps_ytid = set(musiccaps_ytid)
        print('Will exclude {} samples from AudioCaps and {} from MusicCaps'.format(len(audiocaps_ytid), len(musiccaps_ytid)))
dic_code2label = defaultdict(str)
with open(os.path.join(dataset_path, '../AudioSetFull/audioset-processing/data/class_labels_indices.csv'), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
_, code, name = row # /m/02zsn,"Female speech, woman speaking"
dic_code2label[code] = name
dic_filename2code = {}
with open(os.path.join(dataset_path, '../AudioSetFull/audioset-processing/data/unbalanced_train_segments.csv'), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
next(reader)
for row in tqdm(reader):
filename, _, _, codes = row # --aE2O5G5WE /m/03fwl,/m/04rlf,/m/09x0r
filename = 'Y' + filename + '.wav'
dic_filename2code[filename] = codes.split(",")
music_audio_caps_excluded = 0
for part in tqdm(range(41)):
            part_str = str(part).zfill(2)
part_folder = 'unbalanced_train_segments_part{}'.format(part_str)
for filename in os.listdir(os.path.join(file_path, part_folder)):
if not filename.endswith('.wav'):
continue
if filename in audiocaps_ytid or filename in musiccaps_ytid:
music_audio_caps_excluded += 1
continue
if filter_file(file_path, file_list, os.path.join(part_folder, filename)):
continue
if filename not in dic_filename2code:
continue
text_output = ", ".join([dic_code2label[code] for code in dic_filename2code[filename] if code in dic_code2label])
if len(text_output) <= 1:
continue
text_prompt = 'this is a sound of'
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": os.path.join(part_folder, filename),
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "AudioSetSL_singlelabel":
import numpy as np
assert flamingo_task == "EventClassification"
assert split == 'train'
map_split = lambda split: '../AudioSet/train_wav'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
dic = defaultdict(str)
with open(os.path.join(dataset_path, 'class_labels_indices.csv'), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
_, label, name = row # /m/02zsn,"Female speech, woman speaking"
dic[label] = name
graph = compute_label_graph(
dataset_name,
dataset_path,
top_n=200,
output_file=os.path.join(dataset_path, 'label_graph.json')
)
with open(os.path.join(dataset_path, 'train.csv'), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
filename, _, _, labels = row # --aE2O5G5WE /m/03fwl,/m/04rlf,/m/09x0r
filename = filename + '.wav'
if filter_file(file_path, file_list, filename):
continue
label_list = labels.split(",")
assert all(label in dic for label in label_list)
text_labels = ", ".join([dic[label] for label in label_list]).lower()
text_labels = text_labels.split(', ')
text_output = np.random.choice(text_labels)
if len(text_output) <= 1:
continue
num_options = np.random.choice(
[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
p=[ 0.05, 0.1, 0.1, 0.1, 0.1,
0.05, 0.05, 0.05, 0.1, 0.05,
0.05, 0.1, 0.05, 0.05]
)
negative_samples = [x for x in graph[text_output] if x not in set(text_labels)]
candidate_negative_labels = list(np.random.choice(
negative_samples[:num_options*10],
size=num_options-1,
replace=False
))
if type(candidate_negative_labels) is str:
candidate_negative_labels = [candidate_negative_labels]
all_options = [text_output] + candidate_negative_labels
np.random.shuffle(all_options)
text_prompt = 'Classify this sound.\nOPTIONS:\n - {}.'.format(
'.\n - '.join(all_options)
)
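                # Example rendered prompt (hypothetical labels):
                #   Classify this sound.
                #   OPTIONS:
                #    - dog.
                #    - rain.
                #    - speech.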
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "AUDIOCAPS13k":
assert flamingo_task == 'AudioCaptioning'
map_split = lambda split: 'audio_32000Hz/{}'.format(split)
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.flac'), os.listdir(file_path)))
with open(os.path.join(
dataset_path,
'{}_manifest.json'.format(split + ('_v2' if split == 'train' else ''))
), 'r') as f:
data = f.readlines()
data = [json.loads(row) for row in data]
for row in tqdm(data):
filename = row['audio_filepath'].split('/')[-1]
if filter_file(file_path, file_list, filename):
continue
text_output = row['text']
if len(text_output) <= 1:
continue
text_prompt = 'generate audio caption'
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "audiocaps":
assert flamingo_task == 'AudioCaptioning'
map_split = lambda split: 'audio/{}'.format(split if split in ['train', 'test'] else 'valid')
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.flac'), os.listdir(file_path)))
for filename in tqdm(file_list):
if filter_file(file_path, file_list, filename):
continue
with open(os.path.join(file_path, filename.replace('.flac', '.json')), 'r') as f:
data = json.load(f)
captions = data['text']
for text_output in captions:
if len(text_output) <= 1:
continue
text_prompt = 'generate audio caption'
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == 'BG-Gun-Sound-Dataset':
assert flamingo_task == "SoundClassification"
assert split in ["train", "test"]
map_split = lambda split: 'data/gun_sound_v2'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = os.listdir(file_path)
all_cates = set([])
with open(os.path.join(dataset_path, 'data/v3_exp3_{}.csv'.format(split)), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
filename, cate, dist, dire = row
if filter_file(file_path, file_list, filename):
continue
text_output = cate
if len(text_output) <= 1:
continue
text_prompt = 'What is the gun of this sound?'
all_cates.add(cate)
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
print(all_cates)
elif dataset_name == "BirdsDataset":
assert flamingo_task == "SoundClassification"
assert split == 'train'
map_split = lambda split: 'Voice_of_Birds'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = None
for bird_type in tqdm(os.listdir(file_path)):
bird_name = ' '.join(bird_type.split('_')[:-1])
for filename in os.listdir(os.path.join(file_path, bird_type)):
if filter_file(file_path, file_list, os.path.join(bird_type, filename)):
continue
text_output = bird_name
if len(text_output) <= 1:
continue
text_prompt = 'What is the name of bird in this sound?'
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": os.path.join(bird_type, filename),
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "BBCSoundEffects":
assert split in ['train']
assert flamingo_task == 'AudioDescription'
map_split = lambda split: '../WavCaps/BBC_Sound_Effects_flac'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.flac'), os.listdir(file_path)))
with open(os.path.join(dataset_path, 'BBCSoundDownloader/BBCSoundEffects.csv'), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
if len(row) != 7:
continue
filename, description, _, _, _, _, _ = row
filename = filename.replace('.wav', '.flac')
if filter_file(file_path, file_list, filename):
continue
text_output = description
if len(text_output) <= 1:
continue
text_prompt = 'generate audio description'
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "chime-home":
assert flamingo_task == "EventClassification"
assert split == 'train'
map_split = lambda split: 'chime_home/chunks'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file48k_list = list(filter(lambda x: x.endswith('48kHz.wav'), os.listdir(file_path)))
file16k_list = list(filter(lambda x: x.endswith('16kHz.wav'), os.listdir(file_path)))
csv_file_list = list(filter(lambda x: x.endswith('.csv'), os.listdir(file_path)))
label_mapping = {
'c': 'child speaking',
'm': 'male speaking',
'f': 'female speaking',
'p': 'human activity',
't': 'television',
'b': 'household appliances',
's': 'silence'
}
for csv_file in tqdm(csv_file_list):
with open(os.path.join(file_path, csv_file), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
labels = None
for row in reader:
if row[0] == 'majorityvote':
labels = row[1]
break
if labels is None or len(labels) == 0:
continue
filename = csv_file.replace('.csv', '.48kHz.wav')
if filter_file(file_path, file48k_list, filename):
filename = csv_file.replace('.csv', '.16kHz.wav')
if filter_file(file_path, file16k_list, filename):
continue
text_output = ", ".join([label_mapping[l] for l in labels if l in label_mapping])
if len(text_output) <= 1:
continue
text_prompt = 'this is a sound of'
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "CLAP_freesound":
assert flamingo_task == "AudioCaptioning"
assert split in ["train", "test"]
map_split = lambda split: os.path.join('freesound_no_overlap/split', split)
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.flac'), os.listdir(file_path)))
with open(os.path.join(
dataset_path,
'freesound_no_overlap_meta.csv'
), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
                if len(row) <= 1:
                    continue
                if len(row[0].split('/')) != 2:
                    continue
file_split, filename = row[0].split('/')
if file_split != split:
continue
if filter_file(file_path, file_list, filename):
continue
caption_1 = row[1] # caption_2 = row[2] but not very good
text_output = caption_1
if len(text_output) <= 2:
continue
text_prompt = 'generate audio caption'
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "Clotho-AQA":
map_split = lambda split: 'audio_files'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
if flamingo_task == "EventClassification":
dic = defaultdict(str)
with open(os.path.join(dataset_path, 'clotho_aqa_metadata.csv'), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
_, file_name, keywords, _, _, _, _ = row
dic[file_name] = keywords.replace(';', ', ')
with open(os.path.join(dataset_path, 'clotho_aqa_{}.csv'.format(split)), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
filename = row[0]
if filename not in dic or filter_file(file_path, file_list, filename):
continue
text_output = dic[filename]
if len(text_output) <= 1:
continue
text_prompt = 'this is a sound of'
del dic[filename]
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif flamingo_task == "AQA":
dic_qa = defaultdict(list)
with open(os.path.join(dataset_path, 'clotho_aqa_{}.csv'.format(split)), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
filename, question, answer, confidence = row
dic_qa[(filename, question)].append((answer.lower(), confidence.lower()))
            # collapse binary yes/no disagreement into a ternary answer (yes/no/unsure); currently unused
def preprocess(list_ans_conf):
assert set([x[1] for x in list_ans_conf]) <= set(['yes', 'no', 'maybe'])
answers = set([x[0].lower() for x in list_ans_conf])
if answers <= set(['yes', 'no']):
if len(answers) > 1:
return ['unsure']
else:
return list(answers)
else:
return list(answers)
# get majority vote
def majority_vote(list_ans_conf):
assert set([x[1] for x in list_ans_conf]) <= set(['yes', 'no', 'maybe'])
weight = {'yes': 1.0, 'no': 0.1, 'maybe': 0.6}
if set([x[0] for x in list_ans_conf]) <= set(['yes', 'no']):
score = {'yes': 1.0, 'no': -1.0}
pred = sum([score[x[0]] * weight[x[1]] for x in list_ans_conf])
if pred > 0:
return ['yes']
else:
return ['no']
else:
return list(set([x[0] for x in list_ans_conf]))
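            # Worked example: [('yes', 'yes'), ('no', 'maybe')] scores
            # 1.0 * 1.0 + (-1.0) * 0.6 = 0.4 > 0, so majority_vote returns ['yes'].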
for key in dic_qa:
filename, question = key
if filter_file(file_path, file_list, filename):
continue
if split == 'train':
answers = majority_vote(dic_qa[key]) # majority vote
else:
answers = [x[0].strip().lower() for x in dic_qa[key]]
answers = [', '.join(answers)]
for answer in answers:
text_output = answer
if len(text_output) <= 1:
continue
text_prompt = "Question: " + question
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "Clotho-AQA_singlelabel":
import numpy as np
assert flamingo_task == "EventClassification"
map_split = lambda split: '../Clotho-AQA/audio_files'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
dic = defaultdict(str)
with open(os.path.join(dataset_path, 'clotho_aqa_metadata.csv'), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
_, file_name, keywords, _, _, _, _ = row
dic[file_name] = keywords.split(';')
graph = compute_label_graph(
dataset_name,
dataset_path,
top_n=300,
output_file=os.path.join(dataset_path, 'label_graph.json')
)
with open(os.path.join(dataset_path, 'clotho_aqa_{}.csv'.format(split)), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
filename = row[0]
if filename not in dic or filter_file(file_path, file_list, filename):
continue
text_labels = [x.lower().strip() for x in dic[filename]]
del dic[filename]
for _ in range(6):
text_output = np.random.choice(text_labels)
if len(text_output) <= 1:
continue
num_options = np.random.choice(
[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
p=[ 0.05, 0.1, 0.1, 0.1, 0.1,
0.05, 0.05, 0.05, 0.1, 0.05,
0.05, 0.1, 0.05, 0.05]
)
negative_samples = [x for x in graph[text_output] if x not in set(text_labels)]
candidate_negative_labels = list(np.random.choice(
negative_samples[:num_options*20],
size=num_options-1,
replace=False
))
if type(candidate_negative_labels) is str:
candidate_negative_labels = [candidate_negative_labels]
all_options = [text_output] + candidate_negative_labels
np.random.shuffle(all_options)
text_prompt = 'Classify this sound.\nOPTIONS:\n - {}.'.format(
'.\n - '.join(all_options)
)
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "Clotho-v2":
assert flamingo_task == "AudioCaptioning"
assert split in ["train", "val", "test"]
map_split = lambda split: 'development' if split == 'train' else ('validation' if split == "val" else "evaluation")
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
with open(os.path.join(
dataset_path,
'clotho_captions_{}.csv'.format(map_split(split))
), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
filename = row[0]
if filter_file(file_path, file_list, filename):
continue
for text_output in row[1:]:
if len(text_output) <= 1:
continue
text_prompt = 'generate audio caption'
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "CochlScene":
import ndjson
assert flamingo_task == "SceneClassification"
map_split = lambda split: split.capitalize()
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = None
with open(os.path.join(dataset_path, 'cochlscene_{}.ndjson'.format(split))) as ndjsonfile:
reader = ndjson.load(ndjsonfile)
for row in tqdm(reader):
filename = "/".join(row["audiopath"].split("/")[1:])
if filter_file(file_path, file_list, filename):
continue
text_output = row["labels"].lower()
if len(text_output) <= 1:
continue
text_prompt = 'this acoustic scene is'
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "common-accent":
import ndjson
import re
assert flamingo_task == "AccentClassification"
assert split in ["train", "test"]
map_split = lambda split: '22khz'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = os.listdir(file_path)
all_accent = []
split_file = [f for f in os.listdir(dataset_path) if f.startswith(split) and f.endswith('.ndjson')][0]
with open(os.path.join(dataset_path, split_file)) as ndjsonfile:
reader = ndjson.load(ndjsonfile)
for row in tqdm(reader):
accent = row["accent"]
accent = re.sub(r'\(.*?\)', '', accent)
accent = accent.replace('English', '')
accent = accent.split(',')
accent = [x.strip() for x in accent if 'school' not in x]
all_accent += accent
filename = row["filename"]
if filter_file(file_path, file_list, filename):
continue
for accent_each in accent:
if accent_each == 'Javanese':
accent_each = 'Japanese'
if len(accent_each) > 25:
continue
text_output = accent_each
if len(text_output) <= 1:
continue
text_prompt = 'Classify the accent of this speech.'
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
print('all accents:', list(set(all_accent)))
elif dataset_name == "CREMA-D":
assert flamingo_task == "EmotionClassification"
assert split in ["train"]
map_split = lambda split: 'AudioWAV'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
split_file = os.path.join(
dataset_path,
'crema-d_audiopath_text_sid_emotion_filelist.txt'
)
with open(split_file, 'r') as f:
data = f.readlines()
data = [x.replace('\n', '') for x in data]
for row in tqdm(data):
if row.count('|') != 3:
continue
filename, utterances, speaker, emotion = row.split('|')
if filter_file(file_path, file_list, filename):
continue
text_output = emotion
text_prompt = 'this emotion is'
if len(text_output) <= 1:
continue
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "DCASE17Task4":
assert flamingo_task == "SceneClassification"
assert split in ["test"]
map_split = lambda split: 'unbalanced_train_segments_testing_set_audio_formatted_and_segmented_downloads'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
split_file = os.path.join(
dataset_path,
'Task-4-Large-scale-weakly-supervised-sound-event-detection-for-smart-cars',
'groundtruth_release',
'groundtruth_strong_label_testing_set.csv'
)
dic = defaultdict(list)
all_labels = []
with open(split_file, newline='') as csvfile:
reader = csv.reader(csvfile, delimiter='\t', quotechar='"')
for row in tqdm(reader):
filename = 'Y' + row[0]
label = row[-1]
if filter_file(file_path, file_list, filename):
continue
dic[filename] += label.split(', ')
all_labels += label.split(', ')
print('all labels:\n', ', '.join(list(set(all_labels))))
for filename in dic:
text_output = ', '.join(list(set(dic[filename])))
text_prompt = 'this acoustic scene is'
if len(text_output) <= 1:
continue
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "emov-db":
assert flamingo_task == "EmotionClassification"
assert split in ["train", "val"]
map_split = lambda split: '22khz_from_16khz'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
split_file = os.path.join(
dataset_path,
'cleaned_emov_db_audiopath_text_sid_emotion_duration_filelist_merged_{}.txt'.format(split)
)
with open(split_file, 'r') as f:
data = f.readlines()
data = [x.replace('\n', '') for x in data]
for row in tqdm(data):
if row.count('|') != 4:
continue
filename, utterances, speaker, emotion, duration = row.split('|')
if filter_file(file_path, file_list, filename):
continue
text_output = emotion
text_output = EMOTION_MAP_DICT[text_output]
text_prompt = 'this emotion is'
if len(text_output) <= 1:
continue
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "Epidemic_sound":
assert split == 'train'
assert flamingo_task in ["AudioCaptioning", "Tagging"]
map_split = lambda split: 'audio'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.mp3'), os.listdir(file_path)))
with open(os.path.join(dataset_path, 'Epidemic_all_debiased.csv'), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
if len(row) != 5:
continue
_, caption_1, caption_2, caption_t5, fileid = row
filename = '{}.mp3'.format(fileid)
if filter_file(file_path, file_list, filename):
continue
if flamingo_task == "AudioCaptioning":
text_output = caption_t5
if len(text_output) <= 1:
continue
text_prompt = 'generate audio caption'
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif flamingo_task == "Tagging":
if not caption_2.startswith('the sounds of'):
continue
caption_2 = caption_2.replace('the sounds of ', '')
caption_2 = caption_2.replace(', and', ',')
if len(caption_2) < 2:
continue
tags = caption_2.split(', ')
tags = list(map(lambda x: x.replace("'", "").strip().lower(), tags))
text_output = '{}'.format(', '.join(tags))
if len(text_output) <= 1:
continue
text_prompt = 'generate tags'
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "ESC50":
assert flamingo_task in ["EventClassification"]
assert split == 'train'
map_split = lambda split: 'ESC-50-master/audio'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
with open(os.path.join(dataset_path, 'ESC-50-master/meta/esc50.csv'), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
if len(row) != 7:
continue
filename, fold, target, category, esc10, src_file, take = row
if filter_file(file_path, file_list, filename):
continue
text_output = category.replace('_', ' ')
text_prompt = 'classify this sound.'
if len(text_output) <= 1:
continue
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "FMA":
import ast
assert flamingo_task in ["GenreClassification"]
assert split == 'train'
map_split = lambda split: 'fma_large'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = None
with open(os.path.join(dataset_path, 'fma_metadata/raw_tracks.csv'), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
if len(row) != 39:
continue
track_id,album_id,album_title,album_url, \
artist_id,artist_name,artist_url,artist_website, \
license_image_file,license_image_file_large, \
license_parent_id,license_title,license_url, \
tags,track_bit_rate,track_comments,track_composer, \
track_copyright_c,track_copyright_p,track_date_created,track_date_recorded, \
track_disc_number,track_duration,track_explicit,track_explicit_notes, \
track_favorites,track_file,track_genres,track_image_file,track_information, \
track_instrumental,track_interest,track_language_code, \
track_listens,track_lyricist,track_number,track_publisher,track_title,track_url = row
                # FMA stores audio as <first 3 digits>/<6-digit zero-padded id>.mp3
                tid = str(track_id).zfill(6)
                filename = '{}/{}.mp3'.format(tid[:3], tid)
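                # e.g. track_id 2 -> '000/000002.mp3'; track_id 140542 -> '140/140542.mp3'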
if filter_file(file_path, file_list, filename):
continue
if len(track_genres) == 0:
continue
track_genres = ast.literal_eval(track_genres)
genres = ', '.join([dic['genre_title'].lower().strip() for dic in track_genres])
text_output = genres + '.'
text_prompt = "what is the genre of this music?"
if len(text_output) <= 1:
continue
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "FSD50k":
import ndjson
assert flamingo_task == "EventClassification"
assert split in ["train", "test"]
map_split = lambda split: '44khz/dev' if split == 'train' else '44khz/eval'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
with open(os.path.join(dataset_path, '{}.ndjson'.format(map_split(split).replace('44khz/', '')))) as ndjsonfile:
reader = ndjson.load(ndjsonfile)
for row in tqdm(reader):
filename = row["filepath"].split("/")[1]
if filter_file(file_path, file_list, filename):
continue
labels = [x.replace("_", " ").lower() for x in row["labels"]]
text_output = ", ".join(labels)
if len(text_output) <= 1:
continue
text_prompt = 'this is a sound of'
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "GTZAN":
assert flamingo_task == "GenreClassification"
assert split in ["train"]
map_split = lambda split: 'gtzan/data/genres'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = None
for genre in os.listdir(file_path):
genre_wavs = [x for x in os.listdir(os.path.join(file_path, genre)) if x.endswith('.wav')]
for genre_wav in genre_wavs:
filename = os.path.join(genre, genre_wav)
if filter_file(file_path, file_list, filename):
continue
text_output = genre
if len(text_output) <= 1:
continue
text_prompt = 'What is the genre of this music?'
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "IEMOCAP":
assert flamingo_task == "EmotionClassification"
assert split in ["train", "test"]
map_split = lambda split: 'IEMOCAP_full_release/16khz'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = None
        def read_this_ndjson(meta_path):
            # Each line looks like a Python-repr dict rather than strict JSON, so
            # the fields are pulled out by string splitting.
            dic_list = []
            with open(meta_path, 'r') as f:
for line in f:
turn_name = line.split("'turn_name': ")[-1].split(',')[0].replace("'", "")
emotion = line.split("'emotion': ")[-1].split(',')[0].replace("'", "")
dic = {
'turn_name': turn_name,
'emotion': emotion
}
dic_list.append(dic)
return dic_list
all_emotions = []
meta_files = [x for x in os.listdir(os.path.join(dataset_path, 'IEMOCAP_full_release/ndjson')) if x.endswith('.ndjson')]
for meta_file in tqdm(meta_files):
main_folder = meta_file.split('_')[0]
sub_folder = (meta_file.split('.ndjson')[0])[len(main_folder)+1:]
if split == "train" and main_folder == "Session5":
continue
elif split == "test" and main_folder != "Session5":
continue
metadata_list = read_this_ndjson(os.path.join(dataset_path, 'IEMOCAP_full_release/ndjson', meta_file))
for dic in metadata_list:
filename = os.path.join(main_folder, sub_folder, dic['turn_name']+'.wav')
if filter_file(file_path, file_list, filename):
continue
if dic['emotion'] in ['unknown', 'other']:
continue
text_output = dic['emotion']
text_output = EMOTION_MAP_DICT[text_output]
all_emotions.append(text_output)
text_prompt = 'this emotion is'
if len(text_output) <= 1:
continue
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
print('all emotions:', list(set(all_emotions)))
elif dataset_name == "jl-corpus":
assert flamingo_task == "EmotionClassification"
assert split in ["train", "val"]
map_split = lambda split: '44khz'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
split_file = os.path.join(
dataset_path,
'jl-corpus_audiopath_text_sid_emotion_duration_{}_filelist.txt'.format(split)
)
with open(split_file, 'r') as f:
data = f.readlines()
data = [x.replace('\n', '') for x in data]
for row in tqdm(data):
if row.count('|') != 4:
continue
filename, utterances, speaker, emotion, duration = row.split('|')
if filter_file(file_path, file_list, filename):
continue
text_output = emotion
text_output = EMOTION_MAP_DICT[text_output]
text_prompt = 'this emotion is'
if len(text_output) <= 1:
continue
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "LP-MusicCaps-MC":
import pandas as pd
assert flamingo_task in ["AudioCaptioning"]
assert split in ["train", "test"]
map_split = lambda split: '../MusicCaps/44khz'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
parquet_files = [f for f in os.listdir(os.path.join(dataset_path, 'data')) if f.endswith('.parquet') and f.startswith(split)]
print('parquet_files', parquet_files)
metadata_df = pd.concat([pd.read_parquet(os.path.join(dataset_path, 'data', f)) for f in parquet_files])
for index, row in tqdm(metadata_df.iterrows()):
filename = row['ytid'] + '.wav'
if filter_file(file_path, file_list, filename):
continue
text_prompt = 'generate audio caption'
for caption in [row['caption_writing'], row['caption_summary'], row['caption_paraphrase']]:
text_output = caption
if len(text_output) <= 1:
continue
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "LP-MusicCaps-MSD":
import pandas as pd
assert flamingo_task in ["AudioCaptioning"]
assert split in ["train", "test", "val"]
map_split = lambda split: '../MSD/mp3s_22khz'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = None
parquet_files = [f for f in os.listdir(dataset_path) if f.endswith('.parquet') and f.startswith(split)]
print('parquet_files', parquet_files)
metadata_df = pd.concat([pd.read_parquet(os.path.join(dataset_path, f)) for f in parquet_files])
for index, row in tqdm(metadata_df.iterrows()):
filename = row['path']
if filter_file(file_path, file_list, filename):
continue
text_prompt = 'generate audio caption'
for caption in [row['caption_writing'], row['caption_summary'], row['caption_paraphrase']]:
text_output = caption
if len(text_output) <= 1:
continue
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "LP-MusicCaps-MTT":
import pandas as pd
assert flamingo_task in ["AudioCaptioning"]
assert split in ["train", "test", "val"]
map_split = lambda split: '../MagnaTagATune/16khz'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = None
parquet_files = [f for f in os.listdir(dataset_path) if f.endswith('.parquet') and f.startswith(split)]
print('parquet_files', parquet_files)
metadata_df = pd.concat([pd.read_parquet(os.path.join(dataset_path, f)) for f in parquet_files])
for index, row in tqdm(metadata_df.iterrows()):
filename = row['path']
if filter_file(file_path, file_list, filename):
continue
text_prompt = 'generate audio caption'
for caption in [row['caption_writing'], row['caption_summary'], row['caption_paraphrase']]:
text_output = caption
if len(text_output) <= 1:
continue
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "MACS":
assert flamingo_task in ["AudioCaptioning", "Tagging"]
assert split == 'train'
map_split = lambda split: 'TAU_Urban_Acoustic_Scenes_2019/TAU-urban-acoustic-scenes-2019-development/audio'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
        with open(os.path.join(dataset_path, 'MACS.yaml')) as f:
            metadata_list = yaml.load(f, Loader=yaml.FullLoader)['files']
for file_metadata in tqdm(metadata_list):
filename = file_metadata['filename']
if filter_file(file_path, file_list, filename):
continue
for each_annotated in file_metadata['annotations']:
caption = each_annotated['sentence']
tags = ', '.join(each_annotated['tags']).replace('_', ' ')
if flamingo_task == "AudioCaptioning":
text_output = caption
text_prompt = 'generate audio caption'
elif flamingo_task == "Tagging":
raise NotImplementedError
if len(text_output) <= 1:
continue
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "Medley-solos-DB":
import ndjson
assert flamingo_task in ["InstrClassification"]
map_split = lambda split: '44khz'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
with open(os.path.join(dataset_path, 'medleysolosdb_manifest.ndjson')) as ndjsonfile:
metadata_list = ndjson.load(ndjsonfile)
for file_metadata in tqdm(metadata_list):
subset = file_metadata['subset']
if not subset.startswith(split):
continue
filename = file_metadata['filepath']
if filter_file(file_path, file_list, filename):
continue
instrument = file_metadata["instrument"]
text_output = instrument
text_prompt = 'this music note is produced by'
if len(text_output) <= 1:
continue
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "MELD":
import numpy as np
assert flamingo_task in ["EmotionClassification", "SentimentClassification"]
map_split = lambda split: '44khz'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
split_file = os.path.join(
dataset_path,
'{}.txt'.format(split if split in ['train', 'test'] else 'dev')
)
with open(split_file, 'r') as f:
data = f.readlines()
data = [x.replace('\n', '') for x in data]
emotion_count = {
'neutral': 4703, 'happy': 1739, 'sad': 683, 'surprised': 1204,
'disgusted': 271, 'angry': 1108, 'fearful': 268,
}
sentiment_count = {
'neutral': 4703, 'positive': 2330, 'negative': 2943,
}
balancing_factor = 1
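        # Class rebalancing by oversampling: each sample is emitted
        # floor(balancing_factor) times plus one extra copy with probability equal
        # to the fractional part, e.g. for 'sad' emotions 4703 / 683 ~= 6.89, so
        # six copies plus a seventh with probability 0.89.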
for row in tqdm(data):
if row.count('|') != 4:
continue
filename, utterances, speaker, emotion, sentiment = row.split('|')
if filter_file(file_path, file_list, filename):
continue
if flamingo_task == "EmotionClassification":
text_output = emotion
text_output = EMOTION_MAP_DICT[text_output]
text_prompt = 'this emotion is'
if split == 'train':
balancing_factor = float(emotion_count['neutral']) / float(emotion_count[text_output])
elif flamingo_task == "SentimentClassification":
text_output = sentiment
text_prompt = 'this sentiment is'
if split == 'train':
balancing_factor = float(sentiment_count['neutral']) / float(sentiment_count[text_output])
if len(text_output) <= 1:
continue
for _ in range(int(np.floor(balancing_factor))):
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
if np.random.rand() < balancing_factor - np.floor(balancing_factor):
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "MSP-PODCAST-Publish-1.9":
assert flamingo_task == "EmotionClassification"
assert split in ["train", "val", "test"]
map_split = lambda split: 'Audio'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = glob.glob('{}/*/*.wav'.format(file_path))
file_list = [x[len(file_path)+1:] for x in file_list]
subfolder_map = {}
for f in tqdm(file_list):
subfolder, filename = f.split('/')
subfolder_map[filename] = subfolder
file_list = None
emotion_dic = {
'A': 'Angry',
'S': 'Sad',
'H': 'Happy',
'U': 'Surprise',
'F': 'Fear',
'D': 'Disgust',
'C': 'Contempt',
'N': 'Neutral',
'O': 'Other',
'X': 'Not clear'
}
with open(os.path.join(dataset_path, 'Labels/labels_concensus.json')) as f:
data = f.read()
metadata_dic = json.loads(data)
for filename in tqdm(list(metadata_dic.keys())):
values = metadata_dic[filename]
if not values["Split_Set"].lower().startswith(split):
continue
if values["EmoClass"] in ["O", "X"] or values["EmoClass"] not in emotion_dic.keys():
continue
subfolder = subfolder_map[filename]
filename = '{}/{}'.format(subfolder, filename)
if filter_file(file_path, file_list, filename):
continue
text_output = emotion_dic[values["EmoClass"]].lower()
text_output = EMOTION_MAP_DICT[text_output]
text_prompt = 'this emotion is'
if len(text_output) <= 1:
continue
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "mtg-jamendo":
import ndjson
assert flamingo_task == "MusicTagging"
assert split in ["train", "val"]
map_split = lambda split: '44khz'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = None
with open(os.path.join(dataset_path, 'mtg_jamendo_{}_manifest.ndjson'.format(split))) as ndjsonfile:
reader = ndjson.load(ndjsonfile)
for row in tqdm(reader):
filename = row["audiopath"]
if filter_file(file_path, file_list, filename):
continue
text_output = row["caption"]
text_prompt = 'generate music tags (genre, instrument, mood/theme)'
if len(text_output) <= 1:
continue
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "MU-LLAMA":
assert flamingo_task in ['AQA']
assert split in ['train', 'test']
map_split = lambda split: 'MusicQA/audios'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
split_file = 'MusicQA/FinetuneMusicQA.json' if split == 'train' else 'MusicQA/EvalMusicQA.json'
with open(os.path.join(dataset_path, split_file), 'r') as f:
data = f.read()
metadata_list = json.loads(data)
for dic in tqdm(metadata_list):
filename = dic["audio_name"]
if filter_file(file_path, file_list, filename):
continue
text_prompt = 'Question: ' + dic["conversation"][0]["value"].strip()
if not (text_prompt.endswith('.') or text_prompt.endswith('?')):
text_prompt = text_prompt + '.'
text_output = dic["conversation"][1]["value"].strip()
if len(text_output) <= 1:
continue
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "musdbhq":
assert flamingo_task in ["InstrClassification"]
assert split in ["train", "test", "val"]
map_split = lambda split: './'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = None
with open(os.path.join(dataset_path, 'file_list_44k_{}.txt'.format(split))) as f:
data = f.readlines()
data = [x.replace('\n', '') for x in data]
for row in tqdm(data):
if row.count('|') != 1:
continue
filename, duration = row.split('|')
duration = float(duration)
if filter_file(file_path, file_list, filename):
continue
text_output = filename.split('/')[-1].split('.wav')[0]
if len(text_output) <= 1:
continue
text_prompt = 'this music is produced by'
segment_length = 10
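            # Long stems are chunked: a file of duration d seconds yields
            # floor(d / segment_length) entries, each tagged with its own
            # 10-second start offset via "audio_start".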
for audio_start_idx in range(int(duration // segment_length)):
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' '),
"audio_start": audio_start_idx * segment_length
}
dataset_dic["total_num"] += 1
elif dataset_name == "Music-AVQA":
import ast
import re
assert flamingo_task in [
"{}_{}".format(q, t) \
for q in ['AQA', 'AVQA'] \
for t in ['Comparative', 'Counting', 'Existential', 'Location', 'Temporal', 'All']
]
def replace_bracketed_words(input_string, replacements):
def replacer(match):
word = next(replacements)
return word
replacements = iter(replacements)
output_string = re.sub(r'<[^>]*>', replacer, input_string)
return output_string
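        # e.g. replace_bracketed_words('Is the <Object> louder than the <Object>?',
        # ['piano', 'drum']) -> 'Is the piano louder than the drum?'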
map_split = lambda split: 'audio'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
with open(os.path.join(dataset_path, 'MUSIC-AVQA/data/json/avqa-{}.json'.format(split)), 'r') as f:
data = f.read()
metadata_list = json.loads(data)
for dic in tqdm(metadata_list):
filename = dic["video_id"] + '.wav'
if filter_file(file_path, file_list, filename):
continue
types = ast.literal_eval(dic["type"])
if 'Visual' in types:
continue
if flamingo_task.startswith('AQA_') and 'Audio-Visual' in types:
continue
if flamingo_task.startswith('AVQA_') and 'Audio' in types:
continue
t = flamingo_task.split('_')[1]
            if t != 'All' and t not in types:
continue
text_output = dic["anser"]
if len(text_output) <= 1:
continue
question = dic["question_content"].replace("\uff1f", '?')
templ_values = ast.literal_eval(dic["templ_values"])
if len(templ_values) > 0:
question = replace_bracketed_words(question, templ_values)
text_prompt = "Question: " + question
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "MusicCaps":
assert flamingo_task in ["AudioCaptioning", "EventClassification"]
assert split in ["train", "test"]
map_split = lambda split: '44khz'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
with open(os.path.join(dataset_path, 'musiccaps_manifest.json')) as f:
data = f.read()
metadata_list = json.loads(data)
for file_metadata in tqdm(metadata_list):
filename = file_metadata['filepath']
if filter_file(file_path, file_list, filename):
continue
start_s, end_s = file_metadata["start_s"], file_metadata["end_s"]
caption = file_metadata["caption"]
audioset_positive_labels = file_metadata["audioset_positive_labels"] # audioset classes
aspect_list = file_metadata["aspect_list"] # annotated classes
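            # Split selection: AudioSet-eval clips go to 'test' and the rest to
            # 'train'; the condition below skips the mismatched combination.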
if (split == 'train') == file_metadata["is_audioset_eval"]:
continue
if flamingo_task == "AudioCaptioning":
text_output = caption
text_prompt = 'generate audio caption'
elif flamingo_task == "EventClassification":
raise NotImplementedError
if len(text_output) <= 1:
continue
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "NonSpeech7k":
assert flamingo_task in ["EventClassification"]
assert split in ["train", "test"]
map_split = lambda split: split
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
all_classes = []
with open(os.path.join(dataset_path, 'metadata of {} set.csv'.format(split)), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
filename, _, _, _, classname, _, _, _ = row
if filter_file(file_path, file_list, filename):
continue
text_output = classname.lower()
if len(text_output) <= 1:
continue
text_prompt = 'this is a sound of'
all_classes.append(classname)
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
print('all classes:', list(set(all_classes)))
elif dataset_name == "NSynth":
assert flamingo_task in [
"InstrClassification",
"PitchClassification",
"VelocityClassification",
"SourceClassification",
"QualityClassification",
"MIR"
]
assert split in ["train", "test", "val"]
map_split = lambda split: 'nsynth-{}/audio'.format('valid' if split == 'val' else split)
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
with open(os.path.join(dataset_path, map_split(split), '../examples.json')) as f:
data = f.read()
reader = json.loads(data)
for key in tqdm(reader):
filename = key + '.wav'
if filter_file(file_path, file_list, filename):
continue
if flamingo_task == "InstrClassification":
text_output = reader[key]["instrument_family_str"]
text_prompt = 'this music note is produced by'
elif flamingo_task == "PitchClassification":
text_output = str(reader[key]["pitch"])
text_prompt = 'this music note has pitch'
elif flamingo_task == "VelocityClassification":
text_output = str(reader[key]["velocity"])
text_prompt = 'this music note has velocity'
elif flamingo_task == "SourceClassification":
text_output = reader[key]["instrument_source_str"]
text_prompt = 'this music note has sonic source'
elif flamingo_task == "QualityClassification":
qualities_str = reader[key]["qualities_str"]
if len(qualities_str) >= 1:
text_output = ', '.join(qualities_str).replace('_', ' ')
else:
text_output = 'none'
text_prompt = 'this music note has sonic qualities'
elif flamingo_task == "MIR":
instrument = reader[key]["instrument_family_str"]
pitch = str(reader[key]["pitch"])
velocity = str(reader[key]["velocity"])
source = reader[key]["instrument_source_str"]
qualities_str = ', '.join(reader[key]["qualities_str"]).replace('_', ' ')
assert len(instrument) > 0
text_output = 'produced by {}'.format(instrument)
if len(pitch) > 0:
text_output = text_output + ', pitch {}'.format(pitch)
if len(velocity) > 0:
text_output = text_output + ', velocity {}'.format(velocity)
if len(source) > 0:
text_output = text_output + ', source {}'.format(source)
if len(qualities_str) > 0:
text_output = text_output + ', and having qualities like {}'.format(qualities_str)
text_prompt = 'this music note is'
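# Example of the assembled MIR target (field values hypothetical):
#   prompt: 'this music note is'
#   output: 'produced by guitar, pitch 60, velocity 100, source acoustic, and having qualities like bright, distortion'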
if len(text_output) <= 1:
continue
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "OMGEmotion":
import numpy as np
import webrtcvad
import wave
from pydub import AudioSegment
assert flamingo_task == "EmotionClassification"
assert split in ["train", "val"]
def convert_to_wav(file_path):
audio = AudioSegment.from_file(file_path).set_frame_rate(16000).set_channels(1)
wav_path = file_path.rsplit('.', 1)[0] + "_converted.wav"
audio.export(wav_path, format="wav")
return wav_path
def contains_speech(file_path, aggressiveness=0):
# webrtcvad aggressiveness is an integer in [0, 3]: 0 is least aggressive about filtering out non-speech, 3 is most aggressive
wav_path = convert_to_wav(file_path)
vad = webrtcvad.Vad(aggressiveness)
with wave.open(wav_path, 'rb') as audio:
assert audio.getsampwidth() == 2, "Audio must be 16-bit"
assert audio.getnchannels() == 1, "Audio must be mono"
assert audio.getframerate() == 16000, "Audio must be sampled at 16kHz"
frame_duration = 10 # ms
frame_size = int(audio.getframerate() * frame_duration / 1000)
num_frames = int(audio.getnframes() / frame_size)
for _ in range(num_frames):
frame = audio.readframes(frame_size)
if vad.is_speech(frame, audio.getframerate()):
return True
return False
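# Usage sketch (path hypothetical; requires the webrtcvad package):
#   contains_speech('/path/to/utterance.mp3', aggressiveness=0)
#   -> True as soon as any 10 ms frame is classified as speech, else False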
map_split = lambda split: 'processed-{}_utterance_data'.format(split)
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = None
dic_code2emotion = {
"0": "anger",
"1": "disgust",
"2": "fear",
"3": "happy",
"4": "neutral",
"5": "sad",
"6": "surprise",
}
all_emotions = []
meta_file = os.path.join(
dataset_path,
'OMGEmotionChallenge',
'omg_{}Videos.csv'.format('Train' if split == 'train' else 'Validation')
)
with open(meta_file, newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
link, start, end, video, utterance, _, _, EmotionMaxVote = row
emotion = dic_code2emotion[str(EmotionMaxVote)]
filename = os.path.join(video, utterance.replace('.mp4', '.mp3'))
if filter_file(file_path, file_list, filename):
continue
if not contains_speech(os.path.join(file_path, filename)):
print('{} does not contain speech'.format(filename))
continue
text_prompt = 'this emotion is'
text_output = emotion
if len(text_output) <= 1:
continue
all_emotions.append(EMOTION_MAP_DICT[emotion])
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
print('all emotions:', list(set(all_emotions)))
elif dataset_name == "OpenAQA":
assert flamingo_task == 'AQA'
assert split == 'train'
map_split = lambda split: './'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = None
no_word_list = [
'cannot determine', 'not provided', 'cannot be determined', 'sorry', 'i cannot',
'without more information', 'enough information',
'not possible', 'more context', 'enough', 'impossible',
'without additional information',
'unclear', 'cannot', 'not clear', 'do not provide sufficient', 'does not provide',
'difficult to determine', 'no information provided',
"can't infer", "difficult to infer", "not specified", "no specific", "no information",
"without additional", 'it is difficult to', "no indication"
]
print('computing dic_audiosetfull_parts')
audiosetfull_root = '/mnt/fsx-main/rafaelvalle/datasets/audioset/unbalanced_train_segments/22khz/'
part_strings = ['{:02d}'.format(p) for p in range(41)]  # '00', '01', ..., '40'
dic_audiosetfull_parts = {
part: set(os.listdir(os.path.join(audiosetfull_root, 'unbalanced_train_segments_part{}'.format(part)))) \
for part in part_strings
}
audioset20k_filelist = set(os.listdir(os.path.join(file_path, '../AudioSet/train_wav')))
print('computing dic_clotho_filename')
clotho_files = os.listdir(os.path.join(dataset_path, '../Clotho-AQA/audio_files'))
dic_clotho_filename = {
'_'.join([s for s in f.split(' ') if len(s) > 0]): f \
for f in clotho_files
}
print('reading open_ended/all_open_qa.json')
with open(os.path.join(dataset_path, 'openaqa/data/open_ended/all_open_qa.json'), 'r') as f:
data = f.read()
metadata_list = json.loads(data)
for dic in tqdm(metadata_list):
# keys: instruction, input, dataset, audio_id, output, task
text_output = dic["output"]
if len(text_output) <= 1:
continue
if any(word in text_output.lower() for word in no_word_list):
continue
question = dic["instruction"]
text_prompt = question
audio_id = dic["audio_id"]
subset = dic["dataset"]
if subset == 'clotho_development':
filename = audio_id.split('/')[-1]
# collapse repeated underscores so the audio_id matches the keys of dic_clotho_filename
processed_filename = '_'.join([s for s in filename.split('_') if len(s) > 0])
if processed_filename in dic_clotho_filename:
filename = os.path.join(
'../Clotho-AQA/audio_files',
dic_clotho_filename[processed_filename]
)
else:
continue
elif subset in ['audiocaps_train', 'as_20k', 'as_strong_train']:
found = False
filename = audio_id.split('/')[-1].split('.flac')[0] + '.wav'
if filename in audioset20k_filelist:
filename = os.path.join('../AudioSet/train_wav', filename)
found = True
else:
filename = 'Y' + filename  # files in the full AudioSet dump are prefixed with 'Y'
for part in part_strings:
if filename in dic_audiosetfull_parts[part]:
filename = os.path.join(
audiosetfull_root,
'unbalanced_train_segments_part{}'.format(part),
filename
)
found = True
break
if not found:
print(filename, 'not found')
continue
elif subset == 'freesound_10s':
filename = os.path.join(
'../CLAP_freesound/freesound_no_overlap/split/train',
audio_id.split('/')[-1]
)
elif subset == 'vggsound_train':
continue
if filter_file(file_path, file_list, filename):
continue
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "ravdess":
assert flamingo_task == "EmotionClassification"
assert split in ["train", "val"]
map_split = lambda split: '44khz'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = None
split_file = os.path.join(
dataset_path,
'ravdess_audiopath_text_sid_emotion_duration_{}_filelist.txt'.format(split)
)
with open(split_file, 'r') as f:
data = f.readlines()
data = [x.replace('\n', '') for x in data]
for row in tqdm(data):
if row.count('|') != 4:
continue
filename, utterances, speaker, emotion, duration = row.split('|')
if filter_file(file_path, file_list, filename):
continue
text_output = emotion
text_prompt = 'this emotion is'
if len(text_output) <= 1:
continue
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "SongDescriber":
assert flamingo_task in ["AudioCaptioning"]
assert split in ["train"]
map_split = lambda split: './audio/audio'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = None
with open(os.path.join(dataset_path, 'song_describer.csv'), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
caption_id, track_id, caption, is_valid_subset, familiarity, artist_id, album_id, path, duration = row
filename = '{}/{}.2min.mp3'.format(track_id[-2:], track_id)
duration = float(duration)
if filter_file(file_path, file_list, filename):
continue
text_output = caption
if len(text_output) <= 1:
continue
text_prompt = 'generate audio caption'
segment_length = 30  # seconds; the 2-minute mp3 excerpts are indexed in 30-second segments
for audio_start_idx in range(int(duration // segment_length)):
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' '),
"audio_start": audio_start_idx * segment_length
}
dataset_dic["total_num"] += 1
elif dataset_name == "SONYC-UST":
import numpy as np
assert flamingo_task == "EventClassification"
assert split in ["train", "test", "val"]
map_split = lambda split: 'audio'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
all_labels = []
with open(os.path.join(dataset_path, 'annotations.csv'), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
for idx, row in tqdm(enumerate(reader)):
if idx == 0:
header = np.array(row)
continue
if not row[0].startswith(split):
continue
filename = row[2]
if filter_file(file_path, file_list, filename):
continue
labels = [header[i] for i in range(12, len(header)-8) if str(row[i]) == "1"]
labels = [x.split("_")[1].replace('-', ' ').lower() for x in labels if 'X_' not in x]
all_labels += labels
text_output = ", ".join(labels)
if len(text_output) <= 1:
continue
text_prompt = 'this is a sound of'
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
print('all labels:', list(set(all_labels)))
elif dataset_name == "SoundDescs":
import torch
assert flamingo_task in ["AudioDescription"]
assert split in ["train"]
map_split = lambda split: 'raw/audios'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
split_file = os.path.join(dataset_path, 'audio-retrieval-benchmark/data/SoundDescs/{}_list.txt'.format(split))
with open(split_file, 'r') as f:
data = f.readlines()
names = set([x.replace('\n', '') for x in data])
with open(os.path.join(dataset_path, 'audio-retrieval-benchmark/sounddescs_data/descriptions.pkl'), 'rb') as f:
obj = f.read()
metadata_dic = pickle.loads(obj, encoding='latin1')
for name in tqdm(names):
if name not in metadata_dic.keys():
continue
filename = '{}.wav'.format(name)
if filter_file(file_path, file_list, filename):
continue
description = metadata_dic[name]
text_output = description
text_prompt = 'generate audio description'
if len(text_output) <= 1:
continue
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "tess":
assert flamingo_task == "EmotionClassification"
assert split in ["train", "val"]
map_split = lambda split: '24414hz'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = None
split_file = os.path.join(
dataset_path,
'tess_audiopath_text_sid_emotion_duration_{}_filelist.txt'.format(split)
)
with open(split_file, 'r') as f:
data = f.readlines()
data = [x.replace('\n', '') for x in data]
for row in tqdm(data):
if row.count('|') != 4:
continue
filename, utterances, speaker, emotion, duration = row.split('|')
if filter_file(file_path, file_list, filename):
continue
text_output = emotion.replace('_', ' ')
text_output = EMOTION_MAP_DICT[text_output]
text_prompt = 'this emotion is'
if len(text_output) <= 1:
continue
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "UrbanSound8K":
assert flamingo_task in ["EventClassification"]
assert split in ["train"]
map_split = lambda split: 'audio'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = None
with open(os.path.join(dataset_path, 'metadata/UrbanSound8K.csv'), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
filename, fsID, start, end, salience, fold, classID, class_name = row
filename = 'fold{}/{}'.format(fold, filename)
if filter_file(file_path, file_list, filename):
continue
text_output = class_name.replace("_", " ").lower()
if len(text_output) <= 1:
continue
text_prompt = 'this is a sound of'
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "VocalSound":
assert flamingo_task == "VocalClassification"
map_split = lambda split: 'data_44k'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
split_file = os.path.join(
dataset_path,
'meta/{}_meta.csv'.format(split[:2] if split in ['train', 'test'] else split[:3])  # tr_meta.csv / te_meta.csv / val_meta.csv
)
prefix = set([])
with open(split_file, newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
for row in reader:
prefix.add(row[0])
all_labels = set([])
for filename in tqdm(file_list):
if filename.split('_')[0] not in prefix:
continue
if filter_file(file_path, file_list, filename):
continue
label = filename.split('_')[2].split('.wav')[0]
if label == 'throatclearing':
label = 'throat clearing'
text_output = label
text_prompt = 'this vocal sound is'
all_labels.add(label)
if len(text_output) <= 1:
continue
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
print('all labels:\n', "'" + "', '".join(all_labels) + "'")
elif dataset_name.startswith("WavCaps"):
assert split in ["train"]
dataset_name, subset_name = dataset_name.split('-')
dataset_path = os.path.join(os.path.dirname(dataset_path), dataset_name)
dataset_dic['dataset_path'] = dataset_path
map_split = lambda split: subset_name + '_flac'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.flac'), os.listdir(file_path)))
metadata_file = os.listdir(os.path.join(dataset_path, "json_files", subset_name))
metadata_file = [x for x in metadata_file if x.endswith('json')][0]
with open(os.path.join(dataset_path, "json_files", subset_name, metadata_file)) as f:
data = f.read()
reader = json.loads(data)
if subset_name == "AudioSet_SL":
assert flamingo_task == 'AudioCaptioning'
for sample in tqdm(reader['data']):
filename = sample["id"].replace('.wav', '.flac')
if filter_file(file_path, file_list, filename):
continue
text_output = sample['caption']
if len(text_output) <= 1:
continue
text_prompt = 'generate audio caption'
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
else:
assert flamingo_task in ['AudioCaptioning', 'AudioDescription']
for sample in tqdm(reader['data']):
filename = sample["id"] + '.flac'
if filter_file(file_path, file_list, filename):
continue
if flamingo_task == 'AudioCaptioning':
text_output = sample['caption']
text_prompt = 'generate audio caption'
elif flamingo_task == 'AudioDescription':
text_output = sample['description']
text_prompt = 'generate audio description'
if len(text_output) <= 1:
continue
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif dataset_name == "WavText5K":
assert split == 'train'
map_split = lambda split: 'Webcrawl/44100/audios'
file_path = os.path.join(
dataset_path,
map_split(split)
)
assert os.path.exists(file_path), '{} not exist'.format(file_path)
dataset_dic["split_path"] = map_split(split)
file_list = list(filter(lambda x: x.endswith('.wav'), os.listdir(file_path)))
dic = defaultdict(str)
with open(os.path.join(dataset_path, 'WavText5K.csv'), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
next(reader)
for row in tqdm(reader):
_, _, title, description, filename, tags = row
dic[filename] = (title, description, tags)
if flamingo_task == "AudioCaptioning":
for filename in tqdm(dic.keys()):
if filter_file(file_path, file_list, filename):
continue
title, description, tags = dic[filename]
text_output = description
if len(text_output) <= 1:
continue
text_prompt = 'generate audio caption'
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
elif flamingo_task == "Tagging":
for filename in tqdm(dic.keys()):
if filter_file(file_path, file_list, filename):
continue
title, description, tags = dic[filename]
if len(tags) < 2 or not tags.startswith('[') or not tags.endswith(']'):
continue
tags = tags[1:-1].split(', ')
tags = list(map(lambda x: x.replace("'", ""), tags))
text_output = '{}'.format(', '.join(tags))
if len(text_output) <= 1:
continue
text_prompt = 'generate tags'
dataset_dic["data"][dataset_dic["total_num"]] = {
"name": filename,
"prompt": text_prompt,
"output": text_output.replace('\n', ' ')
}
dataset_dic["total_num"] += 1
with open(output_file, 'w') as json_file:
json.dump(dataset_dic, json_file)
# ==================== Precompute CLAP and build Hashing ====================
def int16_to_float32(x):
return (x / 32767.0).astype(np.float32)
def float32_to_int16(x):
x = np.clip(x, a_min=-1., a_max=1.)
return (x * 32767.).astype(np.int16)
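# These two helpers emulate the int16 quantization applied to audio during CLAP training.
# Round-trip sketch:
#   x = np.array([0.5, -0.25], dtype=np.float32)
#   int16_to_float32(float32_to_int16(x))  # ~ array([0.49998, -0.24998], dtype=float32)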
def update_progress_bar(arg):
# multiprocessing callback; expects a global tqdm progress bar named `pbar`
pbar.update()
@suppress_all_output
def load_clap_model(checkpoint):
if checkpoint in ['630k-audioset-best.pt', '630k-best.pt', '630k-audioset-fusion-best.pt', '630k-fusion-best.pt']:
amodel = 'HTSAT-tiny'
elif checkpoint in ['music_speech_audioset_epoch_15_esc_89.98.pt']:
amodel = 'HTSAT-base'
else:
raise NotImplementedError
model = laion_clap.CLAP_Module(
enable_fusion=('fusion' in checkpoint.lower()),
amodel=amodel
).cuda()
model.load_ckpt(ckpt=os.path.join(
'/lustre/fsw/portfolios/adlr/users/zkong/audio-flamingo-data/laion-clap-pretrained/laion_clap',
checkpoint
))
return model
def load_audio(file_path, target_sr=44100, duration=30.0, start=0.0):
if file_path.endswith('.mp3'):
audio = AudioSegment.from_file(file_path)
if len(audio) > (start + duration) * 1000:
audio = audio[start * 1000:(start + duration) * 1000]
if audio.frame_rate != target_sr:
audio = audio.set_frame_rate(target_sr)
if audio.channels > 1:
audio = audio.set_channels(1)
data = np.array(audio.get_array_of_samples())
if audio.sample_width == 2:
data = data.astype(np.float32) / np.iinfo(np.int16).max
elif audio.sample_width == 4:
data = data.astype(np.float32) / np.iinfo(np.int32).max
else:
raise ValueError("Unsupported bit depth: {}".format(audio.sample_width))
else:
with sf.SoundFile(file_path) as audio:
original_sr = audio.samplerate
channels = audio.channels
start_frame = int(start * original_sr)
audio.seek(start_frame)
# read at most `duration` seconds of frames from the seek position
frames_to_read = min(int(duration * original_sr), len(audio) - start_frame)
data = audio.read(frames_to_read)
if data.max() > 1 or data.min() < -1:
data = data / max(abs(data.max()), abs(data.min()))
if original_sr != target_sr:
if channels == 1:
data = librosa.resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr)
else:
data = librosa.resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0]
else:
if channels != 1:
data = data.T[0]
if data.min() >= 0:
data = 2 * data / abs(data.max()) - 1.0
else:
data = data / max(abs(data.max()), abs(data.min()))
return data
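# Usage sketch (path hypothetical):
#   data = load_audio('/path/to/audio.wav', target_sr=48000, duration=10.0, start=0.0)
#   # data is a mono float numpy array at 48 kHz, roughly normalized to [-1, 1]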
@torch.no_grad()
def compute_clap_each(audio_file, model):
try:
data = load_audio(audio_file, target_sr=48000, duration=10)
print(audio_file, 'loaded')
except Exception as e:
print(audio_file, 'unsuccessful due to', e)
return None
audio_data = data.reshape(1, -1)
audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float().cuda()
audio_embed = model.get_audio_embedding_from_data(x=audio_data_tensor, use_tensor=True)
audio_embed = audio_embed.squeeze(0).cpu()
return audio_embed
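# Usage sketch (assumes a CUDA device and a locally available CLAP checkpoint):
#   model = load_clap_model('630k-audioset-fusion-best.pt')
#   emb = compute_clap_each('/path/to/audio.wav', model)
#   # emb is a 1-D torch.Tensor (512-dim for these checkpoints), or None if loading failed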
@torch.no_grad()
def compute_embeddings_batch(batch, audio_files, model):
batch_results = []
for i in batch:
if i >= len(audio_files):
break
audio_file = audio_files[i]
audio_embed = compute_clap_each(audio_file, model)
batch_results.append((i, audio_file, audio_embed))
return batch_results
@torch.no_grad()
def precompute_clap_for_dataset(
dataset_file,
embedding_output_file,
checkpoint='630k-audioset-fusion-best.pt'
):
contents, audio_files = load_dataset_file(dataset_file)
model = load_clap_model(checkpoint)
if os.path.exists(embedding_output_file):
print('loading already computed embedding file from', embedding_output_file)
with open(embedding_output_file, 'rb') as f:
saved_data = pickle.load(f)
curr_audio_indices = saved_data['audio_indices']
curr_audio_files = saved_data['audio_files']
curr_audio_embeds = saved_data['audio_embeds']
else:
curr_audio_indices = []
curr_audio_files = []
curr_audio_embeds = []
print('computing embeddings for {}'.format(dataset_file))
start_index = len(curr_audio_files)
batch_size = 128
batches = [
list(range(i, min(i + batch_size, len(audio_files)))) \
for i in range(start_index, len(audio_files), batch_size)
]
with multiprocessing.Pool(processes=4) as pool:
for i, batch in enumerate(batches):
batch_results = pool.map(
partial(compute_embeddings_batch, model=model, audio_files=audio_files),
[batch]
)
for result in batch_results[0]:
curr_audio_indices.append(result[0])
curr_audio_files.append(result[1])
curr_audio_embeds.append(result[2])
with open(embedding_output_file, 'wb') as f:
pickle.dump({
'audio_indices': curr_audio_indices,
'audio_files': curr_audio_files,
'audio_embeds': curr_audio_embeds
}, f)
print(f"Saved progress for batch {i+1}/{len(batches)}: \
audio_indices {len(curr_audio_indices)}, \
audio_files {len(curr_audio_files)}, \
audio_embeds {len(curr_audio_embeds)}*{curr_audio_embeds[0].shape}")
return curr_audio_indices, curr_audio_files, curr_audio_embeds
def build_faiss_index(embeddings):
d = embeddings[0].size(0)
index = faiss.IndexFlatL2(d)
np_embeddings = np.vstack([emb.numpy() for emb in embeddings])
index.add(np_embeddings)
return index
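# Minimal self-check sketch for the flat L2 index (synthetic embeddings):
#   embs = [torch.randn(512) for _ in range(100)]
#   index = build_faiss_index(embs)
#   distances, indices = index.search(embs[0].numpy().reshape(1, -1), 5)
#   # indices[0][0] == 0: the query's nearest neighbor is itself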
def build_faiss_index_dataset(
dataset_file,
embedding_output_file,
faiss_output_file,
checkpoint='630k-audioset-fusion-best.pt',
only_precompute_clap=False
):
audio_indices, audio_files, audio_embeds = precompute_clap_for_dataset(dataset_file, embedding_output_file, checkpoint)
if only_precompute_clap:
return
valid_indices, valid_files, valid_embeds = [], [], []
for audio_index, audio_file, audio_embed in zip(audio_indices, audio_files, audio_embeds):
if audio_embed is not None:
valid_indices.append(audio_index)
valid_files.append(audio_file)
valid_embeds.append(audio_embed)
print('building faiss index')
faiss_index = build_faiss_index(valid_embeds)
print('saving faiss index')
faiss.write_index(faiss_index, faiss_output_file)
with open(faiss_output_file + '.filenames', 'wb') as f:
pickle.dump({'audio_indices': valid_indices, 'audio_files': valid_files}, f)
# ==================== Generate interleaved dataset files ====================
# only save indices, so that the original samples can be recovered from the split json files
def build_interleaved_dataset(dataset_file, interleaved_output_file, embedding_output_file, faiss_output_file, mode='random', n_samples=3):
contents, audio_files = load_dataset_file(dataset_file)
dataset_dic = {
"dataset_path": contents["dataset_path"],
"split": contents["split"],
"split_path": contents["split_path"],
"flamingo_task": contents["flamingo_task"],
"total_num": 0,
"interleaved_data": {},
}
# interleaved_data is
# {
#     id: {
#         "generation_index_in_split": index of the sample in train/val/test.json,
#         "fewshot_indices_in_train": indices of the few-shot samples in train.json
#     }
# }
if mode == 'knn':
model = load_clap_model(checkpoint='630k-audioset-fusion-best.pt')
print('loading already computed embedding file from', embedding_output_file)
with open(embedding_output_file, 'rb') as f:
precomputed_data = pickle.load(f)
precomputed_audio_indices = precomputed_data['audio_indices']
precomputed_audio_files = precomputed_data['audio_files']
precomputed_audio_embeds = precomputed_data['audio_embeds']
faiss_index = faiss.read_index(faiss_output_file)
with open(faiss_output_file+'.filenames', 'rb') as f:
_data = pickle.load(f)
faiss_index_audio_indices = _data['audio_indices']
faiss_index_audio_files = _data['audio_files']
print('looking for few shot samples and building interleaved_{} data'.format(mode))
for i in tqdm(range(contents["total_num"])):
if mode == 'random':
few_shot_indices = list(np.random.choice(
list(set(list(range(contents["total_num"]))) - set([i])),
size=n_samples-1,
replace=False
))
few_shot_indices = list(map(int, few_shot_indices))
elif mode == 'knn':
if audio_files[i] in precomputed_audio_files:
idx = precomputed_audio_files.index(audio_files[i])
query_embedding_np = precomputed_audio_embeds[idx]
if query_embedding_np is not None:
query_embedding_np = query_embedding_np.numpy().reshape(1, -1)
else:
continue
else:
query_embedding_np = compute_clap_each(audio_files[i], model)
if query_embedding_np is not None:
query_embedding_np = query_embedding_np.numpy().reshape(1, -1)
else:
continue
distances, knn_indices = faiss_index.search(query_embedding_np, n_samples+50)
distances = distances[0]
knn_indices = knn_indices[0]
knn_filenames = [faiss_index_audio_files[idx] for idx in knn_indices]
combined = list(zip(knn_indices, knn_filenames))
unique_indices = defaultdict(list)
for idx, filename in combined:
unique_indices[filename].append(idx)
cleared_knn_indices = [random.choice(unique_indices[filename]) for filename in unique_indices if filename != audio_files[i]]
# for the train split, also exclude the query sample itself from its few-shot set
if dataset_file.endswith('train.json'):
cleared_knn_indices = [knn_i for knn_i in cleared_knn_indices if faiss_index_audio_indices[knn_i] != i]
cleared_knn_indices = cleared_knn_indices[:n_samples-1]
np.random.shuffle(cleared_knn_indices)
few_shot_indices = [faiss_index_audio_indices[knn_i] for knn_i in cleared_knn_indices]
dataset_dic["interleaved_data"][dataset_dic["total_num"]] = {
"generation_index_in_split": i,
"fewshot_indices_in_train": few_shot_indices
}
dataset_dic["total_num"] += 1
with open(interleaved_output_file, 'w') as json_file:
json.dump(dataset_dic, json_file)
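# Recovery sketch: interleaved files only store indices, so the actual samples are
# looked up in the split json files (paths hypothetical):
#   with open(interleaved_output_file) as f: interleaved = json.load(f)
#   train_contents, _ = load_dataset_file(train_json_path)
#   rec = interleaved["interleaved_data"]["0"]
#   few_shot = [train_contents["data"][str(j)] for j in rec["fewshot_indices_in_train"]]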
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('-d', '--dataset_name', type=str, help='dataset name')
parser.add_argument('-f', '--flamingo_task', type=str, help='flamingo task')
parser.add_argument('--interleave', action="store_true", help='prepare the interleave dataset')
args = parser.parse_args()
ROOT = "/lustre/fsw/portfolios/adlr/users/zkong"
dataset_root = os.path.join(ROOT, "datasets")
output_root = os.path.join(ROOT, "audio-flamingo-data/dataset_files")
os.makedirs(output_root, exist_ok=True)
dataset_name = args.dataset_name # "Clotho-v2", "AudioSet", "Clotho-AQA", "WavText5K", "FSD50k", ...
flamingo_task = args.flamingo_task # AQA, AudioCaptioning, EventClassification, SceneClassification, Tagging, ...
# the train split must be processed first; otherwise there is no train embedding / faiss index to query against
for split in ["train", "val", "test"]:
dataset_path = os.path.join(dataset_root, dataset_name)
output_folder = '{}-{}'.format(dataset_name, flamingo_task)
os.makedirs(os.path.join(output_root, output_folder), exist_ok=True)
dataset_file = os.path.join(output_root, output_folder, '{}.json'.format(split))
if not os.path.exists(dataset_file):
try:
prepare_files(dataset_name, dataset_path, split, flamingo_task, dataset_file)
except AssertionError as e:
print('split {} not exist for {}: {}'.format(split, dataset_name, e))
continue
else:
print('{} exists; skipping preparation'.format(dataset_file))
if args.interleave:
faiss_output_file = dataset_file.replace('{}.json'.format(split), "train_faiss_index.index")
embedding_output_file = dataset_file.replace('.json', ".embedding")
if split == 'train':
if (not os.path.exists(faiss_output_file)) or (not os.path.exists(faiss_output_file + '.filenames')):
build_faiss_index_dataset(
dataset_file, embedding_output_file, faiss_output_file,
only_precompute_clap=False
)
else:
print('{} exists; skipping'.format(faiss_output_file))
else:
build_faiss_index_dataset(
dataset_file, embedding_output_file,
faiss_output_file=None,
only_precompute_clap=True
)
print('precomputing embedding for {} subset finished'.format(split))
for mode in ['knn', 'random']:
interleaved_output_file = '/'.join(
dataset_file.split('/')[:-1] + \
['interleaved_{}-'.format(mode) + dataset_file.split('/')[-1]]
)
if not os.path.exists(interleaved_output_file):
build_interleaved_dataset(
dataset_file=dataset_file,
interleaved_output_file=interleaved_output_file,
embedding_output_file=embedding_output_file,
faiss_output_file=faiss_output_file,
mode=mode,
n_samples=4
)
else:
print('{} exists; skipping'.format(interleaved_output_file))