|
|
|
|
|
import os |
|
import gradio as gr |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import torch |
|
|
|
|
|
model_id = "ibm-granite/granite-3.3-2b-instruct"

# Hugging Face access token, read from the environment (e.g. a Space secret or
# `export HF_TOKEN=...`). It must NOT be overwritten with a placeholder string:
# an invalid token is sent with every Hub request and makes downloads fail.
# An unset or empty variable becomes None, which lets huggingface_hub fall back
# to a locally stored login (if any) or anonymous access.
token = os.getenv("HF_TOKEN") or None

# Load the tokenizer and weights once, at import time, so every request reuses
# them. float32 (rather than half precision) is presumably chosen so the model
# runs on CPU-only machines — confirm before changing.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, token=token)
|
|
|
|
|
def generate_response(prompt, max_new_tokens=200):
    """Generate a sampled model reply to a single user prompt.

    The prompt is wrapped in the tokenizer's chat template, because the
    checkpoint is an instruct model that expects chat-formatted input. Only
    the newly generated tokens are decoded, so the returned text does not
    echo the prompt back to the user.

    Args:
        prompt: The user message to send to the model.
        max_new_tokens: Upper bound on the number of tokens to generate
            (defaults to the original fixed value of 200).

    Returns:
        The decoded reply, with special tokens removed and surrounding
        whitespace stripped.
    """
    messages = [{"role": "user", "content": prompt}]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,  # end with the assistant turn header
        return_tensors="pt",
        return_dict=True,  # input_ids + attention_mask for generate(**inputs)
    ).to(model.device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=True)

    # For a causal LM, generate() returns prompt tokens followed by new tokens;
    # slice off the prompt so the user only sees the answer.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
|
|
|
|
|
def disease_prediction(symptoms):
    """Ask the model which conditions could explain the described symptoms.

    Args:
        symptoms: Free-text symptom description from the UI textbox
            (may be empty or None).

    Returns:
        The model's answer. If the input is empty or whitespace-only, returns
        a short hint instead and makes no model call.
    """
    # An empty textbox would otherwise send "Patient has symptoms: ." and the
    # model would invent conditions for nothing.
    symptoms = (symptoms or "").strip()
    if not symptoms:
        return "Please enter your symptoms."
    prompt = f"Patient has symptoms: {symptoms}. What could be the possible conditions?"
    return generate_response(prompt)
|
|
|
def treatment_plan(condition):
    """Ask the model for a treatment plan for a diagnosed condition.

    Args:
        condition: Name or description of the condition from the UI textbox
            (may be empty or None).

    Returns:
        The model's suggested plan, covering medications, lifestyle changes
        and follow-up. If the input is empty or whitespace-only, returns a
        short hint instead and makes no model call.
    """
    # Without a condition the prompt would be meaningless ("treatment plan for ?").
    condition = (condition or "").strip()
    if not condition:
        return "Please enter a diagnosed condition."
    prompt = f"What is the treatment plan for {condition}? Include medications, lifestyle changes, and follow-up."
    return generate_response(prompt)
|
|
|
def health_analytics(vitals):
    """Ask the model for insights about user-supplied health readings.

    Args:
        vitals: Free-text vitals from the UI textbox, e.g.
            "heart rate: 80, BP: 120/80" (may be empty or None).

    Returns:
        The model's analysis. If the input is empty or whitespace-only,
        returns a short hint instead and makes no model call.
    """
    # Without data the model would "analyze" nothing and fabricate findings.
    vitals = (vitals or "").strip()
    if not vitals:
        return "Please enter your vitals."
    prompt = f"Analyze this health data and give insights: {vitals}"
    return generate_response(prompt)
|
|
|
def patient_chat(query):
    """Answer a free-form health question with the model.

    Args:
        query: The user's question from the UI textbox (may be empty or None).

    Returns:
        The model's answer. If the input is empty or whitespace-only, returns
        a short hint instead and makes no model call.
    """
    # An empty question would only send "Medical Question: " to the model.
    query = (query or "").strip()
    if not query:
        return "Please enter a question."
    prompt = f"Medical Question: {query}"
    return generate_response(prompt)
|
|
|
|
|
# Stylesheet passed to gr.Blocks(css=...) for the UI: page font and background,
# heading colour, container padding, rounded textareas, teal buttons, and tab
# panel styling. The `!important` flags are presumably there to win over
# Gradio's built-in theme rules.
# NOTE(review): the bare `button` selector probably also matches Gradio's tab
# header buttons, which would make selected and unselected tabs look alike.
# Confirm this is intended.
custom_css = """

body {

font-family: 'Segoe UI', sans-serif;

background-color: #f8f9fa;

}

h1, h2 {

color: #114B5F;

font-weight: bold;

}

.gradio-container {

padding: 20px !important;

}

textarea {

border-radius: 10px !important;

border: 1px solid #ccc !important;

}

button {

background-color: #114B5F !important;

color: white !important;

border-radius: 8px !important;

padding: 10px 16px !important;

}

.tabitem {

background-color: #d6ecf3 !important;

padding: 10px;

border-radius: 10px;

}

"""
|
|
|
|
|
# Build the four-tab UI. Each tab is a textbox -> button -> textbox pipeline
# wired to one of the prompt helpers defined above.
# The emoji below repair mojibake in the original labels: UTF-8 emoji had been
# decoded as ISO-8859-7 (e.g. "π₯" for 🏥).
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# 🏥 HealthAI - Generative Healthcare Assistant")

    with gr.Tab("🧠 Disease Prediction"):
        with gr.Column():
            symptom_input = gr.Textbox(label="Enter your symptoms")
            disease_output = gr.Textbox(label="Predicted Conditions")
            predict_btn = gr.Button("Predict")
            predict_btn.click(disease_prediction, inputs=symptom_input, outputs=disease_output)

    with gr.Tab("💊 Treatment Plans"):
        with gr.Column():
            condition_input = gr.Textbox(label="Enter diagnosed condition")
            treatment_output = gr.Textbox(label="Recommended Treatment")
            treatment_btn = gr.Button("Get Treatment Plan")
            treatment_btn.click(treatment_plan, inputs=condition_input, outputs=treatment_output)

    with gr.Tab("📊 Health Analytics"):
        with gr.Column():
            vitals_input = gr.Textbox(label="Enter vitals (e.g., heart rate: 80, BP: 120/80...)")
            analytics_output = gr.Textbox(label="AI Insights")
            analytics_btn = gr.Button("Analyze")
            analytics_btn.click(health_analytics, inputs=vitals_input, outputs=analytics_output)

    with gr.Tab("💬 Patient Chat"):
        with gr.Column():
            query_input = gr.Textbox(label="Ask a health-related question")
            chat_output = gr.Textbox(label="Response")
            chat_btn = gr.Button("Ask")
            chat_btn.click(patient_chat, inputs=query_input, outputs=chat_output)
|
|
|
# Start the web server only when run as a script (`python app.py`), not as an
# import side effect. This also lets tools such as `gradio app.py` (reload
# mode) import `demo` without launching a second server.
if __name__ == "__main__":
    demo.launch()