import sys
import time

from importlib.metadata import version

import torch
import torchaudio
import torchaudio.transforms as T

import gradio as gr

from transformers import AutoModelForCTC, Wav2Vec2BertProcessor

# Config
model_name = "Yehor/w2v-bert-uk"

min_duration = 0.5
max_duration = 60

concurrency_limit = 5
use_torch_compile = False

# Torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load the model
asr_model = AutoModelForCTC.from_pretrained(model_name, torch_dtype=torch_dtype, device_map=device)
processor = Wav2Vec2BertProcessor.from_pretrained(model_name)

if use_torch_compile:
    asr_model = torch.compile(asr_model)

# Elements
examples = [
    "example_1.wav",
    "example_2.wav",
    "example_3.wav",
    "example_4.wav",
    "example_5.wav",
    "example_6.wav",
]

examples_table = """
| File  | Text |
| ------------- | ------------- |
| `example_1.wav`  | тема про яку не люблять говорити офіційні джерела у генштабі і міноборони це хімічна зброя окупанти вже тривалий час використовують хімічну зброю заборонену |
| `example_2.wav`  | всіма конвенціями якщо спочатку це були гранати з дронів то тепер фіксують випадки застосування |
| `example_3.wav`  | хімічних снарядів причому склад отруйної речовони різний а отже й наслідки для наших військових теж різні  |
| `example_4.wav`  | використовує на фронті все що має і хімічна зброя не вийняток тож з чим маємо справу розбиралася марія моганисян |
| `example_5.wav`  | двох тисяч випадків застосування росіянами боєприпасів споряджених небезпечними хімічними речовинами |
| `example_6.wav`  | на всі писані норми марія моганисян олександр моторний спецкор марафон єдині новини |
""".strip()

# https://www.tablesgenerator.com/markdown_tables
authors_table = """
## Authors

Follow them in social networks and **contact** if you need any help or have any questions:

| <img src="https://avatars.githubusercontent.com/u/7875085?v=4" width="100"> **Yehor Smoliakov** |
|-------------------------------------------------------------------------------------------------|
| https://t.me/smlkw in Telegram                                                                  |
| https://x.com/yehor_smoliakov at X                                                              |
| https://github.com/egorsmkv at GitHub                                                           |
| https://huggingface.co/Yehor at Hugging Face                                                    |
| or use egorsmkv@gmail.com                                                                       |
""".strip()

description_head = f"""
# Speech-to-Text for Ukrainian

## Overview

This space uses https://huggingface.co/Yehor/w2v-bert-uk model to recognize audio files.

> Due to resource limitations, audio duration **must not** exceed **{max_duration}** seconds.
""".strip()

description_foot = f"""
## Community

- **Discord**: https://discord.gg/yVAjkBgmt4
- Speech Recognition: https://t.me/speech_recognition_uk
- Speech Synthesis: https://t.me/speech_synthesis_uk

## More

Check out other ASR models: https://github.com/egorsmkv/speech-recognition-uk

{authors_table}
""".strip()

transcription_value = """
Recognized text will appear here.

Choose **an example file** below the Recognize button, upload **your audio file**, or use **the microphone** to record own voice.
""".strip()

tech_env = f"""
#### Environment

- Python: {sys.version}
- Torch device: {device}
- Torch dtype: {torch_dtype}
- Use torch.compile: {use_torch_compile}
""".strip()

tech_libraries = f"""
#### Libraries

- torch: {version('torch')}
- torchaudio: {version('torchaudio')}
- transformers: {version('transformers')}
- accelerate: {version('accelerate')}
- gradio: {version('gradio')}
""".strip()


def inference(audio_path, progress=gr.Progress()):
    if not audio_path:
        raise gr.Error("Please upload an audio file.")

    gr.Info("Starting recognition", duration=2)

    progress(0, desc="Recognizing")

    meta = torchaudio.info(audio_path)
    duration = meta.num_frames / meta.sample_rate

    if duration < min_duration:
        raise gr.Error(
            f"The duration of the file is less than {min_duration} seconds, it is {round(duration, 2)} seconds."
        )
    if duration > max_duration:
        raise gr.Error(f"The duration of the file exceeds {max_duration} seconds.")

    paths = [
        audio_path,
    ]

    results = []

    for path in progress.tqdm(paths, desc="Recognizing...", unit="file"):
        t0 = time.time()

        meta = torchaudio.info(audio_path)
        audio_duration = meta.num_frames / meta.sample_rate

        audio_input, sr = torchaudio.load(path)

        if meta.num_channels > 1:
            audio_input = torch.mean(audio_input, dim=0, keepdim=True)

        if meta.sample_rate != 16_000:
            resampler = T.Resample(sr, 16_000, dtype=audio_input.dtype)
            audio_input = resampler(audio_input)

        audio_input = audio_input.squeeze().numpy()

        features = processor([audio_input], sampling_rate=16_000).input_features
        features = torch.tensor(features).to(device)

        if torch_dtype == torch.float16:
            features = features.half()

        with torch.inference_mode():
            logits = asr_model(features).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        predictions = processor.batch_decode(predicted_ids)

        if not predictions:
            predictions = "-"

        elapsed_time = round(time.time() - t0, 2)
        rtf = round(elapsed_time / audio_duration, 4)
        audio_duration = round(audio_duration, 2)

        results.append(
            {
                "path": path.split("/")[-1],
                "transcription": "\n".join(predictions),
                "audio_duration": audio_duration,
                "rtf": rtf,
            }
        )

    gr.Info("Finished!", duration=2)

    result_texts = []

    for result in results:
        result_texts.append(f'**{result["path"]}**')
        result_texts.append("\n\n")
        result_texts.append(f'> {result["transcription"]}')
        result_texts.append("\n\n")
        result_texts.append(f'**Audio duration**: {result["audio_duration"]}')
        result_texts.append("\n")
        result_texts.append(f'**Real-Time Factor**: {result["rtf"]}')

    return "\n".join(result_texts)


demo = gr.Blocks(
    title="Speech-to-Text for Ukrainian",
    analytics_enabled=False,
    theme=gr.themes.Base(),
)

with demo:
    gr.Markdown(description_head)

    gr.Markdown("## Usage")

    with gr.Row():
        audio_file = gr.Audio(label="Audio file", type="filepath")
        transcription = gr.Markdown(
            label="Transcription",
            value=transcription_value,
        )

    gr.Button("Recognize").click(
        inference,
        concurrency_limit=concurrency_limit,
        inputs=audio_file,
        outputs=transcription,
    )

    with gr.Row():
        gr.Examples(label="Choose an example", inputs=audio_file, examples=examples)

    gr.Markdown(examples_table)

    gr.Markdown(description_foot)

    gr.Markdown("### Gradio app uses:")
    gr.Markdown(tech_env)
    gr.Markdown(tech_libraries)

if __name__ == "__main__":
    demo.queue()
    demo.launch()