# Author: engrphoenix
# Commit: 2ab0991 (verified) — "Update app.py"
import streamlit as st
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
@st.cache_resource
def _load_clinicalgpt(model_id="medicalai/ClinicalGPT-base-zh"):
    """Load the ClinicalGPT tokenizer and causal-LM weights once.

    Streamlit re-executes this script from the top on every widget
    interaction; ``st.cache_resource`` keeps the loaded objects across those
    reruns instead of reloading the weights each time.

    Args:
        model_id: Hugging Face Hub repo id (or local path) of the checkpoint.

    Returns:
        ``(tokenizer, model)`` tuple for ``model_id``.
    """
    return (
        AutoTokenizer.from_pretrained(model_id),
        AutoModelForCausalLM.from_pretrained(model_id),
    )


tokenizer, model = _load_clinicalgpt()
import os
# Text-generation pipelines keyed by display name; get_medical_response
# queries each one in turn. The pipeline is built from the tokenizer/model
# objects loaded above rather than from the bare id "ClinicalGPT-base-zh":
# that id omits the "medicalai/" namespace, so (unless a local directory of
# that name exists) it would not resolve to the same checkpoint, and passing
# the objects also avoids loading the weights a second time.
models = {
    "ClinicalGPT-base-zh": pipeline("text-generation", model=model, tokenizer=tokenizer),
}
def get_medical_response(patient_name, age, sex, symptoms, xray_mri=None, medical_reports=None,
                         *, pipelines=None, max_new_tokens=300):
    """Query every configured text-generation pipeline with the patient's details.

    Args:
        patient_name: Patient name, interpolated into the prompt.
        age: Patient age, interpolated into the prompt.
        sex: Patient sex, interpolated into the prompt.
        symptoms: Free-text symptom description.
        xray_mri: Optional text about an imaging upload; skipped when falsy.
            ``main`` passes only the uploaded file's name, so the image
            itself never reaches the model.
        medical_reports: Optional text about a report upload; skipped when
            falsy (likewise only the file name when called from ``main``).
        pipelines: Mapping of display name -> pipeline callable. Defaults to
            the module-level ``models`` dict; pass your own to reuse or test
            this function without loading real models.
        max_new_tokens: Maximum number of tokens each model may generate,
            not counting the prompt.

    Returns:
        dict mapping each model name to its generated continuation, or to an
        ``"Error: <message>"`` string if that model raised.
    """
    if pipelines is None:
        pipelines = models

    # Build the prompt one "Label: value" line at a time.
    parts = [
        "Patient Details:",
        f"Name: {patient_name}",
        f"Age: {age}",
        f"Sex: {sex}",
        f"Symptoms: {symptoms}",
    ]
    if xray_mri:
        parts.append(f"X-ray/MRI: {xray_mri}")
    if medical_reports:
        parts.append(f"Medical Reports: {medical_reports}")
    message_content = "\n".join(parts)

    model_results = {}
    for model_name, model_pipeline in pipelines.items():
        try:
            # max_new_tokens (unlike max_length) excludes the prompt, so a
            # long symptom description cannot eat the generation budget.
            # return_full_text=False drops the echoed prompt so callers get
            # only the model's own output.
            result = model_pipeline(
                message_content,
                max_new_tokens=max_new_tokens,
                return_full_text=False,
            )
            model_results[model_name] = result[0]["generated_text"]
        except Exception as e:  # best-effort: one failing model must not hide the others
            model_results[model_name] = f"Error: {e}"
    return model_results
def main():
    """Render the Streamlit form and show each model's generated diagnosis."""
    st.title("Medical Diagnosis Assistant")

    # Collect patient details.
    patient_name = st.text_input("Patient Name")
    age = st.number_input("Age", min_value=0)
    sex = st.radio("Sex", options=["Male", "Female", "Other"])
    symptoms = st.text_area("Medical Symptoms")

    # Optional uploads. Only the file names are forwarded to the model below;
    # the uploaded contents are never read.
    xray_mri = st.file_uploader("Upload X-ray/MRI Image (Optional)", type=["jpg", "jpeg", "png", "dcm", "pdf"])
    medical_reports = st.file_uploader("Upload Medical Reports (Optional)", type=["pdf", "txt", "docx"])

    if not st.button("Submit"):
        return

    # Without symptoms the prompt carries no clinical information, so don't
    # spend a slow generation call on it.
    if not symptoms.strip():
        st.warning("Please describe the patient's symptoms before submitting.")
        return

    # Generation can take a while; show progress instead of a frozen page.
    with st.spinner("Generating diagnosis..."):
        model_results = get_medical_response(
            patient_name,
            age,
            sex,
            symptoms,
            xray_mri.name if xray_mri else None,
            medical_reports.name if medical_reports else None,
        )

    # One section per model.
    for model_name, diagnosis in model_results.items():
        st.subheader(f"Diagnosis from {model_name}:")
        st.text_area(f"Diagnosis - {model_name}", diagnosis, height=300)


if __name__ == "__main__":
    main()