Vedansh-7 committed · verified
Commit 65eb734 · Parent(s): 0aee992

Delete app.py

Files changed (1)
1. app.py +0 -423
app.py DELETED
@@ -1,423 +0,0 @@
- import torch
- import torch.nn as nn
- import gradio as gr
- from PIL import Image
- import numpy as np
- import math
- import os
- from threading import Event
- import traceback
-
- # Constants
- IMG_SIZE = 128
- TIMESTEPS = 300
- NUM_CLASSES = 2
-
- # Global Cancellation Flag
- cancel_event = Event()
-
- # Device Configuration
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- # --- Model Definitions ---
- class SinusoidalPositionEmbeddings(nn.Module):
-     def __init__(self, dim):
-         super().__init__()
-         self.dim = dim
-         half_dim = dim // 2
-         emb = math.log(10000) / (half_dim - 1)
-         emb = torch.exp(torch.arange(half_dim) * -emb)
-         self.register_buffer('embeddings', emb)
-
-     def forward(self, time):
-         device = time.device
-         embeddings = self.embeddings.to(device)
-         embeddings = time[:, None] * embeddings[None, :]
-         return torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)
-
- class UNet(nn.Module):
-     def __init__(self, in_channels=3, out_channels=3, num_classes=2, time_dim=256):
-         super().__init__()
-         self.num_classes = num_classes
-         self.label_embedding = nn.Embedding(num_classes, time_dim)
-
-         self.time_mlp = nn.Sequential(
-             SinusoidalPositionEmbeddings(time_dim),
-             nn.Linear(time_dim, time_dim),
-             nn.ReLU(),
-             nn.Linear(time_dim, time_dim)
-         )
-
-         # Encoder
-         self.inc = self.double_conv(in_channels, 64)
-         self.down1 = self.down(64 + time_dim * 2, 128)
-         self.down2 = self.down(128 + time_dim * 2, 256)
-         self.down3 = self.down(256 + time_dim * 2, 512)
-
-         # Bottleneck
-         self.bottleneck = self.double_conv(512 + time_dim * 2, 1024)
-
-         # Decoder
-         self.up1 = nn.ConvTranspose2d(1024, 256, kernel_size=2, stride=2)
-         self.upconv1 = self.double_conv(256 + 256 + time_dim * 2, 256)
-
-         self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
-         self.upconv2 = self.double_conv(128 + 128 + time_dim * 2, 128)
-
-         self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
-         self.upconv3 = self.double_conv(64 + 64 + time_dim * 2, 64)
-
-         self.outc = nn.Conv2d(64, out_channels, kernel_size=1)
-
-     def double_conv(self, in_channels, out_channels):
-         return nn.Sequential(
-             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
-             nn.ReLU(inplace=True),
-             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
-             nn.ReLU(inplace=True)
-         )
-
-     def down(self, in_channels, out_channels):
-         return nn.Sequential(
-             nn.MaxPool2d(2),
-             self.double_conv(in_channels, out_channels)
-         )
-
-     def forward(self, x, labels, time):
-         label_indices = torch.argmax(labels, dim=1)
-         label_emb = self.label_embedding(label_indices)
-         t_emb = self.time_mlp(time)
-
-         combined_emb = torch.cat([t_emb, label_emb], dim=1)
-         combined_emb = combined_emb.unsqueeze(-1).unsqueeze(-1)
-
-         x1 = self.inc(x)
-         x1_cat = torch.cat([x1, combined_emb.repeat(1, 1, x1.shape[-2], x1.shape[-1])], dim=1)
-
-         x2 = self.down1(x1_cat)
-         x2_cat = torch.cat([x2, combined_emb.repeat(1, 1, x2.shape[-2], x2.shape[-1])], dim=1)
-
-         x3 = self.down2(x2_cat)
-         x3_cat = torch.cat([x3, combined_emb.repeat(1, 1, x3.shape[-2], x3.shape[-1])], dim=1)
-
-         x4 = self.down3(x3_cat)
-         x4_cat = torch.cat([x4, combined_emb.repeat(1, 1, x4.shape[-2], x4.shape[-1])], dim=1)
-
-         x5 = self.bottleneck(x4_cat)
-
-         x = self.up1(x5)
-         x = torch.cat([x, x3], dim=1)
-         x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
-         x = self.upconv1(x)
-
-         x = self.up2(x)
-         x = torch.cat([x, x2], dim=1)
-         x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
-         x = self.upconv2(x)
-
-         x = self.up3(x)
-         x = torch.cat([x, x1], dim=1)
-         x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
-         x = self.upconv3(x)
-
-         return self.outc(x)
-
- class DiffusionModel(nn.Module):
-     def __init__(self, model, timesteps=TIMESTEPS, time_dim=256):
-         super().__init__()
-         self.model = model
-         self.timesteps = timesteps
-         self.time_dim = time_dim
-
-         # Linear beta schedule, scaled so its endpoints match the standard
-         # 1000-step DDPM schedule; computed in float64 for cumprod precision
-         scale = 1000 / timesteps
-         beta_start = scale * 0.0001
-         beta_end = scale * 0.02
-         betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
-         alphas = 1. - betas
-         alpha_bars = torch.cumprod(alphas, dim=0)
-
-         # Register as float32 buffers so they follow .to(device) and keep the
-         # sampling arithmetic in the model's dtype (indexing a float64 buffer
-         # would silently promote x_t to float64 and break the next conv)
-         self.register_buffer('betas_buffer', betas.float())
-         self.register_buffer('alphas_buffer', alphas.float())
-         self.register_buffer('alpha_bars_buffer', alpha_bars.float())
-
-     def forward_diffusion(self, x_0, t, noise):
-         x_0 = x_0.float()
-         noise = noise.float()
-         alpha_bar_t = self.alpha_bars_buffer[t].view(-1, 1, 1, 1)
-         x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1. - alpha_bar_t) * noise
-         return x_t
-
-     def forward(self, x_0, labels):
-         t = torch.randint(0, self.timesteps, (x_0.shape[0],), device=x_0.device).long()
-         noise = torch.randn_like(x_0)
-         x_t = self.forward_diffusion(x_0, t, noise)
-         predicted_noise = self.model(x_t, labels, t.float())
-         return predicted_noise, noise, t
-
-     @torch.no_grad()
-     def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
-         # Start from random noise with slightly reduced variance
-         x_t = torch.randn(num_images, 3, img_size, img_size, device=device) * 0.9
-
-         # Accept either class indices or one-hot labels
-         labels = labels.to(device)
-         if labels.ndim == 1:
-             labels_one_hot = torch.zeros(num_images, num_classes, device=device)
-             labels_one_hot.scatter_(1, labels.unsqueeze(1), 1)
-             labels = labels_one_hot
-
-         # Reverse diffusion (DDPM ancestral sampling)
-         for t in reversed(range(self.timesteps)):
-             if cancel_event.is_set():
-                 return None
-
-             t_tensor = torch.full((num_images,), t, device=device, dtype=torch.long)
-             predicted_noise = self.model(x_t, labels, t_tensor)
-
-             # Posterior coefficients
-             beta_t = self.betas_buffer[t].to(device)
-             alpha_t = self.alphas_buffer[t].to(device)
-             alpha_bar_t = self.alpha_bars_buffer[t].to(device)
-
-             mean = (1 / torch.sqrt(alpha_t)) * (x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * predicted_noise)
-
-             # Dynamic noise reduction: inject less noise in the later
-             # (small-t) steps of the reverse process
-             if t > 0:
-                 noise_factor = 0.6 + 0.4 * (t / self.timesteps)  # ranges from ~0.6 to ~1.0
-                 noise = torch.randn_like(x_t) * noise_factor
-             else:
-                 noise = torch.zeros_like(x_t)
-
-             x_t = mean + torch.sqrt(beta_t) * noise
-
-             if progress_callback:
-                 progress_callback((self.timesteps - t) / self.timesteps)
-
-         # ---- Post-processing pipeline ----
-         # 1. Denormalize with ImageNet statistics
-         norm_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
-         norm_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
-         x_0 = torch.clamp(x_t * norm_std + norm_mean, 0, 1)
-
-         # 2. Edge-preserving smoothing
-         blurred = torch.nn.functional.avg_pool2d(x_0, kernel_size=5, stride=1, padding=2)
-         mask = torch.abs(x_0 - blurred) < 0.1  # only smooth low-frequency areas
-         x_0 = torch.where(mask, 0.7 * x_0 + 0.3 * blurred, x_0)
-
-         # 3. Adaptive sharpening (unsharp masking)
-         sharpening_strength = 1.4
-         edge_boost = 0.15  # additional edge enhancement
-
-         low_pass = torch.nn.functional.avg_pool2d(x_0, kernel_size=3, stride=1, padding=1)
-         x_0 = torch.clamp((1 + sharpening_strength) * x_0 - sharpening_strength * low_pass, 0, 1)
-
-         # Edge boost (only affects high-contrast areas)
-         edges = x_0 - torch.nn.functional.avg_pool2d(x_0, kernel_size=5, stride=1, padding=2)
-         edges = edges * edge_boost
-         x_0 = torch.clamp(x_0 + edges, 0, 1)
-
-         return x_0
-
- def load_model(model_path, device):
-     unet_model = UNet(num_classes=NUM_CLASSES).to(device)
-     diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
-
-     if os.path.exists(model_path):
-         checkpoint = torch.load(model_path, map_location=device)
-
-         if 'model_state_dict' in checkpoint:
-             # Handle training checkpoint format: strip the 'model.' prefix
-             state_dict = {
-                 k[6:]: v for k, v in checkpoint['model_state_dict'].items()
-                 if k.startswith('model.')
-             }
-
-             # Load UNet weights
-             unet_model.load_state_dict(state_dict, strict=False)
-
-             # Initialize diffusion model with loaded UNet
-             diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
-
-             print(f"Loaded UNet weights from {model_path}")
-         else:
-             # Handle direct model weights format
-             try:
-                 # First try loading a full DiffusionModel state dict
-                 diffusion_model.load_state_dict(checkpoint)
-                 print(f"Loaded full DiffusionModel from {model_path}")
-             except RuntimeError:
-                 # If that fails, load just the UNet weights
-                 unet_model.load_state_dict(checkpoint, strict=False)
-                 diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
-                 print(f"Loaded UNet weights only from {model_path}")
-     else:
-         print(f"Weights file not found at {model_path}")
-         print("Using randomly initialized weights")
-
-     diffusion_model.eval()
-     return diffusion_model
-
- def cancel_generation():
-     cancel_event.set()
-     return "Generation cancelled"
-
- def generate_images(label_str, num_images, progress=gr.Progress()):
-     global loaded_model
-     cancel_event.clear()
-
-     if num_images < 1 or num_images > 10:
-         raise gr.Error("Number of images must be between 1 and 10")
-
-     label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
-     if label_str not in label_map:
-         raise gr.Error("Invalid condition selected")
-
-     labels = torch.zeros(num_images, NUM_CLASSES)
-     labels[:, label_map[label_str]] = 1
-
-     try:
-         def progress_callback(progress_val):
-             progress(progress_val, desc="Generating...")
-             if cancel_event.is_set():
-                 raise gr.Error("Generation was cancelled by user")
-
-         with torch.no_grad():
-             images = loaded_model.sample(
-                 num_images=num_images,
-                 img_size=IMG_SIZE,
-                 num_classes=NUM_CLASSES,
-                 labels=labels,
-                 device=device,
-                 progress_callback=progress_callback
-             )
-
-         if images is None:
-             return None, None
-
-         processed_images = []
-         for img in images:
-             img_np = img.cpu().permute(1, 2, 0).numpy()
-             img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
-             pil_img = Image.fromarray(img_np)
-             processed_images.append(pil_img)
-
-         if num_images == 1:
-             return processed_images[0], processed_images
-         else:
-             return None, processed_images
-
-     except Exception as e:
-         traceback.print_exc()
-         raise gr.Error(f"Generation failed: {str(e)}")
-     finally:
-         torch.cuda.empty_cache()
-
- # Load model
- MODEL_NAME = "model_weights.pth"
- model_path = MODEL_NAME
- print("Loading model...")
- try:
-     loaded_model = load_model(model_path, device)
-     print("Model loaded successfully!")
- except Exception as e:
-     print(f"Failed to load model: {e}")
-     print("Creating dummy model for demonstration")
-     loaded_model = DiffusionModel(UNet(num_classes=NUM_CLASSES), timesteps=TIMESTEPS).to(device)
-
- # Gradio UI
- with gr.Blocks(theme=gr.themes.Soft(
-     primary_hue="violet",
-     neutral_hue="slate",
-     font=[gr.themes.GoogleFont("Poppins")],
-     text_size="md"
- )) as demo:
-     gr.Markdown("""
-     <center>
-     <h1>Synthetic X-ray Generator</h1>
-     <p><em>Generate synthetic chest X-rays conditioned on pathology</em></p>
-     </center>
-     """)
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             condition = gr.Dropdown(
-                 ["Pneumonia", "Pneumothorax"],
-                 label="Select Condition",
-                 value="Pneumonia",
-                 interactive=True
-             )
-             num_images = gr.Slider(
-                 1, 10, value=1, step=1,
-                 label="Number of Images",
-                 interactive=True
-             )
-
-             with gr.Row():
-                 submit_btn = gr.Button("Generate", variant="primary")
-                 cancel_btn = gr.Button("Cancel", variant="stop")
-
-             gr.Markdown("""
-             <div style="text-align: center; margin-top: 10px;">
-             <small>Note: Generation may take several seconds per image</small>
-             </div>
-             """)
-
-         with gr.Column(scale=2):
-             with gr.Tabs():
-                 with gr.TabItem("Output", id="output_tab"):
-                     single_image = gr.Image(
-                         label="Generated X-ray",
-                         height=400,
-                         visible=True
-                     )
-                     gallery = gr.Gallery(
-                         label="Generated X-rays",
-                         columns=3,
-                         height="auto",
-                         object_fit="contain",
-                         visible=False
-                     )
-
-     def update_ui_based_on_count(num_images):
-         if num_images == 1:
-             return {
-                 single_image: gr.update(visible=True),
-                 gallery: gr.update(visible=False)
-             }
-         else:
-             return {
-                 single_image: gr.update(visible=False),
-                 gallery: gr.update(visible=True)
-             }
-
-     num_images.change(
-         fn=update_ui_based_on_count,
-         inputs=num_images,
-         outputs=[single_image, gallery]
-     )
-
-     submit_btn.click(
-         fn=generate_images,
-         inputs=[condition, num_images],
-         outputs=[single_image, gallery]
-     )
-
-     cancel_btn.click(
-         fn=cancel_generation,
-         outputs=None
-     )
-
- demo.css = """
- .gradio-container {
-     background: linear-gradient(135deg, #f5f7fa 0%, #e4e8f0 100%);
- }
- .gallery-container {
-     background-color: white !important;
- }
- """
-
- if __name__ == "__main__":
-     demo.launch(server_name="0.0.0.0", server_port=7860)