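# NOTE: the original first lines of this file were page residue from the repository
# web view it was copied from, not code: author AnsenH, commit 24615d9
# ("feat: add our model"), file size 4.03 kB, plus the "raw" / "history blame" links.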
import os
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from visualization import highlight_score_plot
from get_highlight.get_highlight_algorithm import HighlightModel, get_clip_scores_diff, get_clip_similarity_scores
def calculate_MAP(human_label, scores, human_label_threshold=1):
    """Compute the average precision of ``scores`` against thresholded human labels.

    Time steps whose human label is ``>= human_label_threshold`` are the
    positives. Time steps are ranked by ``scores`` in descending order, the
    precision ``hits_so_far / rank`` is recorded at the rank of every positive,
    and the mean of those precisions is returned.

    Args:
        human_label: Sequence of per-time-step human labels.
        scores: Sequence of per-time-step predicted scores; index ``i`` of
            ``scores`` is compared with index ``i`` of ``human_label``.
        human_label_threshold: Minimum label value counted as a positive.

    Returns:
        The average precision, or ``nan`` when no positive index appears in the
        ranking (e.g. no label reaches the threshold).
    """
    labels = np.asarray(human_label)
    # Positives are exactly the indices whose label clears the threshold; a set
    # makes each membership test in the ranking loop O(1).
    target_indices = set(np.flatnonzero(labels >= human_label_threshold).tolist())
    # Descending ranking; reversing the ascending argsort keeps the original
    # tie order (later index first among equal scores).
    ranking = np.argsort(scores)[::-1].tolist()

    precisions = []
    n_matched = 0
    for rank, time_stamp_idx in enumerate(ranking, start=1):
        if time_stamp_idx in target_indices:
            n_matched += 1
            precisions.append(n_matched / rank)
            if n_matched == len(target_indices):
                # Every positive has been found; later ranks add no precisions.
                break

    if not precisions:
        # np.mean([]) would return nan with a RuntimeWarning; be explicit instead.
        return float('nan')
    return np.mean(precisions)
def _parse_trailing_loss(line):
    """Return the loss value held in the last whitespace-separated token of ``line``.

    A token of the form ``"a/b"`` yields ``float(a)`` (the part before the first '/').
    """
    token = line.split()[-1]
    return float(token.split('/')[0])


def parse_log_nohup(path):
    """Extract per-epoch training and validation losses from a training log file.

    Lines containing ``'Finished training for epoch'`` or
    ``'Finish training for epoch'`` contribute a training loss. Lines containing
    ``'Finished calculating validation loss'`` contribute a validation loss. In
    both cases the loss is read from the last token on the line (see
    ``_parse_trailing_loss``).

    Args:
        path: Path to the log file.

    Returns:
        ``(exp_name, train_loss, val_loss)``, where ``exp_name`` is the name of
        the directory containing the log and the two lists hold the parsed
        floats in file order.

    Raises:
        AssertionError: If the numbers of training and validation losses
            differ (only while assertions are enabled).
    """
    exp_name = os.path.basename(os.path.dirname(path))
    train_loss, val_loss = [], []
    # `with` guarantees the file is closed; the original handle was never closed.
    with open(path, 'r') as file:
        for line in file:
            if 'Finished training for epoch' in line or 'Finish training for epoch' in line:
                train_loss.append(_parse_trailing_loss(line))
            if 'Finished calculating validation loss' in line:
                val_loss.append(_parse_trailing_loss(line))
    assert len(train_loss) == len(val_loss), (
        f'train/val loss count mismatch in {path}: {len(train_loss)} vs {len(val_loss)}'
    )
    return exp_name, train_loss, val_loss
def min_max_normalized_scores(scores):
    """Linearly rescale ``scores`` so the minimum maps to 0 and the maximum to 1.

    Args:
        scores: Array-like of numeric scores.

    Returns:
        A NumPy array of rescaled scores. An empty input returns an empty
        array. A constant input (zero range) returns all zeros instead of
        dividing by zero and producing NaNs.
    """
    scores = np.array(scores)
    if scores.size == 0:
        return scores
    low = scores.min()
    score_range = scores.max() - low
    if score_range == 0:
        return np.zeros(scores.shape)
    return (scores - low) / score_range
def evaluate(testing_data_paths, history_scores, ckpt_paths, videofile_to_frames,
             use_min_max: bool = False, use_similarity: bool = False, use_score_diff: bool = False,
             i2v_path=os.path.join('..', 'weight', 'heads24_attn_epoch30_loss0.22810565.pt')):
    """Score each test video with every checkpoint and report MAP against human labels.

    For each video, every checkpoint in ``ckpt_paths`` is loaded into a
    ``HighlightModel`` and used to score the video's frames. Those scores, the
    stored ``'Summarizer'`` scores and the ``'Human label'`` entry from
    ``history_scores`` are each evaluated with ``calculate_MAP`` at label
    thresholds 1 and 0.5. The results are printed and plotted with
    ``highlight_score_plot``.

    Args:
        testing_data_paths: Paths of the videos to evaluate.
        history_scores: Mapping from video basename to a dict holding at least
            ``'Human label'`` and ``'Summarizer'`` score sequences.
        ckpt_paths: Checkpoint paths; each is reported under the name of its
            parent directory.
        videofile_to_frames: Callable mapping a video path to the frames fed to
            the model.
        use_min_max: Min-max normalize model and Summarizer scores.
        use_similarity: Multiply model scores by the clip-similarity scores.
        use_score_diff: Multiply model scores by the clip score differences.
        i2v_path: Weight file passed to ``HighlightModel`` as ``i2v_path``.

    Returns:
        ``defaultdict(list)`` mapping ``'<model_name>_MAP_1'`` and
        ``'<model_name>_MAP_05'`` to one MAP value per evaluated video.
    """
    metrics = defaultdict(list)
    for video_path in testing_data_paths:
        basename = os.path.basename(video_path)
        frames = videofile_to_frames(video_path)
        print('========', basename, '========')
        model_scores = history_scores[basename]
        gt = model_scores['Human label']
        plot_scores = {}
        # Stay defined even when ckpt_paths is empty, so the plotting/failure
        # path below never hits a NameError.
        frames_per_clip = None
        scores = None

        for ckpt_path in ckpt_paths:
            # NOTE(review): each checkpoint is reloaded for every video; caching the
            # models would avoid repeated loads at the cost of holding them all in memory.
            model = HighlightModel(ckpt_path=ckpt_path, i2v_path=i2v_path)
            model.eval()
            frames_per_clip = model.scoring_model.frames_per_clip
            frame_vectors, _, scores = model(frames)
            # The diff is taken from the raw model scores, before any in-place scaling.
            scores_diff = get_clip_scores_diff(scores) if use_score_diff else None
            if use_similarity:
                clip_similarity_scores, _ = get_clip_similarity_scores(frame_vectors, frames_per_clip)
                scores *= clip_similarity_scores
            if use_score_diff:
                scores *= scores_diff
            exp_name = os.path.basename(os.path.dirname(ckpt_path))
            if len(scores) > len(gt):
                # The model can emit more clip scores than there are labels; drop
                # the surplus so scores stay index-aligned with the labels.
                scores = scores[:len(gt)]
            plot_scores[exp_name] = min_max_normalized_scores(scores) if use_min_max else scores

        summarizer_scores = model_scores['Summarizer']
        plot_scores['Summarizer'] = (
            min_max_normalized_scores(summarizer_scores) if use_min_max else summarizer_scores
        )
        plot_scores['Human label'] = gt

        for model_name, model_pred in plot_scores.items():
            MAP_1 = calculate_MAP(gt, model_pred, human_label_threshold=1)
            MAP_05 = calculate_MAP(gt, model_pred, human_label_threshold=0.5)
            metrics[f'{model_name}_MAP_1'].append(MAP_1)
            metrics[f'{model_name}_MAP_05'].append(MAP_05)
            print(f'{model_name:<80} MAP@1:{MAP_1:.4f} MAP@0.5:{MAP_05:.4f}')

        try:
            highlight_score_plot(basename, plot_scores, frames_per_clip, None)
        except Exception as err:
            # Plotting is best-effort: close the half-built figure and report the
            # failure with its cause instead of silently swallowing it.
            plt.close()
            n_scores = len(scores) if scores is not None else 0
            print(f'Fail:{basename}, {n_scores}, {len(gt)} ({type(err).__name__}: {err})')
    return metrics