Add time conversions from outputs
- app.py +4 -1
- functions/convert_time.py +50 -0
- functions/model_infer.py +1 -1
- functions/punctuation.py +1 -1
- requirements.txt +2 -0
app.py
CHANGED

@@ -3,6 +3,7 @@ import re
 import gradio as gr
 from functions.punctuation import punctuate
 from functions.model_infer import predict_from_document
+from functions.convert_time import match_mask_and_transcript
 
 
 title = "sponsoredBye - never listen to sponsors again"
@@ -12,16 +13,18 @@ article = "Check out [the original Rick and Morty Bot](https://huggingface.co/sp
 
 def pipeline(video_url):
     video_id = video_url.split("?v=")[-1]
-    punctuated_text = punctuate(video_id)
+    punctuated_text, transcript = punctuate(video_id)
     sentences = re.split(r"[\.\!\?]\s", punctuated_text)
     classification, probs = predict_from_document(sentences)
     # return punctuated_text
+    times = match_mask_and_transcript(sentences, transcript, classification)
     return [
         {
             "start": "12:05",
             "end": "12:52",
             "classification": str(classification),
             "probabilities": probs,
+            "times": times,
         }
     ]
 
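As a side note, a minimal sketch of how a caller could turn the new "times" entries into readable skip ranges (not part of this commit; format_skip_ranges is a hypothetical helper):

def format_skip_ranges(pipeline_output):
    # Each result dict now carries a "times" list of (start, end) tuples in seconds,
    # as produced by match_mask_and_transcript.
    ranges = []
    for entry in pipeline_output:
        for start, end in entry["times"]:
            ranges.append(
                f"{int(start // 60):02d}:{int(start % 60):02d}"
                f" - {int(end // 60):02d}:{int(end % 60):02d}"
            )
    return ranges

# usage: format_skip_ranges(pipeline(video_url)) -> e.g. ["12:05 - 12:52"]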
functions/convert_time.py
ADDED

@@ -0,0 +1,50 @@
+import re
+from thefuzz import fuzz
+import numpy as np
+
+
+def match_mask_and_transcript(split_punct, transcript, classification):
+    """
+    Input:
+        split_punct: the punctuated text, split on ?/!/.\s,
+        transcript: original transcript with timestamps
+        classification: classification object (list of numbers 0,1)
+    Output: times
+    """
+
+    # Get the sponsored part
+    sponsored_segment = []
+    for i, val in enumerate(classification):
+        if val == 1:
+            sponsored_segment.append(split_punct[i])
+
+    segment = " ".join(sponsored_segment)
+    sim_scores = list()
+
+    # Check the similarity scores between the sponsored part and the transcript parts
+    for elem in transcript:
+        sim_scores.append(fuzz.partial_ratio(segment, elem["text"]))
+
+    # Get the scores and check if they are above mean + 2*stdev
+    scores = np.array(sim_scores)
+    timestamp_mask = (scores > np.mean(scores) + np.std(scores) * 2).astype(int)
+    timestamps = [
+        (transcript[i]["start"], transcript[i]["duration"])
+        for i, elem in enumerate(timestamp_mask)
+        if elem == 1
+    ]
+
+    # Get the timestamp segments
+    times = []
+    current = -1
+    current_time = 0
+    for elem in timestamps:
+        # Threshold of 5 to see if it is a jump to another segment (also to make sure smaller segments are added together
+        if elem[0] > (current_time + 5):
+            current += 1
+            times.append((elem[0], elem[0] + elem[1]))
+            current_time = elem[0] + elem[1]
+        else:
+            times[current] = (times[current][0], elem[0] + elem[1])
+            current_time = elem[0] + elem[1]
+    return times
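Roughly, the new helper can be exercised like this; a toy example (not from the repo), assuming the transcript follows youtube_transcript_api's list-of-dicts shape and that sentences line up one-to-one with transcript chunks:

from functions.convert_time import match_mask_and_transcript

# Toy transcript: nine 5-second chunks, with the fifth one being a sponsor read.
transcript = [
    {"text": f"regular content chunk number {i}", "start": 5.0 * i, "duration": 5.0}
    for i in range(9)
]
transcript[4] = {"text": "this video is sponsored by ExampleVPN, use code EXAMPLE",
                 "start": 20.0, "duration": 6.0}

sentences = [entry["text"] + "." for entry in transcript]
classification = [1 if "sponsored" in entry["text"] else 0 for entry in transcript]

print(match_mask_and_transcript(sentences, transcript, classification))
# should print roughly [(20.0, 26.0)]: only the sponsor chunk clears the mean + 2*std cutoff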
functions/model_infer.py
CHANGED

@@ -41,6 +41,6 @@ def predict_from_document(sentences):
     # Set the prediction threshold to 0.8 instead of 0.5, now use mean
     output = (
         prediction.flatten()[: len(sentences)]
-        >= np.mean(prediction) + np.
+        >= np.mean(prediction) + np.std(prediction) * 2
     ).astype(int)
     return output, prediction.flatten()[: len(sentences)]
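The revised cutoff flags a sentence only when its score is an outlier relative to the whole document. A rough standalone illustration of that rule (the numbers are made up, not from the repo):

import numpy as np

# Hypothetical sigmoid scores for a 20-sentence video; two sponsor sentences stand out.
prediction = np.full(20, 0.1)
prediction[8:10] = 0.9

threshold = np.mean(prediction) + np.std(prediction) * 2  # ~0.66 for these numbers
output = (prediction >= threshold).astype(int)
print(output)  # only the two outlier sentences are flagged as sponsored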
functions/punctuation.py
CHANGED

@@ -55,4 +55,4 @@ def punctuate(video_id):
     ) # Get the transcript from the YoutubeTranscriptApi
     resp = query_punctuation(splits) # Get the response from the Inference API
     punctuated_transcript = parse_output(resp, splits)
-    return punctuated_transcript
+    return punctuated_transcript, transcript
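Worth noting: the second return value is the raw YoutubeTranscriptApi result, assumed here (and in convert_time.py) to be a list of dicts keyed by "text", "start" and "duration". A quick sketch of the new two-value return (the video id is a placeholder):

from functions.punctuation import punctuate

punctuated_text, transcript = punctuate("VIDEO_ID_HERE")  # placeholder YouTube video id

# Each raw transcript entry is expected to follow youtube_transcript_api's shape,
# which is exactly what match_mask_and_transcript indexes:
print(transcript[0])  # e.g. {'text': '...', 'start': 0.0, 'duration': 3.5}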
requirements.txt
CHANGED

@@ -1,4 +1,6 @@
 youtube_transcript_api
+thefuzz
+numpy
 tensorflow==2.15
 keras
 keras-nlp