Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification | |
| from PIL import Image | |
| import torch | |
| from typing import Tuple, Optional, Dict, Any | |
| from dataclasses import dataclass | |
| import random | |
| from datetime import datetime, timedelta | |
@dataclass
class PatientMetadata:
    """Patient attributes attached to an analysis result.

    Declared as a dataclass so that instances can be built with keyword
    arguments (as the analyzer's synthetic metadata generator does); a plain
    class with bare annotations would reject any constructor argument.
    """

    # Patient age in years.
    age: int
    # Smoking category label, e.g. "Never Smoker".
    smoking_status: str
    # True when a family history is recorded.
    family_history: bool
    # Menopause label, e.g. "Pre-menopausal" or "Post-menopausal".
    menopause_status: str
    # True when a previous mammogram is recorded.
    previous_mammogram: bool
    # Density category label, e.g. "A: Almost entirely fatty".
    breast_density: str
    # True when hormone therapy is recorded.
    hormone_therapy: bool
@dataclass
class AnalysisResult:
    """Outcome of one image analysis.

    Declared as a dataclass so that it can be constructed positionally as
    ``AnalysisResult(has_tumor, tumor_size, confidence, metadata)``; without
    the decorator the class would accept no constructor arguments.
    """

    # True when the detector favours the abnormal class.
    has_tumor: bool
    # Size label chosen by the size model ("no-tumor", "0.5", "1.0" or "1.5").
    tumor_size: str
    # Probability of the winning detection class, in [0, 1].
    confidence: float
    # Patient attributes that accompany this result.
    metadata: PatientMetadata
class BreastSinogramAnalyzer:
    """Classify a breast microwave image and draft a short textual report.

    The analyzer loads two image classifiers (abnormality detection and size
    estimation) plus a text-generation pipeline, and exposes :meth:`analyze`
    as the single entry point that maps a PIL image to a formatted report.
    """

    # Size labels, indexed by the class id predicted by the size model.
    _SIZE_LABELS = ("no-tumor", "0.5", "1.0", "1.5")

    def __init__(self):
        """Initialize the analyzer with required models."""
        print("Initializing system...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        self._init_vision_models()
        self._init_llm()
        print("Initialization complete!")

    def _init_vision_models(self) -> None:
        """Load the detection and size classifiers with their processors.

        Both models are moved to ``self.device`` and switched to eval mode.
        """
        print("Loading detection models...")
        detector_repo = "SIATCN/vit_tumor_classifier"
        size_repo = "SIATCN/vit_tumor_radius_detection_finetuned"

        self.tumor_detector = AutoModelForImageClassification.from_pretrained(
            detector_repo
        ).to(self.device).eval()
        self.tumor_processor = AutoImageProcessor.from_pretrained(detector_repo)

        self.size_detector = AutoModelForImageClassification.from_pretrained(
            size_repo
        ).to(self.device).eval()
        self.size_processor = AutoImageProcessor.from_pretrained(size_repo)

    def _init_llm(self) -> None:
        """Create the text-generation pipeline used for report drafting."""
        print("Loading language model pipeline...")
        # NOTE(review): 4-bit loading presumably needs bitsandbytes and a GPU;
        # confirm the deployment target provides both.
        self.pipe = pipeline(
            "text-generation",
            model="unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
            torch_dtype=torch.float16,
            device_map="auto",
            model_kwargs={
                "load_in_4bit": True,
                "bnb_4bit_compute_dtype": torch.float16,
            },
        )

    def _generate_synthetic_metadata(self) -> PatientMetadata:
        """Return randomly generated patient metadata.

        The values are synthetic: they are drawn with ``random`` and are not
        derived from the analysed image.
        """
        age = random.randint(40, 75)
        smoking_status = random.choice(
            ["Never Smoker", "Former Smoker", "Current Smoker"]
        )
        family_history = random.choice([True, False])
        # Menopause status is derived from age rather than drawn separately.
        menopause_status = "Post-menopausal" if age > 50 else "Pre-menopausal"
        previous_mammogram = random.choice([True, False])
        breast_density = random.choice([
            "A: Almost entirely fatty",
            "B: Scattered fibroglandular",
            "C: Heterogeneously dense",
            "D: Extremely dense",
        ])
        hormone_therapy = random.choice([True, False])
        return PatientMetadata(
            age=age,
            smoking_status=smoking_status,
            family_history=family_history,
            menopause_status=menopause_status,
            previous_mammogram=previous_mammogram,
            breast_density=breast_density,
            hormone_therapy=hormone_therapy,
        )

    def _process_image(self, image: Image.Image) -> Image.Image:
        """Return ``image`` converted to RGB and resized to 224x224."""
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image.resize((224, 224))

    def _analyze_image(self, image: Image.Image) -> AnalysisResult:
        """Run abnormality detection and size estimation on ``image``.

        Args:
            image: Preprocessed image (see :meth:`_process_image`).

        Returns:
            An ``AnalysisResult`` holding the detection flag, the size label,
            the probability of the winning detection class and freshly
            generated synthetic metadata.
        """
        metadata = self._generate_synthetic_metadata()

        # Inference only: skip autograd bookkeeping for both forward passes.
        with torch.no_grad():
            tumor_inputs = self.tumor_processor(
                image, return_tensors="pt"
            ).to(self.device)
            tumor_logits = self.tumor_detector(**tumor_inputs).logits
            tumor_probs = tumor_logits.softmax(dim=-1)[0].cpu()

            size_inputs = self.size_processor(
                image, return_tensors="pt"
            ).to(self.device)
            size_logits = self.size_detector(**size_inputs).logits
            size_probs = size_logits.softmax(dim=-1)[0].cpu()

        # NOTE(review): assumes class index 1 is the abnormal class -- confirm
        # against the detector's label mapping.
        # bool(...) turns the 0-dim comparison tensor into a plain Python bool
        # so the result matches the declared field type.
        has_tumor = bool(tumor_probs[1] > tumor_probs[0])
        confidence = float(tumor_probs[1] if has_tumor else tumor_probs[0])

        tumor_size = self._SIZE_LABELS[int(size_probs.argmax())]
        return AnalysisResult(has_tumor, tumor_size, confidence, metadata)

    @staticmethod
    def _format_risk_factors(metadata) -> str:
        """Return the applicable risk factors as a comma-separated string.

        The lower-cased smoking status is always listed; "family history" and
        "hormone therapy" are listed only when the matching flag is set.
        """
        factors = [
            'family history' if metadata.family_history else '',
            metadata.smoking_status.lower(),
            'hormone therapy' if metadata.hormone_therapy else '',
        ]
        # Drop the empty placeholders instead of trimming separators afterwards.
        return ', '.join(factor for factor in factors if factor)

    def _generate_medical_report(self, analysis: AnalysisResult) -> str:
        """Draft the interpretation text with the language model.

        Falls back to :meth:`_generate_fallback_report` when generation raises
        or when the generated text has fewer than ten words.

        Args:
            analysis: Result produced by :meth:`_analyze_image`.

        Returns:
            A report starting with "INTERPRETATION AND RECOMMENDATION:".
        """
        finding = 'Abnormal' if analysis.has_tumor else 'Normal'
        risk_factors = self._format_risk_factors(analysis.metadata)
        # NOTE(review): the <|system|>/<|user|>/</s> markers are kept exactly
        # as written; confirm they match the chat template expected by the
        # model loaded in _init_llm.
        prompt = "\n".join([
            "<|system|>You are a radiologist providing clear and concise medical reports.</s>",
            "<|user|>Generate a brief medical report for this microwave breast imaging scan:",
            "Findings:",
            f"- {finding} dielectric properties",
            f"- Size: {analysis.tumor_size} cm",
            f"- Confidence: {analysis.confidence:.2%}",
            f"- Patient age: {analysis.metadata.age}",
            f"- Risk factors: {risk_factors}",
            "Provide:",
            "1. One sentence interpreting the findings",
            "2. One clear management recommendation</s>",
            "<|assistant|>",
        ])
        try:
            response = self.pipe(
                prompt,
                max_new_tokens=128,
                temperature=0.3,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
                num_return_sequences=1,
            )[0]["generated_text"]

            # The pipeline output is expected to echo the prompt; keep only
            # the text that follows the assistant marker.
            if "<|assistant|>" in response:
                report = response.split("<|assistant|>")[-1].strip()
            else:
                report = response[len(prompt):].strip()

            # Simple validation: very short generations are treated as failures.
            if len(report.split()) >= 10:
                return "INTERPRETATION AND RECOMMENDATION:\n" + report

            print("Report too short, using fallback")
            return self._generate_fallback_report(analysis)
        except Exception as e:
            # Deliberate best effort: any generation failure yields the
            # template-based report instead of an error.
            print(f"Error in report generation: {str(e)}")
            return self._generate_fallback_report(analysis)

    def _generate_fallback_report(self, analysis: AnalysisResult) -> str:
        """Return a template-based report that does not use the language model."""
        header = "INTERPRETATION AND RECOMMENDATION:"
        if analysis.has_tumor:
            # The two largest size labels get the more urgent recommendation.
            if analysis.tumor_size in ('1.0', '1.5'):
                recommendation = 'Immediate conventional imaging and clinical correlation recommended.'
            else:
                recommendation = 'Follow-up imaging recommended in 6 months.'
            return "\n".join([
                header,
                f"Microwave imaging reveals abnormal dielectric properties measuring {analysis.tumor_size} cm with {analysis.confidence:.1%} confidence level.",
                recommendation,
            ])
        return "\n".join([
            header,
            f"Microwave imaging shows normal dielectric properties with {analysis.confidence:.1%} confidence level.",
            "Routine screening recommended per standard protocol.",
        ])

    def analyze(self, image: Image.Image) -> str:
        """Run the full pipeline and return the formatted result text.

        Args:
            image: Uploaded image, or ``None`` when nothing was provided.

        Returns:
            The analysis summary, patient info and report as one string, or a
            message starting with "Error during analysis:" on failure. This
            method does not raise; errors are reported in the returned text.
        """
        # No image supplied: report it plainly instead of surfacing an
        # AttributeError message from the preprocessing step.
        if image is None:
            return "Error during analysis: no image was provided."

        try:
            processed_image = self._process_image(image)
            analysis = self._analyze_image(processed_image)
            report = self._generate_medical_report(analysis)
            detection = 'Positive' if analysis.has_tumor else 'Negative'
            risk_factors = self._format_risk_factors(analysis.metadata)
            return "\n".join([
                "MICROWAVE IMAGING ANALYSIS:",
                f"• Detection: {detection}",
                f"• Size: {analysis.tumor_size} cm",
                "PATIENT INFO:",
                f"• Age: {analysis.metadata.age} years",
                f"• Risk Factors: {risk_factors}",
                "REPORT:",
                report,
            ])
        except Exception as e:
            # UI boundary: return the failure as text rather than raising.
            return f"Error during analysis: {str(e)}"
def create_interface() -> gr.Interface:
    """Build the Gradio UI: one image upload wired to the analyzer's text output."""
    # Constructing the analyzer loads every model before the UI is created.
    analyzer = BreastSinogramAnalyzer()

    image_input = gr.Image(type="pil", label="Upload Breast Microwave Image")
    result_box = gr.Textbox(label="Analysis Results", lines=20)

    return gr.Interface(
        fn=analyzer.analyze,
        inputs=[image_input],
        outputs=[result_box],
        title="Breast Cancer Microwave Imaging Analysis System",
        description="Upload a breast microwave image for comprehensive analysis and medical assessment.",
    )
if __name__ == "__main__":
    # Script entry point: build the UI and serve it with debug output and a
    # public share link enabled.
    print("Starting application...")
    create_interface().launch(debug=True, share=True)