import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import soundfile as sf
from collections import defaultdict
from dtw import dtw
from sklearn_extra.cluster import KMedoids
from copy import deepcopy
import os, librosa, json
# based on original implementation by
# https://colab.research.google.com/drive/1RApnJEocx3-mqdQC2h5SH8vucDkSlQYt?authuser=1#scrollTo=410ecd91fa29bc73
# by magnús freyr morthens 2023 supported by rannís nsn
def z_score(x, mean, std):
    """
    Standardise x against a mean and standard deviation.

    Returns (x - mean) / std. There is no guard for std == 0; the caller
    is responsible for passing a non-zero spread.
    """
    return (x - mean) / std
# given a sentence and list of its speakers + their alignment files,
# return a dictionary of word alignments
def get_word_aligns(norm_sent, aln_paths):
    """
    Returns a dictionary of word alignments for a given sentence.

    norm_sent: the sentence, words separated by single spaces.
    aln_paths: iterable of (speaker, alignment file path) pairs. Each file
        is read as one tab-separated line per word: word, start, end.

    Returns a defaultdict(list) mapping speaker -> [(word, start, end), ...]
    with start and end converted to float. A speaker whose file does not
    contain exactly one well-formed line per word of norm_sent is reported
    on stdout and left out of the result.
    """
    word_aligns = defaultdict(list)
    slist = norm_sent.split(" ")
    for spk, aln_path in aln_paths:
        with open(aln_path) as f:
            lines = [l.split('\t') for l in f.read().splitlines()]
        try:
            if len(lines) != len(slist):
                raise ValueError("alignment and sentence differ in word count")
            # unpacking a row with the wrong number of columns, or a
            # non-numeric time, also raises ValueError
            aligned = [(w, float(s), float(e)) for w, s, e in lines]
        except ValueError:
            print(slist, lines, "<---- something didn't match")
            continue
        word_aligns[spk] = aligned
    return word_aligns
#TODO pass whole path
def get_pitches(start_time, end_time, fpath):
    """
    Returns a list of z-scored pitch values for part of one recording.

    Reads from .f0 file of Time, F0, IsVoiced (three whitespace-separated
    numbers per line). Pitches are z-scored against the mean and standard
    deviation taken over the whole file, then only the frames whose time
    lies in [start_time, end_time] and whose third column is truthy are
    returned.
    """
    with open(fpath) as f:
        # split lines into floats
        frames = [[float(x) for x in line.split()] for line in f.read().splitlines()]

    # mean and std of all pitches in the whole sentence
    # NOTE(review): frames are left out of the statistics when the third
    # column equals -1, while the loop below tests the same column for
    # truthiness - confirm which convention the .f0 files actually use.
    sentence_f0 = [f0 for _, f0, flag in frames if flag != -1]
    mean = np.mean(sentence_f0)
    std = np.std(sentence_f0)

    pitches = []
    for time, pitch, is_pitch in frames:
        if start_time <= time <= end_time and is_pitch:
            pitches.append(z_score(pitch, mean, std))
            #pitches.append(z_score(fifth_percentile, mean, std))
    return pitches
# TODO: take the whole path
# jcheng used the energy from ESPS get_f0.
# The get_f0 documentation says (?):
#   The RMS value of each record is computed based on a 30 msec hanning
#   window with its left edge placed 5 msec before the beginning of the ...
# jcheng z-scored the energies, per file.
# TODO: implement that (?)
# Not sure whether librosa provides a hamming window in its rms function directly.
# TODO: handle audio that was not originally .wav
def get_rmse(start_time, end_time, wpath):
    """
    Returns an array of RMSE values for a given speech.

    Loads wpath at 16 kHz, cuts the samples between start_time and
    end_time (start rounded down, end rounded up to a whole sample) and
    returns the first row of librosa.feature.rms for that segment.
    """
    audio, sr = librosa.load(wpath, sr=16000)
    first = int(np.floor(start_time * sr))
    last = int(np.ceil(end_time * sr))
    segment = audio[first:last]
    # 480-sample frames with an 80-sample hop, i.e. 30 ms / 5 ms at 16 kHz
    rmse = librosa.feature.rms(y=segment, frame_length=480, hop_length=80)
    return rmse[0]
# may be unnecessary depending how rmse and pitch window/hop are calculated already
def downsample_rmse2pitch(rmse,pitch_len):
    """
    Picks pitch_len values from rmse at evenly spaced positions.

    The positions run from the first to the last element of rmse
    (inclusive) and are rounded to the nearest index, so the result has
    exactly pitch_len entries. rmse may be any sequence; it is converted
    to a numpy array so it can be indexed with an index array.
    """
    rmse = np.asarray(rmse)
    idx = np.round(np.linspace(0, len(rmse) - 1, pitch_len)).astype(int)
    return rmse[idx]
# parse user input string to usable word indices for the sentence
def parse_word_indices(start_end_word_index):
    """
    Converts a 1-based word selection into 0-based (start, end) indices.

    start_end_word_index is either a single number ("3") or a range
    ("2-4"); both ends are inclusive. With more than one '-' the first and
    last numbers are used.

    Raises ValueError if a number cannot be parsed, if the start is below
    1 (it would otherwise wrap around to the end of the sentence), or if
    the end comes before the start.
    """
    ixs = start_end_word_index.split('-')
    s = int(ixs[0])
    e = int(ixs[-1])  # same element as ixs[0] when there is no '-'
    if s < 1 or e < s:
        raise ValueError(f"invalid word index selection: {start_end_word_index!r}")
    return s - 1, e - 1
# take any (1stword, lastword) or (word)
# unit and prepare data for that unit
def get_data(norm_sent,path_key,start_end_word_index):
    """
    Returns pitch and rmse values for one word unit of a sentence, per speaker.

    norm_sent: the sentence, words separated by single spaces.
    path_key: list of (speaker, {'aln': ..., 'f0': ..., 'wav': ...}) pairs;
        it is iterated twice, so it must not be a one-shot iterator.
    start_end_word_index: 1-based word selection, see parse_word_indices.

    Returns (words, data, align_data):
      words      - the selected words joined with '_'
      data       - "<words>**<speaker>" -> list of [pitch, rmse] pairs
      align_data - "<words>**<speaker>" -> [(word, start, end), ...] with
                   times relative to the start of the unit, rounded to
                   two decimals
    """
    s_ix, e_ix = parse_word_indices(start_end_word_index)
    words = '_'.join(norm_sent.split(' ')[s_ix:e_ix+1])

    align_paths = [(spk, pdict['aln']) for spk, pdict in path_key]
    word_aligns = get_word_aligns(norm_sent, align_paths)

    data = defaultdict(list)
    align_data = defaultdict(list)
    for spk, pdict in path_key:
        word_al = word_aligns[spk]
        start_time = word_al[s_ix][1]
        end_time = word_al[e_ix][2]

        # word boundaries re-expressed relative to the start of the unit
        seg_aligns = [
            (w, round(s - start_time, 2), round(e - start_time, 2))
            for w, s, e in word_al[s_ix:e_ix+1]
        ]

        pitches = get_pitches(start_time, end_time, pdict['f0'])
        rmses = get_rmse(start_time, end_time, pdict['wav'])
        # bring the energy track to the same number of points as the pitch track
        rmses = downsample_rmse2pitch(rmses, len(pitches))

        key = f"{words}**{spk}"
        data[key] = [[p, r] for p, r in zip(np.array(pitches), np.array(rmses))]
        align_data[key] = seg_aligns
    return words, data, align_data
def dtw_distance(x, y):
    """
    Returns the normalized DTW distance between two feature sequences.

    x and y are passed unchanged to dtw(); the value returned is the
    normalizedDistance attribute of the resulting alignment.
    """
    alignment = dtw(x, y, keep_internals=True)
    return alignment.normalizedDistance
# recs is a sorted list of rec IDs
# all recs/data contain the same words
# rec1 and rec2 can be the same
def pair_dists(data,words,recs):
    """
    Returns the DTW distance for every ordered pair of recordings.

    data: mapping "<words>**<rec>" -> feature sequence.
    words: the word-unit string used in the keys of data.
    recs: list of recording ids.

    Returns a list of ("<rec1>**<rec2>", distance) tuples, ordered with
    rec1 as the outer loop and rec2 as the inner loop, so it contains
    len(recs) ** 2 entries including each recording paired with itself.
    """
    dtw_dists = []
    for rec1 in recs:
        val1 = data[f'{words}**{rec1}']
        for rec2 in recs:
            val2 = data[f'{words}**{rec2}']
            dtw_dists.append((f"{rec1}**{rec2}", dtw_distance(val1, val2)))
    return dtw_dists
def kmedoids_clustering(X, n_clusters=3):
    """
    Fits a KMedoids model on X and returns (labels, fitted model).

    X: array with one row per item to cluster.
    n_clusters: number of clusters to form (default 3).

    random_state is fixed to 0 so repeated runs give the same clustering.
    """
    # NOTE(review): the caller passes a matrix of pairwise DTW distances, but
    # no metric is given here, so KMedoids uses its default metric on the
    # rows - confirm whether metric='precomputed' was intended.
    kmedoids = KMedoids(n_clusters=n_clusters, random_state=0).fit(X)
    y_km = kmedoids.labels_
    return y_km, kmedoids
def match_tts(clusters, speech_data, tts_data, tts_align, words, seg_aligns, voice):
    """
    Ranks the speaker clusters by closeness to one TTS voice and plots them.

    clusters: list of (recording id, cluster label) pairs.
    speech_data: mapping "<words>**<rec>" -> list of [pitch, rmse] pairs.
    tts_data, tts_align: features and word alignment of the TTS voice.
    words: the word-unit string used in the keys of speech_data.
    seg_aligns: mapping "<words>**<rec>" -> word alignment of that recording.
    voice: name of the TTS voice, used in the plots.

    Each cluster is scored by the mean DTW distance between the TTS
    features and the cluster's recordings; the clusters are then sorted by
    that score. Returns the best (lowest) score followed by six figures:
    pitch plots for the best, second and third cluster, then energy plots
    for the same three. Only the best-cluster plots include the TTS voice.

    Raises IndexError when clusters holds fewer than three distinct labels.
    """
    def members(label):
        # features of every recording assigned to the given cluster
        return {f'{words}**{r}': speech_data[f'{words}**{r}'] for r, c in clusters if c == label}

    tts_info = []
    for label in set(c for r, c in clusters):
        dists = [
            dtw_distance(tts_data, speech_data[f'{words}**{r}'])
            for r, c in clusters if c == label
        ]
        tts_info.append((label, np.mean(dists)))
    tts_info = sorted(tts_info, key=lambda x: x[1])

    best_cluster, best_cluster_score = tts_info[0]
    mid_cluster = tts_info[1][0]
    bad_cluster = tts_info[2][0]

    matched_data = members(best_cluster)
    mid_data = members(mid_cluster)
    bad_data = members(bad_cluster)

    # each figure is titled with the label of the cluster it shows
    tts_fig_p = plot_one_cluster(words,'pitch',matched_data,seg_aligns,best_cluster,tts_data=tts_data,tts_align=tts_align,voice=voice)
    fig_mid_p = plot_one_cluster(words,'pitch',mid_data,seg_aligns,mid_cluster)
    fig_bad_p = plot_one_cluster(words,'pitch',bad_data,seg_aligns,bad_cluster)

    tts_fig_e = plot_one_cluster(words,'rmse',matched_data,seg_aligns,best_cluster,tts_data=tts_data,tts_align=tts_align,voice=voice)
    fig_mid_e = plot_one_cluster(words,'rmse',mid_data,seg_aligns,mid_cluster)
    fig_bad_e = plot_one_cluster(words,'rmse',bad_data,seg_aligns,bad_cluster)

    return best_cluster_score, tts_fig_p, fig_mid_p, fig_bad_p, tts_fig_e, fig_mid_e, fig_bad_e
def gp(d,s,x):
    """
    Returns the path of file "<s>.<x>" inside directory d.
    """
    return os.path.join(d, f'{s}.{x}')
def gen_tts_paths(tdir,voices):
    """
    Returns the file paths of every TTS voice for one sentence.

    All three files of a voice sit in tdir and are named after the voice:
    <voice>.wav, <voice>.tsv (alignment) and <voice>.f0.

    Returns a list of (voice, {'wav': ..., 'aln': ..., 'f0': ...}) pairs in
    the order of voices.
    """
    plist = [
        (v, {'wav': gp(tdir,v,'wav'), 'aln': gp(tdir,v,'tsv'), 'f0': gp(tdir,v,'f0')})
        for v in voices
    ]
    return plist
def gen_h_paths(wdir,adir,f0dir,spks):
    """
    Returns the file paths of every human speaker for one sentence.

    Files are named after the speaker id: <spk>.wav in wdir, <spk>.tsv
    (alignment) in adir and <spk>.f0 in f0dir.

    Returns a list of (speaker, {'wav': ..., 'aln': ..., 'f0': ...}) pairs
    in the order of spks.
    """
    plist = [
        (s, {'wav': gp(wdir,s,'wav'), 'aln': gp(adir,s,'tsv'), 'f0': gp(f0dir,s,'f0')})
        for s in spks
    ]
    return plist
# Since clustering operates strictly on X,
# once a duration metric has been reduced to pairwise distances
# it no longer matters that duration and pitch/energy had different dimensionality.
# TODO: option to run DTW on the 3 features pitch/energy/duration separately.
# Check whether it is possible to cluster with a 3-dimensional distance matrix,
# or whether the clustering cannot take that input in multidimensional space;
# in that case the 3 distances can still be averaged to flatten them, if appropriately scaled.
def cluster(norm_sent,orig_sent,h_spk_ids, h_align_dir, h_f0_dir, h_wav_dir, tts_sent_dir, voices, start_end_word_index):
    """
    Clusters the human recordings of a word unit and matches a TTS voice to them.

    norm_sent: the sentence, words separated by single spaces.
    orig_sent: not used by this function.
    h_spk_ids: ids of the human recordings; they are processed in sorted order.
    h_align_dir, h_f0_dir, h_wav_dir: directories of the human alignment,
        f0 and wav files.
    tts_sent_dir: directory holding the TTS files for this sentence.
    voices: list of TTS voice names.
    start_end_word_index: 1-based word selection, see parse_word_indices.

    Returns what match_tts returns: the best cluster score followed by six
    figures. Only supports one voice at a time currently - with several
    voices each one is matched, but only the result of the last is returned.

    Raises ValueError when voices is empty.
    """
    h_spk_ids = sorted(h_spk_ids)
    nsents = len(h_spk_ids)

    h_all_paths = gen_h_paths(h_wav_dir,h_align_dir,h_f0_dir,h_spk_ids)
    words, h_data, h_seg_aligns = get_data(norm_sent,h_all_paths,start_end_word_index)

    dtw_dists = pair_dists(h_data,words,h_spk_ids)
    # pair_dists lists the pairs with the first recording as the outer loop,
    # so each consecutive run of nsents distances is one row of the matrix
    X = [d[1] for d in dtw_dists]
    X = np.array([X[i:i+nsents] for i in range(0, len(X), nsents)])

    y_km, kmedoids = kmedoids_clustering(X)
    groups = [[r,c] for r,c in zip(h_spk_ids,kmedoids.labels_)]

    tts_all_paths = gen_tts_paths(tts_sent_dir, voices)
    _, tts_data, tts_seg_aligns = get_data(norm_sent,tts_all_paths,start_end_word_index)

    matched = None
    for v in voices:
        voice_data = tts_data[f"{words}**{v}"]
        voice_align = tts_seg_aligns[f"{words}**{v}"]
        # match the data with a cluster -----
        matched = match_tts(groups, h_data, voice_data, voice_align, words, h_seg_aligns,v)

    if matched is None:
        raise ValueError("cluster() needs at least one TTS voice")
    best_cluster_score, tts_fig_p, fig_mid_p, fig_bad_p, tts_fig_e, fig_mid_e, fig_bad_e = matched
    return best_cluster_score, tts_fig_p, fig_mid_p, fig_bad_p, tts_fig_e, fig_mid_e, fig_bad_e
#return words, kmedoids_cluster_dists, group
# TODO: there IS something for making tts_data,
# but I'm probably pretty much on my own for that.
# TODO: this one is very helpful,
# but mind whether I adjusted any of the dictionaries earlier.
def spks_all_cdist():
    """
    Returns DTW distances between every speaker entry and every TTS entry
    that share the same words.

    NOTE(review): reads the names `data` and `tts_data`, which are not
    parameters and are not defined in the visible part of this file -
    confirm where they are expected to come from before calling this.

    Keys of both mappings are split on '-'; the last two parts are taken
    as ids and everything before them as the words. Two keys are compared
    when their words agree position by position (zip stops at the shorter
    list, so a shared prefix also counts as agreement).

    Returns a defaultdict(list) mapping the '-'-joined words to a list of
    ("<id1>-<id2>_<id3>-<id4>", distance) tuples.
    """
    speaker_to_tts_dtw_dists = defaultdict(list)
    for key1, value1 in data.items():
        d = key1.split("-")
        words1 = d[:-2]
        id1, id2 = d[-2], d[-1]
        for key2, value2 in tts_data.items():
            d = key2.split("-")
            words2 = d[:-2]
            id3, id4 = d[-2], d[-1]
            if all(w1 == w2 for w1, w2 in zip(words1, words2)):
                speaker_to_tts_dtw_dists[f"{'-'.join(words1)}"].append((f"{id1}-{id2}_{id3}-{id4}", dtw_distance(value1, value2)))
    return speaker_to_tts_dtw_dists
# TODO: I think this is also great,
# but figure out how it works,
# because of the dict format and so on:
# keying by word index instead of word text, ***********
# and handling 1-word or 3+-word units...
def tts_cdist():
    """
    Returns the mean TTS distance collected for each key of tts_dist_to_cluster.

    NOTE(review): reads the names `kmedoids_cluster_dists` and
    `speaker_to_tts_dtw_dists`, which are not parameters and are not
    defined at module level in the visible part of this file - confirm
    where they are expected to come from before calling this.

    Every (cluster, speaker id, array) entry of kmedoids_cluster_dists is
    paired with every (ids, distance) entry of speaker_to_tts_dtw_dists;
    ids is split on '_' into a speaker id and a TTS id. A pair counts when
    both the speaker id and the words key are equal.
    """
    tts_dist_to_cluster = defaultdict(list)
    for words1, datas1 in kmedoids_cluster_dists.items():
        for d1 in datas1:
            cluster, sp_id1, arr = d1
            for words2, datas2 in speaker_to_tts_dtw_dists.items():
                for d2 in datas2:
                    ids, dist = d2
                    sp_id2, tts_alfur = ids.split("_")
                    if sp_id1 == sp_id2 and words1 == words2:
                        # NOTE(review): the statement under this condition was
                        # missing; collecting the distance per words/cluster
                        # is a reconstruction - confirm the intended key.
                        tts_dist_to_cluster[f"{words1}-{cluster}"].append(dist)
    tts_mean_dist_to_cluster = {
        key: np.mean(value) for key, value in tts_dist_to_cluster.items()
    }
    return tts_mean_dist_to_cluster
# TODO check if anything uses this?
def get_audio_part(start_time, end_time, id, path):
    """
    Returns the audio samples of one recording between two times.

    The file "<id>.wav" in directory path is loaded at 16 kHz; the slice
    starts at start_time rounded down to a whole sample and ends at
    end_time rounded up.
    """
    f = os.path.join(path, id + ".wav")
    audio, sr = librosa.load(f, sr=16000)
    first = int(np.floor(start_time * sr))
    last = int(np.ceil(end_time * sr))
    return audio[first:last]
def plot_one_cluster(words,feature,speech_data,seg_aligns,cluster_id,tts_data=None,tts_align=None,voice=None):
    """
    Plots one feature of every recording in a cluster, optionally with a TTS voice.

    words: the word-unit string, used in the title.
    feature: 'pitch' / 'f0' or 'energy' / 'rmse' (case-insensitive). For any
        other value a message is printed and the empty figure is returned.
    speech_data: mapping "<words>**<speaker>" -> list of [pitch, energy] pairs.
    seg_aligns: mapping with the same keys -> [(word, start, end), ...].
    cluster_id: label shown in the title.
    tts_data, tts_align, voice: features, alignment and name of a TTS voice;
        the voice is drawn in black only when voice is truthy.

    Returns the matplotlib figure.
    """
    colors = ["red", "green", "blue", "orange", "purple", "pink", "brown", "gray", "cyan"]
    cc = 0
    fig = plt.figure(figsize=(10, 5))

    if feature.lower() in ['pitch','f0']:
        fname = 'Pitch'
        ffunc = lambda x: [p for p,e in x]
    elif feature.lower() in ['energy', 'rmse']:
        fname = 'Energy'
        ffunc = lambda x: [e for p,e in x]
    else:
        print('problem with the figure')
        return fig

    plt.title(f"{words} - {fname} - Cluster {cluster_id}")

    for k,v in speech_data.items():
        spk = k.split('**')[1]
        word_times = seg_aligns[k]
        feats = ffunc(v)
        # datapoint interval is 0.005 seconds
        # NOTE(review): this assumes one value per 5 ms frame with no gaps -
        # confirm that the feature extraction keeps every frame.
        feat_xvals = [x*0.005 for x in range(len(feats))]

        # centre around the first word boundary -
        # if 3+ words, too bad.
        if len(word_times)>1:
            realign = np.mean([word_times[0][2],word_times[1][1]])
            feat_xvals = [x - realign for x in feat_xvals]
            word_times = [(w,s-realign,e-realign) for w,s,e in word_times]
            plt.axvline(x= 0, color="gray", linestyle='--', linewidth=1, label=f"{word_times[0][0]} -> {word_times[1][0]} boundary")

            # later boundaries differ per speaker, so draw them in the speaker's colour
            if len(word_times)>2:
                for i in range(1,len(word_times)-1):
                    bound_line = np.mean([word_times[i][2],word_times[i+1][1]])
                    plt.axvline(x=bound_line, color=colors[cc], linestyle='--', linewidth=1, label=f"Speaker {spk} -> {word_times[i+1][0]}")

        plt.scatter(feat_xvals, feats, color=colors[cc], label=f"Speaker {spk}")
        cc += 1
        if cc >= len(colors):
            # more speakers than colours: start the palette again
            cc = 0

    if voice:
        # plot the same feature as for the speakers
        tfeats = ffunc(tts_data)
        t_xvals = [x*0.005 for x in range(len(tfeats))]
        if len(tts_align)>1:
            realign = np.mean([tts_align[0][2],tts_align[1][1]])
            t_xvals = [x - realign for x in t_xvals]
            tts_align = [(w,s-realign,e-realign) for w,s,e in tts_align]
            if len(tts_align)>2:
                for i in range(1,len(tts_align)-1):
                    bound_line = np.mean([tts_align[i][2],tts_align[i+1][1]])
                    plt.axvline(x=bound_line, color="black", linestyle='--', linewidth=1, label=f"TTS -> {tts_align[i+1][0]}")
        plt.scatter(t_xvals, tfeats, color="black", label=f"TTS {voice}")

    return fig