MerlenMaven committed
Commit 589108f · verified · 1 Parent(s): 8f2ce99

Create app.py

Files changed (1)
  1. app.py +281 -0
app.py ADDED
@@ -0,0 +1,281 @@
+ import os
+ import gradio as gr
+ import cv2
+ import librosa
+ import librosa.display
+ import torch
+ import matplotlib.pyplot as plt
+ from scipy.signal import savgol_filter
+ from fer import FER
+ from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
+ from flask import Flask, request, jsonify
+ from flask_cors import CORS
+ from groq import Groq
+ import requests
+ from threading import Thread
+ import concurrent.futures
+
+ # Environment settings for OpenMP
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'  # Allow duplicate OpenMP libraries
+ os.environ['OMP_NUM_THREADS'] = '1'  # Limit the number of OpenMP threads to 1
+
+ # Flask app for the Groq chatbot
+ app = Flask(__name__)
+ CORS(app)
+
+ # Groq API setup (read the key from the environment instead of hardcoding it)
+ client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
+
+
+ # Model configuration
+ weight_model1 = 0.7  # Weight for the FER (vision) model
+ weight_model2 = 0.3  # Weight for the audio model
+ pain_threshold = 0.4  # Threshold for detecting pain
+ confidence_threshold = 0.3  # Confidence threshold for emotions
+ pain_emotions = ["angry", "fear", "sad"]  # Emotions associated with pain
+
+ # Detect whether the input file is audio or video
+ def detect_input_type(file_path):
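+     """Return 'audio', 'video', or 'unknown' based on the file extension."""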
+     _, ext = os.path.splitext(file_path)
+     if ext.lower() in ['.mp3', '.wav', '.flac']:
+         return 'audio'
+     elif ext.lower() in ['.mp4', '.avi', '.mov', '.mkv']:
+         return 'video'
+     else:
+         return 'unknown'
+
+ # ---- FER (vision) model ----
+ def extract_frames_and_analyze(video_path, fer_detector, sampling_rate=1):
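+     """Scan the video with the FER detector and collect confidence scores for
+     pain-related emotions, returning the (optionally smoothed) scores and the
+     indices of the frames they came from.
+     """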
+     cap = cv2.VideoCapture(video_path)
+     pain_scores = []
+     frame_indices = []
+     frame_count = 0
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         # Process only one frame out of every n to optimize performance
+         if frame_count % sampling_rate == 0:
+             # Detect the dominant emotion
+             emotion, score = fer_detector.top_emotion(frame)
+             if emotion in pain_emotions and score >= confidence_threshold:
+                 pain_scores.append(score)
+                 frame_indices.append(frame_count)
+
+         frame_count += 1
+
+     cap.release()
+
+     # If scores were detected, apply smoothing
+     if pain_scores:
+         window_length = min(5, len(pain_scores))
+         if window_length % 2 == 0:
+             window_length = max(3, window_length - 1)
+
+         # Ensure window_length is less than or equal to the length of pain_scores
+         window_length = min(window_length, len(pain_scores))
+
+         # Ensure polyorder is less than window_length
+         polyorder = min(2, window_length - 1)
+
+         pain_scores = savgol_filter(pain_scores, window_length, polyorder=polyorder)
+
+     return pain_scores, frame_indices
+
+ # ---- Audio model ----
+ def analyze_audio(audio_path, model, feature_extractor):
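+     """Run the wav2vec2 emotion classifier on an audio file and return the
+     probabilities of the pain-related emotions; returns [] on failure.
+     """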
+     try:
+         audio, sr = librosa.load(audio_path, sr=16000)
+         inputs = feature_extractor(audio, sampling_rate=sr, return_tensors="pt", padding=True)
+         with torch.no_grad():
+             logits = model(**inputs).logits
+         probs = torch.nn.functional.softmax(logits, dim=-1)
+
+         pain_scores = []
+         for idx, prob in enumerate(probs[0]):
+             emotion = model.config.id2label[idx]
+             if emotion in pain_emotions:
+                 pain_scores.append(prob.item())
+         return pain_scores
+     except Exception as e:
+         print(f"Error during audio analysis: {e}")
+         return []
+
+ # ---- Score fusion ----
+ def combine_scores(scores_model1, scores_model2, weight1, weight2):
+     """Combine scores from FER and audio models using weights."""
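+     # Example: with weight1=0.7, weight2=0.3, score1=0.6 and score2=0.4, the
+     # combined value is 0.7 * 0.6 + 0.3 * 0.4 = 0.54.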
+
+     # If either list is empty, fill it with zeros to match the other model's length
+     if len(scores_model1) == 0:
+         scores_model1 = [0] * len(scores_model2)
+     if len(scores_model2) == 0:
+         scores_model2 = [0] * len(scores_model1)
+
+     # Combine the scores using the weights
+     combined_scores = [
+         (weight1 * score1 + weight2 * score2)
+         for score1, score2 in zip(scores_model1, scores_model2)
+     ]
+
+     return combined_scores
+
+ # ---- Audio / video input processing ----
+ def process_input(file_path, fer_detector, model, feature_extractor):
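+     """Dispatch the file to the audio or video pipeline, fuse the scores, and
+     return a (result, average pain, plot filename) tuple.
+     """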
+     input_type = detect_input_type(file_path)
+
+     if input_type == 'audio':
+         pain_scores_model1 = []
+         pain_scores_model2 = analyze_audio(file_path, model, feature_extractor)
+         final_scores = pain_scores_model2  # No normalization needed here
+     elif input_type == 'video':
+         # Process the video frames and the audio in parallel
+         with concurrent.futures.ThreadPoolExecutor() as executor:
+             future_video = executor.submit(extract_frames_and_analyze, file_path, fer_detector, sampling_rate=5)
+             future_audio = executor.submit(analyze_audio, file_path, model, feature_extractor)
+
+             pain_scores_model1, frame_indices = future_video.result()
+             pain_scores_model2 = future_audio.result()
+
+         final_scores = combine_scores(pain_scores_model1, pain_scores_model2, weight_model1, weight_model2)
+     else:
+         return "Unsupported file type. Please provide an audio or video file.", 0, None
+
+     # Final decision
+     average_pain = sum(final_scores) / len(final_scores) if final_scores else 0
+     pain_detected = average_pain > pain_threshold
+     result = "Pain" if pain_detected else "No Pain"
+
+     # Plot the results
+     if not final_scores:
+         plt.text(0.5, 0.5, "No Data Available", ha='center', va='center', fontsize=16)
+     else:
+         plt.plot(range(len(final_scores)), final_scores, label="Combined Pain Scores", color="purple")
+         plt.axhline(y=pain_threshold, color="green", linestyle="--", label="Pain Threshold")
+     plt.xlabel("Frame / Sample Index")
+     plt.ylabel("Pain Score")
+     plt.title("Pain Detection Scores")
+     plt.legend()
+     plt.grid(True)
+
+     # Save the graph as a file
+     graph_filename = "pain_detection_graph.png"
+     plt.savefig(graph_filename)
+     plt.close()
+
+     return result, average_pain, graph_filename
+
+
+ @app.route('/message', methods=['POST'])
+ def handle_message():
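+     """Flask endpoint: take {'message': ...} from the request body, stream a
+     completion from the Groq llama3-8b-8192 model, and return {'reply': ...}.
+     """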
+     user_input = request.json.get('message', '')
+     completion = client.chat.completions.create(
+         model="llama3-8b-8192",
+         messages=[{"role": "user", "content": user_input}],
+         temperature=1,
+         max_tokens=1024,
+         top_p=1,
+         stream=True,
+         stop=None,
+     )
+
+     response = ""
+     for chunk in completion:
+         response += chunk.choices[0].delta.content or ""
+
+     return jsonify({'reply': response})
+
+ # Chatbot interaction function
+ def gradio_interface(file, chatbot_input, state_pain_results):
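+     """Handle a Gradio submit: run pain detection on the uploaded file (if any),
+     keep the latest results in `state_pain_results`, and forward chat messages
+     to the local Flask /message endpoint.
+     """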
+     model_name = "superb/wav2vec2-large-superb-er"
+     model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
+     feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
+     detector = FER(mtcnn=True)
+
+     chatbot_response = "How can I assist you today?"  # Default chatbot response
+     pain_result = ""
+     average_pain = ""
+     graph_filename = ""
+
+     # Handle file upload and process it when Submit is clicked
+     if file:
+         result, average_pain, graph_filename = process_input(file.name, detector, model, feature_extractor)
+         state_pain_results["result"] = result
+         state_pain_results["average_pain"] = average_pain
+         state_pain_results["graph_filename"] = graph_filename
+
+         # Custom chatbot response based on pain detection
+         if result == "No Pain":
+             chatbot_response = "It seems there's no pain detected. How can I assist you further?"
+         else:
+             chatbot_response = "It seems you have some pain. Would you like me to help with it or provide more details?"
+
+         # Update pain result and graph filename
+         pain_result = result
+     else:
+         # Use the existing state if no new file is uploaded
+         pain_result = state_pain_results.get("result", "")
+         average_pain = state_pain_results.get("average_pain", "")
+         graph_filename = state_pain_results.get("graph_filename", "")
+
+     # If the chatbot_input field is not empty, process the chat message
+     if chatbot_input:
+         # Send message to Flask server to get the response from the Groq model
+         response = requests.post(
+             'http://localhost:5000/message', json={'message': chatbot_input}
+         )
+         data = response.json()
+         chatbot_response = data['reply']
+
+     # Ensure 4 outputs: pain_result, average_pain, graph_output, chatbot_output
+     return pain_result, average_pain, graph_filename, chatbot_response
+
+
+ # Start Flask server in a thread
+ def start_flask():
+     app.run(debug=True, use_reloader=False)
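+     # Flask listens on its default port 5000, which matches the URL used by
+     # the Gradio callback (http://localhost:5000/message).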
+
+ # Launch Gradio and Flask
+ if __name__ == "__main__":
+     # Start Flask in a separate thread
+     flask_thread = Thread(target=start_flask)
+     flask_thread.start()
+
+     # Gradio interface
+     with gr.Blocks() as interface:
+         gr.Markdown("<h1 style='text-align:center;'>PainSense: AI-Driven Pain Detection and Chatbot Assistance</h1>")
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 file_input = gr.File(label="Upload Audio or Video File")
+                 with gr.Row():  # Place the buttons next to each other
+                     clear_button = gr.Button("Clear", elem_id="clear_btn")
+                     submit_button = gr.Button("Submit", variant="primary", elem_id="submit_button")
+                 chatbot_input = gr.Textbox(label="Chat with AI", placeholder="Ask a question...", interactive=True)
+                 chatbot_output = gr.Textbox(label="Chatbot Response", interactive=False)
+
+             with gr.Column(scale=1):
+                 pain_result = gr.Textbox(label="Pain Detection Result")
+                 average_pain = gr.Textbox(label="Average Pain")
+                 graph_output = gr.Image(label="Pain Detection Graph")
+
+         state = gr.State({"result": "", "average_pain": "", "graph_filename": ""})
+
+         # Clear button resets the outputs, the chatbot response, and the file input
+         clear_button.click(
+             lambda: ("", "", None, "", None),
+             outputs=[pain_result, average_pain, graph_output, chatbot_output, file_input],
+         )
+
+         # The uploaded file is only processed when the Submit button is clicked
+         submit_button.click(
+             gradio_interface,
+             inputs=[file_input, chatbot_input, state],
+             outputs=[pain_result, average_pain, graph_output, chatbot_output],
+         )
+
+         # Chat input triggers the chatbot response when 'Enter' is pressed
+         chatbot_input.submit(
+             lambda file, chatbot_input, state: gradio_interface(file, chatbot_input, state)[-1],  # Only update chatbot_output
+             inputs=[file_input, chatbot_input, state],
+             outputs=[chatbot_output],  # Only update the chatbot output
+         )
+
+     interface.launch(debug=True)