import zipfile
import os
import sys

# Path to your ZIP file and extraction directory
zip_path = "fer.zip"  # Ensure the correct path to your ZIP file
extract_folder = "fer"  # Directory where files will be extracted

# Check if the extraction folder exists, if not, extract the ZIP file
if not os.path.exists(extract_folder):
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extract_folder)  # Extract to 'fer' directory

# Add the extracted folder to sys.path so we can import the FER module from there
sys.path.insert(0, os.path.abspath(extract_folder))  # Insert at the beginning

# These environment variables should be set before the numerical libraries below are imported
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'  # Allow duplicate OpenMP runtimes to coexist
os.environ['OMP_NUM_THREADS'] = '1'  # Limit OpenMP to a single thread

import gradio as gr
import cv2
import librosa
import librosa.display
import torch
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
from flask import Flask, request, jsonify
from flask_cors import CORS
from groq import Groq
import requests
from threading import Thread
import concurrent.futures
from fer import FER


# Flask app for Groq Chatbot
app = Flask(__name__)
CORS(app)

# Groq API Setup
client = Groq(api_key="your_api_key")
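# NOTE: hard-coding an API key is only suitable for local experiments; reading it from an
# environment variable such as GROQ_API_KEY would be a safer choice.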


# Model configuration
weight_model1 = 0.7  # Weight for the FER (vision) model
weight_model2 = 0.3  # Weight for the audio model
pain_threshold = 0.4  # Threshold above which pain is reported
confidence_threshold = 0.3  # Minimum confidence required to keep an emotion
pain_emotions = ["angry", "fear", "sad"]  # Emotions treated as pain-related
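
# For video input, frame scores from FER and scores from the audio model are fused
# pairwise as: combined = weight_model1 * fer_score + weight_model2 * audio_score
# (see combine_scores below); audio-only input uses the audio scores directly.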

# Detect whether the input file is audio or video
def detect_input_type(file_path):
    _, ext = os.path.splitext(file_path)
    if ext.lower() in ['.mp3', '.wav', '.flac']:
        return 'audio'
    elif ext.lower() in ['.mp4', '.avi', '.mov', '.mkv']:
        return 'video'
    else:
        return 'unknown'

# ---- FER (vision) model ----
def extract_frames_and_analyze(video_path, fer_detector, sampling_rate=1):
    cap = cv2.VideoCapture(video_path)
    pain_scores = []
    frame_indices = []
    frame_count = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # Only process one frame out of every `sampling_rate` frames for performance
        if frame_count % sampling_rate == 0:
            # Detect the dominant emotion in the frame
            emotion, score = fer_detector.top_emotion(frame)
            if emotion in pain_emotions and score >= confidence_threshold:
                pain_scores.append(score)
                frame_indices.append(frame_count)

        frame_count += 1

    cap.release()

    # If any scores were collected, smooth them with a Savitzky-Golay filter
    if pain_scores:
        # Use an odd window no longer than the score list (at most 5 samples)
        window_length = min(5, len(pain_scores))
        if window_length % 2 == 0:
            window_length -= 1

        if window_length >= 3:
            # polyorder must be strictly smaller than window_length
            polyorder = min(2, window_length - 1)
            pain_scores = savgol_filter(pain_scores, window_length, polyorder=polyorder)

    return pain_scores, frame_indices

# ---- Audio model ----
def analyze_audio(audio_path, model, feature_extractor):
    try:
        audio, sr = librosa.load(audio_path, sr=16000)
        inputs = feature_extractor(audio, sampling_rate=sr, return_tensors="pt", padding=True)
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)

        pain_scores = []
        for idx, prob in enumerate(probs[0]):
            emotion = model.config.id2label[idx]
            if emotion in pain_emotions:
                pain_scores.append(prob.item())
        return pain_scores
    except Exception as e:
        print(f"Erreur lors de l'analyse audio : {e}")
        return []
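
# Note: the audio model classifies the clip as a whole, so analyze_audio returns at most
# one score per pain-related emotion rather than a per-frame timeline.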

# ---- Score fusion ----
def combine_scores(scores_model1, scores_model2, weight1, weight2):
    """Combine scores from FER and audio models using weights."""

    # If either list is empty, fill it with zeros to match the other model's length
    if len(scores_model1) == 0:
        scores_model1 = [0] * len(scores_model2)
    if len(scores_model2) == 0:
        scores_model2 = [0] * len(scores_model1)

    # Combine the scores using the weights (zip truncates to the shorter list
    # when both lists are non-empty but of different lengths)
    combined_scores = [
        (weight1 * score1 + weight2 * score2)
        for score1, score2 in zip(scores_model1, scores_model2)
    ]

    return combined_scores

# ---- Process an audio or video input ----
def process_input(file_path, fer_detector, model, feature_extractor):
    input_type = detect_input_type(file_path)

    if input_type == 'audio':
        pain_scores_model1 = []
        pain_scores_model2 = analyze_audio(file_path, model, feature_extractor)
        final_scores = pain_scores_model2  # No weighting needed for audio-only input
    elif input_type == 'video':
        # Analyze the video frames and the audio track in parallel
        # (librosa can typically decode the audio stream of the video file
        # through its audioread/ffmpeg fallback when ffmpeg is available)
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_video = executor.submit(extract_frames_and_analyze, file_path, fer_detector, sampling_rate=5)
            future_audio = executor.submit(analyze_audio, file_path, model, feature_extractor)

            pain_scores_model1, frame_indices = future_video.result()
            pain_scores_model2 = future_audio.result()

        final_scores = combine_scores(pain_scores_model1, pain_scores_model2, weight_model1, weight_model2)
    else:
        # Return the same 3-tuple shape as the other branches so the caller can unpack it
        return "Unsupported file type. Please provide an audio or video file.", 0, None

    # Final decision
    average_pain = sum(final_scores) / len(final_scores) if final_scores else 0
    pain_detected = average_pain > pain_threshold
    result = "Pain" if pain_detected else "No Pain"

    # Plot the results
    if not final_scores:
        plt.text(0.5, 0.5, "No Data Available", ha='center', va='center', fontsize=16)
    else:
        plt.plot(range(len(final_scores)), final_scores, label="Combined Pain Scores", color="purple")
        plt.axhline(y=pain_threshold, color="green", linestyle="--", label="Pain Threshold")
        plt.xlabel("Frame / Sample Index")
        plt.ylabel("Pain Score")
        plt.title("Pain Detection Scores")
        plt.legend()
        plt.grid(True)
    
    # Save the graph as a file
    graph_filename = "pain_detection_graph.png"
    plt.savefig(graph_filename)
    plt.close()

    return result, average_pain, graph_filename


@app.route('/message', methods=['POST'])
def handle_message():
    user_input = request.json.get('message', '')
    completion = client.chat.completions.create(
        model="llama3-8b-8192",
        messages=[{"role": "user", "content": user_input}],
        temperature=1,
        max_tokens=1024,
        top_p=1,
        stream=True,
        stop=None,
    )

    response = ""
    for chunk in completion:
        response += chunk.choices[0].delta.content or ""

    return jsonify({'reply': response})

# Chatbot interaction function
def gradio_interface(file, chatbot_input, state_pain_results):
    model_name = "superb/wav2vec2-large-superb-er"
    model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
    detector = FER(mtcnn=True)
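    # NOTE: the audio model and the FER detector are re-created on every call; hoisting them
    # to module level would avoid repeated loading if startup latency becomes an issue.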

    chatbot_response = "How can I assist you today?"  # Default chatbot response
    pain_result = ""
    average_pain = ""
    graph_filename = ""

    # Handle file upload and process it when Submit is clicked
    if file:
        result, average_pain, graph_filename = process_input(file.name, detector, model, feature_extractor)
        state_pain_results["result"] = result
        state_pain_results["average_pain"] = average_pain
        state_pain_results["graph_filename"] = graph_filename

        # Custom chatbot response based on pain detection
        if result == "No Pain":
            chatbot_response = "It seems there's no pain detected. How can I assist you further?"
        else:
            chatbot_response = "It seems you have some pain. Would you like me to help with it or provide more details?"

        # Update pain result and graph filename
        pain_result = result
    else:
        # Use the existing state if no new file is uploaded
        pain_result = state_pain_results.get("result", "")
        average_pain = state_pain_results.get("average_pain", "")
        graph_filename = state_pain_results.get("graph_filename", "")

        # If the chatbot_input field is not empty, process the chat message
        if chatbot_input:
            # Send message to Flask server to get the response from Groq model
            response = requests.post(
                'http://localhost:5000/message', json={'message': chatbot_input}
            )
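            # app.run() defaults to port 5000, hence the request targets localhost:5000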
            data = response.json()
            chatbot_response = data['reply']

    # Ensure 4 outputs: pain_result, average_pain, graph_output, chatbot_output
    return pain_result, average_pain, graph_filename, chatbot_response


# Start Flask server in a thread
def start_flask():
    # use_reloader=False keeps Flask from spawning a second process,
    # which would not work from a background thread
    app.run(debug=True, use_reloader=False)

# Launch Gradio and Flask
if __name__ == "__main__":
    # Start Flask in a separate thread
    flask_thread = Thread(target=start_flask)
    flask_thread.start()

    # Gradio interface
    with gr.Blocks() as interface:
        gr.Markdown("""<div style="text-align:center; margin-top:20px;"><h1>PainSense: AI-Driven Pain Detection and Chatbot Assistance</h1></div>""")
        with gr.Row():
            with gr.Column(scale=1):
                file_input = gr.File(label="Upload Audio or Video File")
                with gr.Row():  # Place buttons next to each other
                    clear_button = gr.Button("Clear", elem_id="clear_btn")
                    submit_button = gr.Button("Submit", variant="primary", elem_id="submit_button")
                chatbot_input = gr.Textbox(label="Chat with AI", placeholder="Ask a question...", interactive=True)
                chatbot_output = gr.Textbox(label="Chatbot Response", interactive=False)

            with gr.Column(scale=1):
                pain_result = gr.Textbox(label="Pain Detection Result")
                average_pain = gr.Textbox(label="Average Pain")
                graph_output = gr.Image(label="Pain Detection Graph")

        state = gr.State({"result": "", "average_pain": "", "graph_filename": ""}) 

        # Clear button resets the UI, including the file input, chatbot input, and outputs
        clear_button.click(lambda: ("", "", None, "", None), outputs=[pain_result, average_pain, graph_output, chatbot_output, file_input])

        # File input only triggers processing when the submit button is clicked
        submit_button.click(
            gradio_interface,
            inputs=[file_input, chatbot_input, state],
            outputs=[pain_result, average_pain, graph_output, chatbot_output],
        )

        # Chat input triggers chatbot response when 'Enter' is pressed
        chatbot_input.submit(
            lambda file, chatbot_input, state: gradio_interface(file, chatbot_input, state)[-1],  # Only update chatbot_output
            inputs=[file_input, chatbot_input, state],
            outputs=[chatbot_output]  # Only update chatbot output
        )

    interface.launch(debug=True)