Vedansh-7 committed on
Commit f69a0b1 · verified · 1 Parent(s): 65eb734

Upload app.py

Files changed (1)
  1. app.py +424 -0
app.py ADDED
@@ -0,0 +1,424 @@
+ import torch
+ import torch.nn as nn
+ import gradio as gr
+ from PIL import Image
+ import numpy as np
+ import math
+ import os
+ from threading import Event
+ import traceback
+ import cv2  # Added for bilateral filtering
+
+ # Constants
+ IMG_SIZE = 128
+ TIMESTEPS = 300  # From second code
+ NUM_CLASSES = 2
+
+ # Global cancellation flag
+ cancel_event = Event()
+
+ # Device configuration
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # --- Model Definitions ---
+ class SinusoidalPositionEmbeddings(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.dim = dim
+         half_dim = dim // 2
+         emb = math.log(10000) / (half_dim - 1)
+         emb = torch.exp(torch.arange(half_dim) * -emb)  # From second code (no dtype specified)
+         self.register_buffer('embeddings', emb)
+
+     def forward(self, time):
+         device = time.device  # From second code
+         embeddings = self.embeddings.to(device)
+         embeddings = time[:, None] * embeddings[None, :]  # From second code
+         return torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)
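+
+ # Shape sketch (a sanity check, not part of the app's flow): a batch of
+ # timesteps of shape (B,) is scaled by half_dim frequencies
+ # exp(-k * ln(10000) / (half_dim - 1)) and mapped to (B, dim) as [sin | cos]:
+ #   pe = SinusoidalPositionEmbeddings(256)
+ #   pe(torch.arange(4).float()).shape  # -> torch.Size([4, 256])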
+
+ class UNet(nn.Module):
+     def __init__(self, in_channels=3, out_channels=3, num_classes=2, time_dim=256):
+         super().__init__()
+         self.num_classes = num_classes
+         self.label_embedding = nn.Embedding(num_classes, time_dim)
+
+         self.time_mlp = nn.Sequential(
+             SinusoidalPositionEmbeddings(time_dim),
+             nn.Linear(time_dim, time_dim),
+             nn.ReLU(),
+             nn.Linear(time_dim, time_dim)
+         )
+
+         # Encoder
+         self.inc = self.double_conv(in_channels, 64)
+         self.down1 = self.down(64 + time_dim * 2, 128)
+         self.down2 = self.down(128 + time_dim * 2, 256)
+         self.down3 = self.down(256 + time_dim * 2, 512)
+
+         # Bottleneck
+         self.bottleneck = self.double_conv(512 + time_dim * 2, 1024)
+
+         # Decoder
+         self.up1 = nn.ConvTranspose2d(1024, 256, kernel_size=2, stride=2)
+         self.upconv1 = self.double_conv(256 + 256 + time_dim * 2, 256)
+
+         self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
+         self.upconv2 = self.double_conv(128 + 128 + time_dim * 2, 128)
+
+         self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
+         self.upconv3 = self.double_conv(64 + 64 + time_dim * 2, 64)
+
+         self.outc = nn.Conv2d(64, out_channels, kernel_size=1)
+
+     def double_conv(self, in_channels, out_channels):
+         return nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True)
+         )
+
+     def down(self, in_channels, out_channels):
+         return nn.Sequential(
+             nn.MaxPool2d(2),
+             self.double_conv(in_channels, out_channels)
+         )
+
+     def forward(self, x, labels, time):
+         label_indices = torch.argmax(labels, dim=1)
+         label_emb = self.label_embedding(label_indices)
+         t_emb = self.time_mlp(time)
+
+         combined_emb = torch.cat([t_emb, label_emb], dim=1)
+         combined_emb = combined_emb.unsqueeze(-1).unsqueeze(-1)
+
+         x1 = self.inc(x)
+         x1_cat = torch.cat([x1, combined_emb.repeat(1, 1, x1.shape[-2], x1.shape[-1])], dim=1)
+
+         x2 = self.down1(x1_cat)
+         x2_cat = torch.cat([x2, combined_emb.repeat(1, 1, x2.shape[-2], x2.shape[-1])], dim=1)
+
+         x3 = self.down2(x2_cat)
+         x3_cat = torch.cat([x3, combined_emb.repeat(1, 1, x3.shape[-2], x3.shape[-1])], dim=1)
+
+         x4 = self.down3(x3_cat)
+         x4_cat = torch.cat([x4, combined_emb.repeat(1, 1, x4.shape[-2], x4.shape[-1])], dim=1)
+
+         x5 = self.bottleneck(x4_cat)
+
+         x = self.up1(x5)
+         x = torch.cat([x, x3], dim=1)
+         x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
+         x = self.upconv1(x)
+
+         x = self.up2(x)
+         x = torch.cat([x, x2], dim=1)
+         x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
+         x = self.upconv2(x)
+
+         x = self.up3(x)
+         x = torch.cat([x, x1], dim=1)
+         x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
+         x = self.upconv3(x)
+
+         return self.outc(x)
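+
+ # Conditioning note: the 256-d time embedding and 256-d label embedding are
+ # concatenated and broadcast as 512 extra channels at every scale, hence the
+ # "+ time_dim * 2" in each block's in_channels. A minimal shape check
+ # (assumes a side divisible by 8, as IMG_SIZE = 128 is):
+ #   net = UNet()
+ #   x = torch.randn(2, 3, 128, 128)
+ #   y = torch.nn.functional.one_hot(torch.tensor([0, 1]), 2).float()
+ #   net(x, y, torch.tensor([10., 20.])).shape  # -> torch.Size([2, 3, 128, 128])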
+
+ class DiffusionModel(nn.Module):
+     def __init__(self, model, timesteps=TIMESTEPS, time_dim=256):
+         super().__init__()
+         self.model = model
+         self.timesteps = timesteps
+         self.time_dim = time_dim
+
+         # Linear beta schedule with scaling from second code
+         scale = 1000 / timesteps
+         beta_start = scale * 0.0001
+         beta_end = scale * 0.02
+         self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
+         self.alphas = 1. - self.betas
+         self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0).float())
+
+     def forward_diffusion(self, x_0, t, noise):
+         x_0 = x_0.float()
+         noise = noise.float()
+         alpha_bar_t = self.alpha_bars[t].view(-1, 1, 1, 1)
+         x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1. - alpha_bar_t) * noise
+         return x_t
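+
+     # forward_diffusion implements the standard DDPM closed form for q(x_t | x_0):
+     #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I)
+     # so any timestep can be noised directly, without iterating over t steps.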
+
+     def forward(self, x_0, labels):
+         t = torch.randint(0, self.timesteps, (x_0.shape[0],), device=x_0.device).long()
+         noise = torch.randn_like(x_0)
+         x_t = self.forward_diffusion(x_0, t, noise)
+         predicted_noise = self.model(x_t, labels, t.float())
+         return predicted_noise, noise, t
+
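+     # forward() returns (predicted_noise, noise, t); a DDPM training loop would
+     # typically minimise mse_loss(predicted_noise, noise). Training is not part
+     # of this inference app, so this is only a usage sketch:
+     #   pred, eps, t = diffusion_model(batch, one_hot_labels)
+     #   loss = torch.nn.functional.mse_loss(pred, eps)
+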
+     @torch.no_grad()
+     def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
+         # Start with random noise
+         x_t = torch.randn(num_images, 3, img_size, img_size).to(device)
+
+         # Label handling (one-hot if needed)
+         if labels.ndim == 1:
+             labels_one_hot = torch.zeros(num_images, num_classes).to(device)
+             labels_one_hot[torch.arange(num_images), labels] = 1
+             labels = labels_one_hot
+         else:
+             labels = labels.to(device)
+
+         # REVERTED SAMPLING LOOP WITH NOISE REDUCTION
+         for t in reversed(range(self.timesteps)):
+             if cancel_event.is_set():
+                 return None
+
+             t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
+             predicted_noise = self.model(x_t, labels, t_tensor)
+
+             # Calculate coefficients (cast to float32: betas/alphas are stored in
+             # float64, and mixed-dtype arithmetic would promote x_t to double,
+             # which the float32 UNet cannot consume on the next iteration)
+             beta_t = self.betas[t].float().to(device)
+             alpha_t = self.alphas[t].float().to(device)
+             alpha_bar_t = self.alpha_bars[t].to(device)
+
+             mean = (1 / torch.sqrt(alpha_t)) * (x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * predicted_noise)
+             variance = beta_t
+
+             # Reduced noise injection with a lower multiplier
+             if t > 0:
+                 noise = torch.randn_like(x_t) * 0.8  # Scale injected noise down by 20%
+             else:
+                 noise = torch.zeros_like(x_t)
+
+             x_t = mean + torch.sqrt(variance) * noise
+
+             if progress_callback:
+                 progress_callback((self.timesteps - t) / self.timesteps)
+
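+         # Each iteration above is the DDPM ancestral update with eps_theta = predicted_noise:
+         #   x_{t-1} = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta)
+         #             + sqrt(beta_t) * z,   z ~ N(0, I) for t > 0, z = 0 at t = 0.
+         # The 0.8 factor on z is this app's deviation from the standard sampler,
+         # intended to leave less residual noise in the final images.
+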
+         # Clamp and denormalize
+         x_0 = torch.clamp(x_t, -1., 1.)
+         mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
+         std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
+         x_0 = std * x_0 + mean
+         x_0 = torch.clamp(x_0, 0., 1.)
+
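+         # The mean/std used above are the standard ImageNet normalization
+         # statistics, which suggests the training data was normalized with
+         # these values; x = std * x_norm + mean inverts that transform.
+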
+         # ENHANCED SHARPENING
+         # First apply mild bilateral filtering to reduce noise while preserving edges
+         x_np = x_0.cpu().permute(0, 2, 3, 1).numpy()
+         filtered = []
+         for img in x_np:
+             img = (img * 255).astype(np.uint8)
+             filtered_img = cv2.bilateralFilter(img, d=5, sigmaColor=15, sigmaSpace=15)
+             filtered.append(filtered_img / 255.0)
+         x_0 = torch.tensor(np.array(filtered), device=device).permute(0, 3, 1, 2)
+
+         # Then apply stronger unsharp masking
+         # Note: the 5x5 ones kernel is divided by 75 rather than 25, so the
+         # "blur" is also scaled to one third of the image intensity; the final
+         # clamp keeps values in [0, 1]
+         kernel = torch.ones(3, 1, 5, 5, device=device) / 75
+         kernel = kernel.to(x_0.dtype)
+         blurred = torch.nn.functional.conv2d(
+             x_0,
+             kernel,
+             padding=2,
+             groups=3
+         )
+         x_0 = torch.clamp(1.5 * x_0 - 0.5 * blurred, 0., 1.)  # Increased sharpening factor
+
+         return x_0
+
+ def load_model(model_path, device):
+     unet_model = UNet(num_classes=NUM_CLASSES).to(device)
+     diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
+
+     if os.path.exists(model_path):
+         checkpoint = torch.load(model_path, map_location=device)
+
+         if 'model_state_dict' in checkpoint:
+             # Handle training checkpoint format: strip the leading 'model.'
+             # prefix (6 characters) so the keys match the bare UNet
+             state_dict = {
+                 k[6:]: v for k, v in checkpoint['model_state_dict'].items()
+                 if k.startswith('model.')
+             }
+
+             # Load UNet weights
+             unet_model.load_state_dict(state_dict, strict=False)
+
+             # Initialize diffusion model with loaded UNet
+             diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
+
+             print(f"Loaded UNet weights from {model_path}")
+         else:
+             # Handle direct model weights format
+             try:
+                 # First try loading the full DiffusionModel
+                 diffusion_model.load_state_dict(checkpoint)
+                 print(f"Loaded full DiffusionModel from {model_path}")
+             except RuntimeError:
+                 # If that fails, load just the UNet weights
+                 unet_model.load_state_dict(checkpoint, strict=False)
+                 diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
+                 print(f"Loaded UNet weights only from {model_path}")
+     else:
+         print(f"Weights file not found at {model_path}")
+         print("Using randomly initialized weights")
+
+     diffusion_model.eval()
+     return diffusion_model
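+
+ # The loader accepts two checkpoint layouts (a sketch; the 'model.' prefix
+ # comes from DiffusionModel wrapping the UNet under the attribute 'model'):
+ #   {'model_state_dict': {'model.inc.0.weight': ..., ...}}  # training checkpoint
+ #   {'model.inc.0.weight': ..., 'alpha_bars': ..., ...}     # full state_dict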
+
+ def cancel_generation():
+     cancel_event.set()
+     return "Generation cancelled"
+
+ def generate_images(label_str, num_images, progress=gr.Progress()):
+     global loaded_model
+     cancel_event.clear()
+
+     if num_images < 1 or num_images > 10:
+         raise gr.Error("Number of images must be between 1 and 10")
+
+     label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
+     if label_str not in label_map:
+         raise gr.Error("Invalid condition selected")
+
+     labels = torch.zeros(num_images, NUM_CLASSES)
+     labels[:, label_map[label_str]] = 1
+
+     try:
+         def progress_callback(progress_val):
+             progress(progress_val, desc="Generating...")
+             if cancel_event.is_set():
+                 raise gr.Error("Generation was cancelled by user")
+
+         with torch.no_grad():
+             images = loaded_model.sample(
+                 num_images=num_images,
+                 img_size=IMG_SIZE,
+                 num_classes=NUM_CLASSES,
+                 labels=labels,
+                 device=device,
+                 progress_callback=progress_callback
+             )
+
+         if images is None:
+             return None, None
+
+         processed_images = []
+         for img in images:
+             img_np = img.cpu().permute(1, 2, 0).numpy()
+             img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
+             pil_img = Image.fromarray(img_np)
+             processed_images.append(pil_img)
+
+         if num_images == 1:
+             return processed_images[0], processed_images
+         else:
+             return None, processed_images
+
+     except Exception as e:
+         traceback.print_exc()
+         raise gr.Error(f"Generation failed: {str(e)}")
+     finally:
+         torch.cuda.empty_cache()
+
+ # Load model
+ MODEL_NAME = "model_weights.pth"
+ model_path = MODEL_NAME
+ print("Loading model...")
+ try:
+     loaded_model = load_model(model_path, device)
+     print("Model loaded successfully!")
+ except Exception as e:
+     print(f"Failed to load model: {e}")
+     print("Creating dummy model for demonstration")
+     loaded_model = DiffusionModel(UNet(num_classes=NUM_CLASSES), timesteps=TIMESTEPS).to(device)
+
+ # Gradio UI (from first code)
+ with gr.Blocks(theme=gr.themes.Soft(
+     primary_hue="violet",
+     neutral_hue="slate",
+     font=[gr.themes.GoogleFont("Poppins")],
+     text_size="md"
+ )) as demo:
+     gr.Markdown("""
+     <center>
+     <h1>Synthetic X-ray Generator</h1>
+     <p><em>Generate synthetic chest X-rays conditioned on pathology</em></p>
+     </center>
+     """)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             condition = gr.Dropdown(
+                 ["Pneumonia", "Pneumothorax"],
+                 label="Select Condition",
+                 value="Pneumonia",
+                 interactive=True
+             )
+             num_images = gr.Slider(
+                 1, 10, value=1, step=1,
+                 label="Number of Images",
+                 interactive=True
+             )
+
+             with gr.Row():
+                 submit_btn = gr.Button("Generate", variant="primary")
+                 cancel_btn = gr.Button("Cancel", variant="stop")
+
+             gr.Markdown("""
+             <div style="text-align: center; margin-top: 10px;">
+             <small>Note: Generation may take several seconds per image</small>
+             </div>
+             """)
+
+         with gr.Column(scale=2):
+             with gr.Tabs():
+                 with gr.TabItem("Output", id="output_tab"):
+                     single_image = gr.Image(
+                         label="Generated X-ray",
+                         height=400,
+                         visible=True
+                     )
+                     gallery = gr.Gallery(
+                         label="Generated X-rays",
+                         columns=3,
+                         height="auto",
+                         object_fit="contain",
+                         visible=False
+                     )
+
+     def update_ui_based_on_count(num_images):
+         if num_images == 1:
+             return {
+                 single_image: gr.update(visible=True),
+                 gallery: gr.update(visible=False)
+             }
+         else:
+             return {
+                 single_image: gr.update(visible=False),
+                 gallery: gr.update(visible=True)
+             }
+
+     num_images.change(
+         fn=update_ui_based_on_count,
+         inputs=num_images,
+         outputs=[single_image, gallery]
+     )
+
+     submit_btn.click(
+         fn=generate_images,
+         inputs=[condition, num_images],
+         outputs=[single_image, gallery]
+     )
+
+     cancel_btn.click(
+         fn=cancel_generation,
+         outputs=None
+     )
+
+ demo.css = """
+ .gradio-container {
+     background: linear-gradient(135deg, #f5f7fa 0%, #e4e8f0 100%);
+ }
+ .gallery-container {
+     background-color: white !important;
+ }
+ """
+
+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860)
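+
+ # Usage sketch: running `python app.py` serves the UI on all interfaces at
+ # port 7860, e.g. http://localhost:7860 in a browser (or the Space URL when
+ # hosted on Hugging Face).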