import logging
import pickle
from random import shuffle
from typing import Optional

import numpy as np
import soundfile as sf
import torch
import torchaudio
from torch import Tensor
from transformers import (
    AutoModelForCTC,
    BatchEncoding,
    CamembertForSequenceClassification,
    DistilBertForSequenceClassification,
    PreTrainedTokenizerBase,
    Wav2Vec2Processor,
    pipeline,
)

from Model import BERT, tokenizer, mult_token_id, cls_token_id, pad_token_id, max_pred, maxlen, sep_token_id
from FrModel import FR_BERT, fr_tokenizer, fr_mult_token_id, fr_cls_token_id, fr_pad_token_id, fr_sep_token_id

device = "cpu"
# Model loading
def load_models():
    print("Loading DistilBERT model...")
    model = DistilBertForSequenceClassification.from_pretrained("DistillMDPI1/DistillMDPI1/saved_model")

    print("Loading BERT model...")
    neptune = BERT()
    model_save_path = "neptune_270_papers/neptune_270_papers/model.pt"
    neptune.load_state_dict(torch.load(model_save_path, map_location=torch.device('cpu')))
    neptune.to(device)
    
    print("Loading speech recognition pipeline...")
    pipe = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-tiny.en",
        chunk_length_s=30,
        device=device,
    )
    # Load the label encoder
    with open("DistillMDPI1/DistillMDPI1/label_encoder.pkl", "rb") as f:
        label_encoder = pickle.load(f)
    
    return model, neptune, pipe

def load_fr_models():
    print("Loading CamemBERT model...")
    fr_model = CamembertForSequenceClassification.from_pretrained("Camembert/Camembert/saved_model")
    print("Loading French BERT model...")
    fr_neptune = FR_BERT()
    model_save_path = "fr_neptune/fr_neptune/model.pt"
    fr_neptune.load_state_dict(torch.load(model_save_path, map_location=torch.device('cpu')))
    fr_neptune.to(device)
    print("Loading Wav2Vec2 model for French...")
    wav2vec2_processor = Wav2Vec2Processor.from_pretrained("bhuang/asr-wav2vec2-french")
    wav2vec2_model = AutoModelForCTC.from_pretrained("bhuang/asr-wav2vec2-french").to(device)
    return fr_model, fr_neptune, wav2vec2_processor, wav2vec2_model 
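
# Usage sketch (paths are the repo-relative defaults used above):
#   model, neptune, pipe = load_models()
#   fr_model, fr_neptune, wav2vec2_processor, wav2vec2_model = load_fr_models()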

fr_class_labels = {
    0: ('Physics', 'primary', '#5e7cc8'),
    1: ('AI', 'cyan', '#0dcaf0'),
    2: ('economies', 'warning', '#f7c32e'),
    3: ('environments', 'success', '#0cbc87'),
    4: ('sports', 'orange', '#fd7e14')}
class_labels = {
    16: ('vehicles', 'info', '#4f9ef8'),
    10: ('environments', 'success', '#0cbc87'),
    9: ('energies', 'danger', '#d6293e'),
    0: ('Physics', 'primary', '#0f6fec'),
    13: ('robotics', 'moss', '#B1E5F2'),
    3: ('agriculture', 'agri', '#a8c686'),
    11: ('ML', 'yellow', '#ffc107'),
    8: ('economies', 'warning', '#f7c32e'),
    15: ('technologies', 'vanila', '#FDF0D5'),
    12: ('mathematics', 'coffe', '#7f5539'),
    14: ('sports', 'orange', '#fd7e14'),
    4: ('AI', 'cyan', '#0dcaf0'),
    6: ('Innovation', 'rosy', '#BF98A0'),
    5: ('Science', 'picton', '#5fa8d3'),
    1: ('Societies', 'purple', '#6f42c1'),
    2: ('administration', 'pink', '#d63384'),
    7: ('biology', 'cambridge', '#88aa99')}

def predict_class(text, model):
    # Tokenize the text into overlapping chunks
    inputs = transform_list_of_texts([text], tokenizer, 510, 510, 1, 2550)

    # Collect the class probabilities of every chunk
    all_probabilities = []

    # Run each chunk through the model
    model.eval()
    with torch.no_grad():
        for i, sample in enumerate(inputs['input_ids']):
            for j in range(len(sample)):
                input_ids_tensor = sample[j].clone().detach().to(device).unsqueeze(0)
                attention_mask_tensor = inputs['attention_mask'][i][j].clone().detach().to(device).unsqueeze(0)
                outputs = model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)

                # Softmax over the logits
                probabilities = torch.softmax(outputs.logits, dim=1)[0]
                all_probabilities.append(probabilities)

    # Average the probabilities when the text was split into several chunks
    if len(all_probabilities) > 1:
        mean_probabilities = torch.stack(all_probabilities).mean(dim=0)
    else:
        mean_probabilities = all_probabilities[0]

    # Pick the class with the highest mean probability
    predicted_class_index = torch.argmax(mean_probabilities).item()
    predicted_class = class_labels[predicted_class_index]

    # Build a percentage dictionary sorted by probability
    sorted_percentages = {class_labels[idx]: mean_probabilities[idx].item() * 100 for idx in range(len(class_labels))}
    sorted_percentages = dict(sorted(sorted_percentages.items(), key=lambda item: item[1], reverse=True))

    return predicted_class, sorted_percentages
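
# Usage sketch (assumes the models returned by load_models above):
#   predicted_class, percentages = predict_class("Transformers improve topic classification.", model)
# predicted_class is a (label, css_class, hex_color) tuple from class_labels;
# percentages maps each such tuple to its mean probability in percent, sorted descending.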

def predict_class_for_Neptune(text, model):
    # Tokenize the text
    encoded_text = transform_for_inference_text(text, tokenizer, 125, 125, 1, 2550)
    batch, sentences = prepare_text(encoded_text)
    
    # Process the text through the model
    model.eval()
    all_probabilities = []
    with torch.no_grad():
        for sample in batch:
            input_ids = torch.tensor(sample[0], device=device, dtype=torch.long).unsqueeze(0)
            segment_ids = torch.tensor(sample[1], device=device, dtype=torch.long).unsqueeze(0)
            masked_pos = torch.tensor(sample[2], device=device, dtype=torch.long).unsqueeze(0)
            
            _, _, logits_mclsf1, logits_mclsf2 = model(input_ids, segment_ids, masked_pos)
            probabilities1 = torch.softmax(logits_mclsf1, dim=1)[0]
            probabilities2 = torch.softmax(logits_mclsf2, dim=1)[0]
            all_probabilities.extend([probabilities1, probabilities2])
    
    # Aggregate probabilities
    aggregated_probabilities = torch.stack(all_probabilities).mean(dim=0)
    
    # Identify the majority class
    predicted_class_index = torch.argmax(aggregated_probabilities).item()
    predicted_class = class_labels[predicted_class_index]
    
    # Create a sorted dictionary of percentages
    sorted_percentages = {class_labels[idx]: aggregated_probabilities[idx].item() * 100 for idx in range(len(class_labels))}
    sorted_percentages = dict(sorted(sorted_percentages.items(), key=lambda item: item[1], reverse=True))
    
    return predicted_class, sorted_percentages
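
# Note: predict_class_for_Neptune mirrors predict_class but averages the
# probabilities of Neptune's two chunk-level classification heads
# (logits_mclsf1 and logits_mclsf2) instead of a single DistilBERT head.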

def predict_sentences_class(text, model):
    # Tokenize the text into overlapping chunks
    inputs = transform_list_of_texts([text], tokenizer, 510, 510, 1, 2550)
    aligned_predictions = {}
    
    # Run each chunk through the model
    model.eval()
    with torch.no_grad():
        for i, sample in enumerate(inputs['input_ids']):
            for j in range(len(sample)):
                input_ids_tensor = sample[j].clone().detach().to(device).unsqueeze(0)
                attention_mask_tensor = inputs['attention_mask'][i][j].clone().detach().to(device).unsqueeze(0)
                
                # Decode the sentence
                sentence = tokenizer.decode(input_ids_tensor[0], skip_special_tokens=True)

                # Run the chunk through the model
                outputs = model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)
                
                # Identify the predicted class
                predicted_class_index = torch.argmax(outputs.logits, dim=1).item()
                predicted_class = class_labels[predicted_class_index]  # full (label, css_class, hex) tuple

                # Record the prediction for this sentence
                if sentence not in aligned_predictions:
                    aligned_predictions[sentence] = predicted_class

    return aligned_predictions


def transform_list_of_texts(
    texts: list[str],
    tokenizer: PreTrainedTokenizerBase,
    chunk_size: int,
    stride: int,
    minimal_chunk_length: int,
    maximal_text_length: Optional[int] = None,
) -> BatchEncoding:
    model_inputs = [
        transform_single_text(text, tokenizer, chunk_size, stride, minimal_chunk_length, maximal_text_length)
        for text in texts
    ]
    input_ids = [model_input[0] for model_input in model_inputs]
    attention_mask = [model_input[1] for model_input in model_inputs]
    tokens = {"input_ids": input_ids, "attention_mask": attention_mask}
    return BatchEncoding(tokens)
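
# Shape note: inputs['input_ids'][i] is a [n_chunks, chunk_size + 2] LongTensor
# for the i-th text (each chunk gains a CLS and a SEP token), aligned with
# inputs['attention_mask'][i].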


def transform_single_text(
    text: str,
    tokenizer: PreTrainedTokenizerBase,
    chunk_size: int,
    stride: int,
    minimal_chunk_length: int,
    maximal_text_length: Optional[int],
) -> tuple[Tensor, Tensor]:
    """Transforms (the entire) text to model input of BERT model."""
    if maximal_text_length:
        tokens = tokenize_text_with_truncation(text, tokenizer, maximal_text_length)
    else:
        tokens = tokenize_whole_text(text, tokenizer)
    input_id_chunks, mask_chunks = split_tokens_into_smaller_chunks(tokens, chunk_size, stride, minimal_chunk_length)
    add_special_tokens_at_beginning_and_end(input_id_chunks, mask_chunks)
    add_padding_tokens(input_id_chunks, mask_chunks , chunk_size)
    input_ids, attention_mask = stack_tokens_from_all_chunks(input_id_chunks, mask_chunks)
    return input_ids, attention_mask


def tokenize_whole_text(text: str, tokenizer: PreTrainedTokenizerBase) -> BatchEncoding:
    """Tokenizes the entire text without truncation and without special tokens."""
    tokens = tokenizer(text, add_special_tokens=False, truncation=False, return_tensors="pt")
    return tokens


def tokenize_text_with_truncation(
    text: str, tokenizer: PreTrainedTokenizerBase, maximal_text_length: int
) -> BatchEncoding:
    """Tokenizes the text with truncation to maximal_text_length and without special tokens."""
    tokens = tokenizer(
        text, add_special_tokens=False, max_length=maximal_text_length, truncation=True, return_tensors="pt"
    )
    return tokens


def split_tokens_into_smaller_chunks(
    tokens: BatchEncoding,
    chunk_size: int,
    stride: int,
    minimal_chunk_length: int,
) -> tuple[list[Tensor], list[Tensor]]:
    """Splits tokens into overlapping chunks with given size and stride."""
    input_id_chunks = split_overlapping(tokens["input_ids"][0], chunk_size, stride, minimal_chunk_length)
    mask_chunks = split_overlapping(tokens["attention_mask"][0], chunk_size, stride, minimal_chunk_length)
    return input_id_chunks, mask_chunks


def add_special_tokens_at_beginning_and_end(input_id_chunks: list[Tensor], mask_chunks: list[Tensor]) -> None:
    """
    Adds special CLS token (token id = 101) at the beginning.
    Adds SEP token (token id = 102) at the end of each chunk.
    Adds corresponding attention masks equal to 1 (attention mask is boolean).
    """
    for i in range(len(input_id_chunks)):
        # adding CLS (token id 101) and SEP (token id 102) tokens
        input_id_chunks[i] = torch.cat([Tensor([101]), input_id_chunks[i], Tensor([102])])
        # adding attention masks  corresponding to special tokens
        mask_chunks[i] = torch.cat([Tensor([1]), mask_chunks[i], Tensor([1])])


def add_padding_tokens(input_id_chunks: list[Tensor], mask_chunks: list[Tensor], chunk_size: int) -> None:
    """Adds padding tokens (token id = 0) at the end so that every chunk has exactly chunk_size + 2 tokens."""
    for i in range(len(input_id_chunks)):
        # get required padding length
        pad_len = chunk_size + 2 - input_id_chunks[i].shape[0]
        # check if tensor length satisfies required chunk size
        if pad_len > 0:
            # if padding length is more than 0, we must add padding
            input_id_chunks[i] = torch.cat([input_id_chunks[i], Tensor([0] * pad_len)])
            mask_chunks[i] = torch.cat([mask_chunks[i], Tensor([0] * pad_len)])


def stack_tokens_from_all_chunks(input_id_chunks: list[Tensor], mask_chunks: list[Tensor]) -> tuple[Tensor, Tensor]:
    """Reshapes data to a form compatible with BERT model input."""
    input_ids = torch.stack(input_id_chunks)
    attention_mask = torch.stack(mask_chunks)

    return input_ids.long(), attention_mask.int()


def split_overlapping(tensor: Tensor, chunk_size: int, stride: int, minimal_chunk_length: int) -> list[Tensor]:
    """Helper function for dividing 1-dimensional tensors into overlapping chunks."""
    result = [tensor[i : i + chunk_size] for i in range(0, len(tensor), stride)]
    if len(result) > 1:
        # ignore chunks with less than minimal_length number of tokens
        result = [x for x in result if len(x) >= minimal_chunk_length]
    return result
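

# A minimal, self-contained sketch (hypothetical values) of the chunking helper:
# a 10-token tensor with chunk_size=4 and stride=2 yields overlapping windows
# at offsets 0, 2, 4 and 6, while the trailing 2-token window falls below
# minimal_chunk_length=3 and is dropped.
def _demo_split_overlapping() -> None:
    tokens = torch.arange(10)
    chunks = split_overlapping(tokens, chunk_size=4, stride=2, minimal_chunk_length=3)
    print([chunk.tolist() for chunk in chunks])
    # [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]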

## Voice part

def stack_tokens_from_all_chunks_for_inference(input_id_chunks: list[Tensor], mask_chunks: list[Tensor]) -> tuple[Tensor, Tensor]:
    """Reshapes data to a form compatible with BERT model input."""
    input_ids = torch.stack(input_id_chunks)
    attention_mask = torch.stack(mask_chunks)

    return input_ids.long(), attention_mask.int()

def transform_for_inference_text(
    text: str,
    tokenizer: PreTrainedTokenizerBase,
    chunk_size: int,
    stride: int,
    minimal_chunk_length: int,
    maximal_text_length: Optional[int],
) -> BatchEncoding:
    if maximal_text_length:
        tokens = tokenize_text_with_truncation(text, tokenizer, maximal_text_length)
    else:
        tokens = tokenize_whole_text(text, tokenizer)
    input_id_chunks, mask_chunks = split_tokens_into_smaller_chunks(tokens, chunk_size, stride, minimal_chunk_length)
    add_special_tokens_at_beginning_and_end_inference(input_id_chunks, mask_chunks)
    add_padding_tokens_inference(input_id_chunks, mask_chunks, chunk_size)
    input_ids, attention_mask = stack_tokens_from_all_chunks_for_inference(input_id_chunks, mask_chunks)
    return BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask})

def add_special_tokens_at_beginning_and_end_inference(input_id_chunks: list[Tensor], mask_chunks: list[Tensor]) -> None:
    """
    Intentionally leaves the chunks unchanged: in the inference pipeline the
    MULT, CLS, and SEP special tokens are added later, in prepare_text, when
    pairs of chunks are assembled into model inputs.
    """

def add_padding_tokens_inference(input_id_chunks: list[Tensor], mask_chunks: list[Tensor], chunk_size: int) -> None:
    """Adds padding tokens at the end so that every chunk has exactly chunk_size tokens."""
    for i in range(len(input_id_chunks)):
        # get required padding length; pad_token_id is imported from Model (0 for this BERT tokenizer)
        pad_len = chunk_size - input_id_chunks[i].shape[0]
        # pad only when the chunk is shorter than chunk_size
        if pad_len > 0:
            input_id_chunks[i] = torch.cat([input_id_chunks[i], torch.tensor([pad_token_id] * pad_len)])
            mask_chunks[i] = torch.cat([mask_chunks[i], torch.tensor([0] * pad_len)])

def prepare_text(tokens_splitted: BatchEncoding):
    batch = []
    sentences = []
    input_ids_list = tokens_splitted['input_ids']
    
    for i in range(0, len(input_ids_list), 2):  # process the chunks in pairs
        k = i + 1
        if k == len(input_ids_list):  # odd number of chunks: the last one has no pair
            input_ids_a = input_ids_list[i]
            input_ids_a = [token for token in input_ids_a.view(-1).tolist() if token != pad_token_id]
            input_ids_b = []
            input_ids = [cls_token_id] + [mult_token_id] + input_ids_a + [sep_token_id] + [mult_token_id] + input_ids_b + [sep_token_id]
            text_input_a = tokenizer.decode(input_ids_a)
            sentences.append(text_input_a)
            segment_ids = [0] * (1 + 1 + len(input_ids_a) + 1) + [1] * (1 + len(input_ids_b) + 1)
            
            # Masked LM: mask ~15% of the non-special tokens
            n_pred = min(max_pred, max(1, int(round(len(input_ids) * 0.15))))
            cand_masked_pos = [idx for idx, token in enumerate(input_ids) if token not in [cls_token_id, sep_token_id, mult_token_id]]
            shuffle(cand_masked_pos)
            masked_tokens, masked_pos = [], []
            for pos in cand_masked_pos[:n_pred]:
                masked_pos.append(pos)
                masked_tokens.append(input_ids[pos])
                input_ids[pos] = tokenizer.mask_token_id

            # Zero Padding
            n_pad = maxlen - len(input_ids)
            input_ids.extend([pad_token_id] * n_pad)
            segment_ids.extend([0] * n_pad)

            # Zero Padding for masked tokens
            if max_pred > n_pred:
                n_pad = max_pred - n_pred
                masked_tokens.extend([0] * n_pad)
                masked_pos.extend([0] * n_pad)
        else:
            input_ids_a = input_ids_list[i]
            input_ids_b = input_ids_list[k]
            input_ids_a = [token for token in input_ids_a.view(-1).tolist() if token != pad_token_id]
            input_ids_b = [token for token in input_ids_b.view(-1).tolist() if token != pad_token_id]
            input_ids = [cls_token_id] + [mult_token_id] + input_ids_a + [sep_token_id] + [mult_token_id] + input_ids_b + [sep_token_id]
            segment_ids = [0] * (1 + 1 + len(input_ids_a) + 1) + [1] * (1 + len(input_ids_b) + 1)
            text_input_a = tokenizer.decode(input_ids_a)
            text_input_b = tokenizer.decode(input_ids_b)
            sentences.append(text_input_a)
            sentences.append(text_input_b)

            # Masked LM: mask ~15% of the non-special tokens
            n_pred = min(max_pred, max(1, int(round(len(input_ids) * 0.15))))
            cand_masked_pos = [idx for idx, token in enumerate(input_ids) if token not in [cls_token_id, sep_token_id, mult_token_id]]
            shuffle(cand_masked_pos)
            masked_tokens, masked_pos = [], []
            for pos in cand_masked_pos[:n_pred]:
                masked_pos.append(pos)
                masked_tokens.append(input_ids[pos])
                input_ids[pos] = tokenizer.mask_token_id

            # Zero Padding
            n_pad = maxlen - len(input_ids)
            input_ids.extend([pad_token_id] * n_pad)
            segment_ids.extend([0] * n_pad)

            # Zero Padding for masked tokens
            if max_pred > n_pred:
                n_pad = max_pred - n_pred
                masked_tokens.extend([0] * n_pad)
                masked_pos.extend([0] * n_pad)

        batch.append([input_ids, segment_ids, masked_pos])
    return batch, sentences
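
# Each element of batch is [input_ids, segment_ids, masked_pos]: a
# [CLS][MULT]chunk_a[SEP][MULT]chunk_b[SEP] sequence padded to maxlen, its
# 0/1 segment ids, and the positions masked for the MLM head; sentences holds
# the decoded chunk texts in the same order.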

def inference(text: str):
    encoded_text = transform_for_inference_text(text, tokenizer, 125, 125, 1, 2550)
    batch, sentences = prepare_text(encoded_text)
    return batch, sentences

def predict(inference_batch, neptune, device=device):
    all_preds = []
    neptune.eval()
    with torch.no_grad():
        for batch in inference_batch:
            input_ids = torch.tensor(batch[0], device=device, dtype=torch.long).unsqueeze(0)
            segment_ids = torch.tensor(batch[1], device=device, dtype=torch.long).unsqueeze(0)
            masked_pos = torch.tensor(batch[2], device=device, dtype=torch.long).unsqueeze(0)
            _, _, logits_mclsf1, logits_mclsf2 = neptune(input_ids, segment_ids, masked_pos)
            # one prediction per chunk of the pair, in sentence order
            preds_mult1 = torch.argmax(logits_mclsf1, dim=1).cpu().numpy()
            preds_mult2 = torch.argmax(logits_mclsf2, dim=1).cpu().numpy()

            all_preds.extend(preds_mult1)
            all_preds.extend(preds_mult2)

    return all_preds

def align_predictions_with_sentences(sentences, preds):
    dc = {}
    for sentence, pred in zip(sentences, preds):
        dc[sentence] = class_labels.get(pred, "Unknown")  # map each class index to its (label, css_class, hex) tuple
    return dc
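
# End-to-end sketch for the Neptune path (model loading as above):
#   batch, sentences = inference("First topic sentence. Second topic sentence.")
#   preds = predict(batch, neptune)
#   labelled = align_predictions_with_sentences(sentences, preds)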

#### FRENCH PREPROCESSING ####
def predict_fr_class(text, model):
    # Tokenize the text into overlapping chunks (the function expects a list of texts)
    inputs = transform_list_of_fr_texts([text], fr_tokenizer, 126, 30, 1, 2550)
    # Extract the tensors for the single input text
    input_ids_tensor = inputs["input_ids"][0]
    attention_mask_tensor = inputs["attention_mask"][0]
    # Run the chunks through the model
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)

    # Softmax over the logits, averaged across the chunks
    probabilities = torch.softmax(outputs.logits, dim=1).mean(dim=0)

    # Identify the most probable class
    predicted_class_index = torch.argmax(probabilities).item()
    predicted_class = fr_class_labels[predicted_class_index]

    # Build a percentage dictionary sorted by probability
    sorted_percentages = {fr_class_labels[idx]: probabilities[idx].item() * 100 for idx in range(len(fr_class_labels))}
    sorted_percentages = dict(sorted(sorted_percentages.items(), key=lambda item: item[1], reverse=True))

    return predicted_class, sorted_percentages
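
# Usage sketch: predict_fr_class("Les énergies renouvelables progressent.", fr_model)
# mirrors predict_class but uses the CamemBERT chunking parameters
# (chunk_size=126, stride=30) and the fr_class_labels mapping.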

def prepare_fr_text(tokens_splitted: BatchEncoding):
    batch = []
    sentences = []
    input_ids_list = tokens_splitted['input_ids']
    
    for i in range(0, len(input_ids_list), 2):  # process the chunks in pairs
        k = i + 1
        if k == len(input_ids_list):  # odd number of chunks: the last one has no pair
            input_ids_a = input_ids_list[i]
            input_ids_a = [token for token in input_ids_a.view(-1).tolist() if token != fr_pad_token_id]
            input_ids_b = []
            input_ids = [fr_cls_token_id] + [fr_mult_token_id] + input_ids_a + [fr_sep_token_id] + [fr_mult_token_id] + input_ids_b + [fr_sep_token_id]
            text_input_a = fr_tokenizer.decode(input_ids_a , skip_special_tokens=True)
            sentences.append(text_input_a)
            segment_ids = [0] * (1 + 1 + len(input_ids_a) + 1) + [1] * (1 + len(input_ids_b) + 1)
            
            # Masked LM: mask ~15% of the non-special tokens
            n_pred = min(max_pred, max(1, int(round(len(input_ids) * 0.15))))
            cand_masked_pos = [idx for idx, token in enumerate(input_ids) if token not in [fr_cls_token_id, fr_sep_token_id, fr_mult_token_id]]
            shuffle(cand_masked_pos)
            masked_tokens, masked_pos = [], []
            for pos in cand_masked_pos[:n_pred]:
                masked_pos.append(pos)
                masked_tokens.append(input_ids[pos])
                input_ids[pos] = fr_tokenizer.mask_token_id

            # Zero Padding
            n_pad = maxlen - len(input_ids)
            input_ids.extend([fr_pad_token_id] * n_pad)
            segment_ids.extend([0] * n_pad)

            # Zero Padding for masked tokens
            if max_pred > n_pred:
                n_pad = max_pred - n_pred
                masked_tokens.extend([0] * n_pad)
                masked_pos.extend([0] * n_pad)
        else:
            input_ids_a = input_ids_list[i]
            input_ids_b = input_ids_list[k]
            input_ids_a = [token for token in input_ids_a.view(-1).tolist() if token != fr_pad_token_id]
            input_ids_b = [token for token in input_ids_b.view(-1).tolist() if token != fr_pad_token_id]
            input_ids = [fr_cls_token_id] + [fr_mult_token_id] + input_ids_a + [fr_sep_token_id] + [fr_mult_token_id] + input_ids_b + [fr_sep_token_id]
            segment_ids = [0] * (1 + 1 + len(input_ids_a) + 1) + [1] * (1 + len(input_ids_b) + 1)
            text_input_a = fr_tokenizer.decode(input_ids_a , skip_special_tokens=True)
            text_input_b = fr_tokenizer.decode(input_ids_b, skip_special_tokens=True)
            sentences.append(text_input_a)
            sentences.append(text_input_b)

            # Masked LM: mask ~15% of the non-special tokens
            n_pred = min(max_pred, max(1, int(round(len(input_ids) * 0.15))))
            cand_masked_pos = [idx for idx, token in enumerate(input_ids) if token not in [fr_cls_token_id, fr_sep_token_id, fr_mult_token_id]]
            shuffle(cand_masked_pos)
            masked_tokens, masked_pos = [], []
            for pos in cand_masked_pos[:n_pred]:
                masked_pos.append(pos)
                masked_tokens.append(input_ids[pos])
                input_ids[pos] = fr_tokenizer.mask_token_id

            # Zero Padding
            n_pad = maxlen - len(input_ids)
            input_ids.extend([fr_pad_token_id] * n_pad)
            segment_ids.extend([0] * n_pad)

            # Zero Padding for masked tokens
            if max_pred > n_pred:
                n_pad = max_pred - n_pred
                masked_tokens.extend([0] * n_pad)
                masked_pos.extend([0] * n_pad)

        batch.append([input_ids, segment_ids, masked_pos])
    return batch, sentences

def fr_inference(text: str):
    encoded_text = transform_for_inference_fr_text(text, fr_tokenizer, 125, 125, 1, 2550)
    batch, sentences = prepare_fr_text(encoded_text)
    return batch, sentences

def align_fr_predictions_with_sentences(sentences, preds):
    dc = {}  # Initialize an empty dictionary
    for sentence, pred in zip(sentences, preds):  # Iterate through sentences and predictions
        dc[sentence] = fr_class_labels.get(pred, "Unknown")  # Look up the label for each prediction
    return dc
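
# End-to-end sketch for the French path (assuming FR_BERT shares BERT's
# forward signature, so the model-agnostic predict above can be reused):
#   batch, sentences = fr_inference("Première phrase. Deuxième phrase.")
#   preds = predict(batch, fr_neptune)
#   labelled = align_fr_predictions_with_sentences(sentences, preds)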

def transform_for_inference_fr_text(
    text: str,
    tokenizer: PreTrainedTokenizerBase,
    chunk_size: int,
    stride: int,
    minimal_chunk_length: int,
    maximal_text_length: Optional[int],
) -> BatchEncoding:
    if maximal_text_length:
        tokens = tokenize_text_with_truncation(text, tokenizer, maximal_text_length)
    else:
        tokens = tokenize_whole_text(text, tokenizer)
    input_id_chunks, mask_chunks = split_tokens_into_smaller_chunks(tokens, chunk_size, stride, minimal_chunk_length)
    add_special_tokens_at_beginning_and_end_inference(input_id_chunks, mask_chunks)
    add_padding_fr_tokens_inference(input_id_chunks, mask_chunks, chunk_size)
    input_ids, attention_mask = stack_tokens_from_all_chunks_for_inference(input_id_chunks, mask_chunks)
    return BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask})

def add_padding_fr_tokens_inference(input_id_chunks: list[Tensor], mask_chunks: list[Tensor], chunk_size: int) -> None:
    """Adds padding tokens at the end so that every chunk has exactly chunk_size tokens."""
    for i in range(len(input_id_chunks)):
        # get required padding length; fr_pad_token_id is imported from FrModel (1 for CamemBERT)
        pad_len = chunk_size - input_id_chunks[i].shape[0]
        # pad only when the chunk is shorter than chunk_size
        if pad_len > 0:
            input_id_chunks[i] = torch.cat([input_id_chunks[i], torch.tensor([fr_pad_token_id] * pad_len)])
            mask_chunks[i] = torch.cat([mask_chunks[i], torch.tensor([0] * pad_len)])


def transform_list_of_fr_texts(
    texts: list[str],
    tokenizer: PreTrainedTokenizerBase,
    chunk_size: int,
    stride: int,
    minimal_chunk_length: int,
    maximal_text_length: Optional[int] = None,
) -> BatchEncoding:
    model_inputs = [
        transform_single_fr_text(text, tokenizer, chunk_size, stride, minimal_chunk_length, maximal_text_length)
        for text in texts
    ]
    input_ids = [model_input[0] for model_input in model_inputs]
    attention_mask = [model_input[1] for model_input in model_inputs]
    tokens = {"input_ids": input_ids, "attention_mask": attention_mask}
    return BatchEncoding(tokens)


def transform_single_fr_text(
    text: str,
    tokenizer: PreTrainedTokenizerBase,
    chunk_size: int,
    stride: int,
    minimal_chunk_length: int,
    maximal_text_length: Optional[int],
) -> tuple[Tensor, Tensor]:
    """Transforms (the entire) text to model input of BERT model."""
    if maximal_text_length:
        tokens = tokenize_text_with_truncation(text, tokenizer, maximal_text_length)
    else:
        tokens = tokenize_whole_text(text, tokenizer)
    input_id_chunks, mask_chunks = split_tokens_into_smaller_chunks(tokens, chunk_size, stride, minimal_chunk_length)
    add_fr_special_tokens_at_beginning_and_end(input_id_chunks, mask_chunks)
    add_padding_tokens(input_id_chunks, mask_chunks , chunk_size)
    input_ids, attention_mask = stack_tokens_from_all_chunks(input_id_chunks, mask_chunks)
    return input_ids, attention_mask

def add_fr_special_tokens_at_beginning_and_end(input_id_chunks: list[Tensor], mask_chunks: list[Tensor]) -> None:
    """
    Adds the CamemBERT CLS token (token id = 5) at the beginning
    and the SEP token (token id = 6) at the end of each chunk.
    Adds corresponding attention masks equal to 1 (attention mask is boolean).
    """
    for i in range(len(input_id_chunks)):
        # adding CLS (token id 5) and SEP (token id 6) tokens
        input_id_chunks[i] = torch.cat([Tensor([5]), input_id_chunks[i], Tensor([6])])
        # adding attention masks corresponding to the special tokens
        mask_chunks[i] = torch.cat([Tensor([1]), mask_chunks[i], Tensor([1])])
                
def transcribe_speech(audio_path, wav2vec2_processor, wav2vec2_model):
    logging.info(f"Starting transcription of {audio_path}")
    
    try:
        # Try loading with torchaudio first
        waveform, sample_rate = torchaudio.load(audio_path)
        waveform = waveform.squeeze().numpy()
        logging.info(f"Audio loaded with torchaudio. Shape: {waveform.shape}, Sample rate: {sample_rate}")
    except Exception as e:
        logging.warning(f"torchaudio failed to load the audio. Trying with soundfile. Error: {str(e)}")
        try:
            # If torchaudio fails, try with soundfile (returns a (frames, channels) array)
            waveform, sample_rate = sf.read(audio_path)
            # transpose to channels-first to match torchaudio's layout; keep it as float32 numpy
            waveform = np.asarray(waveform, dtype=np.float32).T
            logging.info(f"Audio loaded with soundfile. Shape: {waveform.shape}, Sample rate: {sample_rate}")
        except Exception as e:
            logging.error(f"Both torchaudio and soundfile failed to load the audio. Error: {str(e)}")
            raise ValueError("Unable to load the audio file.")

    # Downmix to mono: average over the channel axis (channels-first)
    if waveform.ndim > 1:
        waveform = np.mean(waveform, axis=0)
        logging.info(f"Waveform reduced to 1D. New shape: {waveform.shape}")

    # Resample if necessary, keeping the waveform as a numpy array afterwards
    if sample_rate != wav2vec2_processor.feature_extractor.sampling_rate:
        resampler = torchaudio.transforms.Resample(sample_rate, wav2vec2_processor.feature_extractor.sampling_rate)
        waveform = resampler(torch.from_numpy(waveform).float()).numpy()
        logging.info(f"Audio resampled to {wav2vec2_processor.feature_extractor.sampling_rate}Hz")

    # Normalize
    try:
        input_values = wav2vec2_processor(waveform, sampling_rate=wav2vec2_processor.feature_extractor.sampling_rate, return_tensors="pt").input_values
        logging.info(f"Input values shape after processing: {input_values.shape}")
    except Exception as e:
        logging.error(f"Error during audio processing: {str(e)}")
        raise

    # Ensure input_values is 2D (batch_size, sequence_length)
    input_values = input_values.squeeze()
    if input_values.dim() == 0:  # If it's a scalar, unsqueeze twice
        input_values = input_values.unsqueeze(0).unsqueeze(0)
    elif input_values.dim() == 1:  # If it's 1D, unsqueeze once
        input_values = input_values.unsqueeze(0)
    logging.info(f"Final input values shape: {input_values.shape}")

    try:
        with torch.inference_mode():
            logits = wav2vec2_model(input_values.to(device)).logits
        logging.info(f"Model inference successful. Logits shape: {logits.shape}")
    except Exception as e:
        logging.error(f"Error during model inference: {str(e)}")
        raise

    predicted_ids = torch.argmax(logits, dim=-1)
    predicted_sentence = wav2vec2_processor.batch_decode(predicted_ids)
    logging.info(f"Transcription complete. Result: {predicted_sentence[0]}")
    return predicted_sentence[0]
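

# Usage sketch (the .wav path is hypothetical):
#   _, _, processor, w2v_model = load_fr_models()
#   text = transcribe_speech("samples/french_clip.wav", processor, w2v_model)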