import streamlit as st
# from pytube import YouTube
# from pytube import extract
import cv2
from PIL import Image
import clip as openai_clip
import torch
import math
import numpy as np
import tempfile
# from humanfriendly import format_timespan
import json
import sys
from random import randrange
import logging
# from pyunsplash import PyUnsplash
import requests
import io
from io import BytesIO
import base64
import altair as alt
from streamlit_vega_lite import altair_component
import pandas as pd
from datetime import timedelta
from decord import VideoReader, cpu, gpu
from moviepy.video.io.VideoFileClip import VideoFileClip
from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
from moviepy.editor import *
import glob
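
# Videogenic: a Streamlit demo that finds highlight moments in a video by
# scoring CLIP frame embeddings against exemplar photos, charts the scores
# with Altair, and cuts beat-synced highlight reels with moviepy.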

# @st.cache(show_spinner=False)
def load_model():
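  # Load CLIP ViT-B/32 once and stash model, preprocess, and device in
  # session state so Streamlit reruns reuse them.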
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
  model, preprocess = openai_clip.load('ViT-B/32', device=device)
  st.session_state.model = model
  st.session_state.preprocess = preprocess
  st.session_state.device = device

def fetch_video(url):
  # Needs pytube (its top-level import is commented out above); import lazily
  # so the rest of the app runs without it. Currently unused by the main flow.
  from pytube import YouTube
  yt = YouTube(url)
  streams = yt.streams.filter(adaptive=True, subtype='mp4', resolution='360p', only_video=True)
  length = yt.length
  if length >= 300:
    st.error('Please find a YouTube video shorter than 5 minutes. Sorry about this, the server capacity is limited for the time being.')
    st.stop()
  video = streams[0]
  return video, video.url

# @st.cache()
# def extract_frames(video):
#   frames = []
#   capture = cv2.VideoCapture(video)
#   fps = capture.get(cv2.CAP_PROP_FPS)
#   current_frame = 0
#   while capture.isOpened():
#     ret, frame = capture.read()
#     if ret == True:
#       frames.append(Image.fromarray(frame[:, :, ::-1]))
#     else:
#       break
#     current_frame += fps
#     capture.set(cv2.CAP_PROP_POS_FRAMES, current_frame)
#   # print(f'Frames extracted: {len(frames)}')

#   return frames, fps

# @st.cache()
def video_to_frames(video):
  # Decode with decord and sample roughly one frame per second.
  vr = VideoReader(video)
  frame_count = len(vr)
  fps = vr.get_avg_fps()
  first = vr[0].asnumpy()
  y_dim, x_dim = first.shape[0], first.shape[1]
  frames = []
  for i in range(0, frame_count, round(fps)):
    frames.append(Image.fromarray(vr[i].asnumpy()))
  return frames, fps, x_dim, y_dim

def video_to_info(video):
  # Read fps and frame dimensions without decoding the whole video.
  vr = VideoReader(video)
  fps = vr.get_avg_fps()
  frame = vr[0].asnumpy()
  y_dim = frame.shape[0]
  x_dim = frame.shape[1]
  return fps, x_dim, y_dim

# @st.cache()
def encode_frames(video_frames):
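  # Encode frames with CLIP in batches and L2-normalize each feature vector.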
  batch_size = 256
  batches = math.ceil(len(video_frames) / batch_size)
  video_features = torch.empty([0, 512], dtype=torch.float16).to(st.session_state.device)
  for i in range(batches):
    batch_frames = video_frames[i*batch_size : (i+1)*batch_size]
    batch_preprocessed = torch.stack([st.session_state.preprocess(frame) for frame in batch_frames]).to(st.session_state.device)
    with torch.no_grad():
      batch_features = st.session_state.model.encode_image(batch_preprocessed)
      batch_features /= batch_features.norm(dim=-1, keepdim=True)
    video_features = torch.cat((video_features, batch_features))
  # print(f'Features: {video_features.shape}')
  return video_features

def classify_activity(video_features, activities_list):
  # Zero-shot classification: score each activity label against the first
  # frame's CLIP features and return the top-5 labels.
  text = torch.cat([openai_clip.tokenize(f'{activity}') for activity in activities_list]).to(st.session_state.device)
  with torch.no_grad():
    text_features = st.session_state.model.encode_text(text)
    text_features /= text_features.norm(dim=-1, keepdim=True)
  logit_scale = st.session_state.model.logit_scale.exp()
  video_features = torch.from_numpy(video_features)
  similarities = (logit_scale * video_features @ text_features.t()).softmax(dim=-1)
  probs, word_idxs = similarities[0].topk(5)
  primary_activity = []
  for prob, word_idx in zip(probs, word_idxs):
    primary_activity.append(activities_list[word_idx])
  return primary_activity

def encode_photos(photos):
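  # Same batched CLIP encoding as encode_frames, but for image files that
  # still need to be opened.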
  batch_size = 256
  batches = math.ceil(len(photos) / batch_size)
  video_features = torch.empty([0, 512], dtype=torch.float16).to(st.session_state.device)
  for i in range(batches):
    batch_frames = photos[i*batch_size : (i+1)*batch_size]
    batch_preprocessed = torch.stack([st.session_state.preprocess(Image.open(frame)) for frame in batch_frames]).to(st.session_state.device)
    with torch.no_grad():
      batch_features = st.session_state.model.encode_image(batch_preprocessed)
      batch_features /= batch_features.norm(dim=-1, keepdim=True)
    video_features = torch.cat((video_features, batch_features))
  # print(f'Features: {video_features.shape}')
  return video_features

def img_to_bytes(img):
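  # Serialize a PIL image to JPEG bytes.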
  img_byte_arr = io.BytesIO()
  img.save(img_byte_arr, format='JPEG')
  img_byte_arr = img_byte_arr.getvalue()
  return img_byte_arr

def normalize(vector):
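  # Min-max normalize scores into [0, 1].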
  return (vector - np.min(vector)) / (np.max(vector) - np.min(vector))

def format_img(img):
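  # Downscale to a 150x150 thumbnail and return it as a base64 PNG data URL
  # for embedding in the Altair chart.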
  size = 150, 150
  # img = Image.fromarray(img)
  img.thumbnail(size, Image.Resampling.LANCZOS)
  output = io.BytesIO()
  img.save(output, format='PNG')
  encoded_string = f'data:image/png;base64,{base64.b64encode(output.getvalue()).decode()}'
  return encoded_string

def get_photos(keyword):
  # Load one exemplar photo for the selected domain from the local photos/
  # folder; the Unsplash search below is kept for reference but is unreachable.
  photo_collection = []
  for filename in glob.glob(f'photos/{st.session_state.domain.lower()}/*.jpeg')[:1]:
    photo_collection.append(Image.open(filename))
  return photo_collection

  # # api_key = 'hzcKZ0e4we95wSd8_ip2zTB3m2DrOMWehAxrYjqjwg0'
  # api_key = 'fZ1nE7Y4NC-iYGmqgv-WuyM8m9p0LroCdAOZOR6tyho'
  # unsplash_search = PyUnsplash(api_key=api_key)
  # logging.getLogger('pyunsplash').setLevel(logging.DEBUG)
  # search = unsplash_search.search(type_='photos', query=keyword) # per_page
  # photo_collection = []
  # # st.markdown(f'**Unsplash photos for `{keyword}`**')
  # for result in search.entries:
  #   photo_url = result.link_download
  #   response = requests.get(photo_url)
  #   photo = Image.open(BytesIO(response.content))
  #   # st.image(photo, width=200)
  #   photo_collection.append(photo)
  # return photo_collection

def display_results(best_photo_idx):
  # Show the top-scoring frames and collect them for the caller.
  st.markdown('**Top 10 highlights**')
  result_arr = []
  for frame_id in best_photo_idx:
    result = st.session_state.video_frames[frame_id]
    st.image(result)
    result_arr.append(result)
  return result_arr

def make_df(similarities, keyword):
  # One row per sampled frame: keyword, frame index (~seconds), normalized
  # score (raised to the 8th power to sharpen peaks), and a thumbnail data URL.
  df = pd.DataFrame()
  df['keyword'] = [keyword] * len(similarities)
  df['x'] = [i for i, _ in enumerate(similarities)]
  df['y'] = normalize(np.power(similarities, 8))
  df['image'] = [format_img(frame) for frame in st.session_state.video_frames]
  return df

# @st.cache()
def compute_scores(search_query, video_features, text_query, display_results_count=10):
  # Average the CLIP features of the exemplar photos, then score every video
  # frame by similarity to that average. text_query and display_results_count
  # are currently unused.
  sum_photo = torch.zeros(1, 512)
  for photo in search_query:
    with torch.no_grad():
      image_features = st.session_state.model.encode_image(st.session_state.preprocess(photo).unsqueeze(0).to(st.session_state.device))
      image_features /= image_features.norm(dim=-1, keepdim=True)
      sum_photo += image_features
  avg_photo = sum_photo / len(search_query)
  video_features = torch.from_numpy(video_features)
  similarities = (100.0 * video_features @ avg_photo.T)
  # values, best_photo_idx = similarities.topk(display_results_count, dim=0)
  # display_results(best_photo_idx)
  return similarities.cpu().numpy()

def avenir():
  # Altair theme that applies the Avenir font to titles and axes.
  font = 'Avenir'
  return {
    'config': {
      'title': {'font': font},
      'axis': {
        'labelFont': font,
        'titleFont': font
      }
    }
  }

alt.themes.register('avenir', avenir)
alt.themes.enable('avenir')

# TODO: Make playhead scores and average according to keyword
# TODO: Maximum interval selection
# TODO: Interactive legend https://altair-viz.github.io/gallery/interactive_legend.html
# TODO: Multi-line highlight https://altair-viz.github.io/gallery/multiline_highlight.html
@st.cache
def draw_chart(df, mode):
  if mode == 'Automatic':
    # Hover chart: show the nearest frame's thumbnail and score above the line.
    nearest = alt.selection(type='single', nearest=True, on='mouseover', empty='none')
    line = alt.Chart(df).mark_line().encode(
      x=alt.X('x:Q', axis=alt.Axis(labels=True, tickSize=0, title='')),
      y=alt.Y('y:Q', axis=alt.Axis(labels=False, tickSize=0, title='')),
      color=alt.value('#00C7BE'),
    )
    selectors = alt.Chart(df).mark_point().encode(
      x='x:Q',
      opacity=alt.value(0),
    ).add_selection(
      nearest
    )
    rules = alt.Chart(df).mark_rule(color='black').encode(
      x='x:Q',
    ).transform_filter(
      nearest
    )
    points = line.mark_point().encode(
      opacity=alt.condition(nearest, alt.value(1), alt.value(0))
    )
    text = line.mark_text(align='center', yOffset=-110, fontSize=16).encode(
      text=alt.condition(nearest, 'y:N', alt.value(' ')),
      color=alt.value('#000000'),
    ).transform_calculate(y='format(datum.y, ".2f")')
    image = line.mark_image(align='center', width=150, height=150, yOffset=-60).encode(
      url=alt.condition(nearest, 'image', alt.value(' '))
    )
    chart = alt.layer(line, selectors, points, rules, text, image)
  else:
    # 'brush' and 'User selection' share the same interval-brush chart.
    brush = alt.selection(type='interval', encodings=['x'])
    line = alt.Chart(df).mark_line().encode(
      x=alt.X('x:Q', axis=alt.Axis(labels=True, tickSize=0, title='')),
      y=alt.Y('y:Q', axis=alt.Axis(labels=False, tickSize=0, title='')),
      color=alt.value('#00C7BE'),
    ).add_selection(
      brush
    )
    # Mean-score annotation layers, currently left out of the final chart:
    text = alt.Chart(df).transform_filter(brush).mark_text(
      align='right',
      dx=750,
      dy=-12,
      fontSize=24,
      fontWeight=800,
    ).encode(
      y='mean(y):Q',
      text=alt.Text('mean(y):Q', format='.2f'),
    )
    average = alt.Chart(df).mark_rule(color='black', strokeDash=[5, 5]).encode(
      y='mean(y):Q',
    ).transform_filter(
      brush
    )
    # chart = alt.layer(line, average, text)
    chart = line
  return chart.properties(width=1250, height=500).configure_axis(grid=False, domain=False).configure_view(strokeOpacity=0)

def max_subarray(arr, k):
  # Sliding-window max: return the best sum and the half-open interval
  # [left, right) of the length-k window with the highest total score.
  n = len(arr)
  if n < k:
    st.error('Video too short for this template.')
    st.stop()
  curr_sum = sum(arr[:k])
  res, left, right = curr_sum, 0, k
  for i in range(k, n):
    curr_sum += arr[i] - arr[i - k]
    if curr_sum > res:
      res = curr_sum
      left = i - k + 1
      right = i + 1
  return res, left, right

def edit_video(template, df_all):
  # Cut the highest-scoring window out of the source video, intercut black
  # ColorClips to match each track's beats, then overlay the song. Clip
  # boundaries below are hand-timed to each piece of music.
  video_path = f'videos/{st.session_state.domain.lower()}.mp4'
  if template == 'Coming In Hot by Andy Mineo & Lecrae (hype, 7 seconds)':
    res, left, right = max_subarray(df_all['y'].tolist(), 7)
    video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
    fps = video.fps
    x_dim = st.session_state.x_dim
    y_dim = st.session_state.y_dim
    music_path = 'music/coming-in-hot.mp3'
    blank1 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.6)
    flash1 = video.subclip(t_start=0, t_end=1.2)
    blank2 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
    flash2 = video.subclip(t_start=1.3, t_end=1.4)
    blank3 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
    flash3 = video.subclip(t_start=1.5, t_end=3.3)
    blank4 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
    flash4 = video.subclip(t_start=3.4, t_end=3.5)
    blank5 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
    flash5 = video.subclip(t_start=3.6, t_end=4.6)
    blank6 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
    flash6 = video.subclip(t_start=4.7, t_end=4.8)
    blank7 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
    highlight = video.subclip(t_start=4.9, t_end=6.384)
    output = concatenate_videoclips([blank1, flash1, blank2, flash2, blank3, flash3, blank4, flash4, blank5, flash5, blank6, flash6, blank7, highlight])
  elif template == 'Thinking Out Loud Cypher by Jermsego (hype, 8 seconds)':
    res, left, right = max_subarray(df_all['y'].tolist(), 7)
    video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
    fps = video.fps
    x_dim = st.session_state.x_dim
    y_dim = st.session_state.y_dim
    music_path = 'music/thinking-out-loud.mp3'
    blank = ColorClip((x_dim, y_dim), (0, 0, 0), duration=1.6)
    highlight = video.subclip(t_start=0, t_end=6.852)
    output = concatenate_videoclips([blank, highlight])
  elif template == 'Sheesh by Surfaces (upbeat, 10 seconds)':
    res, left, right = max_subarray(df_all['y'].tolist(), 8)
    video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
    fps = video.fps
    x_dim = st.session_state.x_dim
    y_dim = st.session_state.y_dim
    music_path = 'music/sheesh.mp3'
    blank1 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=3.5)
    flash1 = video.subclip(t_start=0, t_end=0.1)
    blank2 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
    flash2 = video.subclip(t_start=0.2, t_end=0.3)
    blank3 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
    flash3 = video.subclip(t_start=0.4, t_end=0.5)
    blank4 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.1)
    flash4 = video.subclip(t_start=0.6, t_end=0.7)
    blank5 = ColorClip((x_dim, y_dim), (0, 0, 0), duration=0.9)
    highlight = video.subclip(t_start=1.6, t_end=7.18408163265)
    output = concatenate_videoclips([blank1, flash1, blank2, flash2, blank3, flash3, blank4, flash4, blank5, highlight])
  elif template == 'Moon by Kid Francescoli (tranquil, 10 seconds)':
    res, left, right = max_subarray(df_all['y'].tolist(), 9)
    video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
    fps = video.fps
    x_dim = st.session_state.x_dim
    y_dim = st.session_state.y_dim
    music_path = 'music/and-it-went-like.mp3'
    blank = ColorClip((x_dim, y_dim), (0, 0, 0), duration=1.9)
    highlight = video.subclip(t_start=0, t_end=8.132)
    output = concatenate_videoclips([blank, highlight])
  elif template == 'Ready Set by Joey Valence & Brae (old school, 10 seconds)':
    res, left, right = max_subarray(df_all['y'].tolist(), 11)
    video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
    fps = video.fps
    x_dim = st.session_state.x_dim
    y_dim = st.session_state.y_dim
    music_path = 'music/ready-set.mp3'
    highlight = video.subclip(t_start=0, t_end=10.512)
    output = highlight
  elif template == 'Lovewave by The 1-800 (tranquil, 13 seconds)':
    res, left, right = max_subarray(df_all['y'].tolist(), 12)
    video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
    fps = video.fps
    x_dim = st.session_state.x_dim
    y_dim = st.session_state.y_dim
    music_path = 'music/lovewave.mp3'
    blank = ColorClip((x_dim, y_dim), (0, 0, 0), duration=2.1)
    highlight = video.subclip(t_start=0, t_end=11.58)
    output = concatenate_videoclips([blank, highlight])
  elif template == 'And It Sounds Like by Forrest Nolan (tranquil, 17 seconds)':
    res, left, right = max_subarray(df_all['y'].tolist(), 16)
    video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
    fps = video.fps
    x_dim = st.session_state.x_dim
    y_dim = st.session_state.y_dim
    music_path = 'music/and-it-sounds-like.mp3'
    blank = ColorClip((x_dim, y_dim), (0, 0, 0), duration=2)
    highlight = video.subclip(t_start=0, t_end=15.928)
    output = concatenate_videoclips([blank, highlight])
  elif template == 'Comfort Chain by Instupendo (lofi, 18 seconds)':
    res, left, right = max_subarray(df_all['y'].tolist(), 19)
    video = VideoFileClip(video_path).subclip(t_start=left, t_end=right)
    fps = video.fps
    x_dim = st.session_state.x_dim
    y_dim = st.session_state.y_dim
    music_path = 'music/comfort-chain.mp3'
    highlight = video.subclip(t_start=0, t_end=18.432000000000002)
    output = highlight
  # st.write(res, left, right)
  song = AudioFileClip(music_path)
  output = output.set_audio(song)
  output.write_videofile('output.mp4', temp_audiofile='temp.m4a', remove_temp=True, audio_codec='aac', logger=None, fps=fps)
  st.video('output.mp4')
  # return output

def crop_video(df_all, left, right):
  # Export the user-brushed interval [left, right] (in seconds) with a looped
  # backing track. df_all is currently unused.
  video_path = f'videos/{st.session_state.domain.lower()}.mp4'
  video = VideoFileClip(video_path)
  fps = video.fps
  music_path = 'music/loop.mp3'
  song = AudioFileClip(music_path)
  video = video.set_audio(song)
  output = video.subclip(t_start=left, t_end=right)
  output.write_videofile('output.mp4', temp_audiofile='temp.m4a', remove_temp=True, audio_codec='aac', logger=None, fps=fps)
  st.video('output.mp4')
  # return output

st.set_page_config(page_title='Videogenic', page_icon = '✨', layout = 'wide', initial_sidebar_state = 'collapsed')

hide_streamlit_style = """
                      <style>
                      #MainMenu {visibility: hidden;}
                      footer {visibility: hidden;}
                      * {font-family: Avenir; cursor: pointer;}
                      .css-gma2qf {display: flex; justify-content: center; font-size: 42px; font-weight: bold;}
                      a:link {text-decoration: none;}
                      a:hover {text-decoration: none;}
                      .st-ba {font-family: Avenir;}
                      </style>
                      """
st.markdown(hide_streamlit_style, unsafe_allow_html=True)

# clustrmaps = """
#             <a href="https://clustrmaps.com/site/1bham" target="_blank" title="Visit tracker"><img src="//www.clustrmaps.com/map_v2.png?d=NhNk5g9hy6Y06nqo7RirhHvZSr89uSS8rPrt471wAXw&cl=ffffff" width="0" height="0"></a>
#             """

# st.markdown(clustrmaps, unsafe_allow_html=True)

# ss = SessionState.get(url=None, id=None, input=None, file_name=None, video=None, video_name=None, video_frames=None, video_features=None, fps=None, mode=None, query=None, progress=1)

st.title('Videogenic ✨')
if 'progress' not in st.session_state:
  st.session_state.progress = 1

# mode = 'play'
# mode = 'brush'
# mode = 'select'

if st.session_state.progress == 1:
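  # Step 1: pick a demo video, then extract and CLIP-encode its frames.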
  with st.spinner('Loading model...'):
    load_model()
  domain = st.selectbox('Select video',('Skydiving', 'Surfing', 'Skateboarding')) # Entire journey, montage, vlog
  st.session_state.domain = domain
  if st.button('Process video'):
    with st.spinner('Processing video...'):
      video_name = f'videos/{st.session_state.domain.lower()}.mp4'
      video_file = open(video_name, 'rb')
      video_bytes = video_file.read()
      st.session_state.video = video_bytes
      # st.video(st.session_state.video)
      # First run: extract frames and cache them to disk. Later runs could
      # skip video_to_frames and just np.load the cache plus video_to_info.
      video_frames, fps, x_dim, y_dim = video_to_frames(video_name)
      np.save(f'files/{st.session_state.domain.lower()}.npy', video_frames)
      fps, x_dim, y_dim = video_to_info(video_name)
      video_frames = np.load(f'files/{st.session_state.domain.lower()}.npy', allow_pickle=True)
      st.session_state.video_frames = video_frames
      st.session_state.fps = fps
      st.session_state.x_dim = x_dim
      st.session_state.y_dim = y_dim
      print('Extracted frames')
      # First run: encode the frames with CLIP and cache the features to disk.
      encoded_frames = encode_frames(video_frames)
      np.save(f'files/{st.session_state.domain.lower()}_features.npy', encoded_frames)
      encoded_frames = np.load(f'files/{st.session_state.domain.lower()}_features.npy', allow_pickle=True)
      st.session_state.video_features = encoded_frames
      print('Encoded frames')
      st.session_state.progress = 2

# with open('activities.txt') as f:
#   activities_list = [line.rstrip('\n') for line in f]
# keywords = classify_activity(st.session_state.video_features, activities_list)
# st.write(keywords)

if st.session_state.progress == 2:
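  # Step 2: score frames against exemplar photos, visualize the score curve,
  # and generate the highlight video.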
  mode = st.radio('Select mode', ('Automatic', 'User selection'))
  st.session_state.mode = mode
  # keywords = list(st.text_input('Enter topic').split(','))
  # if st.button('Compute scores') and keywords is not None:
  with st.spinner('Computing highlight scores...'):
    keyword = st.session_state.domain.lower()
    df_list = []
    # for keyword in keywords:
    img_set = get_photos(keyword)
    similarities = compute_scores(img_set, st.session_state.video_features, keyword)
    # st.write(similarities)
    df = make_df(similarities, keyword)
    df_list.append(df)
    df_all = pd.concat(df_list, ignore_index=True, sort=False)
    st.session_state.df_all = df_all
  # st.write(df_all)
  # highlight_length = 7.033
  # st.write(st.session_state.fps)
  with st.spinner('Visualizing highlight scores...'):
    selection = altair_component(draw_chart(df_all, st.session_state.mode))
    print(selection)
  st.session_state.selection = selection
  # if '_vgsid_' in selection:
  #   # the ids start at 1
  #   st.write(df.iloc[[selection['_vgsid_'][0] - 1]])
  # else:
  #   st.info('Hover over the chart above to see details here.')
  # if 'x' in selection:
  #   # the ids start at 1
  #   st.write(selection['x'])
  # chart = draw_chart(df_all, mode)
  # st.altair_chart(chart, use_container_width=False)
  # st.session_state.progress = 3

  # if st.session_state.progress == 3:
  if st.session_state.mode == 'Automatic':
    # template = st.selectbox('Select template', ['Coming In Hot by Andy Mineo & Lecrae (hype, 7 seconds)', 'Thinking Out Loud Cypher by Jermsego (hype, 8 seconds)', 'Sheesh by Surfaces (upbeat, 10 seconds)',
    #                           'Moon by Kid Francescoli (tranquil, 10 seconds)', 'Ready Set by Joey Valence & Brae (old school, 10 seconds)', 'Lovewave by The 1-800 (tranquil, 13 seconds)',
    #                           'And It Sounds Like by Forrest Nolan (tranquil, 17 seconds)', 'Comfort Chain by Instupendo (lofi, 18 seconds)'])
    template = st.selectbox('Select template', ['Coming In Hot by Andy Mineo & Lecrae (hype, 7 seconds)', 'Sheesh by Surfaces (upbeat, 10 seconds)', 'Lovewave by The 1-800 (tranquil, 13 seconds)'])
    if st.button('Generate video'):
      with st.spinner('Generating highlight video...'):
        edit_video(template, st.session_state.df_all)
        st.balloons()
  elif st.session_state.mode == 'User selection':
    if st.button('Generate video'):
      # Guard against generating before the user has brushed an interval.
      if 'x' not in st.session_state.selection:
        st.warning('Brush an interval on the chart before generating.')
        st.stop()
      left = st.session_state.selection['x'][0]
      right = st.session_state.selection['x'][1]
      with st.spinner('Generating highlight video...'):
        crop_video(st.session_state.df_all, left, right)
        st.balloons()