import os
from flask import Flask, request, jsonify, send_file, Response
import torch
import torchaudio
import librosa
import yaml
import numpy as np
from pydub import AudioSegment
from modules.commons import build_model, load_checkpoint, recursive_munch
from hf_utils import load_custom_model_from_hf
from modules.campplus.DTDNN import CAMPPlus
from modules.bigvgan import bigvgan
from transformers import AutoFeatureExtractor, WhisperModel
from modules.audio import mel_spectrogram
from modules.rmvpe import RMVPE
from io import BytesIO

# WSGI application object; HTTP routes are registered on it via @app.route.
app = Flask(__name__)

# Torch device used for the models and tensors in this module:
# the default CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Main DiT model: fetch its checkpoint and YAML config from the
# "Plachta/Seed-VC" repo via load_custom_model_from_hf.
dit_checkpoint_path, dit_config_path = load_custom_model_from_hf(
    "Plachta/Seed-VC",
    "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
    "config_dit_mel_seed_uvit_whisper_small_wavenet.yml",
)
# Parse the config inside a `with` block so the file handle is closed as soon
# as it has been read instead of lingering until garbage collection.
with open(dit_config_path, 'r', encoding='utf-8') as _dit_config_file:
    config = yaml.safe_load(_dit_config_file)
model_params = recursive_munch(config['model_params'])
model = build_model(model_params, stage='DiT')
# Hop length and sampling rate from the config's preprocess_params section.
hop_length = config['preprocess_params']['spect_params']['hop_length']
sr = config['preprocess_params']['sr']

# Load the trained weights (load_only_params=True), then put every entry of
# `model` in eval mode on `device`. `model` is iterated key by key and indexed
# like a mapping; it also allows attribute access (model.cfm below).
model, _, _, _ = load_checkpoint(model, None, dit_checkpoint_path,
                                 load_only_params=True, ignore_modules=[], is_distributed=False)
for key in model:
    model[key].eval()
    model[key].to(device)
# NOTE(review): presumably preallocates the estimator's inference caches for
# batch size 1 and sequences of up to 8192 steps -- confirm the limit.
model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)

# Load additional models

# CAMPPlus network (feat_dim=80, embedding_size=192), weights fetched from the
# "funasr/campplus" repo. Presumably used to embed the reference speaker --
# confirm against the conversion pipeline.
# NOTE(review): torch.load unpickles the downloaded checkpoint; consider
# weights_only=True if all supported torch versions accept it.
campplus_ckpt_path = load_custom_model_from_hf("funasr/campplus", "campplus_cn_common.bin", config_filename=None)
campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
# The state dict is loaded onto the CPU first; the model is moved to `device` afterwards.
campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
campplus_model.eval()
campplus_model.to(device)

# Pretrained BigVGAN from 'nvidia/bigvgan_v2_22khz_80band_256x' (presumably the
# mel-to-waveform vocoder), with its custom CUDA kernel disabled. Weight norm
# is removed before the model is switched to eval mode and moved to `device`.
bigvgan_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_22khz_80band_256x', use_cuda_kernel=False)
bigvgan_model.remove_weight_norm()
bigvgan_model = bigvgan_model.eval().to(device)

# Whisper checkpoint name: the config's speech_tokenizer.whisper_name when that
# attribute exists, otherwise "openai/whisper-small".
whisper_name = model_params.speech_tokenizer.whisper_name if hasattr(model_params.speech_tokenizer,
                                                                     'whisper_name') else "openai/whisper-small"
# Weights are loaded in float16. NOTE(review): float16 on a CPU-only host may be
# slow or unsupported for some ops -- confirm if CPU deployment matters.
whisper_model = WhisperModel.from_pretrained(whisper_name, torch_dtype=torch.float16).to(device)
# The decoder submodule is deleted; presumably only the encoder is needed here.
del whisper_model.decoder
whisper_feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_name)

# f0-conditioned DiT model: its own checkpoint and YAML config from the
# "Plachta/Seed-VC" repo.
dit_checkpoint_path_f0, dit_config_path_f0 = load_custom_model_from_hf(
    "Plachta/Seed-VC",
    "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth",
    "config_dit_mel_seed_uvit_whisper_base_f0_44k.yml",
)

# Parse the config inside a `with` block so the file handle is closed as soon
# as it has been read instead of lingering until garbage collection.
with open(dit_config_path_f0, 'r', encoding='utf-8') as _f0_config_file:
    config_f0 = yaml.safe_load(_f0_config_file)
model_params_f0 = recursive_munch(config_f0['model_params'])
model_f0 = build_model(model_params_f0, stage='DiT')
# Hop length and sampling rate of the f0 model (may differ from `model`'s).
hop_length_f0 = config_f0['preprocess_params']['spect_params']['hop_length']
sr_f0 = config_f0['preprocess_params']['sr']

# Load the trained weights (load_only_params=True), then put every entry of
# `model_f0` in eval mode on `device`.
model_f0, _, _, _ = load_checkpoint(model_f0, None, dit_checkpoint_path_f0,
                                    load_only_params=True, ignore_modules=[], is_distributed=False)
for key in model_f0:
    model_f0[key].eval()
    model_f0[key].to(device)
# NOTE(review): presumably preallocates inference caches for batch size 1 and
# sequences of up to 8192 steps -- confirm the limit.
model_f0.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)

# F0 (pitch) extractor: RMVPE weights from "lj1995/VoiceConversionWebUI",
# run in full precision (is_half=False) on `device`.
model_path = load_custom_model_from_hf("lj1995/VoiceConversionWebUI", "rmvpe.pt", None)
rmvpe = RMVPE(model_path, is_half=False, device=device)

# Define Mel spectrogram conversion
def to_mel(x):
    """Return the mel spectrogram of waveform ``x``.

    FFT size, window length, hop length and mel-band count are read from the
    main model's ``config['preprocess_params']['spect_params']``. The sampling
    rate is the module-level ``sr``. The call also fixes ``fmin=0``,
    ``fmax=None`` and ``center=False``.
    """
    spect = config['preprocess_params']['spect_params']
    return mel_spectrogram(
        x,
        n_fft=spect['n_fft'],
        win_size=spect['win_length'],
        hop_size=spect['hop_length'],
        num_mels=spect['n_mels'],
        sampling_rate=sr,
        fmin=0,
        fmax=None,
        center=False,
    )

def adjust_f0_semitones(f0_sequence, n_semitones):
    """Transpose an F0 (pitch) sequence by ``n_semitones`` semitones.

    One equal-tempered semitone multiplies frequency by 2**(1/12), so +12
    doubles every value and -12 halves it. ``f0_sequence`` may be anything
    that supports multiplication by a float (scalar, NumPy array, tensor).
    Zero entries stay zero.
    """
    return f0_sequence * (2 ** (n_semitones / 12))

def crossfade(chunk1, chunk2, overlap):
    """Blend the tail of ``chunk1`` into the head of ``chunk2`` over ``overlap`` samples.

    The fade-out gain is cos^2(t) and the fade-in gain is sin^2(t) for the same
    phase t in [0, pi/2], so the two gains sum to 1 at every sample.

    Args:
        chunk1: Previous chunk (NumPy array, assumed 1-D). Its last ``overlap``
            samples are faded out.
        chunk2: Next chunk (NumPy array, assumed 1-D). It is modified IN PLACE:
            its first ``overlap`` samples are replaced by the blend.
        overlap: Number of samples to blend. Values <= 0 mean "no crossfade".
            The value must not exceed the length of either chunk.

    Returns:
        ``chunk2`` itself (the same object, modified in place).
    """
    if overlap <= 0:
        # chunk1[-0:] would select ALL of chunk1 rather than nothing, which
        # fails to broadcast against the empty fade curves. Negative values
        # make np.linspace raise. Nothing to blend either way.
        return chunk2
    fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
    fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
    chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
    return chunk2

# Helpers for the /convert endpoint.

# Form-field spellings accepted by _parse_bool (compared case-insensitively).
_TRUE_STRINGS = frozenset({"1", "true", "yes", "on"})
_FALSE_STRINGS = frozenset({"0", "false", "no", "off", ""})

# Fields the /convert endpoint requires in request.files / request.form.
_REQUIRED_FILES = ("source", "target")
_REQUIRED_FORM_FIELDS = (
    "diffusion_steps",
    "length_adjust",
    "inference_cfg_rate",
    "f0_condition",
    "auto_f0_adjust",
    "pitch_shift",
)


def _parse_bool(value):
    """Parse a form-field string into a bool.

    ``bool(s)`` is True for every non-empty string, including "false" and "0",
    so form flags must be parsed by spelling instead.

    Raises:
        ValueError: if ``value`` is not one of the recognised spellings.
    """
    normalized = value.strip().lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise ValueError(
        f"expected a boolean (true/false, 1/0, yes/no, on/off), got {value!r}"
    )


def _float_to_int16(wave):
    """Convert a float waveform with nominal range [-1, 1] to int16 PCM.

    Samples are clipped first, because casting an out-of-range float to int16
    wraps around and turns loud peaks into full-scale clicks. The scale is
    32767 rather than 32768 so that +1.0 maps to the int16 maximum instead of
    overflowing.
    """
    return (np.clip(wave, -1.0, 1.0) * 32767.0).astype(np.int16)


def _encode_mp3(pcm, frame_rate):
    """Encode mono int16 PCM as a 320 kbps MP3 and return it in a rewound BytesIO."""
    mp3_file = BytesIO()
    AudioSegment(
        pcm.tobytes(), frame_rate=frame_rate,
        sample_width=pcm.dtype.itemsize, channels=1
    ).export(mp3_file, format="mp3", bitrate="320k")
    mp3_file.seek(0)  # rewind so send_file reads from the start
    return mp3_file


# Define the Flask route for voice conversion
@app.route('/convert', methods=['POST'])
def voice_conversion_api():
    """Voice-conversion endpoint (currently a stub, see the TODO below).

    Expects a multipart POST with files ``source`` and ``target`` and these
    form fields:

    - ``diffusion_steps`` (int)
    - ``length_adjust`` (float)
    - ``inference_cfg_rate`` (float)
    - ``f0_condition`` (bool)
    - ``auto_f0_adjust`` (bool)
    - ``pitch_shift`` (int)

    Boolean fields accept true/false, 1/0, yes/no or on/off (case-insensitive).

    Returns:
        An ``audio/mpeg`` attachment named ``converted_audio.mp3``, or a JSON
        ``{"error": ...}`` body with HTTP 400 when a field is missing or cannot
        be parsed.
    """
    missing = [name for name in _REQUIRED_FILES if name not in request.files]
    missing += [name for name in _REQUIRED_FORM_FIELDS if name not in request.form]
    if missing:
        return jsonify({"error": f"missing required field(s): {', '.join(missing)}"}), 400

    source = request.files['source']
    target = request.files['target']
    try:
        # Parsed up front so bad input is rejected with a 400. These values are
        # for the conversion pipeline, which is not implemented yet.
        diffusion_steps = int(request.form['diffusion_steps'])
        length_adjust = float(request.form['length_adjust'])
        inference_cfg_rate = float(request.form['inference_cfg_rate'])
        f0_condition = _parse_bool(request.form['f0_condition'])
        auto_f0_adjust = _parse_bool(request.form['auto_f0_adjust'])
        pitch_shift = int(request.form['pitch_shift'])
    except ValueError as err:
        return jsonify({"error": f"invalid parameter: {err}"}), 400

    # Decode both uploads, resampled to the main model's sampling rate `sr`.
    source_audio = librosa.load(source, sr=sr)[0]
    ref_audio = librosa.load(target, sr=sr)[0]

    # Shape each as a 1-row float tensor on `device`; only the first 25 s of
    # the reference are kept.
    source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(device)
    ref_audio = torch.tensor(ref_audio[:sr * 25]).unsqueeze(0).float().to(device)

    # 16 kHz copies of reference and source.
    ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
    converted_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)

    # TODO: run the actual conversion (feature extraction, DiT inference,
    # vocoding) and use its waveform here. Until then this is placeholder
    # noise: 441000 samples, written below at frame rate `sr`.
    output_wave = np.random.randn(44100 * 10)
    output_wave = _float_to_int16(output_wave)

    return send_file(_encode_mp3(output_wave, sr), mimetype="audio/mpeg",
                     as_attachment=True, download_name="converted_audio.mp3")

if __name__ == "__main__":
    # Run the Flask development server on all interfaces, port 7860.
    # The Werkzeug interactive debugger allows arbitrary code execution from
    # the browser, so it must never be on by default while binding 0.0.0.0.
    # Opt in explicitly (local development only) with FLASK_DEBUG=1.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(host='0.0.0.0', debug=debug_enabled, port=7860)