import os
import gc
import sys
import time
import torch
import spaces
import torchaudio
import numpy as np


from df.enhance import enhance, init_df
from dotenv import load_dotenv
from scipy.signal import resample
from pyannote.audio import Pipeline
from difflib import SequenceMatcher
from transformers import (
    Wav2Vec2ForSequenceClassification,
    AutoFeatureExtractor,
    Wav2Vec2ForCTC,
    AutoProcessor,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)

load_dotenv()


import logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

class ChunkedTranscriber:
    def __init__(self, chunk_size=30, overlap=5, sample_rate=16000):
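        """
        Args:
            chunk_size: length of each transcription chunk in seconds.
            overlap: overlap between consecutive chunks in seconds.
            sample_rate: target sample rate in Hz (the MMS models expect 16 kHz audio).
        """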
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.sample_rate = sample_rate
        self.previous_text = ""
        self.previous_lang = None
        self.speaker_diarization_pipeline = self.load_speaker_diarization_pipeline()

    def load_speaker_diarization_pipeline(self):
        """
        Load the pre-trained speaker diarization pipeline from pyannote-audio.
        """
        pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=os.getenv("HF_TOKEN"))
        return pipeline

    @spaces.GPU(duration=180)
    def diarize_audio(self, audio_path):
        """
        Perform speaker diarization on the input audio.
        """
        diarization_result = self.speaker_diarization_pipeline({"uri": "audio", "audio": audio_path})
        return diarization_result

    def load_lid_mms(self):
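        """
        Load the MMS-LID (256-language) identification model and its feature extractor.
        """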
        model_id = "facebook/mms-lid-256"
        processor = AutoFeatureExtractor.from_pretrained(model_id)
        model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)
        return processor, model

    
    @spaces.GPU(duration=180)
    def language_identification(self, model, processor, chunk, device="cuda"):
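        """
        Run MMS-LID on a single audio chunk and return the detected ISO 639-3 language code.
        """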
        inputs = processor(chunk, sampling_rate=16_000, return_tensors="pt")
        model.to(device)
        inputs = inputs.to(device)
        with torch.no_grad():
          outputs = model(**inputs).logits

        lang_id = torch.argmax(outputs, dim=-1)[0].item()
        detected_lang = model.config.id2label[lang_id]
        del model
        del inputs
        torch.cuda.empty_cache()
        gc.collect()
        return detected_lang


    def load_mms(self):
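        """
        Load the MMS-1B-all multilingual speech recognition model and its processor.
        """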
        model_id = "facebook/mms-1b-all"
        processor = AutoProcessor.from_pretrained(model_id)
        model = Wav2Vec2ForCTC.from_pretrained(model_id)
        return model, processor


    @spaces.GPU(duration=180)
    def mms_transcription(self, model, processor, chunk, device="cuda"):
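        """
        Transcribe a single audio chunk with MMS, decoding the argmax token ids (CTC).
        """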

        inputs = processor(chunk, sampling_rate=16_000, return_tensors="pt")
        model.to(device)
        inputs = inputs.to(device)
        with torch.no_grad():
            outputs = model(**inputs).logits

        ids = torch.argmax(outputs, dim=-1)[0]
        transcription = processor.decode(ids)
        del model
        del inputs
        torch.cuda.empty_cache()
        gc.collect()
        return transcription


    def load_T2T_translation_model(self):
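        """
        Load the distilled 600M-parameter NLLB-200 text-to-text translation model.
        """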
        model_id = "facebook/nllb-200-distilled-600M"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        return model, tokenizer

    
    @spaces.GPU(duration=180)
    def text2text_translation(self, translation_model, translation_tokenizer, transcript, device="cuda"):
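        """
        Translate a transcript into English with NLLB, forcing the eng_Latn target token.
        """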
        tokenized_inputs = translation_tokenizer(transcript, return_tensors='pt')
        translation_model.to(device)
        tokenized_inputs = tokenized_inputs.to(device)
        translated_tokens = translation_model.generate(
            **tokenized_inputs,
            forced_bos_token_id=translation_tokenizer.convert_tokens_to_ids("eng_Latn"),
            max_length=100
        )
        del translation_model
        del tokenized_inputs
        torch.cuda.empty_cache()
        gc.collect()
        return translation_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

        
    def preprocess_audio(self, audio):
        """
        Create overlapping chunks with improved timing logic
        """
        chunk_samples = int(self.chunk_size * self.sample_rate)
        overlap_samples = int(self.overlap * self.sample_rate)

        chunks_with_times = []
        start_idx = 0

        while start_idx < len(audio):
            end_idx = min(start_idx + chunk_samples, len(audio))

            # Add padding for first chunk
            if start_idx == 0:
                chunk = audio[start_idx:end_idx]
                padding = torch.zeros(int(1 * self.sample_rate))
                chunk = torch.cat([padding, chunk])
            else:
                # Include overlap from previous chunk
                actual_start = max(0, start_idx - overlap_samples)
                chunk = audio[actual_start:end_idx]

            # Pad if necessary
            if len(chunk) < chunk_samples:
                chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))

            # Adjust time ranges to account for overlaps
            chunk_start_time = max(0, (start_idx / self.sample_rate) - self.overlap)
            chunk_end_time = min((end_idx / self.sample_rate) + self.overlap, len(audio) / self.sample_rate)

            chunks_with_times.append({
                'chunk': chunk,
                'start_time': start_idx / self.sample_rate,
                'end_time': end_idx / self.sample_rate,
                'transcribe_start': chunk_start_time,
                'transcribe_end': chunk_end_time
            })

            # Move to next chunk with smaller step size for better continuity
            start_idx += (chunk_samples - overlap_samples)

        return chunks_with_times


    def merge_close_segments(self, results):
        """
        Merge segments that are close in time and have the same language
        """
        if not results:
            return results

        merged = []
        current = results[0]

        for next_segment in results[1:]:
            # Skip empty segments
            if not next_segment['text'].strip():
                continue

            # If segments are in the same language and close in time
            if (current['detected_language'] == next_segment['detected_language'] and
                abs(next_segment['start_time'] - current['end_time']) <= self.overlap):

                # Merge the segments
                current['text'] = current['text'] + ' ' + next_segment['text']
                current['end_time'] = next_segment['end_time']
                if 'translated' in current and 'translated' in next_segment:
                    current['translated'] = current['translated'] + ' ' + next_segment['translated']
            else:
                if current['text'].strip():  # Only add non-empty segments
                    merged.append(current)
                current = next_segment

        if current['text'].strip():  # Add the last segment if non-empty
            merged.append(current)

        return merged


    def clean_overlapping_text(self, current_text, prev_text, current_lang, prev_lang, min_overlap=3):
        """
        Improved text cleaning with language awareness and better sentence boundary handling
        """
        if not prev_text or not current_text:
            return current_text

        # If languages are different, don't try to merge
        if prev_lang and current_lang and prev_lang != current_lang:
            return current_text

        # Split into words
        prev_words = prev_text.split()
        curr_words = current_text.split()

        if len(prev_words) < 2 or len(curr_words) < 2:
            return current_text

        # Find matching sequences at the end of prev_text and start of current_text
        matcher = SequenceMatcher(None, prev_words, curr_words)
        matches = list(matcher.get_matching_blocks())

        # Look for the longest overlap that begins at the start of the current text
        best_overlap = 0

        for match in matches:
            # Only matches anchored at the start of the current text count as overlap
            if match.b == 0 and match.size >= min_overlap:
                best_overlap = max(best_overlap, match.size)

        if best_overlap > 0:
            # Remove overlapping content while preserving sentence integrity
            cleaned_words = curr_words[best_overlap:]
            if not cleaned_words:  # If everything was overlapping
                return ""
            return ' '.join(cleaned_words).strip()

        return current_text


    def process_chunk(self, chunk_data, mms_model, mms_processor, translation_model=None, translation_tokenizer=None):
        """
        Process chunk with improved language handling
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        try:
            # Language detection
            lid_processor, lid_model = self.load_lid_mms()
            lid_lang = self.language_identification(lid_model, lid_processor, chunk_data['chunk'])

            # Configure processor
            mms_processor.tokenizer.set_target_lang(lid_lang)
            mms_model.load_adapter(lid_lang)

            # Transcribe
            inputs = mms_processor(chunk_data['chunk'], sampling_rate=self.sample_rate, return_tensors="pt")
            inputs = inputs.to(device)
            mms_model = mms_model.to(device)

            with torch.no_grad():
                outputs = mms_model(**inputs).logits

            ids = torch.argmax(outputs, dim=-1)[0]
            transcription = mms_processor.decode(ids)

            # Clean overlapping text with language awareness
            cleaned_transcription = self.clean_overlapping_text(
                transcription,
                self.previous_text,
                lid_lang,
                self.previous_lang,
                min_overlap=3
            )

            # Update previous state
            self.previous_text = transcription
            self.previous_lang = lid_lang

            if not cleaned_transcription.strip():
                return None

            result = {
                'start_time': chunk_data['start_time'],
                'end_time': chunk_data['end_time'],
                'text': cleaned_transcription,
                'detected_language': lid_lang
            }

            # Handle translation
            if translation_model and translation_tokenizer and cleaned_transcription.strip():
                translation = self.text2text_translation(
                    translation_model,
                    translation_tokenizer,
                    cleaned_transcription
                )
                result['translated'] = translation

            return result

        except Exception as e:
            logger.error(f"Error processing chunk: {e}")
            return None
        finally:
            torch.cuda.empty_cache()
            gc.collect()


    def translate_text(self, text, translation_model, translation_tokenizer, device):
        """
        Translate cleaned text using the provided translation model.
        """
        tokenized_inputs = translation_tokenizer(text, return_tensors='pt')
        tokenized_inputs = tokenized_inputs.to(device)
        translation_model = translation_model.to(device)

        translated_tokens = translation_model.generate(
            **tokenized_inputs,
            forced_bos_token_id=translation_tokenizer.convert_tokens_to_ids("eng_Latn"),
            max_length=100
        )

        translation = translation_tokenizer.batch_decode(
            translated_tokens,
            skip_special_tokens=True
        )[0]

        del translation_model
        del tokenized_inputs
        torch.cuda.empty_cache()
        gc.collect()
        return translation


    def audio_denoising(self, noisy_audio):
        """
        Denoise a waveform with DeepFilterNet before transcription.
        """
        model, df_state, _ = init_df()
        enhanced_audio = enhance(model, df_state, noisy_audio)
        return enhanced_audio

    def transcribe_audio(self, audio_path, translate=False):
        """
        Main transcription function with improved segment merging
        """
        # Perform speaker diarization
        diarization_result = self.diarize_audio(audio_path)

        # Extract speaker segments
        speaker_segments = []

        for turn, _, speaker in diarization_result.itertracks(yield_label=True):
            speaker_segments.append({
                'start_time': turn.start,
                'end_time': turn.end,
                'speaker': speaker
            })
        
        audio = self.load_audio(audio_path)
        chunks = self.preprocess_audio(audio)

        mms_model, mms_processor = self.load_mms()
        translation_model, translation_tokenizer = None, None
        if translate:
            translation_model, translation_tokenizer = self.load_T2T_translation_model()

        # Process chunks
        results = []
        for chunk_data in chunks:
            result = self.process_chunk(
                chunk_data,
                mms_model,
                mms_processor,
                translation_model,
                translation_tokenizer
            )
            if result:
                # Attribute the chunk to the speaker whose turn contains its start time
                for segment in speaker_segments:
                    if segment['start_time'] <= chunk_data['start_time'] < segment['end_time']:
                        result['speaker'] = segment['speaker']
                        break
                results.append(result)

        # Merge close segments and clean up
        merged_results = self.merge_close_segments(results)

        _translation = ""
        _output = ""
        for res in merged_results: 
            _translation+=res['translated']
            _output+=f"{res['start_time']}-{res['end_time']} - Speaker: {res['speaker'].split('_')[1]} - Language: {res['detected_language']}\n Text: {res['text']}\n Translation: {res['translated']}\n\n"
        logger.info(f"\n\n TRANSLATION: {_translation}")
        return _translation, _output


    def load_audio(self, audio_path):
        """
        Load and preprocess audio file.
        """
        waveform, sample_rate = torchaudio.load(audio_path)

        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0)
        else:
            waveform = waveform.squeeze(0)

        # Resample if necessary
        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate,
                new_freq=self.sample_rate
            )
            waveform = resampler(waveform)

        return waveform.float()
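

# Minimal usage sketch, assuming a local "sample.wav" file (hypothetical) and an
# HF_TOKEN environment variable with access to the gated pyannote pipeline.
# Outside a ZeroGPU Space, the @spaces.GPU decorators should behave as no-ops.
if __name__ == "__main__":
    transcriber = ChunkedTranscriber(chunk_size=30, overlap=5)
    translation, output = transcriber.transcribe_audio("sample.wav", translate=True)
    print(output)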