"""Streaming speech recognition with sherpa-onnx transducer models and a Gradio UI."""
import numpy as np
import gradio as gr
import torch
import torchaudio
from sherpa_onnx import OnlineRecognizer

# One streaming transducer recognizer per language; each expects
# <prefix>_tokens.txt, <prefix>_encoder.onnx, <prefix>_decoder.onnx and
# <prefix>_joiner.onnx in the working directory.
def load_recognizer(prefix):
    return OnlineRecognizer.from_transducer(
        tokens=f"{prefix}_tokens.txt",
        encoder=f"{prefix}_encoder.onnx",
        decoder=f"{prefix}_decoder.onnx",
        joiner=f"{prefix}_joiner.onnx",
        num_threads=1,
        decoding_method="modified_beam_search",
        debug=False,
    )

RECOGNIZERS = {
    "English": load_recognizer("en"),
    "French": load_recognizer("fr"),
    "German": load_recognizer("de"),
}

def transcribe_audio_online_streaming(file, language):
    """Generator for file transcription"""
    if file is None:
        yield "Please upload an audio file."
        return

    try:
        recognizer = RECOGNIZERS.get(language)
        if recognizer is None:
            yield f"Unsupported language: {language}"
            return

        # gr.File(type="filepath") passes a plain path string, not a file object
        waveform, sample_rate = torchaudio.load(file)
        # The streaming models expect 16 kHz input; resample if needed
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform = resampler(waveform)
            sample_rate = 16000

        # Downmix multi-channel audio to mono by averaging channels
        waveform_np = waveform.mean(dim=0).numpy()

        # Add 0.5 seconds of silence padding at the beginning and end
        pad_duration = 0.5  # seconds
        pad_samples = int(pad_duration * sample_rate)
        pad_start = np.zeros(pad_samples, dtype=np.float32)
        pad_end = np.zeros(pad_samples, dtype=np.float32)
        waveform_np = np.concatenate([pad_start, waveform_np, pad_end])
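        # Leading silence gives the encoder warm-up context; trailing silence
        # helps the last words get decoded before the flush below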
        
        total_samples = waveform_np.shape[0]
        
        s = recognizer.create_stream()
        chunk_size = 4000  # 0.25 s of audio at 16 kHz
        offset = 0

        while offset < total_samples:
            chunk = waveform_np[offset:offset + chunk_size]
            # sherpa-onnx accepts float32 numpy arrays directly
            s.accept_waveform(sample_rate, chunk)

            # Decode whatever frames are ready, then emit the partial result
            while recognizer.is_ready(s):
                recognizer.decode_streams([s])

            yield recognizer.get_result(s)
            offset += chunk_size

        # Flush: feed trailing silence and mark the input finished so the
        # decoder emits its final tokens
        tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
        s.accept_waveform(sample_rate, tail_paddings)
        s.input_finished()
        
        while recognizer.is_ready(s):
            recognizer.decode_streams([s])
        
        # get_result normally returns a plain str; normalize defensively in
        # case a binding hands back token lists or bytes
        current_text = recognizer.get_result(s)
        if isinstance(current_text, (list, np.ndarray)):
            current_text = " ".join(map(str, current_text))
        elif isinstance(current_text, bytes):
            current_text = current_text.decode("utf-8", errors="ignore")

        yield current_text

    except Exception as e:
        yield f"Error: {e}"

def transcribe_microphone_stream(audio_chunk, stream_state, language):
    """Real-time microphone streaming transcription"""
    try:
        recognizer = RECOGNIZERS.get(language)
        if recognizer is None:
            return f"Unsupported language: {language}", stream_state

        if audio_chunk is None:  # End of stream
            if stream_state is not None:
                # Flush remaining audio with trailing silence
                tail_paddings = np.zeros(int(0.66 * 16000), dtype=np.float32)
                stream_state.accept_waveform(16000, tail_paddings)
                stream_state.input_finished()
                while recognizer.is_ready(stream_state):
                    recognizer.decode_streams([stream_state])
                final_text = recognizer.get_result(stream_state)
                return final_text, None
            return "", None

        sample_rate, waveform_np = audio_chunk
        # Gradio delivers raw integer PCM (typically int16); scale to float32 in [-1, 1]
        if np.issubdtype(waveform_np.dtype, np.integer):
            waveform_np = waveform_np.astype(np.float32) / np.iinfo(waveform_np.dtype).max
        else:
            waveform_np = waveform_np.astype(np.float32)
        if waveform_np.ndim > 1:
            waveform_np = waveform_np.mean(axis=1)

        # Resample to the 16 kHz the models expect
        if sample_rate != 16000:
            waveform = torch.from_numpy(waveform_np).unsqueeze(0)
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform = resampler(waveform)
            waveform_np = waveform.squeeze(0).numpy()
            sample_rate = 16000

        # Initialize stream if first chunk
        if stream_state is None:
            stream_state = recognizer.create_stream()
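        # Caveat: if the user switches language mid-recording, this stream was
        # created by the previous recognizer and is reused until recording restarts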

        # Feed the chunk to the decoder (float32 numpy arrays are accepted directly)
        stream_state.accept_waveform(sample_rate, waveform_np)
        
        # Decode available frames
        while recognizer.is_ready(stream_state):
            recognizer.decode_streams([stream_state])
        
        current_text = recognizer.get_result(stream_state)

        if isinstance(current_text, (list, np.ndarray)):
            current_text = " ".join(map(str, current_text))
        elif isinstance(current_text, bytes):
            current_text = current_text.decode("utf-8", errors="ignore")
        
        return current_text, stream_state

    except Exception as e:
        print(f"Stream error: {e}")
        return str(e), stream_state

def create_app():
    with gr.Blocks() as app:
        gr.Markdown("# Real-time Speech Recognition")
        language_choice = gr.Radio(choices=["English", "French", "German"], label="Select Language", value="English")
        
        with gr.Tabs():
            with gr.Tab("File Transcription"):
                gr.Markdown("Upload an audio file for streaming transcription")
                file_input = gr.File(label="Audio File", type="filepath")
                file_output = gr.Textbox(label="Transcription")
                transcribe_btn = gr.Button("Transcribe")
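                # Clear the textbox first, then stream partial results from the generator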
                transcribe_btn.click(lambda: "", outputs=file_output).then(
                    transcribe_audio_online_streaming,
                    inputs=[file_input, language_choice],
                    outputs=file_output
                )

            with gr.Tab("Live Microphone"):
                gr.Markdown("Speak into your microphone for real-time transcription")
                mic = gr.Audio(
                    sources=["microphone"],
                    streaming=True,
                    type="numpy",
                    label="Live Input",
                    show_download_button=False
                )
                live_output = gr.Textbox(label="Live Transcription")
                state = gr.State()
                
                # Stream each mic chunk through the transcriber; outputs update
                # both the textbox and the saved stream state
                mic.stream(
                    transcribe_microphone_stream,
                    inputs=[mic, state, language_choice],
                    outputs=[live_output, state],
                    show_progress="hidden"
                )

    return app

if __name__ == "__main__":
    app = create_app()
    app.queue()  # generator outputs and mic streaming require the queue
    app.launch()