# AIPromoStudio / app.py
# Author: Bils
# Last commit: 7602ef4 "Update app.py" (verified)
import os
import re
import torch
import tempfile
import logging
import math
from typing import Tuple, Union, Any
from scipy.io.wavfile import write
from pydub import AudioSegment
from dotenv import load_dotenv
import spaces
import gradio as gr
import numpy as np
# Transformers & Models
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
pipeline,
AutoProcessor,
MusicgenForConditionalGeneration,
)
# Coqui TTS
from TTS.api import TTS
# Diffusers for sound design generation
from diffusers import DiffusionPipeline, AudioLDMPipeline
import diffusers
from packaging import version
# ---------------------------------------------------------------------
# Setup Logging and Environment Variables
# ---------------------------------------------------------------------
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Pull variables from a local .env file (if one is found) into the process
# environment before the token lookup below.
load_dotenv()
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None or HF_TOKEN == "":
    # Not fatal: the app keeps starting; only downloads that need a token may fail.
    logging.warning("HF_TOKEN is not set in your environment. Some model downloads might fail.")
# ---------------------------------------------------------------------
# Global Model Caches
# ---------------------------------------------------------------------
# Process-wide caches keyed by model identifier, so each model is loaded at
# most once per process by the get_* helpers below.
LLAMA_PIPELINES: dict[str, Any] = dict()  # model_id -> text-generation pipeline
MUSICGEN_MODELS: dict[str, Any] = dict()  # model_key -> (model, processor)
TTS_MODELS: dict[str, Any] = dict()  # model_name -> Coqui TTS instance
SOUND_DESIGN_PIPELINES: dict[str, Any] = dict()  # model_name -> diffusion pipeline
# ---------------------------------------------------------------------
# Utility Functions
# ---------------------------------------------------------------------
def clean_text(text: str) -> str:
    """
    Strip every asterisk from ``text`` before it is handed to a model.

    Only the ``*`` character is removed; all other characters, including
    whitespace, are kept unchanged.

    Args:
        text (str): Input text to be cleaned.

    Returns:
        str: ``text`` with all asterisks deleted.
    """
    return text.replace("*", "")
# ---------------------------------------------------------------------
# Model Helper Functions
# ---------------------------------------------------------------------
def get_llama_pipeline(model_id: str, token: str, trust_remote_code: bool = False) -> Any:
    """
    Return a cached text-generation pipeline for ``model_id``, loading it on first use.

    Args:
        model_id (str): Hugging Face model ID. In this app the value comes
            directly from a UI textbox, so treat it as untrusted input.
        token (str): Hugging Face token (may be None for public models).
        trust_remote_code (bool): Whether to execute custom modelling code
            shipped inside the model repository. Defaults to False: enabling
            it for a user-supplied ``model_id`` lets any Hub repository run
            arbitrary Python on the server. Only enable it for IDs you control.

    Returns:
        Any: A Hugging Face ``text-generation`` pipeline. It is cached in
        ``LLAMA_PIPELINES`` under ``model_id``.
    """
    if model_id in LLAMA_PIPELINES:
        return LLAMA_PIPELINES[model_id]

    logging.info(f"Loading LLaMA model from {model_id}...")
    # fp16 halves GPU memory; on CPU many fp16 kernels are missing or very slow.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    # ``token`` replaces the deprecated ``use_auth_token`` keyword.
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=token,
        torch_dtype=dtype,
        device_map="auto",
        trust_remote_code=trust_remote_code,
    )
    text_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
    LLAMA_PIPELINES[model_id] = text_pipeline
    return text_pipeline
def get_musicgen_model(model_key: str = "facebook/musicgen-large") -> Tuple[Any, Any]:
    """
    Fetch the MusicGen model/processor pair for ``model_key``, loading it once.

    On a cache miss both objects are created with ``from_pretrained``. The
    model is moved to CUDA when available (CPU otherwise), and the pair is
    stored in ``MUSICGEN_MODELS`` for later calls.

    Args:
        model_key (str): Hugging Face model key (default 'facebook/musicgen-large').

    Returns:
        Tuple[Any, Any]: ``(model, processor)``.
    """
    cached_pair = MUSICGEN_MODELS.get(model_key)
    if cached_pair is not None:
        return cached_pair

    logging.info(f"Loading MusicGen model from {model_key}...")
    musicgen = MusicgenForConditionalGeneration.from_pretrained(model_key)
    musicgen_processor = AutoProcessor.from_pretrained(model_key)
    musicgen.to("cuda" if torch.cuda.is_available() else "cpu")

    loaded_pair = (musicgen, musicgen_processor)
    MUSICGEN_MODELS[model_key] = loaded_pair
    return loaded_pair
def get_tts_model(model_name: str = "tts_models/en/ljspeech/tacotron2-DDC") -> TTS:
    """
    Look up a Coqui TTS model in ``TTS_MODELS``, instantiating it on first request.

    Args:
        model_name (str): Coqui model identifier, e.g.
            'tts_models/en/ljspeech/tacotron2-DDC'.

    Returns:
        TTS: The cached (or newly created) Coqui ``TTS`` instance.
    """
    if model_name not in TTS_MODELS:
        logging.info(f"Loading TTS model: {model_name}...")
        TTS_MODELS[model_name] = TTS(model_name)
    return TTS_MODELS[model_name]
def get_sound_design_pipeline(model_name: str, token: str) -> Any:
    """
    Return a cached AudioLDM 2 diffusion pipeline, loading it on first use.

    ``DiffusionPipeline.from_pretrained`` picks the concrete pipeline class
    from the checkpoint's ``model_index.json`` (``AudioLDM2Pipeline`` for
    ``cvssp/audioldm2``), so no class is forced here. Forcing the AudioLDM
    (v1) class does not match AudioLDM 2's component layout. The pipeline is
    moved to CUDA when available, the same way ``get_musicgen_model`` places
    its model.

    Args:
        model_name (str): Hugging Face repository of the diffusion checkpoint.
        token (str): Hugging Face token (may be None for public models).

    Returns:
        Any: A diffusers pipeline, cached in ``SOUND_DESIGN_PIPELINES``.

    Raises:
        ValueError: If the installed diffusers version is lower than 0.21.0.
    """
    if version.parse(diffusers.__version__) < version.parse("0.21.0"):
        raise ValueError("AudioLDM2 requires diffusers>=0.21.0. Please upgrade your diffusers package.")
    if model_name in SOUND_DESIGN_PIPELINES:
        return SOUND_DESIGN_PIPELINES[model_name]

    logging.info(f"Loading sound design pipeline from {model_name}...")
    pipe = DiffusionPipeline.from_pretrained(model_name, use_auth_token=token)
    # Without this the diffusion loop runs on CPU even inside GPU-decorated calls.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = pipe.to(device)
    SOUND_DESIGN_PIPELINES[model_name] = pipe
    return pipe
# ---------------------------------------------------------------------
# Script Generation Function
# ---------------------------------------------------------------------
# (header text without colon, placeholder used when that section is missing)
_SCRIPT_SECTIONS = (
    ("Voice-Over Script", "No voice-over script found."),
    ("Sound Design Suggestions", "No sound design suggestions found."),
    ("Music Suggestions", "No music suggestions found."),
)
_SECTION_HEADER_RE = re.compile(
    r"(Voice-Over Script|Sound Design Suggestions|Music Suggestions):"
)


def _extract_script_sections(text: str) -> Tuple[str, str, str]:
    """
    Split LLM output into (voice-over, sound design, music) sections.

    Each section runs from its header to the next recognised header, or to the
    end of the text. Sections may therefore appear in any order, and a missing
    section only affects its own slot. If a header repeats, its first occurrence
    wins. A missing or empty section is replaced by its placeholder text.
    """
    headers = list(_SECTION_HEADER_RE.finditer(text))
    found: dict = {}
    for idx, header in enumerate(headers):
        end = headers[idx + 1].start() if idx + 1 < len(headers) else len(text)
        found.setdefault(header.group(1), text[header.end():end].strip())
    voice, sound, music = (found.get(name) or placeholder for name, placeholder in _SCRIPT_SECTIONS)
    return voice, sound, music


@spaces.GPU(duration=100)
def generate_script(user_prompt: str, model_id: str, token: str, duration: int) -> Tuple[str, str, str]:
    """
    Generate a voice-over script, sound design ideas and music ideas from a concept.

    Args:
        user_prompt (str): The user-provided concept.
        model_id (str): Hugging Face ID of the text-generation model.
        token (str): Hugging Face token.
        duration (int): Desired ad duration in seconds (only used in the prompt).

    Returns:
        Tuple[str, str, str]: (voice-over script, sound design suggestions,
        music suggestions). A section the model did not produce is replaced by
        a placeholder. On failure the first element is an error message and
        the other two are empty strings.
    """
    if not user_prompt or not user_prompt.strip():
        return "Error: No concept provided.", "", ""
    try:
        text_pipeline = get_llama_pipeline(model_id, token)
        system_prompt = (
            "You are an expert radio imaging producer specializing in sound design and music. "
            f"Based on the user's concept and the selected duration of {duration} seconds, produce the following:\n"
            "1. A concise voice-over script. Prefix this section with 'Voice-Over Script:'\n"
            "2. Suggestions for sound design. Prefix this section with 'Sound Design Suggestions:'\n"
            "3. Music styles or track recommendations. Prefix this section with 'Music Suggestions:'"
        )
        combined_prompt = f"{system_prompt}\nUser concept: {user_prompt}\nOutput:"
        with torch.inference_mode():
            result = text_pipeline(
                combined_prompt,
                max_new_tokens=300,
                do_sample=True,
                temperature=0.8,
                # Return only the continuation. The prompt itself contains the
                # section headers and would otherwise be parsed as output.
                return_full_text=False,
            )
        generated_text = result[0]["generated_text"]
        return _extract_script_sections(generated_text)
    except Exception as e:
        logging.exception("Error generating script")
        return f"Error generating script: {e}", "", ""
# ---------------------------------------------------------------------
# Voice-Over Generation Function
# ---------------------------------------------------------------------
@spaces.GPU(duration=100)
def generate_voice(script: str, tts_model_name: str = "tts_models/en/ljspeech/tacotron2-DDC") -> Union[str, Any]:
    """
    Synthesize ``script`` to speech with Coqui TTS and save it as a WAV file.

    Asterisks are stripped from the script first (via ``clean_text``). The
    audio is always written to ``voice_over.wav`` in the system temp
    directory, overwriting the file from any previous run.

    Args:
        script (str): The voice-over script.
        tts_model_name (str): Coqui TTS model identifier.

    Returns:
        Union[str, Any]: Path of the written .wav file, or an error message
        string when the script is blank or synthesis fails.
    """
    try:
        if script.strip() == "":
            return "Error: No script provided."
        speech_text = clean_text(script)
        engine = get_tts_model(tts_model_name)
        wav_path = os.path.join(tempfile.gettempdir(), "voice_over.wav")
        engine.tts_to_file(text=speech_text, file_path=wav_path)
    except Exception as e:
        logging.exception("Error generating voice")
        return f"Error generating voice: {e}"
    return wav_path
# ---------------------------------------------------------------------
# Music Generation Function
# ---------------------------------------------------------------------
@spaces.GPU(duration=200)
def generate_music(prompt: str, audio_length: int) -> Union[str, Any]:
    """
    Generate a music track with MusicGen and write it as a 16-bit WAV file.

    Args:
        prompt (str): Music suggestion prompt.
        audio_length (int): Number of audio tokens to generate
            (``max_new_tokens``). More tokens give a longer clip.

    Returns:
        Union[str, Any]: Path to the generated .wav file, or an error message
        string if the prompt is blank or generation fails.
    """
    try:
        if not prompt or not prompt.strip():
            return "Error: No music suggestion provided."
        model_key = "facebook/musicgen-large"
        musicgen_model, musicgen_processor = get_musicgen_model(model_key)
        # Send the inputs to whichever device get_musicgen_model placed the model on.
        inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt").to(musicgen_model.device)
        with torch.inference_mode():
            # Slider values may arrive as floats; generate() expects an int.
            outputs = musicgen_model.generate(**inputs, max_new_tokens=int(audio_length))
        # generate() returns (batch, channels, samples); keep the first of each.
        audio_data = outputs[0, 0].cpu().numpy()

        # The WAV must be written at MusicGen's own decoder rate (32 kHz for
        # the facebook/musicgen-* checkpoints). A hard-coded 44100 Hz plays
        # back too fast and pitched up.
        sampling_rate = musicgen_model.config.audio_encoder.sampling_rate

        # Peak-normalise to the int16 range; skip the division for silent
        # output, which would otherwise produce NaNs.
        peak = float(np.max(np.abs(audio_data))) if audio_data.size else 0.0
        if peak > 0:
            audio_data = audio_data / peak
        normalized_audio = (audio_data * 32767).astype("int16")

        output_path = os.path.join(tempfile.gettempdir(), "musicgen_large_generated_music.wav")
        write(output_path, sampling_rate, normalized_audio)
        return output_path
    except Exception as e:
        logging.exception("Error generating music")
        return f"Error generating music: {e}"
# ---------------------------------------------------------------------
# Sound Design Generation Function
# ---------------------------------------------------------------------
@spaces.GPU(duration=200)
def generate_sound_design(prompt: str) -> Union[str, Any]:
    """
    Generate a sound design clip with AudioLDM 2 and write it as a 16-bit WAV.

    Args:
        prompt (str): Sound design prompt.

    Returns:
        Union[str, Any]: Path to the generated .wav file, or an error message
        string if the prompt is blank or generation fails.
    """
    try:
        if not prompt or not prompt.strip():
            return "Error: No sound design suggestion provided."
        pipe = get_sound_design_pipeline("cvssp/audioldm2", HF_TOKEN)
        result = pipe(prompt)  # Output exposes the generated waveforms under 'audios'
        audio_samples = result["audios"][0]

        # AudioLDM 2 produces 16 kHz audio. Use the vocoder's configured rate
        # when the pipeline exposes one, otherwise 16 kHz. A hard-coded
        # 44100 Hz plays the clip ~2.75x too fast.
        sample_rate = 16000
        vocoder_config = getattr(getattr(pipe, "vocoder", None), "config", None)
        vocoder_rate = getattr(vocoder_config, "sampling_rate", None)
        if vocoder_rate:
            sample_rate = int(vocoder_rate)

        # Peak-normalise to int16; skip the division for silent output.
        peak = float(np.max(np.abs(audio_samples))) if audio_samples.size else 0.0
        if peak > 0:
            audio_samples = audio_samples / peak
        normalized_audio = (audio_samples * 32767).astype("int16")

        output_path = os.path.join(tempfile.gettempdir(), "sound_design_generated.wav")
        write(output_path, sample_rate, normalized_audio)
        return output_path
    except Exception as e:
        logging.exception("Error generating sound design")
        return f"Error generating sound design: {e}"
# ---------------------------------------------------------------------
# Audio Blending Function
# ---------------------------------------------------------------------
def _fit_to_length(segment: AudioSegment, target_ms: int) -> AudioSegment:
    """Loop or trim ``segment`` so it lasts exactly ``target_ms`` milliseconds.

    An empty segment cannot be looped, so silence of the target length is
    returned instead of dividing by zero.
    """
    if len(segment) >= target_ms:
        return segment[:target_ms]
    if len(segment) == 0:
        return AudioSegment.silent(duration=target_ms, frame_rate=segment.frame_rate)
    repeats = math.ceil(target_ms / len(segment))
    return (segment * repeats)[:target_ms]


def blend_audio(voice_path: str, sound_effect_path: str, music_path: str, ducking: bool, duck_level: int = 10) -> Union[str, Any]:
    """
    Mix voice, sound design and music into a single WAV file.

    Steps:
      - Loop/trim the music and sound design beds to the voice duration.
      - If ``ducking`` is enabled, lower both beds by ``duck_level`` dB for
        the whole clip. This is a constant gain reduction, not envelope-driven
        ducking.
      - Overlay music and sound design, then lay the voice on top.

    This is pure pydub CPU work, so it does not request a GPU.

    Args:
        voice_path (str): Path to the voice WAV file.
        sound_effect_path (str): Path to the sound design WAV file.
        music_path (str): Path to the music WAV file.
        ducking (bool): Whether to attenuate the background beds.
        duck_level (int): Attenuation in dB applied when ``ducking`` is True.

    Returns:
        Union[str, Any]: Path to the blended .wav file, or an error message
        string if an input is missing or blending fails.
    """
    try:
        for path in (voice_path, sound_effect_path, music_path):
            # Gradio passes None for an audio component that has no value yet.
            if not path or not os.path.isfile(path):
                return f"Error: Missing audio file for {path}"

        voice = AudioSegment.from_wav(voice_path)
        music = AudioSegment.from_wav(music_path)
        sound_effect = AudioSegment.from_wav(sound_effect_path)

        voice_len = len(voice)  # pydub lengths are in milliseconds
        music = _fit_to_length(music, voice_len)
        sound_effect = _fit_to_length(sound_effect, voice_len)

        if ducking:
            # Subtracting a number from an AudioSegment applies that many dB of gain reduction.
            music = music - duck_level
            sound_effect = sound_effect - duck_level

        background = music.overlay(sound_effect)
        final_audio = background.overlay(voice)

        output_path = os.path.join(tempfile.gettempdir(), "blended_output.wav")
        final_audio.export(output_path, format="wav")
        return output_path
    except Exception as e:
        logging.exception("Error blending audio")
        return f"Error blending audio: {e}"
# ---------------------------------------------------------------------
# Gradio Interface
# ---------------------------------------------------------------------
# The tabs form a pipeline: Step 1's three text outputs feed Steps 2-4, and
# the audio outputs of Steps 2-4 feed Step 5. Labels contain literal emoji,
# so keep this file saved as UTF-8.
with gr.Blocks(css="""
/* Global Styles */
body {
background: linear-gradient(135deg, #1d1f21, #3a3d41);
color: #f0f0f0;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.header {
text-align: center;
padding: 2rem 1rem;
background: linear-gradient(90deg, #6a11cb, #2575fc);
border-radius: 0 0 20px 20px;
margin-bottom: 2rem;
}
.header h1 {
margin: 0;
font-size: 2.5rem;
}
.header p {
font-size: 1.2rem;
}
.gradio-container {
background: #2e2e2e;
border-radius: 10px;
padding: 1rem;
}
.tab-title {
font-size: 1.1rem;
font-weight: bold;
}
.footer {
text-align: center;
font-size: 0.9em;
margin-top: 2rem;
padding: 1rem;
color: #cccccc;
}
""") as demo:
    # Custom Header
    with gr.Row(elem_classes="header"):
        gr.Markdown("""
<h1>🎧 Ai Ads Promo</h1>
<p>Your all-in-one AI solution for creating professional audio ads.</p>
""")
    gr.Markdown("""
**Welcome to Ai Ads Promo!**
This app helps you create amazing audio ads in just a few steps:
1. **Script Generation:** Provide your idea and get a voice-over script, sound design, and music suggestions.
2. **Voice Synthesis:** Convert the script into natural-sounding speech.
3. **Music Production:** Generate a custom music track.
4. **Sound Design:** Create creative sound effects.
5. **Audio Blending:** Seamlessly blend voice, music, and sound design (with optional ducking).
""")
    with gr.Tabs():
        # Step 1: Script Generation
        with gr.Tab("📝 Script Generation"):
            with gr.Row():
                user_prompt = gr.Textbox(
                    label="Promo Ads Idea",
                    placeholder="E.g., A 30-second ad for a radio morning show...",
                    lines=2
                )
            with gr.Row():
                llama_model_id = gr.Textbox(
                    label="LLaMA Model ID",
                    value="meta-llama/Meta-Llama-3-8B-Instruct",
                    placeholder="Enter a valid Hugging Face model ID"
                )
                duration = gr.Slider(
                    label="Desired Ad Duration (seconds)",
                    minimum=15,
                    maximum=60,
                    step=15,
                    value=30
                )
            generate_script_button = gr.Button("Generate Script", variant="primary")
            script_output = gr.Textbox(label="Generated Voice-Over Script", lines=5, interactive=False)
            sound_design_output = gr.Textbox(label="Sound Design Suggestions", lines=3, interactive=False)
            music_suggestion_output = gr.Textbox(label="Music Suggestions", lines=3, interactive=False)
            # HF_TOKEN is bound here on the server rather than exposed as a UI input.
            generate_script_button.click(
                fn=lambda prompt, model_id, dur: generate_script(prompt, model_id, HF_TOKEN, dur),
                inputs=[user_prompt, llama_model_id, duration],
                outputs=[script_output, sound_design_output, music_suggestion_output],
            )
        # Step 2: Voice Synthesis
        with gr.Tab("🎤 Voice Synthesis"):
            gr.Markdown("Generate a natural-sounding voice-over using Coqui TTS.")
            selected_tts_model = gr.Dropdown(
                label="TTS Model",
                choices=[
                    "tts_models/en/ljspeech/tacotron2-DDC",
                    "tts_models/en/ljspeech/vits",
                    "tts_models/en/sam/tacotron-DDC",
                ],
                value="tts_models/en/ljspeech/tacotron2-DDC",
                multiselect=False
            )
            generate_voice_button = gr.Button("Generate Voice-Over", variant="primary")
            voice_audio_output = gr.Audio(label="Voice-Over (WAV)", type="filepath")
            generate_voice_button.click(
                fn=lambda script, tts_model: generate_voice(script, tts_model),
                inputs=[script_output, selected_tts_model],
                outputs=voice_audio_output,
            )
        # Step 3: Music Production
        with gr.Tab("🎶 Music Production"):
            gr.Markdown("Generate a custom music track using the **MusicGen Large** model.")
            audio_length = gr.Slider(
                label="Music Length (tokens)",
                minimum=128,
                maximum=1024,
                step=64,
                value=512,
                info="Increase tokens for longer audio (inference time may vary)."
            )
            generate_music_button = gr.Button("Generate Music", variant="primary")
            music_output = gr.Audio(label="Generated Music (WAV)", type="filepath")
            generate_music_button.click(
                fn=lambda music_prompt, length: generate_music(music_prompt, length),
                inputs=[music_suggestion_output, audio_length],
                outputs=[music_output],
            )
        # Step 4: Sound Design Generation
        with gr.Tab("🎧 Sound Design Generation"):
            gr.Markdown("Generate a creative sound design track based on the script's suggestions.")
            generate_sound_design_button = gr.Button("Generate Sound Design", variant="primary")
            sound_design_audio_output = gr.Audio(label="Generated Sound Design (WAV)", type="filepath")
            generate_sound_design_button.click(
                fn=generate_sound_design,
                inputs=[sound_design_output],
                outputs=[sound_design_audio_output],
            )
        # Step 5: Audio Blending (Voice + Sound Design + Music)
        with gr.Tab("🎚️ Audio Blending"):
            gr.Markdown("Blend your voice-over, sound design, and music track. Enable ducking to lower background audio during voice segments.")
            ducking_checkbox = gr.Checkbox(label="Enable Ducking?", value=True)
            duck_level_slider = gr.Slider(
                label="Ducking Level (dB attenuation)",
                minimum=0,
                maximum=20,
                step=1,
                value=10
            )
            blend_button = gr.Button("Blend Audio", variant="primary")
            blended_output = gr.Audio(label="Final Blended Output (WAV)", type="filepath")
            blend_button.click(
                fn=blend_audio,
                inputs=[voice_audio_output, sound_design_audio_output, music_output, ducking_checkbox, duck_level_slider],
                outputs=blended_output
            )
    # Footer and Visitor Badge
    gr.Markdown("""
<div class="footer">
<hr>
Created with ❤️ by <a href="https://bilsimaging.com" target="_blank" style="color: #88aaff;">bilsimaging.com</a>
<br>
<small>Ai Ads Promo &copy; 2025</small>
</div>
""")
    gr.HTML("""
<div style="text-align: center; margin-top: 1rem;">
<a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold">
<img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold&countColor=%23263759" alt="visitor badge"/>
</a>
</div>
""")

if __name__ == "__main__":
    demo.launch(debug=True)