from flask import Flask, render_template, request, redirect, url_for, jsonify, session
from helper_functions import (
    predict_class, predict_sentences_class, inference, predict,
    align_predictions_with_sentences, load_models, load_fr_models,
    predict_fr_class, fr_inference, align_fr_predictions_with_sentences,
    transcribe_speech,
)
import fitz  # PyMuPDF
import os
import shutil
import io
import torch
import tempfile
from pydub import AudioSegment
import logging
import torchaudio
 
app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = 'static/uploads'
# Make sure the upload directory exists before any file is saved into it
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
device = "cpu"
# Global variables for models
global_model = None
global_neptune = None
global_pipe = None
global_fr_model = None
global_fr_neptune = None
global_fr_pipe = None
global_fr_wav2vec2_processor = None
global_fr_wav2vec2_model = None

  
def init_app():
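    """Load the English and French models once at startup and store them in module-level globals."""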
    global global_model, global_neptune, global_pipe
    global global_fr_model, global_fr_neptune, global_fr_wav2vec2_processor, global_fr_wav2vec2_model
    
    print("Loading English models...")
    global_model, global_neptune, global_pipe = load_models()
    
    print("Loading French models...")
    global_fr_model, global_fr_neptune, global_fr_wav2vec2_processor, global_fr_wav2vec2_model = load_fr_models()
    
    print("Models loaded successfully!")

init_app()

@app.route("/")
def home():
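    """Render the landing page (the English PDF view) with empty prediction placeholders."""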
    predict_class = ""
    class_probabilities = dict()
    chart_data = dict()
    return render_template('pdf.html', class_probabilities= class_probabilities, predicted_class=predict_class,chart_data = chart_data)

@app.route('/pdf')
def pdf():
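    """Render the English PDF-classification page with empty prediction placeholders."""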
    predict_class = ""
    class_probabilities = dict()
    chart_data = dict()
    sentences_prediction = dict()
    return render_template('pdf.html', class_probabilities= class_probabilities, predicted_class=predict_class,chart_data = chart_data,sentences_prediction=sentences_prediction)

@app.route('/pdf/upload', methods=['POST'])
def treatment():
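    """Handle an uploaded PDF: extract its text, classify it, and render the results."""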
    global global_model, global_neptune
    if request.method == 'POST':
        # Get the PDF file from the request
        file = request.files['file']
        filename = file.filename

        # Save the file to the upload directory
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)

        # Open the PDF file
        pdf_document = fitz.open(filepath)

        # Accumulate the extracted text here
        extracted_text = ""

        # Loop over each page and extract its text
        for page_num in range(len(pdf_document)):
            # Get the page object
            page = pdf_document.load_page(page_num)

            # Extract the text of the page
            page_text = page.get_text()

            # Append the page text to the accumulator
            extracted_text += f"\nPage {page_num + 1}:\n{page_text}"

        # Close the PDF file
        pdf_document.close()
        # Classify the full document text
        predicted_class, class_probabilities = predict_class([extracted_text], global_model)
        print(class_probabilities)
        # Classify the extracted text sentence by sentence
        sentences_prediction = predict_sentences_class(extracted_text, global_model)
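        # Build the payload for the front-end chart. The keys of class_probabilities
        # are assumed to be tuples whose first element is the label and whose third
        # element is a display colour, hence the label[0] / color[2] indexing below.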
        chart_data = {
            'datasets': [{
                'data': list(class_probabilities.values()),
                'backgroundColor': [color[2] for color in class_probabilities.keys()],
                'borderColor': [color[2] for color in class_probabilities.keys()]
            }],
            'labels': [label[0] for label in class_probabilities.keys()]
        }
        print(predicted_class)
        print(chart_data)
        print(sentences_prediction)
        # Clear the uploads folder
        for filename in os.listdir(app.config['UPLOAD_FOLDER']):
            file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                print('Failed to delete %s. Reason: %s' % (file_path, e))
        return render_template('pdf.html', extracted_text=extracted_text, class_probabilities=class_probabilities, predicted_class=predicted_class, chart_data=chart_data, sentences_prediction=sentences_prediction)
    return render_template('pdf.html')

## Sentence

@app.route('/sentence', methods=['GET', 'POST'])
def sentence():
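    """Classify a single English sentence submitted through the form."""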
    global global_model
    if request.method == 'POST':
        # Get the form data
        text = [request.form['text']]
        predicted_class, class_probabilities = predict_class(text, global_model)
        # Prepare data for the chart
        chart_data = {
            'datasets': [{
                'data': list(class_probabilities.values()),
                'backgroundColor': [color[2] for color in class_probabilities.keys()],
                'borderColor': [color[2] for color in class_probabilities.keys()]
            }],
            'labels': [label[0] for label in class_probabilities.keys()]
        }
        print(chart_data)
        return render_template('response_sentence.html', text=text, class_probabilities=class_probabilities, predicted_class=predicted_class, chart_data=chart_data)

    # Render the initial form page
    return render_template('sentence.html')

## Voice 
@app.route("/voice_backup")
def slu_backup():
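    """Backup/debug route: classify a bundled sample PDF and render the voice results page."""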
    input_file = "static/uploads/2022.jep-architectures-neuronales.pdf"
    # Open the PDF file
    pdf_document = fitz.open(input_file)
    # Accumulate the extracted text here
    extracted_text = ""
    # Loop over each page and extract its text
    for page_num in range(len(pdf_document)):
        # Get the page object
        page = pdf_document.load_page(page_num)

        # Extract the text of the page
        page_text = page.get_text()

        # Append the page text to the accumulator
        extracted_text += f"\nPage {page_num + 1}:\n{page_text}"

    # Close the PDF file
    pdf_document.close()
    # Run sentence-level inference and align the predictions with the sentences
    inference_batch, sentences = inference(extracted_text)
    predictions = predict(inference_batch, global_neptune)
    sentences_prediction = align_predictions_with_sentences(sentences, predictions)
    predicted_class, class_probabilities = predict_class([extracted_text], global_model)

    chart_data = {
        'datasets': [{
            'data': list(class_probabilities.values()),
            'backgroundColor': [color[2] for color in class_probabilities.keys()],
            'borderColor': [color[2] for color in class_probabilities.keys()]
        }],
        'labels': [label[0] for label in class_probabilities.keys()]
    }
    print(class_probabilities)
    print(chart_data)
    print(sentences_prediction)
    return render_template('voice_backup.html', extracted_text=extracted_text, class_probabilities=class_probabilities, predicted_class=predicted_class, chart_data=chart_data, sentences_prediction=sentences_prediction)

logging.basicConfig(level=logging.DEBUG)

@app.route("/voice", methods=['GET', 'POST'])
def slu():
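    """Transcribe an uploaded English audio clip with the Whisper pipeline and classify the transcript."""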
    global global_neptune, global_pipe, global_model

    if request.method == 'POST':
        logging.debug("Received POST request")
        audio_file = request.files.get('audio')

        if audio_file:
            logging.debug(f"Received audio file: {audio_file.filename}")
            
            # Save audio data to a temporary file
            with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio:
                audio_file.save(temp_audio)
                temp_audio_path = temp_audio.name

            logging.debug(f"Saved audio to temporary file: {temp_audio_path}")

            try:
                # Transcribe audio using Whisper
                result = global_pipe(temp_audio_path)
                extracted_text = result["text"]
                logging.debug(f"Transcribed text: {extracted_text}")

                # Process the transcribed text
                ####inference_batch, sentences = inference(extracted_text)
                ####predictions = predict(inference_batch, global_neptune)
                sentences_prediction = predict_sentences_class(extracted_text, global_model)
                predicted_class, class_probabilities = predict_class([extracted_text], global_model)

                chart_data = {
                    'datasets': [{
                        'data': list(class_probabilities.values()),
                        'backgroundColor': [color[2] for color in class_probabilities.keys()],
                        'borderColor': [color[2] for color in class_probabilities.keys()]
                    }],
                    'labels': [label[0] for label in class_probabilities.keys()]
                }

                response_data = {
                    'extracted_text': extracted_text,
                    'class_probabilities': class_probabilities,
                    'predicted_class': predicted_class,
                    'chart_data': chart_data,
                    'sentences_prediction': sentences_prediction
                }
                logging.debug(f"Prepared response data: {response_data}")

                return render_template('voice.html',
                                       class_probabilities=class_probabilities,
                                       predicted_class=predicted_class,
                                       chart_data=chart_data,
                                       sentences_prediction=sentences_prediction)

            except Exception as e:
                logging.error(f"Error processing audio: {str(e)}")
                return jsonify({'error': str(e)}), 500

            finally:
                # Remove temporary file
                os.unlink(temp_audio_path)

        else:
            logging.error("No audio file received")
            return jsonify({'error': 'No audio file received'}), 400

    # For GET request
    logging.debug("Received GET request")
    return render_template('voice.html', 
                           class_probabilities={}, 
                           predicted_class=[""], 
                           chart_data={}, 
                           sentences_prediction={})

## French Pages 
@app.route('/pdf_fr')
def pdf_fr():
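    """Render the French PDF-classification page with empty prediction placeholders."""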
    predict_class = ""
    class_probabilities = dict()
    chart_data = dict()
    return render_template('pdf_fr.html', class_probabilities= class_probabilities, predicted_class=predict_class,chart_data = chart_data)

@app.route('/pdf_fr/upload', methods=['POST'])
def treatment_fr():
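    """Handle an uploaded French PDF: extract its text, classify it, and render the results."""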
    global global_fr_neptune, global_fr_model
    if request.method == 'POST':
        # Get the PDF file from the request
        file = request.files['file']
        filename = file.filename

        # Save the file to the upload directory
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)

        # Open the PDF file
        pdf_document = fitz.open(filepath)

        # Accumulate the extracted text here
        extracted_text = ""

        # Loop over each page and extract its text
        for page_num in range(len(pdf_document)):
            # Get the page object
            page = pdf_document.load_page(page_num)

            # Extract the text of the page
            page_text = page.get_text()

            # Append the page text to the accumulator
            extracted_text += f"\nPage {page_num + 1}:\n{page_text}"

        # Close the PDF file
        pdf_document.close()
        # Process the text
        ####inference_batch, sentences = fr_inference(extracted_text)
        ####predictions = predict(inference_batch, global_fr_neptune)
        sentences_prediction = predict_sentences_class(extracted_text, global_fr_model)
        # Prepare data for the chart
        predicted_class, class_probabilities = predict_fr_class([extracted_text], global_fr_model)
        
        chart_data = {
            'datasets': [{
                'data': list(class_probabilities.values()),
                'backgroundColor': [color[2] for color in class_probabilities.keys()],
                'borderColor': [color[2] for color in class_probabilities.keys()]
            }],
            'labels': [label[0] for label in class_probabilities.keys()]
        }
        print(predicted_class)
        print(chart_data)
        # Clear the uploads folder
        for filename in os.listdir(app.config['UPLOAD_FOLDER']):
            file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                print('Failed to delete %s. Reason: %s' % (file_path, e))
        return render_template('pdf_fr.html', extracted_text=extracted_text, class_probabilities=class_probabilities, predicted_class=predicted_class, chart_data=chart_data, sentences_prediction=sentences_prediction)
    return render_template('pdf_fr.html')

@app.route('/sentence_fr', methods=['GET', 'POST'])
def sentence_fr():
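    """Classify a single French sentence submitted through the form."""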
    global global_fr_model
    if request.method == 'POST':
        # Get the form data
        text = [request.form['text']]
        predicted_class, class_probabilities = predict_fr_class(text, global_fr_model)
        # Prepare data for the chart
        chart_data = {
            'datasets': [{
                'data': list(class_probabilities.values()),
                'backgroundColor': [color[2] for color in class_probabilities.keys()],
                'borderColor': [color[2] for color in class_probabilities.keys()]
            }],
            'labels': [label[0] for label in class_probabilities.keys()]
        }
        print(predicted_class)
        print(chart_data)
        return render_template('response_fr_sentence.html', text=text, class_probabilities=class_probabilities, predicted_class=predicted_class, chart_data=chart_data)

    # Render the initial form page
    return render_template('sentence_fr.html')


@app.route("/voice_fr", methods=['GET', 'POST'])
def slu_fr():
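    """Transcribe an uploaded French audio clip with the wav2vec2 model and classify the transcript."""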
    global global_fr_neptune, global_fr_model, global_fr_wav2vec2_processor, global_fr_wav2vec2_model

    if request.method == 'POST':
        logging.info("Received POST request for /voice_fr")
        audio_file = request.files.get('audio')

        if audio_file:
            logging.info(f"Received audio file: {audio_file.filename}")
            
            # Read the raw bytes of the uploaded audio file
            audio_data = audio_file.read()

            # Convert the audio to 16 kHz mono WAV if needed
            try:
                audio = AudioSegment.from_file(io.BytesIO(audio_data))
                audio = audio.set_frame_rate(16000).set_channels(1)

                # Save the converted audio to a temporary file
                with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio:
                    audio.export(temp_audio.name, format="wav")
                    temp_audio_path = temp_audio.name

                logging.info(f"Converted and saved audio to temporary file: {temp_audio_path}")
            except Exception as e:
                logging.error(f"Error converting audio: {str(e)}")
                return jsonify({'error': 'Unable to process audio file'}), 400

            try:
                # Transcribe the audio using the helper from helper_functions
                extracted_text = transcribe_speech(temp_audio_path, global_fr_wav2vec2_processor, global_fr_wav2vec2_model)
                logging.info(f"Transcribed text: {extracted_text}")

                # Process the transcribed text
                ####inference_batch, sentences = fr_inference(extracted_text)
                ####predictions = predict(inference_batch, global_fr_neptune)
                sentences_prediction = predict_sentences_class(extracted_text, global_fr_model)
                predicted_class, class_probabilities = predict_fr_class([extracted_text], global_fr_model)

                chart_data = {
                    'datasets': [{
                        'data': list(class_probabilities.values()),
                        'backgroundColor': [color[2] for color in class_probabilities.keys()],
                        'borderColor': [color[2] for color in class_probabilities.keys()]
                    }],
                    'labels': [label[0] for label in class_probabilities.keys()]
                }

                response_data = {
                    'extracted_text': extracted_text,
                    'class_probabilities': class_probabilities,
                    'predicted_class': predicted_class,
                    'chart_data': chart_data,
                    'sentences_prediction': sentences_prediction
                }
                logging.info(f"Prepared response data: {response_data}")

                return render_template('voice_fr.html', 
                           class_probabilities=class_probabilities, 
                           predicted_class=predicted_class, 
                           chart_data=chart_data, 
                           sentences_prediction=sentences_prediction)

            except Exception as e:
                logging.error(f"Error processing audio: {str(e)}")
                return jsonify({'error': str(e)}), 500

            finally:
                # Remove the temporary file
                os.unlink(temp_audio_path)

        else:
            logging.error("No audio file received")
            return jsonify({'error': 'No audio file received'}), 400

    # For GET requests
    logging.info("Received GET request for /voice_fr")
    return render_template('voice_fr.html', 
                           class_probabilities={}, 
                           predicted_class=[""], 
                           chart_data={}, 
                           sentences_prediction={}) 

if __name__ == '__main__':
    app.run(debug=True)