from fastapi import FastAPI, File, UploadFile, HTTPException
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
from PIL import Image
import io
import base64
import os
import logging
from huggingface_hub import login

# Enable logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()
MODEL_NAME = "mervinpraison/Llama-3.2-11B-Vision-Radiology-mini"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Set Hugging Face cache directory.
# Note: TRANSFORMERS_CACHE is deprecated in recent transformers releases in favor
# of HF_HOME; both are kept here, and cache_dir is also passed explicitly below.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
# Ensure Hugging Face token is set
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if not HF_TOKEN:
    logger.error("Hugging Face token not found! Set HUGGINGFACE_TOKEN in the environment.")
    raise RuntimeError("Hugging Face token missing. Set it in your environment.")

# Log in to Hugging Face
try:
    login(HF_TOKEN)
except Exception as e:
    logger.error(f"Failed to authenticate Hugging Face token: {e}")
    raise RuntimeError("Authentication with Hugging Face failed.") from e
# Configure quantization (4-bit loading requires a CUDA GPU and the bitsandbytes package)
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,  # Change to load_in_8bit=True if 4-bit fails
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)
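# Descriptive note: bnb_4bit_use_double_quant additionally quantizes the
# quantization constants themselves, which the QLoRA paper reports saves
# roughly 0.4 bits per parameter of memory at no quality cost.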
# Load model
try:
    logger.info("Loading model and processor...")
    processor = AutoProcessor.from_pretrained(
        MODEL_NAME, cache_dir="/tmp/huggingface", force_download=False
    )
    # device_map="auto" lets accelerate place the quantized weights; calling
    # .to(DEVICE) on a bitsandbytes-quantized model raises a ValueError, so it
    # must not be used here.
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_NAME,
        quantization_config=quantization_config,
        cache_dir="/tmp/huggingface",
        device_map="auto",
    )
    torch.backends.cuda.matmul.allow_tf32 = True  # Allow faster TF32 matmuls on Ampere+ GPUs
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    logger.info("Model loaded successfully.")
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise RuntimeError("Model loading failed. Check model accessibility.") from e
# Allowed formats
ALLOWED_FORMATS = {"jpeg", "jpg", "png", "bmp", "tiff"}

# Register the route; without a decorator the function is never exposed by
# FastAPI. The /predict path is an assumption, as the original snippet omitted
# the decorator entirely.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    try:
        ext = file.filename.split(".")[-1].lower()
        if ext not in ALLOWED_FORMATS:
            raise HTTPException(status_code=400, detail=f"Invalid file format: {ext}. Upload an image file.")

        # Read image
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        # Convert image to base64 (currently unused; kept so the normalized
        # image could be returned alongside the analysis if a client needs it)
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
        # Validation step. The Llama 3.2 Vision processor expects the <|image|>
        # token in the text to mark where the image is attached.
        validation_prompt = "<|image|>Is this a medical X-ray or CT scan? Answer only 'yes' or 'no'."
        validation_inputs = processor(
            text=validation_prompt, images=image, return_tensors="pt"
        ).to(DEVICE)
        with torch.no_grad():
            validation_output = model.generate(
                **validation_inputs,
                max_new_tokens=10,
                do_sample=True,  # temperature/top_p/top_k only take effect when sampling
                temperature=0.1,
                top_p=0.7,
                top_k=50,
            )
        # Decode only the newly generated tokens: the echoed prompt itself
        # contains the word "yes", which would make the check below always pass.
        prompt_len = validation_inputs["input_ids"].shape[-1]
        validation_result = processor.decode(
            validation_output[0][prompt_len:], skip_special_tokens=True
        ).strip().lower()
        logger.info(f"Validation result: {validation_result}")
        if "yes" not in validation_result:
            raise HTTPException(status_code=400, detail="Uploaded image is not an X-ray or CT scan.")
        # Analysis step
        analysis_prompt = """<|image|>Analyze this X-ray image and provide a detailed medical report:
Type of X-ray:
Key Findings:
• [Findings]
Potential Conditions:
• [Possible Diagnoses]
Recommendations:
• [Follow-up Actions]
"""
        analysis_inputs = processor(text=analysis_prompt, images=image, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            analysis_output = model.generate(
                **analysis_inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_p=0.7,
                top_k=50,
            )
        # Again, strip the echoed prompt before decoding
        prompt_len = analysis_inputs["input_ids"].shape[-1]
        analysis_content = processor.decode(
            analysis_output[0][prompt_len:], skip_special_tokens=True
        )
        cleaned_analysis = (
            analysis_content.replace("**", "").replace("*", "•").replace("_", "").strip()
        )
        return {"analysis": cleaned_analysis}
    except HTTPException as http_err:
        logger.error(f"Validation error: {http_err.detail}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=f"Error occurred: {str(e)}")
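
# Local entrypoint, a minimal sketch: Hugging Face Spaces conventionally serve
# on port 7860, but the port (and the /predict route name above) are
# assumptions of this example, not taken from the original deployment.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request against the running server (hypothetical file name):
#   curl -X POST -F "file=@chest_xray.jpg" http://localhost:7860/predict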