thehammadishaq commited on
Commit
03f64ba
·
verified ·
1 Parent(s): bf2f6ca

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +191 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File
2
+ from PIL import Image
3
+ import os
4
+ import uvicorn
5
+ import torch
6
+ import numpy as np
7
+ from io import BytesIO
8
+ from torchvision import transforms , models
9
+ import torch.nn as nn
10
+ from huggingface_hub import hf_hub_download
11
+ import tempfile
12
+ from pathlib import Path
13
+
14
+ # Set up cache directory in a user-accessible location
15
+ CACHE_DIR = Path(tempfile.gettempdir()) / "huggingface_cache"
16
+ os.environ["TRANSFORMERS_CACHE"] = str(CACHE_DIR)
17
+ CACHE_DIR.mkdir(parents=True, exist_ok=True)
18
+
19
+
20
+ app = FastAPI()
21
+
22
+ # Define preprocessing
23
+ preprocessDensenet = transforms.Compose([
24
+ transforms.Resize((224, 224)),
25
+ transforms.RandomHorizontalFlip(p=0.3),
26
+ transforms.RandomAffine(
27
+ degrees=(-15, 15),
28
+ translate=(0.1, 0.1),
29
+ scale=(0.85, 1.15),
30
+ fill=0
31
+ ),
32
+ transforms.RandomApply([
33
+ transforms.ColorJitter(
34
+ brightness=0.2,
35
+ contrast=0.2
36
+ )
37
+ ], p=0.3),
38
+ transforms.RandomApply([
39
+ transforms.GaussianBlur(kernel_size=3)
40
+ ], p=0.2),
41
+ transforms.ToTensor(),
42
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
43
+ transforms.RandomErasing(p=0.1)
44
+ ])
45
+
46
+ preprocessResnet = transforms.Compose([
47
+ transforms.Resize((224, 224)),
48
+ transforms.RandomHorizontalFlip(p=0.5),
49
+ transforms.RandomAffine(
50
+ degrees=(-10, 10),
51
+ translate=(0.1, 0.1),
52
+ scale=(0.9, 1.1),
53
+ fill=0
54
+ ),
55
+ transforms.RandomApply([
56
+ transforms.ColorJitter(
57
+ brightness=0.3,
58
+ contrast=0.3
59
+ )
60
+ ], p=0.3),
61
+ transforms.RandomApply([
62
+ transforms.GaussianBlur(kernel_size=3)
63
+ ], p=0.2),
64
+ transforms.ToTensor(),
65
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
66
+ transforms.RandomErasing(p=0.2)
67
+ ])
68
+
69
+ preprocessGooglenet = transforms.Compose([
70
+ transforms.Resize((224, 224)),
71
+ transforms.RandomHorizontalFlip(p=0.3), # Less aggressive flipping for medical images
72
+ transforms.RandomAffine(
73
+ degrees=(-5, 5), # Slight rotation
74
+ translate=(0.05, 0.05), # Small translations
75
+ scale=(0.95, 1.05), # Subtle scaling
76
+ fill=0 # Fill with black
77
+ ),
78
+ transforms.RandomApply([
79
+ transforms.ColorJitter(
80
+ brightness=0.2,
81
+ contrast=0.2
82
+ )
83
+ ], p=0.3), # Subtle intensity variations
84
+ transforms.ToTensor(),
85
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
86
+ ])
87
+
88
+ def create_densenet169():
89
+ model = models.densenet169(pretrained=False)
90
+ model.classifier = nn.Sequential(
91
+ nn.BatchNorm1d(model.classifier.in_features), # Added batch normalization
92
+ nn.Dropout(p=0.4), # Increased dropout
93
+ nn.Linear(model.classifier.in_features, 512), # Added intermediate layer
94
+ nn.ReLU(),
95
+ nn.Dropout(p=0.3),
96
+ nn.Linear(512, 2)
97
+ )
98
+ return model
99
+
100
+ def create_resnet18():
101
+ model = models.resnet18(pretrained=False)
102
+ model.fc = nn.Sequential(
103
+ nn.Dropout(p=0.5),
104
+ nn.Linear(model.fc.in_features, 2)
105
+ )
106
+ return model
107
+
108
+ def create_googlenet():
109
+ model = models.googlenet(pretrained=False)
110
+ model.aux1 = None
111
+ model.aux2 = None
112
+ model.fc = nn.Sequential(
113
+ nn.Dropout(p=0.5),
114
+ nn.Linear(model.fc.in_features, 2)
115
+ )
116
+ return model
117
+
118
+ def load_model_from_hf(repo_id, model_creator):
119
+ try:
120
+ model_path = hf_hub_download(
121
+ repo_id=repo_id,
122
+ filename="model.pth",
123
+ cache_dir=CACHE_DIR
124
+ )
125
+ # Create model architecture
126
+ model = model_creator()
127
+ # Load the checkpoint
128
+ checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
129
+
130
+ # Extract model_state_dict from the checkpoint
131
+ if "model_state_dict" in checkpoint:
132
+ state_dict = checkpoint["model_state_dict"]
133
+ else:
134
+ state_dict = checkpoint # In case it's just the state_dict without wrapping
135
+
136
+ model.load_state_dict(state_dict)
137
+ model.eval()
138
+ return model
139
+ except Exception as e:
140
+ print(f"Error loading model from {repo_id}: {str(e)}")
141
+ return None
142
+
143
+ modelss = {"Densenet169": None, "Resnet18": None, "Googlenet": None}
144
+
145
+ modelss["Densenet169"] = load_model_from_hf(
146
+ "Arham-Irfan/Densenet169_pnuemonia_binaryclassification",
147
+ create_densenet169
148
+ )
149
+ modelss["Resnet18"] = load_model_from_hf(
150
+ "Arham-Irfan/Resnet18_pnuemonia_binaryclassification",
151
+ create_resnet18
152
+ )
153
+ modelss["Googlenet"] = load_model_from_hf(
154
+ "Arham-Irfan/Googlenet_pnuemonia_binaryclassification",
155
+ create_googlenet
156
+ )
157
+
158
+ classes = ["Normal", "Pneumonia"]
159
+
160
+ @app.post("/predict")
161
+ async def predict_pneumonia(file: UploadFile = File(...)):
162
+ try:
163
+ image = Image.open(BytesIO(await file.read())).convert("RGB")
164
+ img_tensor1 = preprocessDensenet(image).unsqueeze(0)
165
+ img_tensor2 = preprocessResnet(image).unsqueeze(0)
166
+ img_tensor3 = preprocessGooglenet(image).unsqueeze(0)
167
+
168
+ with torch.no_grad():
169
+ output1 = torch.softmax(modelss["Densenet169"](img_tensor1), dim=1).numpy()[0]
170
+ output2 = torch.softmax(modelss["Resnet18"](img_tensor2), dim=1).numpy()[0]
171
+ output3 = torch.softmax(modelss["Googlenet"](img_tensor3), dim=1).numpy()[0]
172
+
173
+ weights = [0.45, 0.33, 0.22]
174
+ ensemble_prob = weights[0] * output1 + weights[1] * output2 + weights[2] * output3
175
+ pred_index = np.argmax(ensemble_prob)
176
+
177
+ return {
178
+ "prediction": classes[pred_index],
179
+ "confidence": float(ensemble_prob[pred_index]),
180
+ "model_details": {
181
+ "Densenet169": float(output1[pred_index]),
182
+ "Resnet18": float(output2[pred_index]),
183
+ "Googlenet": float(output3[pred_index])
184
+ }
185
+ }
186
+ except Exception as e:
187
+ return {"error": f"Prediction error: {str(e)}"}
188
+
189
+
190
+ if __name__ == "__main__":
191
+ uvicorn.run(app, host="0.0.0.0", port=7860)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ torch
4
+ torchvision
5
+ numpy
6
+ Pillow
7
+ huggingface_hub
8
+ python-multipart