Vedansh-7 committed
Commit 78db2c7 · 1 Parent(s): f25462b

Delete app.py

Files changed (1)
  1. app.py +0 -451
app.py DELETED
@@ -1,451 +0,0 @@
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
import numpy as np
import math
import os
from threading import Event
import traceback
import cv2  # used for bilateral filtering and unsharp masking

# Constants
IMG_SIZE = 128
TIMESTEPS = 300
NUM_CLASSES = 2

# Global cancellation flag: set by the Cancel button, checked inside the sampling loop
cancel_event = Event()

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Model Definitions ---
class SinusoidalPositionEmbeddings(nn.Module):
    """Transformer-style sinusoidal embeddings of the diffusion timestep."""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        half_dim = dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim) * -emb)
        self.register_buffer('embeddings', emb)

    def forward(self, time):
        embeddings = self.embeddings.to(time.device)
        embeddings = time[:, None] * embeddings[None, :]
        return torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)

class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, num_classes=2, time_dim=256):
        super().__init__()
        self.num_classes = num_classes
        self.label_embedding = nn.Embedding(num_classes, time_dim)

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.ReLU(),
            nn.Linear(time_dim, time_dim)
        )

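        # The broadcast (time, label) embedding is concatenated to the feature
        # maps at every stage, hence the extra `time_dim * 2` input channels below.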
        # Encoder
        self.inc = self.double_conv(in_channels, 64)
        self.down1 = self.down(64 + time_dim * 2, 128)
        self.down2 = self.down(128 + time_dim * 2, 256)
        self.down3 = self.down(256 + time_dim * 2, 512)

        # Bottleneck
        self.bottleneck = self.double_conv(512 + time_dim * 2, 1024)

        # Decoder
        self.up1 = nn.ConvTranspose2d(1024, 256, kernel_size=2, stride=2)
        self.upconv1 = self.double_conv(256 + 256 + time_dim * 2, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upconv2 = self.double_conv(128 + 128 + time_dim * 2, 128)

        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.upconv3 = self.double_conv(64 + 64 + time_dim * 2, 64)

        self.outc = nn.Conv2d(64, out_channels, kernel_size=1)

    def double_conv(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def down(self, in_channels, out_channels):
        return nn.Sequential(
            nn.MaxPool2d(2),
            self.double_conv(in_channels, out_channels)
        )

    def forward(self, x, labels, time):
        label_indices = torch.argmax(labels, dim=1)
        label_emb = self.label_embedding(label_indices)
        t_emb = self.time_mlp(time)

        # Broadcast the combined embedding to a (B, 2 * time_dim, 1, 1) map
        combined_emb = torch.cat([t_emb, label_emb], dim=1)
        combined_emb = combined_emb.unsqueeze(-1).unsqueeze(-1)

        x1 = self.inc(x)
        x1_cat = torch.cat([x1, combined_emb.repeat(1, 1, x1.shape[-2], x1.shape[-1])], dim=1)

        x2 = self.down1(x1_cat)
        x2_cat = torch.cat([x2, combined_emb.repeat(1, 1, x2.shape[-2], x2.shape[-1])], dim=1)

        x3 = self.down2(x2_cat)
        x3_cat = torch.cat([x3, combined_emb.repeat(1, 1, x3.shape[-2], x3.shape[-1])], dim=1)

        x4 = self.down3(x3_cat)
        x4_cat = torch.cat([x4, combined_emb.repeat(1, 1, x4.shape[-2], x4.shape[-1])], dim=1)

        x5 = self.bottleneck(x4_cat)

        # Decoder with skip connections from the encoder stages
        x = self.up1(x5)
        x = torch.cat([x, x3], dim=1)
        x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
        x = self.upconv1(x)

        x = self.up2(x)
        x = torch.cat([x, x2], dim=1)
        x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
        x = self.upconv2(x)

        x = self.up3(x)
        x = torch.cat([x, x1], dim=1)
        x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
        x = self.upconv3(x)

        return self.outc(x)

class DiffusionModel(nn.Module):
    def __init__(self, model, timesteps=TIMESTEPS, time_dim=256):
        super().__init__()
        self.model = model
        self.timesteps = timesteps
        self.time_dim = time_dim

        # Linear beta schedule, rescaled so the endpoints match the standard
        # 1000-step schedule (beta in [1e-4, 0.02] at 1000 steps)
        scale = 1000 / timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
        alphas = 1. - betas
        self.register_buffer('alpha_bars', torch.cumprod(alphas, dim=0).float())
        # Store the schedule itself in float32: float64 coefficients would
        # promote x_t to float64 during sampling and break the float32 convs
        self.betas = betas.float()
        self.alphas = alphas.float()

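    # Closed-form forward (noising) process:
    #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I)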
    def forward_diffusion(self, x_0, t, noise):
        x_0 = x_0.float()
        noise = noise.float()
        alpha_bar_t = self.alpha_bars[t].view(-1, 1, 1, 1)
        x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1. - alpha_bar_t) * noise
        return x_t

    def forward(self, x_0, labels):
        t = torch.randint(0, self.timesteps, (x_0.shape[0],), device=x_0.device).long()
        noise = torch.randn_like(x_0)
        x_t = self.forward_diffusion(x_0, t, noise)
        predicted_noise = self.model(x_t, labels, t.float())
        return predicted_noise, noise, t

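    # forward() provides the (predicted_noise, noise, t) triple an external
    # training loop would feed into an MSE noise-prediction loss; this app
    # itself only calls sample() below.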
    @torch.no_grad()
    def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
        # Start from random Gaussian noise
        x_t = torch.randn(num_images, 3, img_size, img_size).to(device)

        # Label handling (one-hot encode if integer labels were passed)
        if labels.ndim == 1:
            labels_one_hot = torch.zeros(num_images, num_classes).to(device)
            labels_one_hot[torch.arange(num_images), labels] = 1
            labels = labels_one_hot
        else:
            labels = labels.to(device)

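        # Ancestral DDPM update, iterating t = T-1 ... 0:
        #   mu = (1 / sqrt(alpha_t)) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta)
        #   x_{t-1} = mu + sqrt(beta_t) * z,  z ~ N(0, I)  (z = 0 at t = 0)
        # The 0.8 factor below deliberately under-scales z to reduce residual noise.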
        # Sampling loop with reduced noise injection
        for t in reversed(range(self.timesteps)):
            if cancel_event.is_set():
                return None

            t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
            predicted_noise = self.model(x_t, labels, t_tensor)

            # Schedule coefficients for this step
            beta_t = self.betas[t].to(device)
            alpha_t = self.alphas[t].to(device)
            alpha_bar_t = self.alpha_bars[t].to(device)

            mean = (1 / torch.sqrt(alpha_t)) * (x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * predicted_noise)
            variance = beta_t

            # Reduced noise injection (0.8 multiplier, i.e. 20% less noise)
            if t > 0:
                noise = torch.randn_like(x_t) * 0.8
            else:
                noise = torch.zeros_like(x_t)

            x_t = mean + torch.sqrt(variance) * noise

            if progress_callback:
                progress_callback((self.timesteps - t) / self.timesteps)

        # Clamp, then denormalize (standard ImageNet mean/std statistics)
        x_0 = torch.clamp(x_t, -1., 1.)
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
        x_0 = std * x_0 + mean
        x_0 = torch.clamp(x_0, 0., 1.)

        # Post-processing: mild bilateral filtering to reduce noise while
        # preserving edges
        x_np = x_0.cpu().permute(0, 2, 3, 1).numpy()
        filtered = []
        for img in x_np:
            img = (img * 255).astype(np.uint8)
            filtered_img = cv2.bilateralFilter(img, d=5, sigmaColor=15, sigmaSpace=15)
            filtered.append(filtered_img / 255.0)
        x_0 = torch.tensor(np.array(filtered), dtype=torch.float32, device=device).permute(0, 3, 1, 2)

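        # Unsharp masking: x_sharp = clamp(1.5 * x - 0.5 * blur(x)). Note the
        # 5x5 box kernel is divided by 75 rather than 25, so `blurred` is a
        # third of the local mean (kept as in the original code).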
        kernel = torch.ones(3, 1, 5, 5, device=device) / 75
        kernel = kernel.to(x_0.dtype)
        blurred = torch.nn.functional.conv2d(
            x_0,
            kernel,
            padding=2,
            groups=3
        )
        x_0 = torch.clamp(1.5 * x_0 - 0.5 * blurred, 0., 1.)

        return x_0

def load_model(model_path, device):
    unet_model = UNet(num_classes=NUM_CLASSES).to(device)
    diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)

    if os.path.exists(model_path):
        checkpoint = torch.load(model_path, map_location=device)

        if 'model_state_dict' in checkpoint:
            # Training-checkpoint format: strip the 'model.' prefix added by
            # the DiffusionModel wrapper
            state_dict = {
                k[6:]: v for k, v in checkpoint['model_state_dict'].items()
                if k.startswith('model.')
            }

            # Load UNet weights
            unet_model.load_state_dict(state_dict, strict=False)

            # Rebuild the diffusion model around the loaded UNet
            diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)

            print(f"Loaded UNet weights from {model_path}")
        else:
            # Direct model weights format
            try:
                # First try loading the full DiffusionModel
                diffusion_model.load_state_dict(checkpoint)
                print(f"Loaded full DiffusionModel from {model_path}")
            except RuntimeError:
                # If that fails, load just the UNet weights
                unet_model.load_state_dict(checkpoint, strict=False)
                diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
                print(f"Loaded UNet weights only from {model_path}")
    else:
        print(f"Weights file not found at {model_path}")
        print("Using randomly initialized weights")

    diffusion_model.eval()
    return diffusion_model

def cancel_generation():
    cancel_event.set()
    return "Generation cancelled"

def generate_images(label_str, num_images, progress=gr.Progress()):
    global loaded_model
    cancel_event.clear()

    if num_images < 1 or num_images > 10:
        raise gr.Error("Number of images must be between 1 and 10")

    label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
    if label_str not in label_map:
        raise gr.Error("Invalid condition selected")

    # One-hot encode the selected condition
    labels = torch.zeros(num_images, NUM_CLASSES)
    labels[:, label_map[label_str]] = 1

    try:
        def progress_callback(progress_val):
            progress(progress_val, desc="Generating...")
            if cancel_event.is_set():
                raise gr.Error("Generation was cancelled by user")

        with torch.no_grad():
            images = loaded_model.sample(
                num_images=num_images,
                img_size=IMG_SIZE,
                num_classes=NUM_CLASSES,
                labels=labels,
                device=device,
                progress_callback=progress_callback
            )

        if images is None:  # generation was cancelled
            return None, None

        processed_images = []
        for img in images:
            # Collapse RGB to grayscale
            img_np = img.cpu().permute(1, 2, 0).mean(dim=-1).numpy()

            # Normalize to 0-1
            img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)

            # Additional sharpening with OpenCV
            img_np_uint8 = (img_np * 255).astype(np.uint8)

            # Unsharp mask for additional sharpness
            blurred = cv2.GaussianBlur(img_np_uint8, (0, 0), 2.0)
            sharpened = cv2.addWeighted(img_np_uint8, 1.5, blurred, -0.5, 0)

            # Apply a simple hand-rolled bone-like colormap (black to white
            # with a slight blue cast), avoiding a matplotlib dependency
            sharpened_normalized = sharpened / 255.0
            r = np.clip(sharpened_normalized * 1.2 - 0.1, 0, 1)
            g = np.clip(sharpened_normalized * 1.1 - 0.05, 0, 1)
            b = np.clip(sharpened_normalized * 1.0 + 0.1, 0, 1)

            # Combine channels and convert to uint8
            bone_colored = np.stack([r, g, b], axis=-1)
            bone_colored_uint8 = (bone_colored * 255).astype(np.uint8)

            pil_img = Image.fromarray(bone_colored_uint8)
            processed_images.append(pil_img)

        # A single image goes to the Image component, batches go to the Gallery
        if num_images == 1:
            return processed_images[0], processed_images
        else:
            return None, processed_images

    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Generation failed: {str(e)}")
    finally:
        torch.cuda.empty_cache()

# Load model at import time
MODEL_NAME = "model_weights.pth"
model_path = MODEL_NAME
print("Loading model...")
try:
    loaded_model = load_model(model_path, device)
    print("Model loaded successfully!")
except Exception as e:
    print(f"Failed to load model: {e}")
    print("Creating dummy model for demonstration")
    loaded_model = DiffusionModel(UNet(num_classes=NUM_CLASSES), timesteps=TIMESTEPS).to(device)

# Gradio UI
with gr.Blocks(theme=gr.themes.Soft(
    primary_hue="violet",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("Poppins")],
    text_size="md"
)) as demo:
    gr.Markdown("""
    <center>
    <h1>Synthetic X-ray Generator</h1>
    <p><em>Generate synthetic chest X-rays conditioned on pathology</em></p>
    </center>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            condition = gr.Dropdown(
                ["Pneumonia", "Pneumothorax"],
                label="Select Condition",
                value="Pneumonia",
                interactive=True
            )
            num_images = gr.Slider(
                1, 10, value=1, step=1,
                label="Number of Images",
                interactive=True
            )

            with gr.Row():
                submit_btn = gr.Button("Generate", variant="primary")
                cancel_btn = gr.Button("Cancel", variant="stop")

            gr.Markdown("""
            <div style="text-align: center; margin-top: 10px;">
                <small>Note: Generation may take several seconds per image</small>
            </div>
            """)

        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Output", id="output_tab"):
                    single_image = gr.Image(
                        label="Generated X-ray",
                        height=400,
                        visible=True
                    )
                    gallery = gr.Gallery(
                        label="Generated X-rays",
                        columns=3,
                        height="auto",
                        object_fit="contain",
                        visible=False
                    )

    # Show the single-image view for one image, the gallery otherwise
    def update_ui_based_on_count(num_images):
        if num_images == 1:
            return {
                single_image: gr.update(visible=True),
                gallery: gr.update(visible=False)
            }
        else:
            return {
                single_image: gr.update(visible=False),
                gallery: gr.update(visible=True)
            }

    num_images.change(
        fn=update_ui_based_on_count,
        inputs=num_images,
        outputs=[single_image, gallery]
    )

    submit_btn.click(
        fn=generate_images,
        inputs=[condition, num_images],
        outputs=[single_image, gallery]
    )

    # Cancel sets the global Event; the sampling loop checks it each step
    cancel_btn.click(
        fn=cancel_generation,
        outputs=None
    )

demo.css = """
.gradio-container {
    background: linear-gradient(135deg, #f5f7fa 0%, #e4e8f0 100%);
}
.gallery-container {
    background-color: white !important;
}
"""

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)