Spaces:
Runtime error
Runtime error
| import os | |
| import logging | |
| import h5py | |
| import soundfile | |
| import librosa | |
| import numpy as np | |
| import pandas as pd | |
| from scipy import stats | |
| import datetime | |
| import pickle | |
def create_folder(fd):
    """Create the directory ``fd``, including missing parents, if needed.

    Args:
      fd: str, path of the directory to create

    Raises:
      FileExistsError: if ``fd`` already exists but is not a directory
    """
    # exist_ok=True removes the check-then-create race: a separate
    # os.path.exists() test can be invalidated by another process creating
    # the folder before os.makedirs() runs.
    os.makedirs(fd, exist_ok=True)
def get_filename(path):
    """Return the base name of ``path`` without its last extension.

    The path is first canonicalised with os.path.realpath, so symbolic links
    are resolved before the name is extracted.

    Args:
      path: str, e.g. '/data/audio/Yabc.wav'

    Returns:
      na: str, e.g. 'Yabc'. Only the final extension is removed, so
        'clip.tar.gz' gives 'clip.tar'.
    """
    resolved = os.path.realpath(path)
    # os.path.basename honours the platform's path separator, unlike a
    # hard-coded split on '/'.
    na_ext = os.path.basename(resolved)
    return os.path.splitext(na_ext)[0]
def get_sub_filepaths(folder):
    """Collect the paths of every file found below ``folder``, recursively.

    Args:
      folder: str, root directory to walk

    Returns:
      paths: list of str, one entry per file, in os.walk order. Each entry
        is the walked directory joined with the file name.
    """
    return [
        os.path.join(parent, filename)
        for parent, _subdirs, filenames in os.walk(folder)
        for filename in filenames
    ]
def create_logging(log_dir, filemode):
    """Set up file and console logging and return the ``logging`` module.

    Log files are named '0000.log', '0001.log', ... inside ``log_dir``; the
    first index that does not exist yet as a file is used.

    Args:
      log_dir: str, directory that holds the log files (created if missing)
      filemode: str, mode passed to logging.basicConfig for the log file

    Returns:
      The ``logging`` module itself, after configuration.
    """
    create_folder(log_dir)

    # Probe successive indices until an unused log file name is found.
    index = 0
    log_path = os.path.join(log_dir, '{:04d}.log'.format(index))
    while os.path.isfile(log_path):
        index += 1
        log_path = os.path.join(log_dir, '{:04d}.log'.format(index))

    # File logging: everything from DEBUG upwards.
    file_config = dict(
        level=logging.DEBUG,
        format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
        datefmt='%a, %d %b %Y %H:%M:%S',
        filename=log_path,
        filemode=filemode,
    )
    logging.basicConfig(**file_config)

    # Console logging: INFO and above, attached to the root logger.
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(
        logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s'))
    logging.getLogger('').addHandler(stream_handler)

    return logging
def read_metadata(csv_path, classes_num, id_to_ix):
    """Read metadata of AudioSet from a csv file.

    The first three lines of the file are treated as headers and skipped.
    Every remaining line is split on ', ' and is expected to look like:

        --4gqARaEJE, 0.000, 10.000, "/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"

    Args:
      csv_path: str
      classes_num: int, number of columns of the target matrix
      id_to_ix: dict, maps a label id (e.g. '/m/068hy') to its column index

    Returns:
      meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)}
        'target' is a boolean multi-hot matrix.

    Raises:
      KeyError: if a label id of the csv file is missing from ``id_to_ix``
    """
    with open(csv_path, 'r') as fr:
        lines = fr.readlines()

    lines = lines[3:]  # Remove heads
    audios_num = len(lines)

    # The np.bool alias was removed in NumPy 1.24; the builtin bool selects
    # the same boolean dtype on every NumPy version.
    targets = np.zeros((audios_num, classes_num), dtype=bool)
    audio_names = []

    for n, line in enumerate(lines):
        # items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_"\n']
        items = line.split(', ')

        # Audios are started with an extra 'Y' when downloading
        audio_name = 'Y{}.wav'.format(items[0])
        # The label ids are the comma separated content between the quotes.
        label_ids = items[3].split('"')[1].split(',')
        audio_names.append(audio_name)

        # Target
        for label_id in label_ids:
            targets[n, id_to_ix[label_id]] = True

    meta_dict = {'audio_name': np.array(audio_names), 'target': targets}
    return meta_dict
def float32_to_int16(x):
    """Convert a float waveform to int16 samples.

    Values are clipped to [-1, 1] and scaled by 32767 before the cast.

    Args:
      x: array of float samples; the absolute peak must not exceed 1.2
        (checked with an assert)

    Returns:
      np.int16 array
    """
    peak = np.max(np.abs(x))
    assert peak <= 1.2
    clipped = np.clip(x, -1, 1)
    scaled = clipped * 32767.
    return scaled.astype(np.int16)
def int16_to_float32(x):
    """Convert int16 samples to a float32 waveform by dividing by 32767.

    Args:
      x: array of integer samples

    Returns:
      np.float32 array
    """
    scaled = x / 32767.
    return scaled.astype(np.float32)
def pad_or_truncate(x, audio_length):
    """Pad or truncate a 1-D audio array to exactly ``audio_length`` samples.

    Shorter inputs are padded with trailing zeros, longer inputs keep only
    their first ``audio_length`` samples.

    Args:
      x: 1-D array of samples
      audio_length: int, number of samples of the result

    Returns:
      Array of length ``audio_length``. The padded result keeps the dtype of
      ``x``.
    """
    missing = audio_length - len(x)
    if missing < 0:
        return x[0 : audio_length]

    # Build the padding in the input's own dtype. np.zeros defaults to
    # float64, which would silently upcast e.g. float32 or int16 audio and
    # make the result dtype depend on whether padding happened.
    padding = np.zeros(missing, dtype=np.asarray(x).dtype)
    return np.concatenate((x, padding), axis=0)
def d_prime(auc):
    """Compute d-prime from an AUC value.

    The result is sqrt(2) times the standard normal quantile (inverse CDF)
    evaluated at ``auc``.

    Args:
      auc: float or array of AUC values

    Returns:
      d-prime value(s)
    """
    normal_quantile = stats.norm.ppf(auc)
    return normal_quantile * np.sqrt(2.0)
class Mixup(object):
    """Generator of mixup coefficients drawn from Beta(alpha, alpha)."""

    def __init__(self, mixup_alpha, random_seed=1234):
        """Mixup coefficient generator.

        Args:
          mixup_alpha: float, used as both shape parameters of the Beta
            distribution
          random_seed: int, seed of the private np.random.RandomState
        """
        self.mixup_alpha = mixup_alpha
        self.random_state = np.random.RandomState(random_seed)

    def get_lambda(self, batch_size):
        """Get mixup random coefficients.

        One coefficient ``lam`` is drawn per pair of consecutive samples; the
        pair receives ``lam`` and ``1 - lam``.

        Args:
          batch_size: int

        Returns:
          mixup_lambdas: (batch_size,)
        """
        mixup_lambdas = []
        for _ in range(0, batch_size, 2):
            lam = self.random_state.beta(self.mixup_alpha, self.mixup_alpha, 1)[0]
            mixup_lambdas.append(lam)
            mixup_lambdas.append(1. - lam)

        # Pairs are generated two at a time, so an odd batch_size would
        # otherwise return batch_size + 1 values; trim to the documented shape.
        return np.array(mixup_lambdas[:batch_size])
class StatisticsContainer(object):
    """Collect per-iteration statistics and persist them with pickle."""

    def __init__(self, statistics_path):
        """Contain statistics of different training iterations.

        Args:
          statistics_path: str, path of the pickle file written by dump().
            A timestamped backup path is derived from it by replacing the
            extension with '_<YYYY-mm-dd_HH-MM-SS>.pkl'.
        """
        self.statistics_path = statistics_path

        self.backup_statistics_path = '{}_{}.pkl'.format(
            os.path.splitext(self.statistics_path)[0],
            datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))

        self.statistics_dict = {'bal': [], 'test': []}

    def append(self, iteration, statistics, data_type):
        """Record ``statistics`` for ``iteration`` under ``data_type``.

        Args:
          iteration: value stored in ``statistics['iteration']``
          statistics: dict; it is modified in place (the 'iteration' key is
            added) and stored by reference
          data_type: key of ``self.statistics_dict``, e.g. 'bal' or 'test'
        """
        statistics['iteration'] = iteration
        self.statistics_dict[data_type].append(statistics)

    def dump(self):
        """Write the statistics to the main path and to the backup path."""
        paths = (self.statistics_path, self.backup_statistics_path)

        # Context managers guarantee that both files are flushed and closed.
        for path in paths:
            with open(path, 'wb') as f:
                pickle.dump(self.statistics_dict, f)

        for path in paths:
            logging.info(' Dump statistics to {}'.format(path))

    def load_state_dict(self, resume_iteration):
        """Reload statistics from disk, keeping entries up to an iteration.

        Args:
          resume_iteration: entries whose 'iteration' is greater than this
            value are discarded
        """
        # NOTE: pickle.load executes arbitrary code from the file; only load
        # statistics files produced by a trusted run.
        with open(self.statistics_path, 'rb') as f:
            loaded = pickle.load(f)

        resume_statistics_dict = {'bal': [], 'test': []}

        for key, entries in loaded.items():
            resume_statistics_dict[key] = [
                statistics for statistics in entries
                if statistics['iteration'] <= resume_iteration]

        self.statistics_dict = resume_statistics_dict