import os
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt

from visualization import highlight_score_plot
from get_highlight.get_highlight_algorithm import HighlightModel, get_clip_scores_diff, get_clip_similarity_scores
def calculate_MAP(human_label, scores, human_label_threshold=1):
    """Average precision of `scores` against the human labels.

    Time stamps whose human label is >= `human_label_threshold` form the
    relevant set; precision is recorded at each relevant hit while walking
    the score-ranked list, then averaged.
    """
    sorted_label_indices = np.argsort(human_label)[::-1]
    sorted_scores_indices = np.argsort(scores)[::-1]
    top_n = np.sum(np.array(human_label) >= human_label_threshold)
    target_indices = sorted_label_indices[:top_n]
    precisions = []
    n_matched = 0
    for n, time_stamp_idx in enumerate(sorted_scores_indices):
        if time_stamp_idx in target_indices:
            n_matched += 1
            precisions.append(n_matched / (n + 1))
    return np.mean(precisions)
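
# Quick sanity check for calculate_MAP (a minimal sketch with toy data, not
# part of the evaluation pipeline). Labels mark indices 0 and 3 as relevant;
# the scores rank index 0 first and index 3 third, so the average precision
# is (1/1 + 2/3) / 2 ~= 0.8333.
if __name__ == '__main__':
    toy_labels = [1.0, 0.0, 0.2, 1.0]
    toy_scores = [0.9, 0.1, 0.5, 0.4]
    print(calculate_MAP(toy_labels, toy_scores, human_label_threshold=1))  # ~0.8333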
def parse_log_nohup(path):
    """Parse per-epoch train/validation losses out of a nohup training log."""
    with open(path, 'r') as file:
        lines = file.readlines()
    exp_name = os.path.basename(os.path.dirname(path))
    train_loss, val_loss = [], []
    for line in lines:
        # The loss is the last whitespace-separated token; some log lines
        # report it as 'loss/denominator', so keep the numerator only.
        if 'Finished training for epoch' in line or 'Finish training for epoch' in line:
            loss = line.split(' ')[-1]
            if '/' in loss:
                loss = loss.split('/')[0]
            train_loss.append(float(loss))
        if 'Finished calculating validation loss' in line:
            loss = line.split(' ')[-1]
            if '/' in loss:
                loss = loss.split('/')[0]
            val_loss.append(float(loss))
    assert len(train_loss) == len(val_loss)
    return exp_name, train_loss, val_loss
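
# Minimal usage sketch for parse_log_nohup (the log path below is a
# hypothetical stand-in; point it at a real nohup log): plot the parsed
# train/val loss curves for one experiment.
if __name__ == '__main__':
    demo_log = os.path.join('..', 'logs', 'exp_demo', 'nohup.out')  # assumed path
    if os.path.exists(demo_log):
        name, tr, va = parse_log_nohup(demo_log)
        plt.plot(tr, label='train loss')
        plt.plot(va, label='val loss')
        plt.title(name)
        plt.legend()
        plt.show()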
def min_max_normalized_scores(scores):
    # Rescale into [0, 1]; guard against a constant input (zero range).
    scores = np.array(scores, dtype=float)
    span = scores.max() - scores.min()
    return (scores - scores.min()) / span if span else np.zeros_like(scores)
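
# Tiny sanity check for min_max_normalized_scores (toy data only): the
# smallest score maps to 0.0 and the largest to 1.0.
if __name__ == '__main__':
    print(min_max_normalized_scores([2.0, 4.0, 3.0]))  # -> [0.  1.  0.5]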
def evaluate(testing_data_paths, history_scores, ckpt_paths, videofile_to_frames,
             use_min_max: bool = False, use_similarity: bool = False, use_score_diff: bool = False):
    """Score every checkpoint on every test video and collect MAP metrics."""
    metrics = defaultdict(list)
    for video_path in testing_data_paths:
        basename = os.path.basename(video_path)
        frames = videofile_to_frames(video_path)
        print('========', basename, '========')
        plot_scores = dict()
        model_scores = history_scores[basename]
        gt = model_scores['Human label']
        for ckpt_path in ckpt_paths:
            model = HighlightModel(ckpt_path=ckpt_path,
                                   i2v_path=os.path.join('..', 'weight', 'heads24_attn_epoch30_loss0.22810565.pt'))
            model.eval()
            frames_per_clip = model.scoring_model.frames_per_clip
            frame_vectors, _, scores = model(frames)
            clip_similarity_scores, _ = get_clip_similarity_scores(frame_vectors, frames_per_clip)
            scores_diff = get_clip_scores_diff(scores)
            # Optionally reweight the raw clip scores by inter-clip
            # similarity and/or by the clip-to-clip score difference.
            if use_similarity:
                scores *= clip_similarity_scores
            if use_score_diff:
                scores *= scores_diff
            exp_name = os.path.basename(os.path.dirname(ckpt_path))
            # The model can emit one extra trailing clip; trim it so the
            # score series lines up with the human labels.
            if len(gt) != len(scores):
                scores = scores[:-1]
            plot_scores[exp_name] = min_max_normalized_scores(scores) if use_min_max else scores
        plot_scores['Summarizer'] = min_max_normalized_scores(model_scores['Summarizer']) if use_min_max else model_scores['Summarizer']
        plot_scores['Human label'] = model_scores['Human label']
        for model_name in plot_scores.keys():
            MAP_1 = calculate_MAP(plot_scores['Human label'], plot_scores[model_name], human_label_threshold=1)
            MAP_05 = calculate_MAP(plot_scores['Human label'], plot_scores[model_name], human_label_threshold=0.5)
            metrics[f'{model_name}_MAP_1'].append(MAP_1)
            metrics[f'{model_name}_MAP_05'].append(MAP_05)
            print(f'{model_name:<80} MAP@1:{MAP_1:.4f} MAP@0.5:{MAP_05:.4f}')
        try:
            highlight_score_plot(basename, plot_scores, frames_per_clip, None)
        except Exception:
            plt.close()
            print(f'Fail: {basename}, {len(scores)}, {len(gt)}')
    return metrics
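
# Minimal driver sketch for evaluate. Every path, the score history, and the
# frame-loader stub below are hypothetical stand-ins: wire in the real test
# videos, checkpoints, cached Summarizer/human-label scores, and the
# project's actual frame loader before running this for real.
if __name__ == '__main__':
    def videofile_to_frames(path):
        # Placeholder only: swap in the project's real frame loader.
        raise NotImplementedError('plug in the project frame loader here')

    testing_data_paths = ['../videos/demo_clip.mp4']   # assumed test video
    ckpt_paths = ['../ckpt/exp_demo/best.pt']          # assumed checkpoint
    history_scores = {'demo_clip.mp4': {'Summarizer': [0.2, 0.8, 0.5],
                                        'Human label': [0, 1, 0.5]}}  # assumed cached scores
    if all(os.path.exists(p) for p in testing_data_paths + ckpt_paths):
        results = evaluate(testing_data_paths, history_scores, ckpt_paths,
                           videofile_to_frames, use_min_max=True)
        for name, values in results.items():
            print(f'{name}: mean {np.mean(values):.4f}')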