import os
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Hugging Face access token: read it from the environment (e.g. a Space secret
# or `export HF_TOKEN=...`) rather than hard-coding it in the source.
token = os.getenv("HF_TOKEN")

model_id = "ibm-granite/granite-3.3-2b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, token=token)
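# Optional (assumption: a CUDA GPU is available): move the model to the GPU in
# half precision for faster generation. Left commented out because the app is
# configured above for CPU / float32.
# if torch.cuda.is_available():
#     model = model.to("cuda", dtype=torch.float16)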
# Core generation function shared by all four features
def generate_response(prompt):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=200, do_sample=True)
    # Note: decode() returns the prompt followed by the generated continuation.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
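# Quick local smoke test (hypothetical prompt; run manually if desired):
#   print(generate_response("List three common causes of headache."))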
# Four feature functions
def disease_prediction(symptoms):
    prompt = f"Patient has symptoms: {symptoms}. What could be the possible conditions?"
    return generate_response(prompt)

def treatment_plan(condition):
    prompt = f"What is the treatment plan for {condition}? Include medications, lifestyle changes, and follow-up."
    return generate_response(prompt)

def health_analytics(vitals):
    prompt = f"Analyze this health data and give insights: {vitals}"
    return generate_response(prompt)

def patient_chat(query):
    prompt = f"Medical Question: {query}"
    return generate_response(prompt)
# Custom CSS
custom_css = """
body {
    font-family: 'Segoe UI', sans-serif;
    background-color: #f8f9fa;
}
h1, h2 {
    color: #114B5F;
    font-weight: bold;
}
.gradio-container {
    padding: 20px !important;
}
textarea {
    border-radius: 10px !important;
    border: 1px solid #ccc !important;
}
button {
    background-color: #114B5F !important;
    color: white !important;
    border-radius: 8px !important;
    padding: 10px 16px !important;
}
.tabitem {
    background-color: #d6ecf3 !important;
    padding: 10px;
    border-radius: 10px;
}
"""
# Gradio Interface
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# 🏥 HealthAI - Generative Healthcare Assistant")

    with gr.Tab("🧠 Disease Prediction"):
        with gr.Column():
            symptom_input = gr.Textbox(label="Enter your symptoms")
            disease_output = gr.Textbox(label="Predicted Conditions")
            predict_btn = gr.Button("Predict")
            predict_btn.click(disease_prediction, inputs=symptom_input, outputs=disease_output)

    with gr.Tab("💊 Treatment Plans"):
        with gr.Column():
            condition_input = gr.Textbox(label="Enter diagnosed condition")
            treatment_output = gr.Textbox(label="Recommended Treatment")
            treatment_btn = gr.Button("Get Treatment Plan")
            treatment_btn.click(treatment_plan, inputs=condition_input, outputs=treatment_output)

    with gr.Tab("📊 Health Analytics"):
        with gr.Column():
            vitals_input = gr.Textbox(label="Enter vitals (e.g., heart rate: 80, BP: 120/80...)")
            analytics_output = gr.Textbox(label="AI Insights")
            analytics_btn = gr.Button("Analyze")
            analytics_btn.click(health_analytics, inputs=vitals_input, outputs=analytics_output)

    with gr.Tab("💬 Patient Chat"):
        with gr.Column():
            query_input = gr.Textbox(label="Ask a health-related question")
            chat_output = gr.Textbox(label="Response")
            chat_btn = gr.Button("Ask")
            chat_btn.click(patient_chat, inputs=query_input, outputs=chat_output)
demo.launch()
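# To get a temporary public URL when running outside Hugging Face Spaces,
# launch with sharing enabled instead:
#   demo.launch(share=True)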