Spaces:
Running
Running
| import os | |
| import subprocess | |
| import glob | |
| import streamlit as st | |
| from utils import get_configs, get_display_names, get_path_for_viz, get_video_height, get_text_str | |
| # from gdrive import download_file | |
| # st.header("EVREAL - Event-based Video Reconstruction Evaluation and Analysis Library") | |
| # | |
| # paper_link = "https://arxiv.org/abs/2305.00434" | |
| # code_link = "https://github.com/ercanburak/EVREAL" | |
| # page_link = "https://ercanburak.github.io/evreal.html" | |
| # instructions_video = "https://www.youtube.com/watch?v=" | |
| # | |
| # st.markdown("Paper: " + paper_link, unsafe_allow_html=True) | |
| # st.markdown("Code: " + code_link, unsafe_allow_html=True) | |
| # st.markdown("Page: " + page_link, unsafe_allow_html=True) | |
| # st.markdown("Please see this video for instructions on how to use this tool: " + instructions_video, unsafe_allow_html=True) | |
st.title("Result Analysis Tool")

font_path = "font/Ubuntu-B.ttf"

# One config directory per kind of entity shown in the tool.
dataset_cfg_path = os.path.join("cfg", "dataset")
model_cfg_path = os.path.join("cfg", "model")
metric_cfg_path = os.path.join("cfg", "metric")
viz_cfg_path = os.path.join("cfg", "viz")

datasets = get_configs(dataset_cfg_path)
models = get_configs(model_cfg_path)
metrics = get_configs(metric_cfg_path)
visualizations = get_configs(viz_cfg_path)

dataset_display_names = get_display_names(datasets)
model_display_names = get_display_names(models)
metric_display_names = get_display_names(metrics)
viz_display_names = get_display_names(visualizations)


def _require_unique_display_names(kind, display_names):
    """Raise ValueError if any entry of *display_names* occurs more than once.

    The widgets further down map a selected display name back to its config
    entry, so a duplicated name would make that lookup ambiguous. An explicit
    exception is used instead of ``assert`` so the check still runs when
    Python is started with ``-O``.

    Args:
        kind: Label used in the error message (e.g. "Dataset").
        display_names: Iterable of hashable display names to check.

    Raises:
        ValueError: If at least one display name is repeated.
    """
    seen = set()
    duplicates = []
    for name in display_names:
        if name in seen and name not in duplicates:
            duplicates.append(name)
        seen.add(name)
    if duplicates:
        raise ValueError("{} display names are not unique: {}".format(
            kind, ", ".join(str(name) for name in duplicates)))


_require_unique_display_names("Dataset", dataset_display_names)
_require_unique_display_names("Model", model_display_names)
_require_unique_display_names("Metric", metric_display_names)
_require_unique_display_names("Viz", viz_display_names)
# --- Method selection -------------------------------------------------------
# The multiselect yields display names; translate them back to config entries.
selected_model_names = st.multiselect('Select multiple methods to compare', model_display_names)
selected_models = [cfg for cfg in models if cfg['display_name'] in selected_model_names]

# --- Dataset / sequence selection, side by side -----------------------------
# NOTE(review): the `with` nesting below is reconstructed from a flattened
# copy of the script; only the two selectboxes are assumed to live inside
# the columns - confirm against the original layout.
dataset_col, sequence_col = st.columns(2)
with dataset_col:
    selected_dataset_name = st.selectbox('Select dataset', options=dataset_display_names)
    matching_datasets = [cfg for cfg in datasets if cfg['display_name'] == selected_dataset_name]
    selected_dataset = matching_datasets[0]
with sequence_col:
    selected_sequence = st.selectbox('Select sequence', options=selected_dataset["sequences"].keys())

# --- Metric selection -------------------------------------------------------
# Only metrics whose 'no_ref' flag equals the dataset's 'no_ref' flag are offered.
usable_metrics = [cfg for cfg in metrics if cfg['no_ref'] == selected_dataset['no_ref']]
usable_metric_display_names = get_display_names(usable_metrics)
selected_metric_names = st.multiselect('Select metrics to display', usable_metric_display_names)
selected_metrics = [cfg for cfg in usable_metrics if cfg['display_name'] in selected_metric_names]

# --- Extra visualization selection ------------------------------------------
# Datasets without frames cannot show visualizations whose gt_type is 'frame'.
usable_viz = (
    visualizations
    if selected_dataset['has_frames']
    else [cfg for cfg in visualizations if cfg['gt_type'] != 'frame']
)
usable_viz_display_names = get_display_names(usable_viz)
selected_viz = st.multiselect('Select other visualizations to display', usable_viz_display_names)
selected_visualizations = [cfg for cfg in visualizations if cfg['display_name'] in selected_viz]

# Nothing below runs until the user presses the button.
get_results_clicked = st.button('Get Results')
if not get_results_clicked:
    st.stop()
st.write("Retrieving results...")
progress_bar = st.progress(0)

# Split the chosen visualizations by which columns they apply to.
gt_only_viz = [entry for entry in selected_visualizations if entry['viz_type'] == 'gt_only']
model_only_viz = [entry for entry in selected_visualizations if entry['viz_type'] == 'model_only']
both_viz = [entry for entry in selected_visualizations if entry['viz_type'] == 'both']

# Built-in entries: the reconstruction row and the ground-truth pseudo-model.
recon_viz = {"name": "recon", "display_name": "Reconstruction", "viz_type": "both", "gt_type": "frame"}
ground_truth = {"name": "gt", "display_name": "Ground Truth", "model_id": "groundtruth"}

# Rows of a method column: reconstruction first, then shared visualizations,
# metrics, and finally the model-only visualizations.
model_viz = [recon_viz] + both_viz + selected_metrics + model_only_viz
num_model_rows = len(model_viz)

# Rows of the ground-truth column: frame-based entries (only when the dataset
# has frames) followed by event-based entries.
# NOTE(review): which statements sit under the `has_frames` check is
# reconstructed from a flattened copy of the script - confirm.
frame_gt_viz = []
if selected_dataset['has_frames']:
    frame_gt_viz = (
        [recon_viz]
        + [entry for entry in both_viz if entry['gt_type'] == 'frame']
        + [entry for entry in gt_only_viz if entry['gt_type'] == 'frame']
    )
gt_viz = (
    frame_gt_viz
    + [entry for entry in both_viz if entry['gt_type'] == 'event']
    + [entry for entry in gt_only_viz if entry['gt_type'] == 'event']
)
num_gt_rows = len(gt_viz)

num_rows = max(num_model_rows, num_gt_rows)
# Counted before the ground-truth column is appended to selected_models.
total_videos_needed = len(selected_models) * num_model_rows + num_gt_rows

# Ground truth becomes the last column, but only if it has something to show.
if gt_viz:
    selected_models.append(ground_truth)

padding = 2
font_size = 20
num_cols = len(selected_models)
crop_str = "crop=trunc(iw/2)*2:trunc(ih/2)*2"
double_padding = padding*2
pad_str = "pad=ceil(iw/2)*2+{}:ceil(ih/2)*2+{}:{}:{}:white".format(double_padding, double_padding, padding, padding)
num_elements = num_rows * num_cols

# remove previous temp data
for stale_clip in glob.glob('temp_data/temp_*.mp4'):
    os.remove(stale_clip)

w = selected_dataset["width"]
h = selected_dataset["height"]

# Accumulators filled while walking the grid.
input_filter_parts = []
xstack_input_parts = []
layout_parts = []
video_paths = []
row_heights = ["" for _ in range(num_rows)]
gt_viz_indices = []

# Width reserved on the left for row labels: sized from the longest display
# name among the rows below the first one; no label column when there is
# only the first row.
if len(model_viz) <= 1:
    left_pad = 0
else:
    left_pad = (font_size*0.7)*max([len(entry['display_name']) for entry in model_viz[1:]]) + padding*2
# Walk the grid row by row. Every cell that has a video contributes one
# ffmpeg input, one per-input filter chain and one xstack layout entry.
for row_idx in range(num_rows):
    for col_idx in range(num_cols):
        # Index the next ffmpeg input will get if this cell has a video.
        vid_idx = len(video_paths)
        progress_bar.progress(float(vid_idx) / total_videos_needed)

        cur_model = selected_models[col_idx]
        is_gt_column = cur_model['name'] == "gt"

        # The ground-truth column and the method columns have their own row
        # lists; a column shorter than the grid simply leaves the cell empty.
        column_viz = gt_viz if is_gt_column else model_viz
        if row_idx >= len(column_viz):
            continue
        cur_viz = column_viz[row_idx]

        video_path = get_path_for_viz(selected_dataset, selected_sequence, cur_model, cur_viz)
        video_path = os.path.join("data", video_path)
        # download_file(video_path, local_video_path)
        if not os.path.isfile(video_path):
            raise ValueError("Could not find video: " + video_path)
        if is_gt_column:
            gt_viz_indices.append(vid_idx)

        # The first video placed in a row provides that row's height term.
        if row_heights[row_idx] == "":
            row_heights[row_idx] = "h{}".format(vid_idx)

        if row_idx == 0:
            # Top row: reserve a strip above the video for the column title.
            pad_height = font_size+padding*2
            pad_txt_str = ",pad={}:{}:0:{}:white".format(w+padding*2, h+font_size+padding*4, pad_height)
            text_str = get_text_str(pad_height, w, cur_model['display_name'], font_path, font_size)
            pad_txt_str = pad_txt_str + "," + text_str
        elif col_idx == 0:
            # First column of a lower row: reserve a strip on the left for the
            # row label. The label comes from the visualization actually shown
            # in this cell, so it stays correct (and in range) when the
            # ground-truth column is the first column.
            pad_txt_str = ",pad={}:ih:{}:0:white".format(w + left_pad + padding*2, left_pad)
            text_str = get_text_str("h", left_pad, cur_viz['display_name'], font_path, font_size)
            pad_txt_str = pad_txt_str + "," + text_str
        else:
            pad_txt_str = ""

        input_filter_part = "[{}:v]scale={}:-1,{}{}[v{}]".format(vid_idx, w, pad_str, pad_txt_str, vid_idx)
        input_filter_parts.append(input_filter_part)
        xstack_input_part = "[v{}]".format(vid_idx)
        xstack_input_parts.append(xstack_input_part)
        video_paths.append(video_path)

        # Horizontal offset: cells that do not carry the row label start after
        # the label strip plus the widths of the first `col_idx` inputs
        # (inputs 0..col_idx-1 are the top-row cells, added first).
        if row_idx == 0 or col_idx > 0:
            layout_w_parts = [str(left_pad)] + ["w{}".format(i) for i in range(col_idx)]
            layout_w = "+".join(layout_w_parts)
        else:
            # Labelled first-column cell of a lower row: flush left.
            layout_w = "0"

        # Vertical offset: sum of the heights above. Ground truth stacks on
        # its own previously placed videos; methods use the row heights.
        if row_idx == 0:
            layout_h = "0"
        elif is_gt_column:
            layout_h = "+".join(["h{}".format(i) for i in gt_viz_indices[:-1]])
        else:
            layout_h = "+".join(row_heights[:row_idx])

        layout_part = layout_w + "_" + layout_h
        layout_parts.append(layout_part)
num_inputs = len(video_paths)
if num_inputs == 0:
    # Nothing was collected for the grid, so there is nothing for ffmpeg to do.
    st.error("No videos to display for the current selection.")
    st.stop()

input_scaling_str = ";".join(input_filter_parts)
xstack_input_str = "".join(xstack_input_parts)
layout_str = "|".join(layout_parts)

# Extra output options, whitespace separated.
# opt = "-c:v libx264 -preset veryslow -crf 18 -c:a copy"
opt = ""
# opt_fill = ":fill=black"
opt_fill = ":fill=white"
# opt_fill = ""

# Build the command as an argument list and run it without a shell: paths and
# the filter graph are then passed verbatim, with no quoting or expansion
# surprises.
ffmpeg_args = ["ffmpeg", "-y"]
for input_path in video_paths:
    ffmpeg_args.extend(["-i", input_path])

if num_inputs == 1:
    # NOTE(review): xstack is believed to require at least two inputs, so a
    # single video is emitted directly from its scaled/padded stream - confirm
    # against the installed ffmpeg version.
    filter_complex_str = input_scaling_str
    ffmpeg_args.extend(["-filter_complex", filter_complex_str, "-map", xstack_input_str])
else:
    filter_complex_str = (input_scaling_str + ";" + xstack_input_str
                          + "xstack=inputs=" + str(num_inputs)
                          + ":layout=" + layout_str + opt_fill)
    ffmpeg_args.extend(["-filter_complex", filter_complex_str])

ffmpeg_args.extend(opt.split())
ffmpeg_args.append("output.mp4")

# Human-readable rendering for the log only; it is not shell-quoted.
ffmpeg_command_str = " ".join(ffmpeg_args)
print(ffmpeg_command_str)

try:
    ret = subprocess.call(ffmpeg_args)
except OSError:
    # ffmpeg could not be started (e.g. not installed); report it like any
    # other failed run instead of surfacing a traceback.
    ret = -1
if ret != 0:
    st.error("Error while generating video.")
    st.stop()

# Read the result and close the file before handing the bytes to Streamlit.
with open('output.mp4', 'rb') as video_file:
    video_bytes = video_file.read()
st.video(video_bytes)