Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import requests | |
| import re | |
| from transformers import AutoTokenizer, pipeline | |
| from youtube_transcript_api._transcripts import TranscriptListFetcher | |
# Token budget per batch; read by `inference` when grouping words.
max_size = 512
# NOTE(review): not referenced anywhere in the visible source - confirm
# whether anything still needs it before removing.
classes = [False, True]

# Tokenizer and tagging pipeline are both loaded from the same local
# checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained("./checkpoint-6000")
tagger = pipeline(
    task="token-classification",
    model="./checkpoint-6000",
    aggregation_strategy="first",
)
# Recognises the usual YouTube URL shapes (optional scheme and sub-domain,
# watch / embed / v / short youtu.be links) and captures the 11-character
# video id in group 1.
# The sub-domain class is written in upper case only, so the pattern has to be
# compiled case-insensitively: without the flag a lower-case prefix such as
# "www." or "m." makes the whole match fail.
pattern = re.compile(
    r"(?:https?:\/\/)?(?:[0-9A-Z-]+\.)?(?:youtube|youtu|youtube-nocookie)\.(?:com|be)\/(?:watch\?v=|watch\?.+&v=|embed\/|v\/|.+\?v=)?([^&=\n%\?]{11})",
    re.IGNORECASE,
)
def video_id(url):
    """Return the video id captured by ``pattern`` from *url*, or None.

    The pattern is anchored at the start of the string (``match``), and the
    id is whatever its first capture group holds.
    """
    found = pattern.match(url)
    if found is None:
        return None
    return found.group(1)
def process(obj):
    """Flatten a caption payload into a list of timed words.

    ``obj`` is the decoded JSON that ``get_transcript`` downloads (requested
    with ``fmt=json3``).  Only these fields are read: ``obj["events"]``; per
    event ``segs`` and ``tStartMs``; per segment ``utf8`` and ``tOffsetMs``.

    Returns a list of ``{"w": text, "s": start, "e": end}`` dicts, where the
    times are the event's ``tStartMs`` plus the segment's ``tOffsetMs``
    (presumably milliseconds, going by the field names).  A word ends where
    the next word of the same event starts.  The last word of an event is
    emitted only once a following event made of a single newline segment
    arrives, and it ends at that event's ``tStartMs``.

    Events without segments are ignored, and a word whose timing or text key
    is missing is skipped.  ``obj`` itself is never modified.
    """
    words = []
    line_start = None  # tStartMs of the latest text event; None until one is seen
    prev = None  # segment still waiting for its end time
    prev_offset = 0  # offset of `prev` inside its event
    for event in obj["events"]:
        segs = event.get("segs")
        if not segs:
            # No "segs" key, or an empty list: nothing to time.
            continue

        if len(segs) == 1 and segs[0]["utf8"] == "\n":
            # A lone newline closes the pending word of the previous event.
            if line_start is not None:
                words.append(
                    {
                        "w": prev["utf8"],
                        "s": line_start + prev_offset,
                        "e": event["tStartMs"],
                    }
                )
            continue

        line_start = event["tStartMs"]
        # The first segment starts with the event itself; keep its offset in a
        # local instead of writing it into the caller's dict.
        prev, prev_offset = segs[0], 0
        for seg in segs[1:]:
            try:
                entry = {
                    "w": prev["utf8"],
                    "s": line_start + prev_offset,
                    "e": line_start + seg["tOffsetMs"],
                }
            except KeyError:
                # Incomplete segment: skip it and keep waiting on `prev`.
                continue
            words.append(entry)
            prev, prev_offset = seg, seg["tOffsetMs"]
    return words
def get_transcript(video_id, session):
    """Download and flatten the English caption track of a video.

    The caption metadata is taken from the video page through
    ``TranscriptListFetcher``'s private helpers.  When several tracks carry
    the language code ``"en"`` the last one listed is used.  Its URL is
    requested through *session* with ``&fmt=json3`` appended and the decoded
    JSON is handed to ``process``.

    Returns the result of ``process``, or None when no ``"en"`` track exists.
    """
    fetcher = TranscriptListFetcher(session)
    page_html = fetcher._fetch_video_html(video_id)
    captions = fetcher._extract_captions_json(page_html, video_id)

    english_urls = [
        track["baseUrl"] + "&fmt=json3"
        for track in captions["captionTracks"]
        if track["languageCode"] == "en"
    ]
    if not english_urls:
        return None

    response = session.get(english_urls[-1])
    return process(response.json())
def transcript(url):
    """Return the English transcript of the video at *url* as one string.

    The words returned by ``get_transcript`` are stripped and joined with
    single spaces.  Instead of raising, an ``"ERROR: ..."`` message is
    returned when *url* is not recognised by ``video_id`` or when the video
    has no English caption track.
    """
    i = video_id(url)
    if not i:
        return "ERROR: Failed to load transcript (it the link a valid youtube url?)..."

    # Close the HTTP session once the download is done.
    with requests.Session() as session:
        words = get_transcript(i, session)

    # get_transcript returns None when there is no English track; iterating
    # that would raise TypeError.
    if words is None:
        return "ERROR: No English transcript is available for this video."
    return " ".join(word["w"].strip() for word in words)
def inference(transcript):
    """Tag the phrases of *transcript* as sponsor / non-sponsor.

    The text is split on single spaces and tokenised word by word.
    Consecutive words are grouped into batches whose counted token total
    stays within ``max_size``; a word that would overflow the running batch
    opens the next one, so no word is skipped.  Each batch is decoded back to
    text and passed to ``tagger``.

    Returns a list of ``{"sponsor": bool, "phrase": str}`` dicts in batch
    order, where ``sponsor`` is True for entities grouped as ``"LABEL_1"``.
    """
    # One list of ids per word.  The first and last id of each list are
    # presumably the tokenizer's special tokens; they are dropped before
    # decoding.
    word_ids = tokenizer(transcript.split(" "))["input_ids"]
    last_index = len(word_ids) - 1

    spans = []  # (start, stop) word indices of each batch
    batch_start = 0
    batch_tokens = 0
    for index, ids in enumerate(word_ids):
        # Same accounting as before: the very first word keeps its leading id
        # and the very last word keeps its trailing id.
        if index == 0:
            counted = ids[:-1]
        elif index == last_index:
            counted = ids[1:]
        else:
            counted = ids[1:-1]
        cost = len(counted)

        if batch_tokens and batch_tokens + cost > max_size:
            spans.append((batch_start, index))
            batch_start, batch_tokens = index, 0
        batch_tokens += cost

    if batch_tokens > 0:
        # Slice from the recorded start index.  The earlier arithmetic
        # (loop index minus word count) went negative when nothing had
        # overflowed, and kept only the last word.
        spans.append((batch_start, len(word_ids)))

    results = []
    for start, stop in spans:
        text = " ".join(
            tokenizer.batch_decode([ids[1:-1] for ids in word_ids[start:stop]])
        )
        for entity in tagger(text):
            results.append(
                {
                    "sponsor": entity["entity_group"] == "LABEL_1",
                    "phrase": entity["word"],
                }
            )
    return results
def predict(transcript):
    """Turn the ``inference`` output into ``(phrase, label)`` pairs.

    The label is ``"Sponsor"`` for spans flagged as sponsor and None
    otherwise; the list is what the HighlightedText component receives.
    """
    highlighted = []
    for span in inference(transcript):
        phrase = span["phrase"]
        label = "Sponsor" if span["sponsor"] else None
        highlighted.append((phrase, label))
    return highlighted
# Two-column UI: the first column fetches a transcript from a video URL, the
# second runs the sponsor tagger on whatever is in the transcript box.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(
                label="Video URL",
                placeholder="Video URL",
                lines=1,
                max_lines=1,
            )
            btn = gr.Button(value="Fetch Transcript")
            gr.Examples(examples=["youtu.be/xsLJZyih3Ac"], inputs=[inp])
            text = gr.Textbox(
                label="Transcript",
                placeholder="<generated transcript>",
            )
            # URL box -> transcript box
            btn.click(fn=transcript, inputs=inp, outputs=text)
        with gr.Column():
            p = gr.Button(value="Predict Sponsors")
            highlight = gr.HighlightedText()
            # transcript box -> highlighted sponsor spans
            p.click(fn=predict, inputs=text, outputs=highlight)

demo.launch()