# Commit 7818ba9 (PAVULURI KIRAN): Updated FastAPI app and requirements
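"""FastAPI inference service for mervinpraison/Llama-3.2-11B-Vision-Radiology-mini.

The /predict/ endpoint first asks the model to confirm that an uploaded image
is an X-ray or CT scan, then generates a structured radiology report from it.
"""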
from fastapi import FastAPI, File, UploadFile, HTTPException
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
from PIL import Image
import io
import base64
import os
import logging
from huggingface_hub import login
# Enable logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
MODEL_NAME = "mervinpraison/Llama-3.2-11B-Vision-Radiology-mini"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Set the Hugging Face cache directory (TRANSFORMERS_CACHE is deprecated in
# recent transformers releases in favor of HF_HOME; both are set here for
# compatibility with older versions)
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
# Ensure Hugging Face Token is Set
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if not HF_TOKEN:
    logger.error("Hugging Face token not found! Set HUGGINGFACE_TOKEN in the environment.")
    raise RuntimeError("Hugging Face token missing. Set it in your environment.")
# Log in to Hugging Face
try:
    login(HF_TOKEN)
except Exception as e:
    logger.error(f"Failed to authenticate Hugging Face token: {e}")
    raise RuntimeError("Authentication with Hugging Face failed.") from e
# Configure 4-bit quantization (bitsandbytes; requires a CUDA GPU)
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,  # Switch to load_in_8bit=True if 4-bit loading fails
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)
# Load the model and processor
try:
    logger.info("Loading model and processor...")
    processor = AutoProcessor.from_pretrained(
        MODEL_NAME, cache_dir="/tmp/huggingface", force_download=False
    )
    # Note: bitsandbytes-quantized models must not be moved with .to();
    # transformers raises an error if you try. device_map="auto" lets
    # accelerate place the weights on the available device(s) instead.
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_NAME,
        quantization_config=quantization_config,
        cache_dir="/tmp/huggingface",
        device_map="auto",
    )
    torch.backends.cuda.matmul.allow_tf32 = True  # Allow TF32 matmuls on Ampere+ GPUs
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    logger.info("Model loaded successfully.")
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise RuntimeError("Model loading failed. Check model accessibility.") from e
# Allowed Formats
ALLOWED_FORMATS = {"jpeg", "jpg", "png", "bmp", "tiff"}
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    try:
        ext = file.filename.split(".")[-1].lower()
        if ext not in ALLOWED_FORMATS:
            raise HTTPException(status_code=400, detail=f"Invalid file format: {ext}. Upload an image file.")
        # Read and decode the uploaded image
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        # Convert the image to base64 (note: base64_image is currently unused;
        # drop this block or return the string to clients as needed)
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
        # Validation step: ask the model whether the image is a radiology scan.
        # The Llama 3.2 Vision processor expects an <|image|> token in the text
        # to mark where the image attaches.
        validation_prompt = "<|image|>Is this a medical X-ray or CT scan? Answer only 'yes' or 'no'."
        validation_inputs = processor(
            text=validation_prompt, images=image, return_tensors="pt"
        ).to(model.device)
        with torch.no_grad():
            validation_output = model.generate(
                **validation_inputs, max_new_tokens=10, do_sample=True,
                temperature=0.1, top_p=0.7, top_k=50
            )
        # Decode only the newly generated tokens. Decoding the full sequence
        # would include the prompt, whose "'yes' or 'no'" wording makes the
        # substring check below succeed unconditionally.
        new_tokens = validation_output[:, validation_inputs["input_ids"].shape[1]:]
        validation_result = processor.batch_decode(new_tokens, skip_special_tokens=True)[0].strip().lower()
        logger.info(f"Validation result: {validation_result}")
        if "yes" not in validation_result:
            raise HTTPException(status_code=400, detail="Uploaded image is not an X-ray or CT scan.")
        # Analysis step: generate the structured report
        analysis_prompt = """<|image|>Analyze this X-ray image and provide a detailed medical report:
Type of X-ray:
Key Findings:
• [Findings]
Potential Conditions:
• [Possible Diagnoses]
Recommendations:
• [Follow-up Actions]
"""
        analysis_inputs = processor(text=analysis_prompt, images=image, return_tensors="pt").to(model.device)
        with torch.no_grad():
            analysis_output = model.generate(
                **analysis_inputs, max_new_tokens=512, do_sample=True,
                temperature=0.7, top_p=0.7, top_k=50
            )
        # Again, decode only the tokens generated after the prompt
        new_tokens = analysis_output[:, analysis_inputs["input_ids"].shape[1]:]
        analysis_content = processor.batch_decode(new_tokens, skip_special_tokens=True)[0]
        # Light markdown cleanup for plain-text clients
        cleaned_analysis = (
            analysis_content.replace("**", "").replace("*", "•").replace("_", "").strip()
        )
        return {"analysis": cleaned_analysis}
    except HTTPException as http_err:
        logger.error(f"Validation error: {http_err.detail}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=f"Error occurred: {str(e)}") from e
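# A minimal sketch of how the service can be run and exercised locally,
# assuming this module is saved as app.py and uvicorn is installed (both
# are assumptions, not part of the original commit):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
#   curl -X POST "http://localhost:8000/predict/" \
#        -F "file=@chest_xray.jpg;type=image/jpeg"
#
if __name__ == "__main__":
    # Convenience entry point for local runs; in production, launch via the
    # uvicorn CLI (or gunicorn with uvicorn workers) instead.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)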