Reshmarb committed
Commit bfa8c19 · 1 parent: 20eb8e9

file added

Files changed (1)
  1. app.py +42 -63
app.py CHANGED
@@ -30,36 +30,41 @@ file_handler.setFormatter(formatter)
 logger.addHandler(console_handler)
 logger.addHandler(file_handler)
 
-#Initialize Groq Client
+# Initialize Groq Client
 client = Groq(api_key=os.getenv("GROQ_API_KEY_1"))
-# logger.info(f"API Key: {client}") # Just for debugging
-
-# # Initialize Groq Client
-#client = Groq(api_key="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")
+logger.info("Groq client initialized.")
 
 # Initialize spaCy NLP model for named entity recognition (NER)
 spacy.cli.download("en_core_web_sm")
 nlp = spacy.load("en_core_web_sm")
+logger.info("spaCy NLP model loaded.")
 
 # Initialize sentiment analysis model using Hugging Face
 sentiment_analyzer = pipeline("sentiment-analysis")
+logger.info("Sentiment analysis model loaded.")
 
 # Load pre-trained YOLOv5 model
 def load_yolov5_model():
+    logger.info("Loading YOLOv5 model...")
     model = torch.hub.load(r"ultralytics/yolov5", 'custom', path=r'./models/best.pt')
     model.eval()
+    logger.info("YOLOv5 model loaded and set to evaluation mode.")
     return model
 
 model = load_yolov5_model()
 
 # Function to preprocess user input for better NLP understanding
 def preprocess_input(user_input):
+    logger.info("Preprocessing user input...")
     user_input = user_input.strip().lower()
+    logger.info(f"Preprocessed input: {user_input}")
     return user_input
 
 # Function for sentiment analysis (optional)
 def analyze_sentiment(user_input):
+    logger.info("Analyzing sentiment...")
     result = sentiment_analyzer(user_input)
+    logger.info(f"Sentiment analysis result: {result[0]['label']}")
     return result[0]['label']
 
 # Function to extract medical entities from input using NER
@@ -90,79 +95,76 @@ diseases = [
 ]
 
 def extract_medical_entities(user_input):
+    logger.info("Extracting medical entities...")
     user_input = preprocess_input(user_input)
     medical_entities = []
     for word in user_input.split():
         if word in symptoms or word in diseases:
             medical_entities.append(word)
+    logger.info(f"Extracted medical entities: {medical_entities}")
     return medical_entities
 
 # Function to encode the image
 def encode_image(uploaded_image):
     try:
-        logger.debug("Encoding image...")
+        logger.info("Encoding image...")
         buffered = BytesIO()
         uploaded_image.save(buffered, format="PNG")
-        logger.debug("Image encoding complete.")
-        return base64.b64encode(buffered.getvalue()).decode("utf-8")
+        encoded_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
+        logger.info("Image encoding complete.")
+        return encoded_image
     except Exception as e:
        logger.error(f"Error encoding image: {e}")
        raise
 
 # Initialize messages
 def initialize_messages():
-    return [{"role": "system", "content": '''You are Dr. HealthBuddy, a professional, empathetic, and knowledgeable virtual doctor chatbot.'''}]
+    logger.info("Initializing messages...")
+    messages = [{"role": "system", "content": '''You are Dr. HealthBuddy, a professional, empathetic, and knowledgeable virtual doctor chatbot.'''}]
+    logger.info("Messages initialized.")
+    return messages
 
 messages = initialize_messages()
 
 # Function for image prediction using YOLOv5
 def predict_image(image):
     try:
-        # Debug: Check if the image is None
+        logger.info("Predicting image...")
         if image is None:
+            logger.error("No image uploaded.")
             return "Error: No image uploaded.", "No description available."
 
-        # Convert PIL image to NumPy array (OpenCV format)
-        image_np = np.array(image) # Convert PIL image to NumPy array
-
-        # Convert RGB to BGR (OpenCV uses BGR by default)
+        image_np = np.array(image)
         image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
-
-        # Resize the image to match the model's expected input size
         image_resized = cv2.resize(image_np, (224, 224))
 
-        # Transform the image for the model
         transform = transforms.Compose([
-            transforms.ToTensor(), # Convert image to tensor
-            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Normalize
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
         ])
-        im = transform(image_resized).unsqueeze(0) # Add batch dimension (BCHW)
+        im = transform(image_resized).unsqueeze(0)
 
-        # Get predictions
         with torch.no_grad():
-            output = model(im) # Raw model output (logits)
+            output = model(im)
 
-        # Apply softmax to get confidence scores
         softmax = torch.nn.Softmax(dim=1)
         probs = softmax(output)
 
-        # Get the predicted class and its confidence score
        predicted_class_id = torch.argmax(probs, dim=1).item()
        confidence_score = probs[0, predicted_class_id].item()
 
-        # Get predicted class name if available
        if hasattr(model, 'names'):
            class_name = model.names[predicted_class_id]
            prediction_result = f"Predicted Class: {class_name}\nConfidence: {confidence_score:.4f}"
-            description = get_description(class_name) # Function to get description
        else:
            prediction_result = f"Predicted Class ID: {predicted_class_id}\nConfidence: {confidence_score:.4f}"
            description = "No description available."
 
-        # Display the image with OpenCV (optional)
        cv2.imshow("Processed Image", image_resized)
-        cv2.waitKey(1) # Wait for 1 ms to display the image
+        cv2.waitKey(1)
 
+        logger.info(f"Prediction result: {prediction_result}")
        return prediction_result, description
 
    except Exception as e:
@@ -171,13 +173,16 @@ def predict_image(image):
 
 # Function to get description based on predicted class
 def get_description(class_name):
+    logger.info(f"Getting description for class: {class_name}")
     descriptions = {
         "bcc": "Basal cell carcinoma (BCC) is a type of skin cancer that begins in the basal cells. It often appears as a slightly transparent bump on the skin, though it can take other forms. BCC grows slowly and is unlikely to spread to other parts of the body, but early treatment is important to prevent damage to surrounding tissues.",
         "atopic": "Atopic dermatitis is a chronic skin condition characterized by itchy, inflamed skin. It is common in individuals with a family history of allergies or asthma.",
         "acne": "Acne is a skin condition that occurs when hair follicles become clogged with oil and dead skin cells. It often causes pimples, blackheads, and whiteheads, and is most common among teenagers.",
         # Add more descriptions as needed
     }
-    return descriptions.get(class_name.lower(), "No description available.")
+    description = descriptions.get(class_name.lower(), "No description available.")
+    logger.info(f"Description: {description}")
+    return description
 
 # Custom LLM Bot Function
 def customLLMBot(user_input, uploaded_image, chat_history):
@@ -185,27 +190,19 @@ def customLLMBot(user_input, uploaded_image, chat_history):
         global messages
         logger.info("Processing input...")
 
-        # Preprocess the user input
         user_input = preprocess_input(user_input)
-
-        # Analyze sentiment (Optional)
         sentiment = analyze_sentiment(user_input)
         logger.info(f"Sentiment detected: {sentiment}")
 
-        # Extract medical entities (Optional)
         medical_entities = extract_medical_entities(user_input)
         logger.info(f"Extracted medical entities: {medical_entities}")
 
-        # Append user input to the chat history
         chat_history.append(("user", user_input))
 
         if uploaded_image is not None:
-            # Encode the image to base64
             base64_image = encode_image(uploaded_image)
-
             logger.debug(f"Image received, size: {len(base64_image)} bytes")
 
-            # Create a message for the image prompt
             messages_image = [
                 {
                     "role": "user",
@@ -223,7 +220,6 @@ def customLLMBot(user_input, uploaded_image, chat_history):
             )
             logger.info("Image processed successfully.")
         else:
-            # Process text input
             logger.info("Processing text input...")
             messages.append({
                 "role": "user",
@@ -235,21 +231,17 @@ def customLLMBot(user_input, uploaded_image, chat_history):
             )
             logger.info("Text processed successfully.")
 
-        # Extract the reply
         LLM_reply = response.choices[0].message.content
         logger.debug(f"LLM reply: {LLM_reply}")
 
-        # Append the bot's response to the chat history
         chat_history.append(("bot", LLM_reply))
         messages.append({"role": "assistant", "content": LLM_reply})
 
-        # Generate audio for response
         audio_file = f"response_{uuid.uuid4().hex}.mp3"
         tts = gTTS(LLM_reply, lang='en')
         tts.save(audio_file)
         logger.info(f"Audio response saved as {audio_file}")
 
-        # Return chat history and audio file
         return chat_history, audio_file
 
     except Exception as e:
@@ -258,15 +250,14 @@ def customLLMBot(user_input, uploaded_image, chat_history):
 
 # Gradio Interface
 def chatbot_ui():
+    logger.info("Setting up Gradio interface...")
     with gr.Blocks() as demo:
         gr.Markdown("# Healthcare Chatbot Doctor")
 
-        # State for user chat history
         chat_history = gr.State([])
 
-        # Layout for chatbot and input box alignment
         with gr.Row():
-            with gr.Column(scale=3): # Main column for chatbot
+            with gr.Column(scale=3):
                 chatbot = gr.Chatbot(label="Responses", elem_id="chatbot")
                 user_input = gr.Textbox(
                     label="Ask a health-related question",
@@ -274,81 +265,69 @@ def chatbot_ui():
                    elem_id="user-input",
                    lines=1,
                )
-            with gr.Column(scale=1): # Side column for image and buttons
+            with gr.Column(scale=1):
                 uploaded_image = gr.Image(label="Upload an Image", type="pil")
                 submit_btn = gr.Button("Submit")
                 clear_btn = gr.Button("Clear")
                 audio_output = gr.Audio(label="Audio Response")
 
-        # New section for image prediction (left and right layout)
         with gr.Row():
-            # Left side: Upload image
             with gr.Column():
                 gr.Markdown("### Upload Image for Prediction")
                 prediction_image = gr.Image(label="Upload Image", type="pil")
                 predict_btn = gr.Button("Predict")
 
-            # Right side: Prediction result and description
            with gr.Column():
                gr.Markdown("### Prediction Result")
                prediction_output = gr.Textbox(label="Result", interactive=False)
 
-                # Description column
                gr.Markdown("### Description")
                description_output = gr.Textbox(label="Description", interactive=False)
 
-                # Clear button for prediction result (below description box)
                clear_prediction_btn = gr.Button("Clear Prediction")
 
-        # Define actions
        def handle_submit(user_query, image, history):
            logger.info("User submitted a query.")
            response, audio = customLLMBot(user_query, image, history)
            return response, audio, None, "", history
 
-        # Clear prediction result and image
        def clear_prediction(prediction_image, prediction_output, description_output):
+            logger.info("Clearing prediction results.")
            return None, "", ""
 
-        # Submit on pressing Enter key
        user_input.submit(
            handle_submit,
            inputs=[user_input, uploaded_image, chat_history],
            outputs=[chatbot, audio_output, uploaded_image, user_input, chat_history],
        )
 
-        # Submit on button click
        submit_btn.click(
            handle_submit,
            inputs=[user_input, uploaded_image, chat_history],
            outputs=[chatbot, audio_output, uploaded_image, user_input, chat_history],
        )
 
-        # Action for clearing all fields
        clear_btn.click(
            lambda: ([], "", None, []),
            inputs=[],
            outputs=[chatbot, user_input, uploaded_image, chat_history],
        )
 
-        # Action for image prediction
        predict_btn.click(
            predict_image,
            inputs=[prediction_image],
-            outputs=[prediction_output, description_output], # Update both outputs
+            outputs=[prediction_output, description_output],
        )
 
-        # Action for clearing prediction result and image
        clear_prediction_btn.click(
            clear_prediction,
            inputs=[prediction_image, prediction_output, description_output],
            outputs=[prediction_image, prediction_output, description_output],
        )
 
+        logger.info("Gradio interface setup complete.")
    return demo
 
 # Launch the interface
-#chatbot_ui().launch(server_name="localhost", server_port=7860)
-
-# Launch the interface
-chatbot_ui().launch(server_name="0.0.0.0", server_port=7860)
+logger.info("Launching chatbot interface...")
+chatbot_ui().launch(server_name="0.0.0.0", server_port=7860)