# Page header captured with the file (not part of the program):
# "Arham-Irfan's picture" / "Update app.py" / commit e5dcce7 (verified)
from fastapi import FastAPI, UploadFile, File
from PIL import Image
import os
import uvicorn
import torch
import numpy as np
from io import BytesIO
from torchvision import transforms, models
import torch.nn as nn
from huggingface_hub import hf_hub_download
import tempfile
from pathlib import Path
import pydicom
# Directory used for downloaded model files; created at import time below.
CACHE_DIR = Path("/huggingface/cache")
HF_CACHE_DIR = Path("/huggingface/cache")

# Point the Hugging Face tooling at the cache directory.
# NOTE(review): these variables are assigned after `huggingface_hub` has
# already been imported above - confirm the library still honours them.
os.environ.update(
    {
        "TRANSFORMERS_CACHE": str(HF_CACHE_DIR),
        "HF_HOME": str(HF_CACHE_DIR.parent),
        "HF_HUB_CACHE": str(HF_CACHE_DIR),
        "HF_HUB_ENABLE_HF_TRANSFER": "1",
    }
)

CACHE_DIR.mkdir(parents=True, exist_ok=True)

# ASGI application object; routes are registered on it further down.
app = FastAPI()
# Inference-time preprocessing for the DenseNet-169 model.
#
# Only deterministic steps are applied: resize to 224x224, convert to a
# tensor, and normalise with the same mean/std as before. Random
# augmentations (flip, affine, colour jitter, blur, erasing) are
# training-time tools; running them while serving makes the same upload
# yield different probabilities from one request to the next.
preprocessDensenet = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Inference-time preprocessing for the ResNet-18 model.
#
# Only deterministic steps are applied: resize to 224x224, convert to a
# tensor, and normalise with the same mean/std as before. Random
# augmentations (flip, affine, colour jitter, blur, erasing) are
# training-time tools; running them while serving makes the same upload
# yield different probabilities from one request to the next.
preprocessResnet = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Inference-time preprocessing for the GoogLeNet model.
#
# Only deterministic steps are applied: resize to 224x224, convert to a
# tensor, and normalise with the same mean/std as before. Random
# augmentations (flip, affine, colour jitter) are training-time tools;
# running them while serving makes the same upload yield different
# probabilities from one request to the next.
preprocessGooglenet = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
def create_densenet169():
    """Build a DenseNet-169 (no pretrained weights) with a 2-output head.

    The stock classifier is replaced by BatchNorm -> Dropout(0.4) ->
    Linear(in_features, 512) -> ReLU -> Dropout(0.3) -> Linear(512, 2).

    Returns:
        The constructed model.
    """
    net = models.densenet169(pretrained=False)
    # Width of the feature vector feeding the stock classifier.
    feature_dim = net.classifier.in_features
    head_layers = [
        nn.BatchNorm1d(feature_dim),
        nn.Dropout(p=0.4),
        nn.Linear(feature_dim, 512),
        nn.ReLU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 2),
    ]
    net.classifier = nn.Sequential(*head_layers)
    return net
def create_resnet18():
    """Build a ResNet-18 (no pretrained weights) with a 2-output head.

    The final fully connected layer is replaced by Dropout(0.5) followed by
    a Linear layer mapping the original input width to 2 outputs.

    Returns:
        The constructed model.
    """
    net = models.resnet18(pretrained=False)
    in_dim = net.fc.in_features
    net.fc = nn.Sequential(nn.Dropout(p=0.5), nn.Linear(in_dim, 2))
    return net
def create_googlenet():
    """Build a GoogLeNet (no pretrained weights) with a 2-output head.

    Both auxiliary classifier attributes are set to ``None`` and the final
    fully connected layer is replaced by Dropout(0.5) followed by a Linear
    layer mapping the original input width to 2 outputs.

    Returns:
        The constructed model.
    """
    net = models.googlenet(pretrained=False)
    # Drop the auxiliary heads; only the main classifier is kept.
    net.aux1 = net.aux2 = None
    in_dim = net.fc.in_features
    net.fc = nn.Sequential(nn.Dropout(p=0.5), nn.Linear(in_dim, 2))
    return net
def load_model_from_hf(repo_id, model_creator):
    """Download ``model.pth`` from a Hub repo and load it into a new model.

    Args:
        repo_id: Hugging Face Hub repository identifier.
        model_creator: Zero-argument callable returning the model whose
            state dict should be populated.

    Returns:
        The model switched to eval mode, or ``None`` if any step raised
        (the error is printed, not re-raised).
    """
    try:
        weights_file = hf_hub_download(
            repo_id=repo_id,
            filename="model.pth",
            cache_dir=CACHE_DIR
        )
        net = model_creator()
        payload = torch.load(weights_file, map_location=torch.device('cpu'))
        # The file may hold either a bare state dict or a checkpoint that
        # wraps it under "model_state_dict".
        weights = payload["model_state_dict"] if "model_state_dict" in payload else payload
        net.load_state_dict(weights)
        net.eval()
    except Exception as e:
        print(f"Error loading model from {repo_id}: {str(e)}")
        return None
    return net
# Ensemble members keyed by display name. Each value is the loaded model, or
# None when load_model_from_hf could not load it.
modelss = {
    "Densenet169": load_model_from_hf(
        "Arham-Irfan/Densenet169_pnuemonia_binaryclassification",
        create_densenet169,
    ),
    "Resnet18": load_model_from_hf(
        "Arham-Irfan/Resnet18_pnuemonia_binaryclassification",
        create_resnet18,
    ),
    "Googlenet": load_model_from_hf(
        "Arham-Irfan/Googlenet_pnuemonia_binaryclassification",
        create_googlenet,
    ),
}

# Output labels, indexed by the argmax of the model probabilities.
classes = ["Normal", "Pneumonia"]
def convert_dicom_to_rgb(dicom_path):
    """Read a DICOM file and return its pixel data as an 8-bit PIL image.

    Pixel values are min-max scaled to the 0-255 range. A 2-D
    (single-channel) array is replicated into three identical channels.

    Args:
        dicom_path: Location of the DICOM file, passed straight to
            ``pydicom.dcmread``.

    Returns:
        A PIL image built from the scaled ``uint8`` array.
    """
    dicom_data = pydicom.dcmread(dicom_path)
    # Work in float64: computing the value range in the stored integer dtype
    # can wrap around for wide signed ranges.
    pixels = dicom_data.pixel_array.astype(np.float64)
    low = pixels.min()
    value_range = pixels.max() - low
    if value_range > 0:
        pixels = (pixels - low) / value_range * 255
    else:
        # Constant image: scaling would divide by zero and produce NaNs, so
        # emit an all-zero image instead.
        pixels = np.zeros_like(pixels)
    img_array = pixels.astype(np.uint8)
    if img_array.ndim == 2:
        img_array = np.stack([img_array] * 3, axis=-1)
    # NOTE(review): PhotometricInterpretation (e.g. MONOCHROME1 inversion) is
    # not taken into account here - confirm whether the models expect it.
    return Image.fromarray(img_array)
@app.post("/predict")
async def predict_pneumonia(file: UploadFile = File(...)):
    """Classify an uploaded chest image with the three-model ensemble.

    Files whose name ends in ``.dcm`` are decoded as DICOM; anything else is
    opened with PIL and converted to RGB. Each model's softmax output is
    combined with fixed weights (0.45 / 0.33 / 0.22) and the class with the
    highest combined probability is reported.

    Args:
        file: The uploaded image.

    Returns:
        On success, a dict with ``prediction`` (class label), ``confidence``
        (ensemble probability of that class) and ``model_details`` (each
        model's probability for that class). On any failure, a dict with a
        single ``error`` message.
    """
    try:
        # Fail with a clear message if any model did not load at startup,
        # rather than with "'NoneType' object is not callable".
        unavailable = [name for name, model in modelss.items() if model is None]
        if unavailable:
            missing = ", ".join(unavailable)
            return {"error": f"Prediction error: model(s) not loaded: {missing}"}

        file_bytes = await file.read()
        # The filename is client-supplied and may be absent; it is used only
        # to pick the decoder, never to build a filesystem path.
        file_ext = (file.filename or "").split(".")[-1].lower()
        if file_ext == "dcm":
            # A private temporary file avoids path traversal and collisions
            # between concurrent uploads, and is removed when the block exits.
            with tempfile.NamedTemporaryFile(suffix=".dcm") as tmp:
                tmp.write(file_bytes)
                tmp.flush()
                image = convert_dicom_to_rgb(tmp.name)
        else:
            image = Image.open(BytesIO(file_bytes)).convert("RGB")

        # Preprocess for each model and add a batch dimension.
        img_tensor1 = preprocessDensenet(image).unsqueeze(0)
        img_tensor2 = preprocessResnet(image).unsqueeze(0)
        img_tensor3 = preprocessGooglenet(image).unsqueeze(0)

        with torch.no_grad():
            output1 = torch.softmax(modelss["Densenet169"](img_tensor1), dim=1).numpy()[0]
            output2 = torch.softmax(modelss["Resnet18"](img_tensor2), dim=1).numpy()[0]
            output3 = torch.softmax(modelss["Googlenet"](img_tensor3), dim=1).numpy()[0]

        # Weighted average of the per-model class probabilities.
        weights = [0.45, 0.33, 0.22]
        ensemble_prob = weights[0] * output1 + weights[1] * output2 + weights[2] * output3
        pred_index = int(np.argmax(ensemble_prob))

        return {
            "prediction": classes[pred_index],
            "confidence": float(ensemble_prob[pred_index]),
            "model_details": {
                "Densenet169": float(output1[pred_index]),
                "Resnet18": float(output2[pred_index]),
                "Googlenet": float(output3[pred_index])
            }
        }
    except Exception as e:
        return {"error": f"Prediction error: {str(e)}"}
# Serve the app directly when this file is run as a script: all interfaces,
# port 7860.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)