import gradio as gr import json from difflib import Differ import ffmpeg import os import tempfile from pathlib import Path import time import aiohttp import asyncio import base64 from dotenv import load_dotenv import logging # --- Configuration --- # Set true if you're using huggingface inference API API https://huggingface.co/inference-api API_BACKEND = True MODEL = "facebook/wav2vec2-base-960h" API_URL = f'https://api-inference.huggingface.co/models/{MODEL}' RETRY_ATTEMPTS = 5 RETRY_DELAY = 5 TIMESTAMP_GROUPING_THRESHOLD = 0.1 # --- Logging Configuration --- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s - %(funcName)s') # --- Initialization --- if API_BACKEND: load_dotenv(Path(".env")) HF_TOKEN = os.environ.get("HF_TOKEN") if not HF_TOKEN: logging.error("HF_TOKEN environment variable not set. Please set it in a .env file.") raise ValueError("HF_TOKEN environment variable not set.") headers = {"Authorization": f"Bearer {HF_TOKEN}"} else: import torch from transformers import pipeline device = 0 if torch.cuda.is_available() else -1 try: logging.info(f"Initializing local model: {MODEL} on device: {device}") speech_recognizer = pipeline( task="automatic-speech-recognition", model=MODEL, tokenizer=MODEL, framework="pt", device=device, ) logging.info("Local model initialized successfully.") except Exception as e: logging.error(f"Error initializing local model {MODEL}: {e}") raise RuntimeError(f"Error initializing local model {MODEL}: {e}") videos_out_path = Path("./videos_out") videos_out_path.mkdir(parents=True, exist_ok=True) logging.info(f"Output directory created: {videos_out_path}") samples_data_files = sorted(Path('examples').glob('*.json')) SAMPLES = [] for file in samples_data_files: try: with open(file, 'r') as f: sample = json.load(f) if 'video' in sample and 'transcription' in sample and 'timestamps' in sample: SAMPLES.append(sample) else: logging.warning(f"Skipping sample file {file} due to missing keys (video, transcription, or timestamps).") except (json.JSONDecodeError, FileNotFoundError) as e: logging.error(f"Error loading sample file {file}: {e}") VIDEOS = [[sample['video']] for sample in SAMPLES] logging.info(f"Loaded {len(SAMPLES)} example samples.") # --- Helper Functions --- async def query_api(audio_bytes: bytes): """ Query the Hugging Face Inference API for Automatic Speech Recognition. Includes retry logic with exponential backoff. """ payload = json.dumps({ "inputs": base64.b64encode(audio_bytes).decode("utf-8"), "parameters": { "return_timestamps": "char", "chunk_length_s": 10, "stride_length_s": [4, 2] }, "options": {"use_gpu": False} }).encode("utf-8") async with aiohttp.ClientSession() as session: for attempt in range(RETRY_ATTEMPTS): logging.info(f'Transcribing from API attempt {attempt + 1}/{RETRY_ATTEMPTS}') try: async with session.post(API_URL, headers=headers, data=payload) as response: logging.info(f"API Response Status: {response.status}") content_type = response.headers.get('Content-Type', '') if response.status == 200 and 'application/json' in content_type: return await response.json() elif response.status != 200 and 'application/json' in content_type: error_response = await response.json() if 'error' in error_response and 'estimated_time' in error_response: wait_time = error_response['estimated_time'] logging.warning(f"Model loading, waiting for {wait_time} seconds...") await asyncio.sleep(wait_time + RETRY_DELAY) elif 'error' in error_response: raise RuntimeError(f"API Error: {error_response['error']}") else: raise RuntimeError(f"Unknown API Error: {error_response}") else: response_text = await response.text() raise RuntimeError(f"Unexpected API response format (Status: {response.status}, Content-Type: {content_type}): {response_text}") except aiohttp.ClientError as e: logging.error(f"AIOHTTP Client Error during API call (Attempt {attempt + 1}): {e}") except RuntimeError as e: logging.error(f"Runtime error during API call (Attempt {attempt + 1}): {e}") if attempt < RETRY_ATTEMPTS - 1: wait_time = RETRY_DELAY * (2 ** attempt) logging.info(f"Retrying in {wait_time} seconds...") await asyncio.sleep(wait_time) raise RuntimeError(f"Failed to get transcription after {RETRY_ATTEMPTS} attempts.") def ping_telemetry(name: str): """ Send a telemetry ping to Hugging Face Spaces. This is fire-and-forget and doesn't affect the main process flow. """ url = f'https://huggingface.co/api/telemetry/spaces/radames/edit-video-by-editing-text/{name}' logging.info(f"Pinging telemetry: {url}") async def send_ping(): try: async with aiohttp.ClientSession() as session: async with session.get(url) as response: logging.info(f"Telemetry pong: {response.status}") except aiohttp.ClientError as e: logging.warning(f"Failed to send telemetry ping: {e}") asyncio.create_task(send_ping()) # --- Main Gradio Functions --- async def speech_to_text(video_file_path, progress=gr.Progress()): """ Takes a video path to convert to audio, transcribe audio channel to text and char timestamps. Includes progress reporting. """ if video_file_path is None: raise gr.Error("Error: No video input provided.") video_path = Path(video_file_path) if not video_path.exists(): raise gr.Error(f"Error: Video file not found at {video_path}") temp_audio_file = None try: progress(0, desc="Converting video to audio...") with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile: temp_audio_file = Path(tmpfile.name) loop = asyncio.get_running_loop() await loop.run_in_executor( None, lambda: ffmpeg.input(video_path).output( str(temp_audio_file), format="wav", ac=1, ar='16k').overwrite_output().global_args('-loglevel', 'quiet').run() ) logging.info(f"Video converted to temporary audio file: {temp_audio_file}") with open(temp_audio_file, 'rb') as f: audio_memory = f.read() except ffmpeg.Error as e: logging.error(f"Error converting video to audio: {e.stderr.decode()}") raise gr.Error(f"Error converting video to audio: {e.stderr.decode()}") except Exception as e: logging.error(f"An unexpected error occurred during audio conversion: {e}") raise gr.Error(f"An unexpected error occurred during audio conversion: {e}") finally: if temp_audio_file and temp_audio_file.exists(): os.remove(temp_audio_file) logging.info(f"Cleaned up temporary audio file: {temp_audio_file}") ping_telemetry("speech_to_text") progress(0.5, desc="Transcribing audio...") if API_BACKEND: try: inference_response = await query_api(audio_memory) logging.info("Inference Response received from API.") if not isinstance(inference_response, dict) or 'text' not in inference_response or 'chunks' not in inference_response: raise RuntimeError(f"Unexpected API response structure: {inference_response}") transcription = inference_response["text"].lower() timestamps = [[chunk.get("text", "").lower(), chunk.get("timestamp", [None, None])[0], chunk.get("timestamp", [None, None])[1]] for chunk in inference_response.get('chunks', []) if isinstance(chunk, dict)] timestamps = [ts for ts in timestamps if ts[1] is not None and ts[2] is not None] progress(1.0, desc="Transcription complete.") return (transcription, transcription, timestamps) except Exception as e: logging.error(f"Error fetching transcription from API: {e}") raise gr.Error(f"Error fetching transcription from API: {e}") else: try: logging.info(f'Transcribing via local model {MODEL}') loop = asyncio.get_running_loop() output = await loop.run_in_executor( None, lambda: speech_recognizer( audio_memory, return_timestamps="char", chunk_length_s=10, stride_length_s=(4, 2)) ) logging.info("Inference complete with local model.") if not isinstance(output, dict) or 'text' not in output or 'chunks' not in output: raise RuntimeError(f"Unexpected model output structure: {output}") transcription = output["text"].lower() timestamps = [[chunk.get("text", "").lower(), chunk.get("timestamp", [None, None])[0] if not isinstance(chunk.get("timestamp", [None, None])[0], list) else chunk.get("timestamp", [None, None])[0][0], chunk.get("timestamp", [None, None])[1] if not isinstance(chunk.get("timestamp", [None, None])[1], list) else chunk.get("timestamp", [None, None])[1][0] ] for chunk in output.get('chunks', []) if isinstance(chunk, dict)] timestamps = [ts for ts in timestamps if ts[1] is not None and ts[2] is not None] progress(1.0, desc="Transcription complete.") return (transcription, transcription, timestamps) except Exception as e: logging.error(f"Error running inference with local model: {e}") raise gr.Error(f"Error running inference with local model: {e}") async def cut_timestamps_to_video(video_in, transcription, text_in, timestamps, progress=gr.Progress()): """ Given original video input, text transcript + timestamps, and edited text cuts video segments into a single video. Includes progress reporting and improved timestamp alignment. """ if video_in is None or text_in is None or transcription is None or timestamps is None: raise gr.Error("Inputs undefined. Please provide video, transcription, and edited text.") video_path = Path(video_in) if not video_path.exists(): raise gr.Error(f"Error: Video file not found at {video_path}") progress(0, desc="Analyzing text differences...") d = Differ() diff_chars = list(d.compare(transcription, text_in)) # --- Improved Timestamp Alignment --- timestamps_to_keep = [] timestamp_idx = 0 diff_idx = 0 while diff_idx < len(diff_chars) and timestamp_idx < len(timestamps): diff_line = diff_chars[diff_idx] ts_info = timestamps[timestamp_idx] ts_char = ts_info[0] if diff_line.startswith(' '): if diff_line[2:].lower() == ts_char.lower(): timestamps_to_keep.append(ts_info) timestamp_idx += 1 diff_idx += 1 else: logging.warning(f"Timestamp alignment mismatch: Diff char '{diff_line[2:]}' vs Timestamp char '{ts_char}'. Skipping timestamp.") diff_idx += 1 elif diff_line.startswith('-'): if diff_line[2:].lower() == ts_char.lower(): timestamp_idx += 1 diff_idx += 1 else: logging.warning(f"Timestamp alignment mismatch for deletion: Diff char '{diff_line[2:]}' vs Timestamp char '{ts_char}'. Skipping diff char.") diff_idx += 1 elif diff_line.startswith('+'): diff_idx += 1 elif diff_line.startswith('?'): diff_idx += 1 else: logging.warning(f"Unexpected diff line format: {diff_line}. Skipping.") diff_idx += 1 logging.info(f"Identified {len(timestamps_to_keep)} timestamps to keep after diff alignment.") progress(0.2, desc="Grouping timestamps...") grouped_segments = [] if timestamps_to_keep: current_segment = [timestamps_to_keep[0]] for i in range(1, len(timestamps_to_keep)): if timestamps_to_keep[i][1] - current_segment[-1][2] < TIMESTAMP_GROUPING_THRESHOLD: current_segment.append(timestamps_to_keep[i]) else: grouped_segments.append(current_segment) current_segment = [timestamps_to_keep[i]] grouped_segments.append(current_segment) logging.info(f"Grouped timestamps into {len(grouped_segments)} segments.") cut_intervals = [[segment[0][1], segment[-1][2]] for segment in grouped_segments] video_file_name = video_path.stem output_video_path = videos_out_path / f"{video_file_name}_cut.mp4" if cut_intervals: progress(0.4, desc="Cutting video segments...") input_video_stream = ffmpeg.input(video_in) filter_complex_parts = [] input_streams = [] for i, interval in enumerate(cut_intervals): start, end = interval filter_complex_parts.append(f"[0:v]trim=start={start},end={end},setpts=PTS-STARTPTS[v{i}]") filter_complex_parts.append(f"[0:a]atrim=start={start},end={end},asetpts=PTS-STARTPTS[a{i}]") input_streams.append(f"[v{i}][a{i}]") concat_input_str = ''.join(input_streams) concat_filter = f"{concat_input_str}concat=n={len(cut_intervals)}:v=1:a=1[outv][outa]" filter_complex_parts.append(concat_filter) filter_complex_str = ';'.join(filter_complex_parts) try: loop = asyncio.get_running_loop() await loop.run_in_executor( None, lambda: ffmpeg.output( input_video_stream, str(output_video_path), filter_complex=filter_complex_str, map=['[outv]', '[outa]'], preset='fast', crf=23 ).overwrite_output().global_args('-loglevel', 'quiet').run() ) logging.info(f"Video segments cut and concatenated to: {output_video_path}") except ffmpeg.Error as e: logging.error(f"Error cutting video: {e.stderr.decode()}") raise gr.Error(f"Error cutting video: {e.stderr.decode()}") except Exception as e: logging.error(f"An unexpected error occurred during video cutting: {e}") raise gr.Error(f"An unexpected error occurred during video cutting: {e}") else: logging.warning("No text was kept, creating a short empty video.") try: loop = asyncio.get_running_loop() await loop.run_in_executor( None, lambda: ffmpeg.input('color=c=black:s=1280x720:d=0.1', f='lavfi').output( str(output_video_path), format='mp4', vcodec='libx264', pix_fmt='yuv420p', t='0.1' ).overwrite_output().global_args('-loglevel', 'quiet').run() ) logging.info(f"Created short empty video at: {output_video_path}") except ffmpeg.Error as e: logging.error(f"Error creating empty video: {e.stderr.decode()}") output_video_path = Path(video_in) logging.warning("Failed to create empty video, returning original video path as fallback.") except Exception as e: logging.error(f"An unexpected error occurred during empty video creation: {e}") output_video_path = Path(video_in) logging.warning("Failed to create empty video, returning original video path as fallback.") diff_output_tokens = [(token[2:], token[0] if token[0] != ' ' else None) for token in diff_chars] ping_telemetry("video_cuts") progress(1.0, desc="Video cutting complete.") return (diff_output_tokens, str(output_video_path)) def load_example(id): """Loads example video and transcription.""" if 0 <= id < len(SAMPLES): sample = SAMPLES[id] video = sample.get('video') transcription = sample.get('transcription', '').lower() timestamps = sample.get('timestamps', []) if video is None: logging.error(f"Example at index {id} is missing video path.") raise gr.Error(f"Example at index {id} is missing video path.") return (video, transcription, transcription, timestamps) else: logging.error(f"Invalid example index: {id}") raise gr.Error(f"Invalid example index: {id}") # --- Gradio Layout --- css = """ #cut_btn, #reset_btn { align-self:stretch; } #\\31 3 { max-width: 540px; } .output-markdown {max-width: 65ch !important;} #video-container{ max-width: 40rem; } """ with gr.Blocks(css=css) as demo: transcription_var = gr.State(value="") timestamps_var = gr.State(value=[]) video_in = gr.Video(label="Video file", elem_id="video-container") text_in = gr.Textbox(label="Transcription", lines=10, interactive=True) video_out = gr.Video(label="Video Out", interactive=False) diff_out = gr.HighlightedText(label="Cuts Diffs", combine_adjacent=True, show_legend=True) gr.Markdown(""" # Edit Video By Editing Text This project is a quick proof of concept of a simple video editor where the edits are made by editing the audio transcription. Using the [Huggingface Automatic Speech Recognition Pipeline](https://huggingface.co/tasks/automatic-speech-recognition) with a fine tuned [Wav2Vec2 model using Connectionist Temporal Classification (CTC)](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self) you can predict not only the text transcription but also the [character or word base timestamps](https://huggingface.co/docs/transformers/v4.19.2/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline.__call__.return_timestamps) """) with gr.Row(): examples = gr.Dataset(components=[video_in], samples=VIDEOS, type="index", label="Examples") examples.click( load_example, inputs=[examples], outputs=[video_in, text_in, transcription_var, timestamps_var], queue=False ) with gr.Row(): with gr.Column(): # video_in is rendered when defined within gr.Blocks transcribe_btn = gr.Button("Transcribe Audio") transcribe_btn.click( speech_to_text, inputs=[video_in], outputs=[text_in, transcription_var, timestamps_var] ) gr.Markdown(""" ### Now edit as text After running the video transcription, you can make cuts to the text below (only cuts, not additions!)""") with gr.Row(): with gr.Column(): # text_in is rendered when defined within gr.Blocks with gr.Row(): cut_btn = gr.Button("Cut to video", elem_id="cut_btn") cut_btn.click( cut_timestamps_to_video, inputs=[video_in, transcription_var, text_in, timestamps_var], outputs=[diff_out, video_out] ) reset_transcription = gr.Button( "Reset to last transcription", elem_id="reset_btn") reset_tran