from fastapi import FastAPI, UploadFile, File
from PIL import Image
import os
import uvicorn
import torch
import numpy as np
from io import BytesIO
from torchvision import transforms, models
import torch.nn as nn
from huggingface_hub import hf_hub_download
import tempfile
from pathlib import Path

# Set up the Hugging Face cache directory in a user-writable location.
# (cache_dir is also passed explicitly to hf_hub_download below.)
CACHE_DIR = Path(tempfile.gettempdir()) / "huggingface_cache"
os.environ["HF_HOME"] = str(CACHE_DIR)
CACHE_DIR.mkdir(parents=True, exist_ok=True)

app = FastAPI()

# Preprocessing pipelines, mirroring each model's training-time augmentation.
# Note: these include random augmentations, so predictions are not fully
# deterministic; a plain Resize/ToTensor/Normalize pipeline is more typical
# at inference time.
preprocessDensenet = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.3),
    transforms.RandomAffine(
        degrees=(-15, 15),
        translate=(0.1, 0.1),
        scale=(0.85, 1.15),
        fill=0
    ),
    transforms.RandomApply([
        transforms.ColorJitter(brightness=0.2, contrast=0.2)
    ], p=0.3),
    transforms.RandomApply([
        transforms.GaussianBlur(kernel_size=3)
    ], p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.1)
])

preprocessResnet = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAffine(
        degrees=(-10, 10),
        translate=(0.1, 0.1),
        scale=(0.9, 1.1),
        fill=0
    ),
    transforms.RandomApply([
        transforms.ColorJitter(brightness=0.3, contrast=0.3)
    ], p=0.3),
    transforms.RandomApply([
        transforms.GaussianBlur(kernel_size=3)
    ], p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.2)
])

preprocessGooglenet = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.3),  # Less aggressive flipping for medical images
    transforms.RandomAffine(
        degrees=(-5, 5),          # Slight rotation
        translate=(0.05, 0.05),   # Small translations
        scale=(0.95, 1.05),       # Subtle scaling
        fill=0                    # Fill with black
    ),
    transforms.RandomApply([
        transforms.ColorJitter(brightness=0.2, contrast=0.2)
    ], p=0.3),                    # Subtle intensity variations
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])


def create_densenet169():
    model = models.densenet169(pretrained=False)
    model.classifier = nn.Sequential(
        nn.BatchNorm1d(model.classifier.in_features),   # Batch-normalize pooled features
        nn.Dropout(p=0.4),
        nn.Linear(model.classifier.in_features, 512),   # Intermediate layer
        nn.ReLU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 2)                               # Binary output: Normal vs. Pneumonia
    )
    return model


def create_resnet18():
    model = models.resnet18(pretrained=False)
    model.fc = nn.Sequential(
        nn.Dropout(p=0.5),
        nn.Linear(model.fc.in_features, 2)
    )
    return model


def create_googlenet():
    model = models.googlenet(pretrained=False)
    model.aux1 = None  # Auxiliary classifiers are not used
    model.aux2 = None
    model.fc = nn.Sequential(
        nn.Dropout(p=0.5),
        nn.Linear(model.fc.in_features, 2)
    )
    return model


def load_model_from_hf(repo_id, model_creator):
    try:
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename="model.pth",
            cache_dir=CACHE_DIR
        )

        # Create the model architecture
        model = model_creator()

        # Load the checkpoint on CPU
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

        # Extract the state dict, whether or not it is wrapped in a checkpoint dict
        if "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            state_dict = checkpoint

        model.load_state_dict(state_dict)
        model.eval()
        return model
    except Exception as e:
        print(f"Error loading model from {repo_id}: {str(e)}")
        return None


modelss = {"Densenet169": None, "Resnet18": None, "Googlenet": None}
modelss["Densenet169"] = load_model_from_hf( "Arham-Irfan/Densenet169_pnuemonia_binaryclassification", create_densenet169 ) modelss["Resnet18"] = load_model_from_hf( "Arham-Irfan/Resnet18_pnuemonia_binaryclassification", create_resnet18 ) modelss["Googlenet"] = load_model_from_hf( "Arham-Irfan/Googlenet_pnuemonia_binaryclassification", create_googlenet ) classes = ["Normal", "Pneumonia"] @app.post("/predict") async def predict_pneumonia(file: UploadFile = File(...)): try: image = Image.open(BytesIO(await file.read())).convert("RGB") img_tensor1 = preprocessDensenet(image).unsqueeze(0) img_tensor2 = preprocessResnet(image).unsqueeze(0) img_tensor3 = preprocessGooglenet(image).unsqueeze(0) with torch.no_grad(): output1 = torch.softmax(modelss["Densenet169"](img_tensor1), dim=1).numpy()[0] output2 = torch.softmax(modelss["Resnet18"](img_tensor2), dim=1).numpy()[0] output3 = torch.softmax(modelss["Googlenet"](img_tensor3), dim=1).numpy()[0] weights = [0.45, 0.33, 0.22] ensemble_prob = weights[0] * output1 + weights[1] * output2 + weights[2] * output3 pred_index = np.argmax(ensemble_prob) return { "prediction": classes[pred_index], "confidence": float(ensemble_prob[pred_index]), "model_details": { "Densenet169": float(output1[pred_index]), "Resnet18": float(output2[pred_index]), "Googlenet": float(output3[pred_index]) } } except Exception as e: return {"error": f"Prediction error: {str(e)}"} if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=7860)