# Hugging Face Spaces page banner captured with the source (status: Running);
# it is not part of the program.
import os

import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Hub repository that hosts the checkpoint. The tokenizer, the model and the
# pipeline all derive from this one id so they cannot drift apart (the
# pipeline was previously given the id without its "medicalai/" namespace).
MODEL_REPO_ID = "medicalai/ClinicalGPT-base-zh"


@st.cache_resource
def _load_model_resources():
    """Load the tokenizer, model and text-generation pipelines once.

    Streamlit re-executes the script on every widget interaction, so the
    loaded objects are cached instead of being rebuilt on each rerun.

    Returns:
        A ``(tokenizer, model, pipelines)`` tuple, where ``pipelines`` maps a
        display name to a ready-to-call text-generation pipeline.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID)
    loaded_model = AutoModelForCausalLM.from_pretrained(MODEL_REPO_ID)
    pipelines = {
        # Reuse the objects loaded above rather than passing a model id, so
        # the weights are not loaded a second time.
        "ClinicalGPT-base-zh": pipeline(
            "text-generation",
            model=loaded_model,
            tokenizer=loaded_tokenizer,
        ),
    }
    return loaded_tokenizer, loaded_model, pipelines


# Initialize the Hugging Face pipelines for multiple models
tokenizer, model, models = _load_model_resources()
# Function to get medical diagnosis using all models
def get_medical_response(patient_name, age, sex, symptoms, xray_mri=None, medical_reports=None, pipelines=None):
    """Ask every configured model for a response to the patient's details.

    Args:
        patient_name: Patient name, inserted verbatim into the prompt.
        age: Patient age, inserted verbatim into the prompt.
        sex: Patient sex, inserted verbatim into the prompt.
        symptoms: Free-text symptom description.
        xray_mri: Optional text describing an X-ray/MRI (the caller passes a
            file name); omitted from the prompt when falsy.
        medical_reports: Optional text describing medical reports (the caller
            passes a file name); omitted from the prompt when falsy.
        pipelines: Optional mapping of model name to a callable invoked as
            ``fn(prompt, max_new_tokens=300)`` and returning a sequence whose
            first item has a ``"generated_text"`` key. Defaults to the
            module-level ``models`` mapping.

    Returns:
        A dict mapping each model name to its generated text, or to
        ``"Error: <message>"`` when that model failed.
    """
    # Prepare the input message with the provided patient details
    prompt_lines = [
        "Patient Details:",
        f"Name: {patient_name}",
        f"Age: {age}",
        f"Sex: {sex}",
        f"Symptoms: {symptoms}",
    ]
    # Optional attachments are mentioned only when provided
    if xray_mri:
        prompt_lines.append(f"X-ray/MRI: {xray_mri}")
    if medical_reports:
        prompt_lines.append(f"Medical Reports: {medical_reports}")
    message_content = "\n".join(prompt_lines)

    # `is None` (not truthiness) so an explicitly empty mapping is honoured.
    if pipelines is None:
        pipelines = models

    # Dictionary to store results from each model
    model_results = {}
    for model_name, model_pipeline in pipelines.items():
        try:
            # max_new_tokens bounds only the generated continuation. The
            # previous max_length=300 also counted the prompt tokens, so a
            # long symptom description left little or no room for output.
            result = model_pipeline(message_content, max_new_tokens=300)
            model_results[model_name] = result[0]["generated_text"]
        except Exception as e:
            # Deliberate best effort: one failing model must not hide the
            # results of the others, so the error is reported in its slot.
            model_results[model_name] = f"Error: {e}"
    return model_results
# Streamlit UI
def main():
    """Render the patient intake form and display each model's response.

    Nothing is sent to the models until "Submit" is pressed, and a submission
    with blank symptoms is rejected with a warning instead of being sent.
    """
    st.title("Medical Diagnosis Assistant")

    # Collect patient details
    patient_name = st.text_input("Patient Name")
    age = st.number_input("Age", min_value=0)
    sex = st.radio("Sex", options=["Male", "Female", "Other"])
    symptoms = st.text_area("Medical Symptoms")

    # Optional file inputs
    xray_mri = st.file_uploader(
        "Upload X-ray/MRI Image (Optional)",
        type=["jpg", "jpeg", "png", "dcm", "pdf"],
    )
    medical_reports = st.file_uploader(
        "Upload Medical Reports (Optional)",
        type=["pdf", "txt", "docx"],
    )

    if not st.button("Submit"):
        return

    # Without symptoms the prompt carries no clinical content, so any
    # generated "diagnosis" would be meaningless; ask for input instead.
    if not symptoms.strip():
        st.warning("Please describe the patient's symptoms before submitting.")
        return

    # Only the uploaded files' names are forwarded; their contents are not read.
    model_results = get_medical_response(
        patient_name,
        age,
        sex,
        symptoms,
        xray_mri.name if xray_mri else None,
        medical_reports.name if medical_reports else None,
    )

    # Display the results for each model
    for model_name, diagnosis in model_results.items():
        st.subheader(f"Diagnosis from {model_name}:")
        st.text_area(f"Diagnosis - {model_name}", diagnosis, height=300)
# Launch the UI when this file is executed as a script.
if __name__ == "__main__":
    main()