Spaces: Running on Zero
| import gradio as gr | |
| import huggingface_hub | |
| import os | |
| import subprocess | |
| import threading | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from scipy.io import wavfile | |
# Fetch the VTA-LDM (clip4clip-v-large) checkpoint from the Hugging Face Hub
# into ./ckpt when the module is loaded; infer() later points the inference
# script at files inside this folder.
huggingface_hub.snapshot_download(
    local_dir='./ckpt/vta-ldm-clip4clip-v-large',
    repo_id='ariesssxu/vta-ldm-clip4clip-v-large',
)
def stream_output(pipe):
    """Copy a text-mode *pipe* to stdout line by line until end of file.

    Used as a thread target so a child process's output is echoed live.
    """
    while True:
        text = pipe.readline()
        if text == '':
            # readline() returns '' only at EOF on a text-mode stream.
            break
        print(text, end='')
def print_directory_contents(path):
    """Print an indented tree of every directory and file under *path*.

    Each directory is printed as ``name/``. Its files are printed one level
    deeper, using four spaces per nesting level.

    Args:
        path: Directory to walk. A trailing separator is accepted.
    """
    for root, _dirs, files in os.walk(path):
        # Depth relative to *path*. relpath (rather than str.replace) stays
        # correct when *path* ends with a separator.
        rel = os.path.relpath(root, path)
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent = ' ' * 4 * level
        # normpath drops a trailing separator so basename is never empty.
        print(f"{indent}{os.path.basename(os.path.normpath(root))}/")
        subindent = ' ' * 4 * (level + 1)
        for f in files:
            print(f"{subindent}{f}")
# Show what the checkpoint download left on disk, so the startup log
# lists which model files are present.
print_directory_contents(path='./ckpt')
def get_wav_files(path):
    """Walk *path*, print an indented tree of it, and collect ``.wav`` files.

    Directories are printed as ``name/``. ``.wav`` files (matched
    case-insensitively) are printed as their full path; every other file is
    printed by name only. Four spaces are used per nesting level.

    Args:
        path: Directory to search. A trailing separator is accepted.

    Returns:
        list[str]: Paths of all ``.wav`` files found, in ``os.walk`` order.
    """
    wav_files = []
    for root, _dirs, files in os.walk(path):
        # Depth relative to *path*. relpath (rather than str.replace) stays
        # correct when *path* ends with a separator.
        rel = os.path.relpath(root, path)
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent = ' ' * 4 * level
        # normpath drops a trailing separator so basename is never empty.
        print(f"{indent}{os.path.basename(os.path.normpath(root))}/")
        subindent = ' ' * 4 * (level + 1)
        for f in files:
            file_path = os.path.join(root, f)
            if f.lower().endswith('.wav'):
                wav_files.append(file_path)
                print(f"{subindent}{file_path}")
            else:
                print(f"{subindent}{f}")
    return wav_files
def check_outputs_folder(folder_path):
    """Delete everything inside *folder_path* while keeping the folder itself.

    Files, symlinks and subdirectories are all removed. A failure on one
    entry is reported and skipped, so the rest of the cleanup still runs.
    If the folder does not exist, a message is printed and nothing else
    happens.

    Args:
        folder_path: Directory whose contents should be cleared.
    """
    import shutil  # for rmtree; shutil is not among this file's imports

    if not os.path.isdir(folder_path):
        print(f'The folder {folder_path} does not exist.')
        return
    for filename in os.listdir(folder_path):
        file_path = os.path.join(folder_path, filename)
        try:
            # Check links before isdir so a symlinked directory is unlinked,
            # not recursively deleted through.
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except OSError as e:
            print(f'Failed to delete {file_path}. Reason: {e}')
def plot_spectrogram(wav_file, output_image):
    """Render a spectrogram of a WAV file and save it as an image.

    Multi-channel audio is averaged down to a single channel first. The
    figure is always closed, even if plotting or saving fails, so repeated
    calls in a long-running server do not accumulate open figures.

    Args:
        wav_file: Path of the WAV file to read with ``scipy.io.wavfile``.
        output_image: Destination path; the format follows its extension.
    """
    sample_rate, audio_data = wavfile.read(wav_file)
    # wavfile.read returns (n_samples, n_channels) for multi-channel files.
    if audio_data.ndim == 2:
        audio_data = audio_data.mean(axis=1)

    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        _spectrum, _freqs, _times, image = ax.specgram(
            audio_data, Fs=sample_rate, NFFT=1024, noverlap=512,
            cmap='inferno', aspect='auto',
        )
        ax.set_title('Spectrogram')
        ax.set_xlabel('Time [s]')
        ax.set_ylabel('Frequency [Hz]')
        fig.colorbar(image, ax=ax, label='Intensity [dB]')
        fig.savefig(output_image)
    finally:
        plt.close(fig)
def infer(video_in):
    """Run VTA-LDM on an uploaded video and return the generated audio.

    Runs ``inference_from_video.py`` as a child process on the folder that
    holds the upload, echoing its output live. It then picks the first
    ``.wav`` file found under ./outputs/tmp and renders a spectrogram of it.

    Args:
        video_in: Path of the video file Gradio stored for the upload, or
            ``None``/empty when the user submitted without a video.

    Returns:
        tuple[str, str]: ``(wav_path, 'spectrogram.png')`` for the Audio and
        Image output components.

    Raises:
        gr.Error: If no video was supplied, or if the inference script left
            no ``.wav`` file behind. Gradio shows the message to the user.
    """
    import sys  # launch the child with this same interpreter, not PATH's "python"

    if not video_in:
        raise gr.Error("Please upload a video before submitting.")

    # Empty ./outputs/tmp so the .wav lookup below only sees this run's output.
    check_outputs_folder('./outputs/tmp')
    print(f"VIDEO IN PATH: {video_in}")

    # The script's --data_path takes a directory, so pass the folder that
    # contains the uploaded file.
    folder_path = os.path.dirname(video_in)

    command = [
        sys.executable, 'inference_from_video.py',
        '--original_args', 'ckpt/vta-ldm-clip4clip-v-large/summary.jsonl',
        '--model', 'ckpt/vta-ldm-clip4clip-v-large/pytorch_model_2.bin',
        '--data_path', folder_path,
    ]
    # The context manager closes both pipes and reaps the child on exit.
    with subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                          text=True, bufsize=1) as process:
        # Drain stdout and stderr concurrently so neither pipe can fill up
        # and stall the child.
        readers = [
            threading.Thread(target=stream_output, args=(pipe,))
            for pipe in (process.stdout, process.stderr)
        ]
        for reader in readers:
            reader.start()
        process.wait()
        for reader in readers:
            reader.join()
    print("Inference script finished with return code:", process.returncode)

    # NOTE(review): ./outputs/tmp is assumed to be where the script writes
    # its results; confirm against inference_from_video.py.
    print_directory_contents('./outputs/tmp')
    wave_files = get_wav_files('./outputs/tmp')
    print(wave_files)
    if not wave_files:
        raise gr.Error(
            f"Inference produced no audio (exit code {process.returncode}); "
            "see the Space logs for details."
        )

    plot_spectrogram(wave_files[0], 'spectrogram.png')
    return wave_files[0], 'spectrogram.png'
# Gradio UI: upload a video, press Submit, and receive the generated audio
# together with its spectrogram.
with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Video-To-Audio")
        video_in = gr.Video(label='Video IN')
        submit_btn = gr.Button("Submit")
        output_sound = gr.Audio(label="Audio OUT")
        output_spectrogram = gr.Image(label='Spectrogram')

    # infer returns (wav_path, spectrogram_path), matching this outputs order.
    submit_btn.click(
        fn=infer,
        inputs=[video_in],
        outputs=[output_sound, output_spectrogram],
        show_api=False,
    )

demo.launch(show_api=False, show_error=True)