import gradio as gr
import numpy as np
import os
import time
import torch
from scipy.io import wavfile

# Bark imports
from bark import generate_audio, SAMPLE_RATE
from bark.generation import preload_models

# Hugging Face Transformers and Datasets
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
from datasets import load_dataset
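
# Assumed runtime dependencies (install commands are a sketch, not pinned versions):
#   pip install gradio torch scipy transformers datasets
#   pip install git+https://github.com/suno-ai/bark.git  # provides bark.generate_audio / preload_models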

class VoiceSynthesizer:
    def __init__(self):
        # Create working directory
        self.base_dir = os.path.dirname(os.path.abspath(__file__))
        self.working_dir = os.path.join(self.base_dir, "working_files")
        os.makedirs(self.working_dir, exist_ok=True)
        
        # Initialize models dictionary
        self.models = {
            "bark": self._initialize_bark,
            "speecht5": self._initialize_speecht5
        }
        
        # Default model
        self.current_model = "bark"
        
        # Initialize Bark models
        try:
            print("Attempting to load Bark models...")
            preload_models()
            print("Bark models loaded successfully.")
        except Exception as e:
            print(f"Bark model loading error: {e}")
    
    def _initialize_bark(self):
        """Bark model initialization (already done in __init__)"""
        return None
    
    def _initialize_speecht5(self):
        """Initialize SpeechT5 model from Hugging Face"""
        try:
            # Load SpeechT5 model and processor
            model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
            processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
            vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
            
            # Load speaker embeddings (x-vectors); the first entry serves as a
            # default speaker voice
            embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
            speaker_embeddings = torch.tensor(embeddings_dataset[0]["xvector"]).unsqueeze(0)
            
            return {
                "model": model,
                "processor": processor,
                "vocoder": vocoder,
                "speaker_embeddings": speaker_embeddings
            }
        except Exception as e:
            print(f"SpeechT5 model loading error: {e}")
            return None
    
    def set_model(self, model_name):
        """Set the current model for speech synthesis"""
        if model_name not in self.models:
            raise ValueError(f"Model {model_name} not supported")
        self.current_model = model_name
    
    def generate_speech(self, text, model_name=None, voice_preset=None):
        """Generate speech using selected model"""
        if not text or not text.strip():
            return None, "Please enter some text to speak"
        
        # Use specified model or current model
        current_model = model_name or self.current_model
        
        try:
            if current_model == "bark":
                return self._generate_bark_speech(text, voice_preset)
            elif current_model == "speecht5":
                return self._generate_speecht5_speech(text, voice_preset)
            else:
                raise ValueError(f"Unsupported model: {current_model}")
        
        except Exception as e:
            print(f"Speech generation error: {e}")
            import traceback
            traceback.print_exc()
            return None, f"Error generating speech: {str(e)}"
    
    def _generate_bark_speech(self, text, voice_preset=None):
        """Generate speech using Bark"""
        # List of Bark voice presets
        voice_presets = [
            "v2/en_speaker_6",  # Female
            "v2/en_speaker_3",  # Male
            "v2/en_speaker_9",  # Neutral
        ]
        
        # Select voice preset
        history_prompt = voice_preset if voice_preset else voice_presets[0]
        
        # Generate audio
        audio_array = generate_audio(
            text, 
            history_prompt=history_prompt
        )
        
        # Save generated audio
        filename = f"bark_speech_{int(time.time())}.wav"
        filepath = os.path.join(self.working_dir, filename)
        wavfile.write(filepath, SAMPLE_RATE, audio_array)
        
        return filepath, None
    
    def _generate_speecht5_speech(self, text, speaker_id=None):
        """Generate speech using SpeechT5 (speaker_id is currently unused;
        a single default x-vector speaker embedding is applied)"""
        # Lazily initialize and cache the SpeechT5 components so the
        # pretrained weights are only loaded once per session
        if getattr(self, "_speecht5_cache", None) is None:
            self._speecht5_cache = self.models["speecht5"]()
        speecht5_models = self._speecht5_cache
        if not speecht5_models:
            return None, "SpeechT5 model not loaded"
        
        model = speecht5_models["model"]
        processor = speecht5_models["processor"]
        vocoder = speecht5_models["vocoder"]
        speaker_embeddings = speecht5_models["speaker_embeddings"]
        
        # Prepare inputs
        inputs = processor(text=text, return_tensors="pt")
        
        # Generate speech; passing the vocoder yields a waveform
        # instead of a mel spectrogram
        speech = model.generate_speech(
            inputs["input_ids"],
            speaker_embeddings,
            vocoder=vocoder
        )
        
        # Convert to numpy array
        audio_array = speech.cpu().numpy()
        
        # Save generated audio
        filename = f"speecht5_speech_{int(time.time())}.wav"
        filepath = os.path.join(self.working_dir, filename)
        wavfile.write(filepath, 16000, audio_array)
        
        return filepath, None
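
# Direct-usage sketch (bypassing the Gradio UI), assuming the models above load
# successfully; kept as a comment so the module has no import-time side effects:
#
#   synth = VoiceSynthesizer()
#   wav_path, err = synth.generate_speech(
#       "Hello from Bark.", model_name="bark", voice_preset="v2/en_speaker_6"
#   )
#   print(err if err else f"Audio written to {wav_path}")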

def create_interface():
    synthesizer = VoiceSynthesizer()
    
    with gr.Blocks() as interface:
        gr.Markdown("# 🎙️ Advanced Voice Synthesis")
        
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Speech Generation")
                text_input = gr.Textbox(label="Enter Text to Speak")
                
                # Model Selection
                model_dropdown = gr.Dropdown(
                    choices=[
                        "bark (Suno AI)",
                        "speecht5 (Microsoft)"
                    ],
                    label="Select TTS Model",
                    value="bark (Suno AI)"
                )
                
                # Voice Preset Dropdowns
                with gr.Row():
                    bark_preset = gr.Dropdown(
                        choices=[
                            "v2/en_speaker_6 (Female)",
                            "v2/en_speaker_3 (Male)", 
                            "v2/en_speaker_9 (Neutral)"
                        ],
                        label="Bark Voice Preset",
                        visible=True
                    )
                    
                    speecht5_preset = gr.Dropdown(
                        choices=[
                            "Default Speaker"
                        ],
                        label="SpeechT5 Speaker",
                        visible=False
                    )
                
                generate_btn = gr.Button("Generate Speech")
                audio_output = gr.Audio(label="Generated Speech")
                error_output = gr.Textbox(label="Errors", visible=True)
        
        # Dynamic model and preset visibility
        def update_model_visibility(model):
            if "bark" in model.lower():
                return {
                    bark_preset: gr.update(visible=True),
                    speecht5_preset: gr.update(visible=False)
                }
            else:
                return {
                    bark_preset: gr.update(visible=False),
                    speecht5_preset: gr.update(visible=True)
                }
        
        model_dropdown.change(
            fn=update_model_visibility,
            inputs=model_dropdown,
            outputs=[bark_preset, speecht5_preset]
        )
        
        # Speech generation logic
        def generate_speech_wrapper(text, model, bark_preset, speecht5_preset):
            # Map model name
            model_map = {
                "bark (Suno AI)": "bark",
                "speecht5 (Microsoft)": "speecht5"
            }
            
            # Select the preset matching the chosen model and strip the descriptive
            # label, e.g. "v2/en_speaker_6 (Female)" -> "v2/en_speaker_6"
            preset = bark_preset if "bark" in model.lower() else speecht5_preset
            if preset:
                preset = preset.split(" (")[0]
            
            return synthesizer.generate_speech(
                text, 
                model_name=model_map[model], 
                voice_preset=preset
            )
        
        generate_btn.click(
            fn=generate_speech_wrapper,
            inputs=[text_input, model_dropdown, bark_preset, speecht5_preset],
            outputs=[audio_output, error_output]
        )
    
    return interface

if __name__ == "__main__":
    interface = create_interface()
    interface.launch(
        share=False,
        debug=True,
        show_error=True,
        server_name='0.0.0.0',
        server_port=7860
    )