import gradio as gr
import torch
import open_clip
import os
from PIL import Image

#############################################
# ✅ 1) Configure OpenAI API (Latest SDK v1.0+)
#############################################
from openai import OpenAI

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))  # Store the API key in an env variable
GPT_MODEL = "gpt-4"

#############################################
# ✅ 2) Load Fine-Tuned BiomedCLIP from HF Hub
#############################################
model_name = "hf-hub:mgbam/OpenCLIP-BiomedCLIP-Finetuned"  # Replace with your repo

model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    model_name,
    pretrained=None,  # Weights come from the HF Hub checkpoint itself
)
tokenizer = open_clip.get_tokenizer(model_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

#############################################
# ✅ 3) Expanded Label Set for Medical Imaging
#############################################
LABELS = [
    # X-ray & CT modalities
    "chest X-ray",
    "spinal X-ray",
    "dental X-ray",
    "abdominal X-ray",
    "brain CT scan",
    "chest CT scan",
    "abdominal CT scan",
    # MRI modalities
    "brain MRI",
    "spinal MRI",
    "cardiac MRI",
    # Ultrasound modalities
    "abdominal ultrasound",
    "thyroid ultrasound",
    "breast ultrasound",
    "fetal ultrasound",
    "echocardiogram",
    # PET & nuclear imaging
    "PET scan",
    "bone scan",
    "thyroid scan",
    # Histopathology & microscopy
    "histopathology (H&E stain)",
    "immunohistochemistry histopathology",
    "squamous cell carcinoma histopathology",
    "adenocarcinoma histopathology",
    # Other medical images
    "retinal scan",
    "COVID line chart",
    "ECG scan",
    "pie chart",
]

#############################################
# ✅ 4) Image Classification with BiomedCLIP
#############################################
def classify_image(image: Image.Image):
    """
    Classifies an uploaded medical image using BiomedCLIP zero-shot learning.
    Returns the best label and its confidence score.
    """
    if image is None:
        return "No image provided.", 0.0

    # Preprocess the image
    image_tensor = preprocess_val(image).unsqueeze(0).to(device)

    # Convert labels into text prompts
    text_prompts = [f"This is a {label}" for label in LABELS]
    text_tokens = tokenizer(text_prompts).to(device)

    with torch.no_grad():
        image_features = model.encode_image(image_tensor)
        text_features = model.encode_text(text_tokens)

        # Normalize features before computing cosine similarities;
        # encode_image/encode_text do not normalize on their own.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logit_scale = model.logit_scale.exp()
        logits = (logit_scale * image_features @ text_features.T).softmax(dim=-1)

    # Extract the top label
    probs = logits.squeeze(0).cpu().numpy()
    best_idx = probs.argmax()
    best_label = LABELS[best_idx]
    best_confidence = float(probs[best_idx])

    return best_label, best_confidence
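# Quick local sanity check (hypothetical; assumes an image file such as
# "sample_scan.png" sits next to this script — adjust the path as needed):
#
#   img = Image.open("sample_scan.png")
#   label, conf = classify_image(img)
#   print(f"{label}: {conf:.3f}")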
#############################################
# ✅ 5) GPT-4 Medical Explanation
#############################################
def gpt4_explanation(label: str, confidence: float):
    """
    Uses GPT-4 to generate a detailed medical explanation of what might be
    wrong based on the identified imaging modality.
    """
    system_prompt = (
        "You are a medical imaging expert. A deep learning model has classified an image "
        f"as '{label}' with {confidence:.2f} confidence. "
        "Please provide a concise, expert-level medical explanation about what this imaging modality is used for, "
        "common abnormalities detected, and potential clinical implications. If applicable, "
        "suggest possible conditions related to the scan."
    )
    user_prompt = (
        f"The AI model classified the image as '{label}' with confidence {confidence:.2f}. "
        "Please provide a professional-level medical explanation."
    )

    # Call OpenAI GPT-4 (latest SDK)
    response = client.chat.completions.create(
        model=GPT_MODEL,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.7,
    )

    # Extract and return the GPT-4 response
    return response.choices[0].message.content.strip()

#############################################
# ✅ 6) Gradio Interface
#############################################
def explain_image(image: Image.Image):
    """
    1) Classify the image with BiomedCLIP.
    2) Use GPT-4 to generate a detailed explanation of the label.
    """
    if image is None:
        return "No image provided."

    # Step 1: Get the classification label
    label, confidence = classify_image(image)

    # Step 2: Get the GPT-4 explanation
    explanation = gpt4_explanation(label, confidence)

    # Combine the results
    output = (
        f"### **Predicted Label:** {label}\n"
        f"**Confidence:** {confidence:.4f}\n\n"
        f"### **AI Medical Explanation:**\n{explanation}"
    )
    return output

demo = gr.Interface(
    fn=explain_image,
    inputs=gr.Image(type="pil"),
    outputs="markdown",
    title="🩺 Medical Image Diagnosis with BiomedCLIP + GPT-4",
    description=(
        "Upload a medical image. The AI model will classify it among several medical imaging types, "
        "and GPT-4 will provide a detailed explanation. This tool is for educational and research purposes only."
    ),
)

if __name__ == "__main__":
    demo.launch()
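# To run locally (a sketch; assumes the OpenAI key is exported first,
# and "app.py" is a hypothetical filename for this script):
#
#   export OPENAI_API_KEY="..."
#   python app.py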