import json
import os
import string
from bisect import bisect_left
from collections import defaultdict

import matplotlib.animation as animation
import numpy as np
import torch
import torchaudio
from matplotlib import pyplot as plt

# Note: set_audio_backend is deprecated in torchaudio 2.x, where the backend
# is selected per call; this line is only needed on older versions.
torchaudio.set_audio_backend("sox_io")


def load_json(json_path: str):
    """Load a JSON file and return the parsed object."""
    with open(json_path, "r", encoding="utf-8") as f_in:
        data = json.load(f_in)
    return data


def check_window_signal(info_t, w_s, w_e):
    """Return True if the [w_s, w_e] window (in seconds) fits inside the signal."""
    length = w_e - w_s
    frame_offset = int(w_s * info_t.sample_rate)
    num_frames = int(length * info_t.sample_rate)
    return frame_offset + num_frames <= int(info_t.num_frames)


def index_narrations(ann_path):
    """
    Index a narration annotation file by video uid.

    Returns two dicts keyed by video uid: one mapping to
    (timestamp_sec, narration_text, annotation_uid, timestamp_frame) tuples,
    the other to (start_sec, end_sec, summary_text) tuples.
    """
    narration_raw = load_json(ann_path)
    narration_dict = defaultdict(list)
    summary_dict = defaultdict(list)
    avg_len = []
    for v_id, narr in narration_raw.items():
        narr_list = []
        summ_list = []
        # merge both narration passes when present
        if "narration_pass_1" in narr:
            narr_list += narr["narration_pass_1"]["narrations"]
            summ_list += narr["narration_pass_1"]["summaries"]
        if "narration_pass_2" in narr:
            narr_list += narr["narration_pass_2"]["narrations"]
            summ_list += narr["narration_pass_2"]["summaries"]
        if len(narr_list) > 0:
            narration_dict[v_id] = [
                (
                    float(n_t["timestamp_sec"]),
                    n_t["narration_text"],
                    n_t["annotation_uid"],
                    n_t["timestamp_frame"],
                )
                for n_t in narr_list
            ]
            avg_len.append(len(narration_dict[v_id]))
        else:
            narration_dict[v_id] = []
        if len(summ_list) > 0:
            summary_dict[v_id] = [
                (
                    float(s_t["start_sec"]),
                    float(s_t["end_sec"]),
                    s_t["summary_text"],
                )
                for s_t in summ_list
            ]
        else:
            summary_dict[v_id] = []
    # print(f"Number of Videos with narration {len(narration_dict)}")
    # print(f"Avg. narration length {np.mean(avg_len)}")
    # print(f"Number of Videos with summaries {len(summary_dict)}")
    return narration_dict, summary_dict
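

# A minimal usage sketch (the annotation path and video uid below are
# hypothetical placeholders):
#
#   narration_dict, summary_dict = index_narrations("narration.json")
#   for ts_sec, text, ann_uid, ts_frame in narration_dict["some_video_uid"]:
#       print(ts_sec, text)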


def get_signal_info(signal_fn: str):
    """Return torchaudio metadata (sample rate, number of frames, ...) for a file."""
    return torchaudio.info(signal_fn)


def get_signal_frames(signal_fn: str, video_start_sec: float, video_end_sec: float):
    """
    Given an audio track, return the frames between video_start_sec and video_end_sec.
    """
    info_t = get_signal_info(signal_fn)
    length = video_end_sec - video_start_sec
    aframes, _ = torchaudio.load(
        signal_fn,
        normalize=True,
        frame_offset=int(video_start_sec * info_t.sample_rate),
        num_frames=int(length * info_t.sample_rate),
    )
    return {"signal": aframes, "meta": info_t}
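

# Usage sketch (the file path is a hypothetical placeholder):
#
#   out = get_signal_frames("clip_audio.wav", video_start_sec=5.0, video_end_sec=10.0)
#   out["signal"]  # tensor of shape (channels, ~5 * sample_rate)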


def tosec(value):
    """Convert milliseconds to seconds."""
    return value / 1000


def toms(value):
    """Convert seconds to milliseconds."""
    return value * 1000


def delta(first_num: float, second_num: float):
    """Compute the absolute value of the difference of two numbers."""
    return abs(first_num - second_num)


def padIMU(signal, duration_sec):
    """
    Trim or zero-pad the (N, 6) signal to exactly round(duration_sec) * 200
    samples, assuming a 200 Hz sampling rate.
    """
    expected_elements = round(duration_sec) * 200
    if signal.shape[0] > expected_elements:
        signal = signal[:expected_elements, :]
    elif signal.shape[0] < expected_elements:
        padding = expected_elements - signal.shape[0]
        padded_zeros = np.zeros((padding, 6))
        signal = np.concatenate([signal, padded_zeros], 0)
    return signal


def resample(
    signals: np.ndarray,
    timestamps: np.ndarray,
    original_sample_rate: int,
    resample_rate: int,
):
    """
    Resample (N, C) data to a new sample rate and rebuild evenly spaced
    timestamps (in milliseconds) starting from the first original timestamp.
    """
    signals = torch.as_tensor(signals)
    timestamps = torch.from_numpy(timestamps).unsqueeze(-1)
    signals = torchaudio.functional.resample(
        waveform=signals.T,
        orig_freq=original_sample_rate,
        new_freq=resample_rate,
    ).T.numpy()
    nsamples = len(signals)
    period = 1 / resample_rate
    # timestamps are expected to be shape (N, 1)
    initial_seconds = timestamps[0] / 1e3
    ntimes = (torch.arange(nsamples) * period).view(-1, 1) + initial_seconds
    timestamps = (ntimes * 1e3).squeeze().numpy()
    return signals, timestamps
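

# Worked example: 1000 samples at ~98 Hz resampled to 200 Hz yields roughly
# 1000 * 200 / 98 ≈ 2041 samples, with timestamps advancing in 5 ms steps
# (1 / 200 s) from the first original timestamp. For instance:
#
#   sig = np.random.randn(1000, 6)
#   ts = np.arange(1000) * (1000.0 / 98)  # milliseconds
#   sig2, ts2 = resample(sig, ts, 98, 200)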


def resampleIMU(signal, timestamps):
    """Estimate the sampling rate from the ms timestamps and resample to 200 Hz."""
    sampling_rate = int(1000 * (1 / (np.mean(np.diff(timestamps)))))
    # resample everything to 200 Hz
    if sampling_rate != 200:
        signal, timestamps = resample(signal, timestamps, sampling_rate, 200)
    return signal, timestamps


def get_imu_frames(
    imu_path,
    uid: str,
    video_start_sec: float,
    video_end_sec: float,
):
    """
    Given an IMU trace, return the frames between video_start_sec and
    video_end_sec resampled to 200 Hz, or None if no valid window exists.
    """
    signal = np.load(os.path.join(imu_path, f"{uid}.npy"))
    signal = signal.transpose()
    timestamps = np.load(os.path.join(imu_path, f"{uid}_timestamps.npy"))
    if toms(video_start_sec) > timestamps[-1] or toms(video_end_sec) > timestamps[-1]:
        return None
    start_id = bisect_left(timestamps, toms(video_start_sec))
    end_id = bisect_left(timestamps, toms(video_end_sec))
    # make sure the retrieved window interval is correct within a max 4 sec margin
    if (
        delta(video_start_sec, tosec(timestamps[start_id])) > 4
        or delta(video_end_sec, tosec(timestamps[end_id])) > 4
    ):
        return None
    # get the window
    if start_id == end_id:
        start_id -= 1
        end_id += 1
    signal, timestamps = signal[start_id:end_id], timestamps[start_id:end_id]
    if len(signal) < 10 or len(timestamps) < 10:
        return None
    # resample the signal at 200 Hz if necessary
    signal, timestamps = resampleIMU(signal, timestamps)
    # pad the signal if necessary
    signal = padIMU(signal, video_end_sec - video_start_sec)
    sample_dict = {
        "timestamp": timestamps,
        "signal": torch.tensor(signal.T),
        "sampling_rate": 200,
    }
    return sample_dict
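

# Usage sketch (the directory layout with "{uid}.npy" / "{uid}_timestamps.npy"
# pairs matches the np.load calls above; the path and uid are placeholders):
#
#   sample = get_imu_frames("processed_imu", "some_video_uid", 5.0, 10.0)
#   if sample is not None:
#       sample["signal"].shape  # torch.Size([6, 1000]) for a 5 s window at 200 Hz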


def display_animation(frames, title, save_path_gif):
    fig, ax = plt.subplots()
    artists = [[ax.imshow(frame)] for frame in frames]
    plt.title(title)
    ani = animation.ArtistAnimation(fig, artists)
    ani.save(save_path_gif, writer="imagemagick")
    plt.close()


def display_animation_imu(frames, imu, title, save_path_gif):
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1)
    ax1.set_title(title)
    ax2.set_title("Acc.")
    ax3.set_title("Gyro.")
    artists = [[ax1.imshow(frame)] for frame in frames]
    ani = animation.ArtistAnimation(fig, artists)
    # rows 0-2: accelerometer x/y/z; rows 3-5: gyroscope x/y/z
    ax2.plot(imu[0].cpu().numpy(), color="red")
    ax2.plot(imu[1].cpu().numpy(), color="blue")
    ax2.plot(imu[2].cpu().numpy(), color="green")
    ax3.plot(imu[3].cpu().numpy(), color="red")
    ax3.plot(imu[4].cpu().numpy(), color="blue")
    ax3.plot(imu[5].cpu().numpy(), color="green")
    plt.tight_layout()
    ani.save(save_path_gif, writer="imagemagick")
    plt.close()
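

# Usage sketch (both save functions need the ImageMagick writer installed;
# `frames` would be decoded video frames and `sample` the dict returned by
# get_imu_frames above):
#
#   display_animation_imu(frames, sample["signal"], "window 5-10s", "out.gif")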


def filter_narration(narration_text: str) -> bool:
    """Return True if the narration carries the camera-wearer tag '#C'."""
    return "#c" in narration_text.lower()


def clean_narration_text(narration_text: str) -> str:
    """Drop '#C' tags, replace '#unsure' with 'something', strip punctuation,
    lowercase, and cap the text at 128 characters."""
    return (
        narration_text.replace("#C C ", "")
        .replace("#C", "")
        .replace("#unsure", "something")
        .strip()
        .strip(string.punctuation)
        .lower()[:128]
    )
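

# Example (the input string is illustrative):
#
#   filter_narration("#C C picks up a knife.")      # True
#   clean_narration_text("#C C picks up a knife.")  # "picks up a knife"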