from fastapi import FastAPI, UploadFile, File
from PIL import Image
import os
import uvicorn
import torch
import numpy as np
from io import BytesIO
from torchvision import transforms, models
import torch.nn as nn
from huggingface_hub import hf_hub_download
import tempfile
from pathlib import Path
# Set up the cache directory in a user-writable location
CACHE_DIR = Path(tempfile.gettempdir()) / "huggingface_cache"
os.environ["TRANSFORMERS_CACHE"] = str(CACHE_DIR)
CACHE_DIR.mkdir(parents=True, exist_ok=True)
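# TRANSFORMERS_CACHE is read by the transformers library, while hf_hub_download
# below receives cache_dir explicitly. Pointing HF_HOME at the same directory
# (an addition, not in the original) keeps any other huggingface_hub downloads
# in the same writable location:
os.environ.setdefault("HF_HOME", str(CACHE_DIR))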
app = FastAPI()

# Define the per-model preprocessing pipelines
preprocessDensenet = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.3),
    transforms.RandomAffine(
        degrees=(-15, 15),
        translate=(0.1, 0.1),
        scale=(0.85, 1.15),
        fill=0
    ),
    transforms.RandomApply([
        transforms.ColorJitter(
            brightness=0.2,
            contrast=0.2
        )
    ], p=0.3),
    transforms.RandomApply([
        transforms.GaussianBlur(kernel_size=3)
    ], p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.1)
])
preprocessResnet = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAffine(
        degrees=(-10, 10),
        translate=(0.1, 0.1),
        scale=(0.9, 1.1),
        fill=0
    ),
    transforms.RandomApply([
        transforms.ColorJitter(
            brightness=0.3,
            contrast=0.3
        )
    ], p=0.3),
    transforms.RandomApply([
        transforms.GaussianBlur(kernel_size=3)
    ], p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.2)
])
preprocessGooglenet = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.3),  # Less aggressive flipping for medical images
    transforms.RandomAffine(
        degrees=(-5, 5),         # Slight rotation
        translate=(0.05, 0.05),  # Small translations
        scale=(0.95, 1.05),      # Subtle scaling
        fill=0                   # Fill with black
    ),
    transforms.RandomApply([
        transforms.ColorJitter(
            brightness=0.2,
            contrast=0.2
        )
    ], p=0.3),  # Subtle intensity variations
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
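# Note that the three pipelines above carry training-time augmentations (random
# flips, affine jitter, blur, erasing), so repeated calls on the same image
# produce different tensors and therefore slightly different predictions. A
# minimal deterministic alternative for single-shot inference (a sketch; the
# original may intend light test-time augmentation) would be:
preprocess_eval = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])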
def create_densenet169():
    model = models.densenet169(weights=None)  # weights come from the checkpoint below
    model.classifier = nn.Sequential(
        nn.BatchNorm1d(model.classifier.in_features),  # Added batch normalization
        nn.Dropout(p=0.4),                             # Increased dropout
        nn.Linear(model.classifier.in_features, 512),  # Added intermediate layer
        nn.ReLU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 2)
    )
    return model
def create_resnet18():
    model = models.resnet18(weights=None)
    model.fc = nn.Sequential(
        nn.Dropout(p=0.5),
        nn.Linear(model.fc.in_features, 2)
    )
    return model
def create_googlenet():
    model = models.googlenet(weights=None)
    # The auxiliary classifiers are only used during training, so drop them
    # for inference.
    model.aux1 = None
    model.aux2 = None
    model.fc = nn.Sequential(
        nn.Dropout(p=0.5),
        nn.Linear(model.fc.in_features, 2)
    )
    return model
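# Quick sanity check for the heads above (illustrative, not in the original):
# each factory maps a batch of 3x224x224 images, as produced by the pipelines
# above, to 2 logits, e.g.
#   create_resnet18()(torch.randn(1, 3, 224, 224)).shape == torch.Size([1, 2])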
def load_model_from_hf(repo_id, model_creator):
    try:
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename="model.pth",
            cache_dir=CACHE_DIR
        )
        # Create the model architecture
        model = model_creator()
        # Load the checkpoint on CPU
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        # Extract the state dict, whether or not it is wrapped in a checkpoint dict
        if "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
        model.eval()
        return model
    except Exception as e:
        print(f"Error loading model from {repo_id}: {str(e)}")
        return None
modelss = {"Densenet169": None, "Resnet18": None, "Googlenet": None}
modelss["Densenet169"] = load_model_from_hf(
    "Arham-Irfan/Densenet169_pnuemonia_binaryclassification",
    create_densenet169
)
modelss["Resnet18"] = load_model_from_hf(
    "Arham-Irfan/Resnet18_pnuemonia_binaryclassification",
    create_resnet18
)
modelss["Googlenet"] = load_model_from_hf(
    "Arham-Irfan/Googlenet_pnuemonia_binaryclassification",
    create_googlenet
)
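# Startup sanity check (an addition, not in the original): load_model_from_hf
# returns None on failure, and the prediction endpoint assumes all three
# models exist, so flag any gaps early.
for name, m in modelss.items():
    if m is None:
        print(f"Warning: {name} failed to load; predictions will fail.")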
classes = ["Normal", "Pneumonia"]

@app.post("/predict")  # Expose the handler as a POST endpoint (route path assumed)
async def predict_pneumonia(file: UploadFile = File(...)):
    try:
        if any(m is None for m in modelss.values()):
            return {"error": "One or more models failed to load."}
        image = Image.open(BytesIO(await file.read())).convert("RGB")
        img_tensor1 = preprocessDensenet(image).unsqueeze(0)
        img_tensor2 = preprocessResnet(image).unsqueeze(0)
        img_tensor3 = preprocessGooglenet(image).unsqueeze(0)
        with torch.no_grad():
            output1 = torch.softmax(modelss["Densenet169"](img_tensor1), dim=1).numpy()[0]
            output2 = torch.softmax(modelss["Resnet18"](img_tensor2), dim=1).numpy()[0]
            output3 = torch.softmax(modelss["Googlenet"](img_tensor3), dim=1).numpy()[0]
        # Weighted soft-voting ensemble over the three models' probabilities
        weights = [0.45, 0.33, 0.22]
        ensemble_prob = weights[0] * output1 + weights[1] * output2 + weights[2] * output3
        pred_index = int(np.argmax(ensemble_prob))
        return {
            "prediction": classes[pred_index],
            "confidence": float(ensemble_prob[pred_index]),
            "model_details": {
                "Densenet169": float(output1[pred_index]),
                "Resnet18": float(output2[pred_index]),
                "Googlenet": float(output3[pred_index])
            }
        }
    except Exception as e:
        return {"error": f"Prediction error: {str(e)}"}
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
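# Example client call (a sketch; assumes the service is reachable on
# localhost:7860 and the route path chosen above):
#
#   import requests
#   with open("chest_xray.jpg", "rb") as f:
#       r = requests.post("http://localhost:7860/predict", files={"file": f})
#   print(r.json())
#
# The response contains the ensemble prediction, its confidence, and each
# model's probability for the predicted class.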