import os
import re
import torch
import tempfile
import logging
import math
from typing import Tuple, Union, Any
from scipy.io.wavfile import write
from pydub import AudioSegment
from dotenv import load_dotenv
import spaces
import gradio as gr
import numpy as np

# Transformers & Models
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    AutoProcessor,
    MusicgenForConditionalGeneration,
)

# Coqui TTS
from TTS.api import TTS

# Diffusers for sound design generation
from diffusers import AudioLDM2Pipeline
import diffusers
from packaging import version

# ---------------------------------------------------------------------
# Setup Logging and Environment Variables
# ---------------------------------------------------------------------
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    logging.warning("HF_TOKEN is not set in your environment. Some model downloads might fail.")

# ---------------------------------------------------------------------
# Global Model Caches
# ---------------------------------------------------------------------
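# Each cache maps a model identifier to its loaded weights so repeated Gradio calls avoid reloading.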
LLAMA_PIPELINES: dict[str, Any] = {}
MUSICGEN_MODELS: dict[str, Any] = {}
TTS_MODELS: dict[str, Any] = {}
SOUND_DESIGN_PIPELINES: dict[str, Any] = {}

# ---------------------------------------------------------------------
# Utility Functions
# ---------------------------------------------------------------------
def clean_text(text: str) -> str:
    """
    Remove characters (currently just asterisks) that the TTS model may mispronounce.
    
    Args:
        text (str): Input text to be cleaned.
    
    Returns:
        str: Cleaned text.
    """
    return re.sub(r'\*', '', text)

# ---------------------------------------------------------------------
# Model Helper Functions
# ---------------------------------------------------------------------
def get_llama_pipeline(model_id: str, token: str) -> Any:
    """
    Returns a cached LLaMA text-generation pipeline or loads a new one.
    
    Args:
        model_id (str): Hugging Face model ID.
        token (str): Hugging Face token.
        
    Returns:
        Any: A Hugging Face text-generation pipeline.
    """
    if model_id in LLAMA_PIPELINES:
        return LLAMA_PIPELINES[model_id]

    logging.info(f"Loading LLaMA model from {model_id}...")
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=token,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
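    # device_map="auto" lets accelerate place the weights across available devices, so no explicit .to(device) is needed.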
    text_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
    LLAMA_PIPELINES[model_id] = text_pipeline
    return text_pipeline

def get_musicgen_model(model_key: str = "facebook/musicgen-large") -> Tuple[Any, Any]:
    """
    Returns a cached MusicGen model and processor, or loads new ones.
    
    Args:
        model_key (str): Hugging Face model key (default is 'facebook/musicgen-large').
    
    Returns:
        Tuple[Any, Any]: The MusicGen model and its processor.
    """
    if model_key in MUSICGEN_MODELS:
        return MUSICGEN_MODELS[model_key]

    logging.info(f"Loading MusicGen model from {model_key}...")
    model = MusicgenForConditionalGeneration.from_pretrained(model_key)
    processor = AutoProcessor.from_pretrained(model_key)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    MUSICGEN_MODELS[model_key] = (model, processor)
    return model, processor

def get_tts_model(model_name: str = "tts_models/en/ljspeech/tacotron2-DDC") -> TTS:
    """
    Returns a cached TTS model or loads a new one.
    
    Args:
        model_name (str): Identifier for the TTS model.
    
    Returns:
        TTS: A Coqui TTS model.
    """
    if model_name in TTS_MODELS:
        return TTS_MODELS[model_name]

    logging.info(f"Loading TTS model: {model_name}...")
    tts_model = TTS(model_name)
    TTS_MODELS[model_name] = tts_model
    return tts_model

def get_sound_design_pipeline(model_name: str, token: str) -> Any:
    """
    Returns a cached AudioLDM2 pipeline for sound design, or loads a new one.
    Raises an error if diffusers version is less than 0.21.0.
    
    Args:
        model_name (str): The model name to load.
        token (str): Hugging Face token.
    
    Returns:
        Any: An AudioLDM2Pipeline for sound design.
    
    Raises:
        ValueError: If diffusers version is lower than 0.21.0.
    """
    if version.parse(diffusers.__version__) < version.parse("0.21.0"):
        raise ValueError("AudioLDM2 requires diffusers>=0.21.0. Please upgrade your diffusers package.")
    
    if model_name in SOUND_DESIGN_PIPELINES:
        return SOUND_DESIGN_PIPELINES[model_name]
    
    logging.info(f"Loading sound design pipeline from {model_name}...")
    pipe = DiffusionPipeline.from_pretrained(
        model_name,
        pipeline_class=AudioLDMPipeline,
        use_auth_token=token
    )
    SOUND_DESIGN_PIPELINES[model_name] = pipe
    return pipe

# ---------------------------------------------------------------------
# Script Generation Function
# ---------------------------------------------------------------------
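# On Hugging Face Spaces, @spaces.GPU requests ZeroGPU hardware for up to `duration` seconds per call.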
@spaces.GPU(duration=100)
def generate_script(user_prompt: str, model_id: str, token: str, duration: int) -> Tuple[str, str, str]:
    """
    Generates a voice-over script, sound design suggestions, and music ideas based on the user prompt.
    
    Args:
        user_prompt (str): The user-provided concept.
        model_id (str): The LLaMA model ID.
        token (str): Hugging Face token.
        duration (int): The desired duration in seconds.
    
    Returns:
        Tuple[str, str, str]: Voice-over script, sound design suggestions, and music suggestions.
    """
    try:
        text_pipeline = get_llama_pipeline(model_id, token)
        system_prompt = (
            "You are an expert radio imaging producer specializing in sound design and music. "
            f"Based on the user's concept and the selected duration of {duration} seconds, produce the following:\n"
            "1. A concise voice-over script. Prefix this section with 'Voice-Over Script:'\n"
            "2. Suggestions for sound design. Prefix this section with 'Sound Design Suggestions:'\n"
            "3. Music styles or track recommendations. Prefix this section with 'Music Suggestions:'"
        )
        combined_prompt = f"{system_prompt}\nUser concept: {user_prompt}\nOutput:"

        with torch.inference_mode():
            result = text_pipeline(
                combined_prompt,
                max_new_tokens=300,
                do_sample=True,
                temperature=0.8
            )

        generated_text = result[0]["generated_text"]
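        # By default the text-generation pipeline returns the prompt together with the completion,
        # so keep only the text that follows the "Output:" marker.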
        if "Output:" in generated_text:
            generated_text = generated_text.split("Output:")[-1].strip()

        # Extract sections using regex
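        # Non-greedy groups stop each section at the next heading; re.DOTALL lets them span newlines.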
        pattern = r"Voice-Over Script:\s*(.*?)\s*Sound Design Suggestions:\s*(.*?)\s*Music Suggestions:\s*(.*)"
        match = re.search(pattern, generated_text, re.DOTALL)
        if match:
            voice_script, sound_design, music_suggestions = (grp.strip() for grp in match.groups())
        else:
            voice_script = "No voice-over script found."
            sound_design = "No sound design suggestions found."
            music_suggestions = "No music suggestions found."

        return voice_script, sound_design, music_suggestions

    except Exception as e:
        logging.exception("Error generating script")
        return f"Error generating script: {e}", "", ""

# ---------------------------------------------------------------------
# Voice-Over Generation Function
# ---------------------------------------------------------------------
@spaces.GPU(duration=100)
def generate_voice(script: str, tts_model_name: str = "tts_models/en/ljspeech/tacotron2-DDC") -> Union[str, Any]:
    """
    Generates a voice-over audio file from a script using Coqui TTS.
    
    Args:
        script (str): The voice-over script.
        tts_model_name (str): The TTS model name.
    
    Returns:
        Union[str, Any]: The file path to the generated .wav file or an error message.
    """
    try:
        if not script.strip():
            return "Error: No script provided."

        cleaned_script = clean_text(script)
        tts_model = get_tts_model(tts_model_name)
        output_path = os.path.join(tempfile.gettempdir(), "voice_over.wav")
        tts_model.tts_to_file(text=cleaned_script, file_path=output_path)
        return output_path

    except Exception as e:
        logging.exception("Error generating voice")
        return f"Error generating voice: {e}"

# ---------------------------------------------------------------------
# Music Generation Function
# ---------------------------------------------------------------------
@spaces.GPU(duration=200)
def generate_music(prompt: str, audio_length: int) -> Union[str, Any]:
    """
    Generates a music track using the MusicGen model based on the prompt.
    
    Args:
        prompt (str): Music suggestion prompt.
        audio_length (int): Number of tokens determining audio length.
    
    Returns:
        Union[str, Any]: The file path to the generated .wav file or an error message.
    """
    try:
        if not prompt.strip():
            return "Error: No music suggestion provided."

        model_key = "facebook/musicgen-large"
        musicgen_model, musicgen_processor = get_musicgen_model(model_key)
        device = "cuda" if torch.cuda.is_available() else "cpu"

        inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt").to(device)
        with torch.inference_mode():
            outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)

        audio_data = outputs[0, 0].cpu().numpy()
        # Normalize to the 16-bit integer range, guarding against an all-silent clip
        peak = np.max(np.abs(audio_data))
        normalized_audio = (audio_data / peak * 32767).astype("int16") if peak > 0 else np.zeros_like(audio_data, dtype="int16")
        # MusicGen generates audio at its audio encoder's sampling rate (32 kHz for the released checkpoints)
        sampling_rate = musicgen_model.config.audio_encoder.sampling_rate
        output_path = os.path.join(tempfile.gettempdir(), "musicgen_large_generated_music.wav")
        write(output_path, sampling_rate, normalized_audio)
        return output_path

    except Exception as e:
        logging.exception("Error generating music")
        return f"Error generating music: {e}"

# ---------------------------------------------------------------------
# Sound Design Generation Function
# ---------------------------------------------------------------------
@spaces.GPU(duration=200)
def generate_sound_design(prompt: str) -> Union[str, Any]:
    """
    Generates a sound design audio file using AudioLDM 2 based on the prompt.
    
    Args:
        prompt (str): Sound design prompt.
    
    Returns:
        Union[str, Any]: The file path to the generated .wav file or an error message.
    """
    try:
        if not prompt.strip():
            return "Error: No sound design suggestion provided."
        
        pipe = get_sound_design_pipeline("cvssp/audioldm2", HF_TOKEN)
        result = pipe(prompt)  # Returns an output object exposing the generated waveforms under 'audios'
        audio_samples = result["audios"][0]
        # Normalize to the 16-bit integer range, guarding against an all-silent clip
        peak = np.max(np.abs(audio_samples))
        normalized_audio = (audio_samples / peak * 32767).astype("int16") if peak > 0 else np.zeros_like(audio_samples, dtype="int16")
        output_path = os.path.join(tempfile.gettempdir(), "sound_design_generated.wav")
        # AudioLDM2 generates 16 kHz audio, so write the file at that rate rather than 44.1 kHz
        write(output_path, 16000, normalized_audio)
        return output_path

    except Exception as e:
        logging.exception("Error generating sound design")
        return f"Error generating sound design: {e}"

# ---------------------------------------------------------------------
# Audio Blending Function
# ---------------------------------------------------------------------
@spaces.GPU(duration=100)
def blend_audio(voice_path: str, sound_effect_path: str, music_path: str, ducking: bool, duck_level: int = 10) -> Union[str, Any]:
    """
    Blends three audio files (voice, sound design, and music) by:
      - Looping/trimming music and sound design to match voice duration.
      - Optionally applying ducking to background tracks.
      - Overlaying the voice on top of the background.
    
    Args:
        voice_path (str): Path to the voice audio file.
        sound_effect_path (str): Path to the sound design audio file.
        music_path (str): Path to the music audio file.
        ducking (bool): Whether to apply ducking.
        duck_level (int): Amount of attenuation in dB.
    
    Returns:
        Union[str, Any]: The file path to the blended .wav file or an error message.
    """
    try:
        for path in [voice_path, sound_effect_path, music_path]:
            if not os.path.isfile(path):
                return f"Error: Missing audio file for {path}"

        # Load audio segments
        voice = AudioSegment.from_wav(voice_path)
        music = AudioSegment.from_wav(music_path)
        sound_effect = AudioSegment.from_wav(sound_effect_path)
        voice_len = len(voice)  # duration in milliseconds

        # Loop or trim music to match voice duration using pydub multiplication
        if len(music) < voice_len:
            repeats = math.ceil(voice_len / len(music))
            music = (music * repeats)[:voice_len]
        else:
            music = music[:voice_len]

        # Loop or trim sound design to match voice duration
        if len(sound_effect) < voice_len:
            repeats = math.ceil(voice_len / len(sound_effect))
            sound_effect = (sound_effect * repeats)[:voice_len]
        else:
            sound_effect = sound_effect[:voice_len]

        # Apply ducking if enabled
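        # Note: this applies a constant attenuation to the whole background bed rather than a
        # voice-gated sidechain duck; a simple approximation that works for short promos.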
        if ducking:
            music = music - duck_level
            sound_effect = sound_effect - duck_level

        # Overlay music and sound effect for background
        background = music.overlay(sound_effect)
        # Overlay voice on top of background
        final_audio = background.overlay(voice)

        output_path = os.path.join(tempfile.gettempdir(), "blended_output.wav")
        final_audio.export(output_path, format="wav")
        return output_path

    except Exception as e:
        logging.exception("Error blending audio")
        return f"Error blending audio: {e}"

# ---------------------------------------------------------------------
# Gradio Interface
# ---------------------------------------------------------------------
with gr.Blocks(css="""
    /* Global Styles */
    body {
        background: linear-gradient(135deg, #1d1f21, #3a3d41);
        color: #f0f0f0;
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    }
    .header {
        text-align: center;
        padding: 2rem 1rem;
        background: linear-gradient(90deg, #6a11cb, #2575fc);
        border-radius: 0 0 20px 20px;
        margin-bottom: 2rem;
    }
    .header h1 {
        margin: 0;
        font-size: 2.5rem;
    }
    .header p {
        font-size: 1.2rem;
    }
    .gradio-container {
        background: #2e2e2e;
        border-radius: 10px;
        padding: 1rem;
    }
    .tab-title {
        font-size: 1.1rem;
        font-weight: bold;
    }
    .footer {
        text-align: center;
        font-size: 0.9em;
        margin-top: 2rem;
        padding: 1rem;
        color: #cccccc;
    }
""") as demo:

    # Custom Header
    with gr.Row(elem_classes="header"):
        gr.Markdown("""
        <h1>🎧 Ai Ads Promo</h1>
        <p>Your all-in-one AI solution for creating professional audio ads.</p>
        """)

    gr.Markdown("""
    **Welcome to Ai Ads Promo!**

    This app helps you create amazing audio ads in just a few steps:
    
    1. **Script Generation:** Provide your idea and get a voice-over script, sound design, and music suggestions.
    2. **Voice Synthesis:** Convert the script into natural-sounding speech.
    3. **Music Production:** Generate a custom music track.
    4. **Sound Design:** Create creative sound effects.
    5. **Audio Blending:** Seamlessly blend voice, music, and sound design (with optional ducking).
    """)

    with gr.Tabs():
        # Step 1: Script Generation
        with gr.Tab("📝 Script Generation"):
            with gr.Row():
                user_prompt = gr.Textbox(
                    label="Promo Ads Idea", 
                    placeholder="E.g., A 30-second ad for a radio morning show...",
                    lines=2
                )
            with gr.Row():
                llama_model_id = gr.Textbox(
                    label="LLaMA Model ID", 
                    value="meta-llama/Meta-Llama-3-8B-Instruct", 
                    placeholder="Enter a valid Hugging Face model ID"
                )
                duration = gr.Slider(
                    label="Desired Ad Duration (seconds)",
                    minimum=15, 
                    maximum=60, 
                    step=15, 
                    value=30
                )
            generate_script_button = gr.Button("Generate Script", variant="primary")
            script_output = gr.Textbox(label="Generated Voice-Over Script", lines=5, interactive=False)
            sound_design_output = gr.Textbox(label="Sound Design Suggestions", lines=3, interactive=False)
            music_suggestion_output = gr.Textbox(label="Music Suggestions", lines=3, interactive=False)

            generate_script_button.click(
                fn=lambda prompt, model_id, dur: generate_script(prompt, model_id, HF_TOKEN, dur),
                inputs=[user_prompt, llama_model_id, duration],
                outputs=[script_output, sound_design_output, music_suggestion_output],
            )

        # Step 2: Voice Synthesis
        with gr.Tab("🎤 Voice Synthesis"):
            gr.Markdown("Generate a natural-sounding voice-over using Coqui TTS.")
            selected_tts_model = gr.Dropdown(
                label="TTS Model",
                choices=[
                    "tts_models/en/ljspeech/tacotron2-DDC",  
                    "tts_models/en/ljspeech/vits", 
                    "tts_models/en/sam/tacotron-DDC", 
                ],
                value="tts_models/en/ljspeech/tacotron2-DDC",
                multiselect=False
            )
            generate_voice_button = gr.Button("Generate Voice-Over", variant="primary")
            voice_audio_output = gr.Audio(label="Voice-Over (WAV)", type="filepath")

            generate_voice_button.click(
                fn=lambda script, tts_model: generate_voice(script, tts_model),
                inputs=[script_output, selected_tts_model],
                outputs=voice_audio_output,
            )

        # Step 3: Music Production
        with gr.Tab("🎶 Music Production"):
            gr.Markdown("Generate a custom music track using the **MusicGen Large** model.")
            audio_length = gr.Slider(
                label="Music Length (tokens)",
                minimum=128, 
                maximum=1024, 
                step=64, 
                value=512,
                info="Increase tokens for longer audio (inference time may vary)."
            )
            generate_music_button = gr.Button("Generate Music", variant="primary")
            music_output = gr.Audio(label="Generated Music (WAV)", type="filepath")

            generate_music_button.click(
                fn=lambda music_prompt, length: generate_music(music_prompt, length),
                inputs=[music_suggestion_output, audio_length],
                outputs=[music_output],
            )

        # Step 4: Sound Design Generation
        with gr.Tab("🎧 Sound Design Generation"):
            gr.Markdown("Generate a creative sound design track based on the script's suggestions.")
            generate_sound_design_button = gr.Button("Generate Sound Design", variant="primary")
            sound_design_audio_output = gr.Audio(label="Generated Sound Design (WAV)", type="filepath")
            
            generate_sound_design_button.click(
                fn=generate_sound_design,
                inputs=[sound_design_output],
                outputs=[sound_design_audio_output],
            )

        # Step 5: Audio Blending (Voice + Sound Design + Music)
        with gr.Tab("🎚️ Audio Blending"):
            gr.Markdown("Blend your voice-over, sound design, and music track. Enable ducking to lower background audio during voice segments.")
            ducking_checkbox = gr.Checkbox(label="Enable Ducking?", value=True)
            duck_level_slider = gr.Slider(
                label="Ducking Level (dB attenuation)", 
                minimum=0, 
                maximum=20, 
                step=1, 
                value=10
            )
            blend_button = gr.Button("Blend Audio", variant="primary")
            blended_output = gr.Audio(label="Final Blended Output (WAV)", type="filepath")

            blend_button.click(
                fn=blend_audio,
                inputs=[voice_audio_output, sound_design_audio_output, music_output, ducking_checkbox, duck_level_slider],
                outputs=blended_output
            )

    # Footer and Visitor Badge
    gr.Markdown("""
    <div class="footer">
        <hr>
        Created with ❤️ by <a href="https://bilsimaging.com" target="_blank" style="color: #88aaff;">bilsimaging.com</a>
        <br>
        <small>Ai Ads Promo &copy; 2025</small>
    </div>
    """)
    gr.HTML("""
    <div style="text-align: center; margin-top: 1rem;">
        <a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold">
            <img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold&countColor=%23263759" alt="visitor badge"/>
        </a>
    </div>
    """)

if __name__ == "__main__":
    demo.launch(debug=True)