"""FastAPI service that classifies chest X-rays as Normal or Pneumonia.

Three CNNs (DenseNet-169, ResNet-18, GoogLeNet) are downloaded from the
Hugging Face Hub when this module is imported, and their softmax outputs are
combined with a fixed weighted average. Uploads may be ordinary images
(anything PIL can open) or DICOM files (``.dcm`` / ``.dicom`` extension).
"""
import logging
import os
import tempfile
from io import BytesIO
from pathlib import Path

import numpy as np
import pydicom
import torch
import torch.nn as nn
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import models, transforms

logger = logging.getLogger(__name__)

CACHE_DIR = Path("/huggingface/cache")
HF_CACHE_DIR = Path("/huggingface/cache")

# NOTE(review): huggingface_hub generally reads these variables when it is first
# imported (above), so they may have no effect on this process; the downloads
# below pass cache_dir explicitly instead. Moving these lines above the import
# would also activate HF_HUB_ENABLE_HF_TRANSFER, which requires the optional
# `hf_transfer` package -- confirm it is installed before reordering.
os.environ["TRANSFORMERS_CACHE"] = str(HF_CACHE_DIR)
os.environ["HF_HOME"] = str(HF_CACHE_DIR.parent)
os.environ["HF_HUB_CACHE"] = str(HF_CACHE_DIR)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
CACHE_DIR.mkdir(parents=True, exist_ok=True)

app = FastAPI()

# ImageNet channel statistics used by the Normalize step.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _make_eval_transform():
    """Build the deterministic inference pipeline: resize, to-tensor, normalize.

    Training-time augmentation (random flips, affine jitter, colour jitter,
    blur, random erasing) must not be applied here: it would make the
    prediction for the same uploaded image vary from request to request.
    """
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])


# One pipeline per model so they can diverge later if a model needs it.
preprocessDensenet = _make_eval_transform()
preprocessResnet = _make_eval_transform()
preprocessGooglenet = _make_eval_transform()


def create_densenet169():
    """Return an untrained DenseNet-169 with a 2-class MLP classifier head."""
    model = models.densenet169(pretrained=False)
    in_features = model.classifier.in_features
    model.classifier = nn.Sequential(
        nn.BatchNorm1d(in_features),
        nn.Dropout(p=0.4),
        nn.Linear(in_features, 512),
        nn.ReLU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 2),
    )
    return model


def create_resnet18():
    """Return an untrained ResNet-18 whose fc layer outputs 2 classes."""
    model = models.resnet18(pretrained=False)
    model.fc = nn.Sequential(
        nn.Dropout(p=0.5),
        nn.Linear(model.fc.in_features, 2),
    )
    return model


def create_googlenet():
    """Return an untrained GoogLeNet (auxiliary heads removed) with a 2-class fc."""
    model = models.googlenet(pretrained=False)
    model.aux1 = None
    model.aux2 = None
    model.fc = nn.Sequential(
        nn.Dropout(p=0.5),
        nn.Linear(model.fc.in_features, 2),
    )
    return model


def load_model_from_hf(repo_id, model_creator):
    """Download ``model.pth`` from a Hub repo and load it into a fresh model.

    Args:
        repo_id: Hugging Face repository id holding ``model.pth``.
        model_creator: Zero-argument callable returning the matching architecture.

    Returns:
        The model in eval mode, or ``None`` if downloading or loading failed.
        Failures are logged rather than raised so the server can still start.
    """
    try:
        model_path = hf_hub_download(
            repo_id=repo_id, filename="model.pth", cache_dir=CACHE_DIR
        )
        model = model_creator()
        # NOTE(review): torch.load unpickles the file unless weights_only=True;
        # only point this at repositories you trust.
        checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
        # Checkpoints are either a bare state dict or a dict wrapping one.
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
        model.eval()
        return model
    except Exception:
        logger.exception("Error loading model from %s", repo_id)
        return None


modelss = {"Densenet169": None, "Resnet18": None, "Googlenet": None}
modelss["Densenet169"] = load_model_from_hf(
    "Arham-Irfan/Densenet169_pnuemonia_binaryclassification", create_densenet169
)
modelss["Resnet18"] = load_model_from_hf(
    "Arham-Irfan/Resnet18_pnuemonia_binaryclassification", create_resnet18
)
modelss["Googlenet"] = load_model_from_hf(
    "Arham-Irfan/Googlenet_pnuemonia_binaryclassification", create_googlenet
)

classes = ["Normal", "Pneumonia"]

# Fixed weights for averaging the per-model softmax outputs (sum to 1.0).
ENSEMBLE_WEIGHTS = {"Densenet169": 0.45, "Resnet18": 0.33, "Googlenet": 0.22}

# Upload extensions (lower-case, with dot) routed through the DICOM reader.
_DICOM_EXTENSIONS = {".dcm", ".dicom"}


def _scale_to_uint8(pixels):
    """Min-max scale an array to 0..255 uint8; a constant array maps to zeros.

    Arithmetic is done in float64 so signed/unsigned integer pixel data cannot
    wrap around while subtracting the minimum.
    """
    pixels = np.asarray(pixels, dtype=np.float64)
    lo = pixels.min()
    span = pixels.max() - lo
    if span == 0:
        return np.zeros(pixels.shape, dtype=np.uint8)
    return ((pixels - lo) / span * 255).astype(np.uint8)


def convert_dicom_to_rgb(dicom_path) -> Image.Image:
    """Read a DICOM dataset and return its pixels as an 8-bit RGB PIL image.

    Args:
        dicom_path: A filesystem path or binary file-like object, anything
            accepted by ``pydicom.dcmread``.

    Returns:
        An RGB image; greyscale data is replicated into three channels.
    """
    dicom_data = pydicom.dcmread(dicom_path)
    pixels = dicom_data.pixel_array
    if getattr(dicom_data, "SamplesPerPixel", 1) == 1 and pixels.ndim == 3:
        # Multi-frame greyscale (frames, rows, cols): classify the first frame.
        pixels = pixels[0]
    img_array = _scale_to_uint8(pixels)
    photometric = str(getattr(dicom_data, "PhotometricInterpretation", "")).strip()
    if photometric == "MONOCHROME1":
        # MONOCHROME1 stores inverted greyscale (minimum value displays white).
        img_array = 255 - img_array
    if img_array.ndim == 2:
        img_array = np.stack([img_array] * 3, axis=-1)
    return Image.fromarray(img_array)


def _decode_upload(file_bytes, filename):
    """Decode uploaded bytes into an RGB PIL image, using the extension to pick DICOM."""
    buffer = BytesIO(file_bytes)
    if Path(filename).suffix.lower() in _DICOM_EXTENSIONS:
        # Read from memory: never build a filesystem path from the client's filename.
        return convert_dicom_to_rgb(buffer)
    return Image.open(buffer).convert("RGB")


def _model_probabilities(image):
    """Run every model on *image* and return ``{model_name: class-probability array}``.

    Raises:
        RuntimeError: if any model failed to load at startup.
    """
    pipelines = {
        "Densenet169": preprocessDensenet,
        "Resnet18": preprocessResnet,
        "Googlenet": preprocessGooglenet,
    }
    missing = [name for name in pipelines if modelss.get(name) is None]
    if missing:
        raise RuntimeError(f"model(s) not loaded: {', '.join(missing)}")
    probs = {}
    with torch.no_grad():
        for name, preprocess in pipelines.items():
            batch = preprocess(image).unsqueeze(0)
            probs[name] = torch.softmax(modelss[name](batch), dim=1).numpy()[0]
    return probs


def _predict_from_bytes(file_bytes, filename):
    """Decode an upload, run the weighted ensemble, and build the response dict."""
    image = _decode_upload(file_bytes, filename)
    probs = _model_probabilities(image)
    ensemble_prob = sum(ENSEMBLE_WEIGHTS[name] * probs[name] for name in ENSEMBLE_WEIGHTS)
    pred_index = int(np.argmax(ensemble_prob))
    return {
        "prediction": classes[pred_index],
        "confidence": float(ensemble_prob[pred_index]),
        # Each model's probability for the class the ensemble chose.
        "model_details": {
            "Densenet169": float(probs["Densenet169"][pred_index]),
            "Resnet18": float(probs["Resnet18"][pred_index]),
            "Googlenet": float(probs["Googlenet"][pred_index]),
        },
    }


@app.post("/predict")
async def predict_pneumonia(file: UploadFile = File(...)):
    """Classify an uploaded chest X-ray (image or DICOM).

    Returns the ensemble prediction, its confidence and per-model details, or
    ``{"error": "Prediction error: ..."}`` if anything fails.
    """
    try:
        file_bytes = await file.read()
        # Decoding and CNN inference are CPU-bound; keep them off the event loop.
        return await run_in_threadpool(_predict_from_bytes, file_bytes, file.filename or "")
    except Exception as e:
        logger.exception("Prediction failed for upload %r", file.filename)
        return {"error": f"Prediction error: {str(e)}"}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)