Spaces:
Sleeping
Sleeping
file modified
Browse files- app.py +89 -62
- requirements.txt +4 -4
app.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import gradio as gr
|
2 |
from gtts import gTTS
|
3 |
import uuid
|
@@ -5,14 +6,9 @@ import base64
|
|
5 |
from io import BytesIO
|
6 |
import os
|
7 |
import logging
|
8 |
-
import torch
|
9 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
10 |
-
from torchvision import transforms
|
11 |
-
from PIL import Image
|
12 |
-
import torchvision.models as models
|
13 |
|
14 |
# Set up logger
|
15 |
-
logger = logging.getLogger(
|
16 |
logger.setLevel(logging.DEBUG)
|
17 |
console_handler = logging.StreamHandler()
|
18 |
file_handler = logging.FileHandler('chatbot_log.log')
|
@@ -22,81 +18,102 @@ file_handler.setFormatter(formatter)
|
|
22 |
logger.addHandler(console_handler)
|
23 |
logger.addHandler(file_handler)
|
24 |
|
25 |
-
#
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
#
|
31 |
-
logger.info("Loading Vision model...")
|
32 |
-
vision_model = models.resnet18(pretrained=True)
|
33 |
-
vision_model.eval()
|
34 |
-
preprocess = transforms.Compose([
|
35 |
-
transforms.Resize(256),
|
36 |
-
transforms.CenterCrop(224),
|
37 |
-
transforms.ToTensor(),
|
38 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
39 |
-
])
|
40 |
|
41 |
# Function to encode the image
|
42 |
def encode_image(uploaded_image):
|
43 |
try:
|
44 |
logger.debug("Encoding image...")
|
45 |
buffered = BytesIO()
|
46 |
-
uploaded_image.save(buffered, format="PNG")
|
47 |
logger.debug("Image encoding complete.")
|
48 |
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
49 |
except Exception as e:
|
50 |
logger.error(f"Error encoding image: {e}")
|
51 |
raise
|
|
|
|
|
|
|
52 |
|
53 |
-
|
54 |
-
def analyze_image(image):
|
55 |
-
try:
|
56 |
-
logger.info("Analyzing image...")
|
57 |
-
input_tensor = preprocess(image).unsqueeze(0)
|
58 |
-
with torch.no_grad():
|
59 |
-
outputs = vision_model(input_tensor)
|
60 |
-
_, predicted_class = outputs.max(1)
|
61 |
-
logger.info(f"Predicted class: {predicted_class.item()}")
|
62 |
-
return f"The image likely belongs to class {predicted_class.item()}."
|
63 |
-
except Exception as e:
|
64 |
-
logger.error(f"Error analyzing image: {e}")
|
65 |
-
return "An error occurred while analyzing the image."
|
66 |
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
-
# Function for chatbot logic
|
81 |
def customLLMBot(user_input, uploaded_image, chat_history):
|
82 |
try:
|
83 |
global messages
|
84 |
logger.info("Processing input...")
|
|
|
|
|
85 |
chat_history.append(("user", user_input))
|
86 |
|
87 |
if uploaded_image is not None:
|
88 |
-
#
|
89 |
-
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
else:
|
93 |
-
#
|
94 |
-
|
95 |
-
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
# Generate audio for response
|
99 |
-
LLM_reply = chat_history[-1][1]
|
100 |
audio_file = f"response_{uuid.uuid4().hex}.mp3"
|
101 |
tts = gTTS(LLM_reply, lang='en')
|
102 |
tts.save(audio_file)
|
@@ -106,9 +123,11 @@ def customLLMBot(user_input, uploaded_image, chat_history):
|
|
106 |
return chat_history, audio_file
|
107 |
|
108 |
except Exception as e:
|
|
|
109 |
logger.error(f"Error in customLLMBot function: {e}")
|
110 |
return [(("user", user_input or "Image uploaded"), ("bot", f"An error occurred: {e}"))], None
|
111 |
|
|
|
112 |
# Gradio Interface
|
113 |
def chatbot_ui():
|
114 |
with gr.Blocks() as demo:
|
@@ -117,8 +136,9 @@ def chatbot_ui():
|
|
117 |
# State for user chat history
|
118 |
chat_history = gr.State([])
|
119 |
|
|
|
120 |
with gr.Row():
|
121 |
-
with gr.Column(scale=3):
|
122 |
chatbot = gr.Chatbot(label="Responses", elem_id="chatbot")
|
123 |
user_input = gr.Textbox(
|
124 |
label="Ask a health-related question",
|
@@ -126,7 +146,7 @@ def chatbot_ui():
|
|
126 |
elem_id="user-input",
|
127 |
lines=1,
|
128 |
)
|
129 |
-
with gr.Column(scale=1):
|
130 |
uploaded_image = gr.Image(label="Upload an Image", type="pil")
|
131 |
submit_btn = gr.Button("Submit")
|
132 |
clear_btn = gr.Button("Clear")
|
@@ -134,21 +154,25 @@ def chatbot_ui():
|
|
134 |
|
135 |
# Define actions
|
136 |
def handle_submit(user_query, image, history):
|
|
|
137 |
response, audio = customLLMBot(user_query, image, history)
|
138 |
-
return response, audio, None,
|
139 |
|
|
|
140 |
user_input.submit(
|
141 |
handle_submit,
|
142 |
inputs=[user_input, uploaded_image, chat_history],
|
143 |
-
outputs=[chatbot, audio_output, uploaded_image,
|
144 |
)
|
145 |
|
|
|
146 |
submit_btn.click(
|
147 |
handle_submit,
|
148 |
inputs=[user_input, uploaded_image, chat_history],
|
149 |
-
outputs=[chatbot, audio_output, uploaded_image,
|
150 |
)
|
151 |
|
|
|
152 |
clear_btn.click(
|
153 |
lambda: ([], "", None, []),
|
154 |
inputs=[],
|
@@ -157,5 +181,8 @@ def chatbot_ui():
|
|
157 |
|
158 |
return demo
|
159 |
|
|
|
160 |
# Launch the interface
|
161 |
chatbot_ui().launch(server_name="0.0.0.0", server_port=7860)
|
|
|
|
|
|
1 |
+
from groq import Groq
|
2 |
import gradio as gr
|
3 |
from gtts import gTTS
|
4 |
import uuid
|
|
|
6 |
from io import BytesIO
|
7 |
import os
|
8 |
import logging
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
# Set up logger
|
11 |
+
logger = logging.getLogger(_name_)
|
12 |
logger.setLevel(logging.DEBUG)
|
13 |
console_handler = logging.StreamHandler()
|
14 |
file_handler = logging.FileHandler('chatbot_log.log')
|
|
|
18 |
logger.addHandler(console_handler)
|
19 |
logger.addHandler(file_handler)
|
20 |
|
21 |
+
# Initialize Groq Client
|
22 |
+
client = Groq(api_key=os.getenv("GROQ_API_KEY_2"))
|
23 |
+
|
24 |
+
# client = Groq(
|
25 |
+
# api_key="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",
|
26 |
+
# )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
# Function to encode the image
|
29 |
def encode_image(uploaded_image):
|
30 |
try:
|
31 |
logger.debug("Encoding image...")
|
32 |
buffered = BytesIO()
|
33 |
+
uploaded_image.save(buffered, format="PNG") # Ensure the correct format
|
34 |
logger.debug("Image encoding complete.")
|
35 |
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
36 |
except Exception as e:
|
37 |
logger.error(f"Error encoding image: {e}")
|
38 |
raise
|
39 |
+
def initialize_messages():
|
40 |
+
return [{"role": "system",
|
41 |
+
"content": '''You are Dr. HealthBuddy, a highly experienced and professional virtual doctor chatbot with over 40 years of expertise across all medical fields. You provide health-related information, symptom guidance, lifestyle tips, and actionable solutions using a dataset to reference common symptoms and conditions. Your goal is to offer concise, empathetic, and knowledgeable responses tailored to each patient’s needs.
|
42 |
|
43 |
+
You only respond to health-related inquiries and strive to provide the best possible guidance. Your responses should include clear explanations, actionable steps, and when necessary, advise patients to seek in-person care from a healthcare provider for a proper diagnosis or treatment. Maintain a friendly, professional, and empathetic tone in all your interactions.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
+
Prompt Template:
|
46 |
+
- Input: Patient’s health concerns, including symptoms, questions, or specific issues they mention.
|
47 |
+
- Response: Start with a polite acknowledgment of the patient’s concern. Provide a clear, concise explanation and suggest practical, actionable steps based on the dataset. If needed, advise on when to consult a healthcare provider.
|
48 |
+
|
49 |
+
Examples:
|
50 |
+
|
51 |
+
- User: "I have skin rash and itching. What could it be?"
|
52 |
+
Response: "According to the data, skin rash and itching are common symptoms of conditions like fungal infections. You can try keeping the affected area dry and clean, and using over-the-counter antifungal creams. If the rash persists or worsens, please consult a dermatologist."
|
53 |
+
|
54 |
+
- User: "What might cause nodal skin eruptions?"
|
55 |
+
Response: "Nodal skin eruptions could be linked to conditions such as fungal infections. It's best to monitor the symptoms and avoid scratching. For a proper diagnosis, consider visiting a healthcare provider."
|
56 |
+
|
57 |
+
- User: "I am a 22-year-old female diagnosed with hypothyroidism. I've gained 10 kg recently. What should I do?"
|
58 |
+
Response: "Hi. You have done well managing your hypothyroidism. For effective weight loss, focus on a balanced diet rich in vegetables, lean proteins, and whole grains. Pair this with regular exercise like brisk walking or yoga. Also, consult your endocrinologist to ensure your thyroid levels are well-controlled. Let me know if you have more questions."
|
59 |
+
|
60 |
+
- User: "I’ve been feeling discomfort between my shoulder blades after sitting for long periods. What could this be?"
|
61 |
+
Response: "Hello. The discomfort between your shoulder blades could be related to posture or strain. Try adjusting your sitting position and consider ergonomic changes to your workspace. Over-the-counter pain relievers or hot compresses may help. If the pain persists, consult an orthopedic specialist for further evaluation."
|
62 |
+
|
63 |
+
Always ensure the tone remains compassionate, and offer educational insights while stressing that you are not a substitute for professional medical advice. Encourage users to consult a healthcare provider for any serious or persistent health concerns.'''
|
64 |
+
}]
|
65 |
+
messages=initialize_messages()
|
66 |
|
|
|
67 |
def customLLMBot(user_input, uploaded_image, chat_history):
|
68 |
try:
|
69 |
global messages
|
70 |
logger.info("Processing input...")
|
71 |
+
|
72 |
+
# Append user input to the chat history
|
73 |
chat_history.append(("user", user_input))
|
74 |
|
75 |
if uploaded_image is not None:
|
76 |
+
# Encode the image to base64
|
77 |
+
base64_image = encode_image(uploaded_image)
|
78 |
+
|
79 |
+
# Log the image size and type
|
80 |
+
logger.debug(f"Image received, size: {len(base64_image)} bytes")
|
81 |
+
|
82 |
+
# Create a message for the image prompt
|
83 |
+
messages.append({
|
84 |
+
"role": "user",
|
85 |
+
"content": "What's in this image?",
|
86 |
+
"type": "image_url", # If this is supported in Groq API
|
87 |
+
"image_url": {"url": f"data:image/png;base64,{base64_image}"}
|
88 |
+
})
|
89 |
+
|
90 |
+
logger.info("Sending image to Groq API for processing...")
|
91 |
+
response = client.chat.completions.create(
|
92 |
+
model="llama-3.2-11b-vision-preview",
|
93 |
+
messages=messages,
|
94 |
+
)
|
95 |
+
logger.info("Image processed successfully.")
|
96 |
else:
|
97 |
+
# Process text input
|
98 |
+
logger.info("Processing text input...")
|
99 |
+
messages.append({
|
100 |
+
"role": "user",
|
101 |
+
"content": user_input
|
102 |
+
})
|
103 |
+
response = client.chat.completions.create(
|
104 |
+
model="llama-3.2-11b-vision-preview",
|
105 |
+
messages=messages,
|
106 |
+
)
|
107 |
+
logger.info("Text processed successfully.")
|
108 |
+
|
109 |
+
# Extract the reply
|
110 |
+
LLM_reply = response.choices[0].message.content
|
111 |
+
logger.debug(f"LLM reply: {LLM_reply}")
|
112 |
+
|
113 |
+
# Append the bot's response to the chat history
|
114 |
+
chat_history.append(("bot", LLM_reply))
|
115 |
|
116 |
# Generate audio for response
|
|
|
117 |
audio_file = f"response_{uuid.uuid4().hex}.mp3"
|
118 |
tts = gTTS(LLM_reply, lang='en')
|
119 |
tts.save(audio_file)
|
|
|
123 |
return chat_history, audio_file
|
124 |
|
125 |
except Exception as e:
|
126 |
+
# Handle errors gracefully
|
127 |
logger.error(f"Error in customLLMBot function: {e}")
|
128 |
return [(("user", user_input or "Image uploaded"), ("bot", f"An error occurred: {e}"))], None
|
129 |
|
130 |
+
|
131 |
# Gradio Interface
|
132 |
def chatbot_ui():
|
133 |
with gr.Blocks() as demo:
|
|
|
136 |
# State for user chat history
|
137 |
chat_history = gr.State([])
|
138 |
|
139 |
+
# Layout for chatbot and input box alignment
|
140 |
with gr.Row():
|
141 |
+
with gr.Column(scale=3): # Main column for chatbot
|
142 |
chatbot = gr.Chatbot(label="Responses", elem_id="chatbot")
|
143 |
user_input = gr.Textbox(
|
144 |
label="Ask a health-related question",
|
|
|
146 |
elem_id="user-input",
|
147 |
lines=1,
|
148 |
)
|
149 |
+
with gr.Column(scale=1): # Side column for image and buttons
|
150 |
uploaded_image = gr.Image(label="Upload an Image", type="pil")
|
151 |
submit_btn = gr.Button("Submit")
|
152 |
clear_btn = gr.Button("Clear")
|
|
|
154 |
|
155 |
# Define actions
|
156 |
def handle_submit(user_query, image, history):
|
157 |
+
logger.info("User submitted a query.")
|
158 |
response, audio = customLLMBot(user_query, image, history)
|
159 |
+
return response, audio, None,'', history # Clear the image after submission
|
160 |
|
161 |
+
# Submit on pressing Enter key
|
162 |
user_input.submit(
|
163 |
handle_submit,
|
164 |
inputs=[user_input, uploaded_image, chat_history],
|
165 |
+
outputs=[chatbot, audio_output, uploaded_image,user_input, chat_history],
|
166 |
)
|
167 |
|
168 |
+
# Submit on button click
|
169 |
submit_btn.click(
|
170 |
handle_submit,
|
171 |
inputs=[user_input, uploaded_image, chat_history],
|
172 |
+
outputs=[chatbot, audio_output, uploaded_image,user_input, chat_history],
|
173 |
)
|
174 |
|
175 |
+
# Action for clearing all fields
|
176 |
clear_btn.click(
|
177 |
lambda: ([], "", None, []),
|
178 |
inputs=[],
|
|
|
181 |
|
182 |
return demo
|
183 |
|
184 |
+
|
185 |
# Launch the interface
|
186 |
chatbot_ui().launch(server_name="0.0.0.0", server_port=7860)
|
187 |
+
|
188 |
+
#chatbot_ui().launch(server_name="localhost", server_port=7860)
|
requirements.txt
CHANGED
@@ -2,7 +2,7 @@ gtts
|
|
2 |
gradio
|
3 |
groq
|
4 |
loguru
|
5 |
-
torch
|
6 |
-
transformers
|
7 |
-
torchvision
|
8 |
-
pillow
|
|
|
2 |
gradio
|
3 |
groq
|
4 |
loguru
|
5 |
+
# torch
|
6 |
+
# transformers
|
7 |
+
# torchvision
|
8 |
+
# pillow
|