import argparse
import logging
import random
import uuid
import numpy as np
from transformers import pipeline
from diffusers import DiffusionPipeline, StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image, export_to_video
from transformers import (
    SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5ForSpeechToSpeech,
    BlipProcessor, BlipForConditionalGeneration, TrOCRProcessor, VisionEncoderDecoderModel, 
    ViTImageProcessor, AutoTokenizer, AutoImageProcessor, TimesformerForVideoClassification,
    MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation, DPTForDepthEstimation, DPTFeatureExtractor
)
from datasets import load_dataset
from PIL import Image
from torchvision import transforms
import torch
import torchaudio
from speechbrain.pretrained import WaveformEnhancement
import joblib
from huggingface_hub import hf_hub_url, cached_download
from controlnet_aux import OpenposeDetector, MLSDdetector, HEDdetector, CannyDetector, MidasDetector
import warnings
import time
from espnet2.bin.tts_inference import Text2Speech
import soundfile as sf
from asteroid.models import BaseModel
import traceback
import os
import yaml

# The ML libraries imported above (transformers, diffusers, etc.) emit many
# deprecation/config warnings at import and inference time; silence them all.
warnings.filterwarnings("ignore")

def setup_logger():
    """Return this module's logger, configured to emit INFO+ records to stderr.

    Idempotent: ``logging.getLogger(__name__)`` always returns the same logger
    object, so the original version attached a fresh ``StreamHandler`` on every
    call, duplicating every log line when called more than once. The handler is
    now attached only if the logger has none yet.

    Returns:
        logging.Logger: the configured module logger.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    # Only attach a handler on the first call; later calls reuse the existing one.
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger

# Module-level logger used by everything below (created at import time).
logger = setup_logger()

def load_config(config_path):
    """Load a YAML configuration file and return its deserialized content.

    Args:
        config_path (str): Path to the YAML configuration file.

    Returns:
        The parsed YAML document (typically a dict for these config files).
    """
    with open(config_path, "r") as file:
        # safe_load only constructs plain Python data (dict/list/str/num);
        # the original FullLoader can still instantiate some Python objects
        # and is discouraged for configuration files.
        return yaml.safe_load(file)

def parse_args():
    """Parse command-line options.

    Recognizes a single option, ``--config`` (path to the YAML config file),
    defaulting to ``config.yaml``.

    Returns:
        argparse.Namespace: parsed arguments with a ``config`` attribute.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, default="config.yaml")
    return cli.parse_args()

# --- Import-time bootstrap: parse CLI args, load config, derive globals. ---
args = parse_args()

# Ensure the config is always set when not running as the main script
# (when imported — e.g. by the gradio front-end — there are no meaningful CLI
# args, so force the gradio-specific config file).
if __name__ != "__main__":
    args.config = "config.gradio.yaml"

config = load_config(args.config)

# Which models to deploy locally; "huggingface" inference mode delegates all
# inference to the remote HF API, so local deployment is forced off.
local_deployment = config["local_deployment"]
if config["inference_mode"] == "huggingface":
    local_deployment = "none"

# requests-style proxy mapping, or None when no proxy is configured.
PROXY = {"https": config["proxy"]} if config["proxy"] else None

# Start of the model-loading timer — presumably read after load_pipes finishes;
# the matching end-timestamp is not visible in this chunk (TODO confirm).
start = time.time()

# Prefix prepended to every model id below; an empty string makes from_pretrained
# resolve ids against the HF hub instead of a local model directory — TODO confirm.
local_models = ""  # Changed to empty string

def load_pipes(local_deployment):
    standard_pipes = {}
    other_pipes = {}
    controlnet_sd_pipes = {}
    
    if local_deployment in ["full"]:
        other_pipes = {
            "damo-vilab/text-to-video-ms-1.7b": {
                "model": DiffusionPipeline.from_pretrained(f"{local_models}damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"),
                "device": "cuda:0"
            },
            "JorisCos/DCCRNet_Libri1Mix_enhsingle_16k": {
                "model": BaseModel.from_pretrained("JorisCos/DCCRNet_Libri1Mix_enhsingle_16k"),
                "device": "cuda:0"
            },
            "microsoft/speecht5_vc": {
                "processor": SpeechT5Processor.from_pretrained(f"{local_models}microsoft/speecht5_vc"),
                "model": SpeechT5ForSpeechToSpeech.from_pretrained(f"{local_models}microsoft/speecht5_vc"),
                "vocoder": SpeechT5HifiGan.from_pretrained(f"{local_models}microsoft/speecht5_hifigan"),
                "embeddings_dataset": load_dataset(f"{local_models}Matthijs/cmu-arctic-xvectors", split="validation"),
                "device": "cuda:0"
            },
            "facebook/maskformer-swin-base-coco": {
                "feature_extractor": MaskFormerFeatureExtractor.from_pretrained(f"{local_models}facebook/maskformer-swin-base-coco"),
                "model": MaskFormerForInstanceSegmentation.from_pretrained(f"{local_models}facebook/maskformer-swin-base-coco"),
                "device": "cuda:0"
            },
            "Intel/dpt-hybrid-midas": {
                "model": DPTForDepthEstimation.from_pretrained(f"{local_models}Intel/dpt-hybrid-midas", low_cpu_mem_usage=True),
                "feature_extractor": DPTFeatureExtractor.from_pretrained(f"{local_models}Intel/dpt-hybrid-midas"),
                "device": "cuda:0"
            }