Vedansh-7 committed on
Commit 786c6b5 · 1 Parent(s): 78db2c7

Upload app.py

Files changed (1)
app.py +398 -0
app.py ADDED
@@ -0,0 +1,398 @@
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
import numpy as np
import math
import os
from threading import Event
import traceback

# Constants
IMG_SIZE = 128
TIMESTEPS = 300   # number of diffusion steps used for training and sampling
NUM_CLASSES = 2

# Global cancellation flag: set by the Cancel button, checked inside the sampling loop
cancel_event = Event()

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Model Definitions ---
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        half_dim = dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim) * -emb)
        self.register_buffer('embeddings', emb)

    def forward(self, time):
        embeddings = self.embeddings.to(time.device)
        # Outer product: (batch, 1) * (1, half_dim) -> (batch, half_dim)
        embeddings = time[:, None] * embeddings[None, :]
        return torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)

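# Note: this is the standard transformer-style sinusoidal encoding,
#   PE(t) = [sin(t * w_i), cos(t * w_i)],  w_i = 10000^(-i / (dim/2 - 1)),
# so every timestep maps to a distinct, smoothly varying vector of length dim.
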
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, num_classes=2, time_dim=256):
        super().__init__()
        self.num_classes = num_classes
        self.label_embedding = nn.Embedding(num_classes, time_dim)

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.ReLU(),
            nn.Linear(time_dim, time_dim)
        )

        # Encoder
        self.inc = self.double_conv(in_channels, 64)
        self.down1 = self.down(64 + time_dim * 2, 128)
        self.down2 = self.down(128 + time_dim * 2, 256)
        self.down3 = self.down(256 + time_dim * 2, 512)

        # Bottleneck
        self.bottleneck = self.double_conv(512 + time_dim * 2, 1024)

        # Decoder
        self.up1 = nn.ConvTranspose2d(1024, 256, kernel_size=2, stride=2)
        self.upconv1 = self.double_conv(256 + 256 + time_dim * 2, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upconv2 = self.double_conv(128 + 128 + time_dim * 2, 128)

        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.upconv3 = self.double_conv(64 + 64 + time_dim * 2, 64)

        self.outc = nn.Conv2d(64, out_channels, kernel_size=1)

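    # Channel bookkeeping: forward() concatenates the broadcast (time + label)
    # embedding, 2 * time_dim channels, onto every stage's feature map, which is
    # why each double_conv input size above includes "+ time_dim * 2".
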
    def double_conv(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def down(self, in_channels, out_channels):
        return nn.Sequential(
            nn.MaxPool2d(2),
            self.double_conv(in_channels, out_channels)
        )

    def forward(self, x, labels, time):
        # Embed the one-hot label and the timestep, then fuse them
        label_indices = torch.argmax(labels, dim=1)
        label_emb = self.label_embedding(label_indices)
        t_emb = self.time_mlp(time)

        combined_emb = torch.cat([t_emb, label_emb], dim=1)
        combined_emb = combined_emb.unsqueeze(-1).unsqueeze(-1)

        # Encoder: concatenate the broadcast conditioning at every scale
        x1 = self.inc(x)
        x1_cat = torch.cat([x1, combined_emb.repeat(1, 1, x1.shape[-2], x1.shape[-1])], dim=1)

        x2 = self.down1(x1_cat)
        x2_cat = torch.cat([x2, combined_emb.repeat(1, 1, x2.shape[-2], x2.shape[-1])], dim=1)

        x3 = self.down2(x2_cat)
        x3_cat = torch.cat([x3, combined_emb.repeat(1, 1, x3.shape[-2], x3.shape[-1])], dim=1)

        x4 = self.down3(x3_cat)
        x4_cat = torch.cat([x4, combined_emb.repeat(1, 1, x4.shape[-2], x4.shape[-1])], dim=1)

        x5 = self.bottleneck(x4_cat)

        # Decoder: upsample, add the skip connection, then the conditioning
        x = self.up1(x5)
        x = torch.cat([x, x3], dim=1)
        x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
        x = self.upconv1(x)

        x = self.up2(x)
        x = torch.cat([x, x2], dim=1)
        x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
        x = self.upconv2(x)

        x = self.up3(x)
        x = torch.cat([x, x1], dim=1)
        x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
        x = self.upconv3(x)

        return self.outc(x)

class DiffusionModel(nn.Module):
    def __init__(self, model, timesteps=TIMESTEPS, time_dim=256):
        super().__init__()
        self.model = model
        self.timesteps = timesteps
        self.time_dim = time_dim

        # Linear beta schedule, rescaled so a shorter schedule spans the same
        # total noise range as the reference 1000-step schedule
        scale = 1000 / timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
        self.alphas = 1. - self.betas
        self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0).float())

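    # Note: betas and alphas are plain CPU float64 tensors, moved to the target
    # device step by step inside sample(); only alpha_bars is registered as a
    # buffer, so it follows the module's device automatically.
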
    def forward_diffusion(self, x_0, t, noise):
        x_0 = x_0.float()
        noise = noise.float()
        alpha_bar_t = self.alpha_bars[t].view(-1, 1, 1, 1)
        x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1. - alpha_bar_t) * noise
        return x_t

    def forward(self, x_0, labels):
        t = torch.randint(0, self.timesteps, (x_0.shape[0],), device=x_0.device).long()
        noise = torch.randn_like(x_0)
        x_t = self.forward_diffusion(x_0, t, noise)
        predicted_noise = self.model(x_t, labels, t.float())
        return predicted_noise, noise, t

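    # Note: forward_diffusion implements the closed form of the forward process,
    #   q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I),
    # which lets training jump straight to a random timestep t instead of
    # adding noise one step at a time.
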
    @torch.no_grad()
    def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
        x_t = torch.randn(num_images, 3, img_size, img_size).to(device)

        if labels.ndim == 1:
            labels_one_hot = torch.zeros(num_images, num_classes).to(device)
            labels_one_hot[torch.arange(num_images), labels] = 1
            labels = labels_one_hot
        else:
            labels = labels.to(device)

        for t in reversed(range(self.timesteps)):
            if cancel_event.is_set():
                return None

            t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
            predicted_noise = self.model(x_t, labels, t_tensor)

            beta_t = self.betas[t].to(device)
            alpha_t = self.alphas[t].to(device)
            alpha_bar_t = self.alpha_bars[t].to(device)

            mean = (1 / torch.sqrt(alpha_t)) * (x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * predicted_noise)
            variance = beta_t

            if t > 0:
                noise = torch.randn_like(x_t)
            else:
                noise = torch.zeros_like(x_t)

            x_t = mean + torch.sqrt(variance) * noise

            if progress_callback:
                progress_callback((self.timesteps - t) / self.timesteps)

        x_0 = torch.clamp(x_t, -1., 1.)

        # Undo the ImageNet normalization applied during training, then map to [0, 1]
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
        x_0 = std * x_0 + mean
        x_0 = torch.clamp(x_0, 0., 1.)

        return x_0

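    # Note: each loop iteration in sample() is the standard DDPM ancestral step
    #   x_{t-1} = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta) / sqrt(alpha_t)
    #             + sqrt(beta_t) * z,
    # with z ~ N(0, I) for t > 0 and z = 0 at the final step (Ho et al., 2020).
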
def load_model(model_path, device):
    unet_model = UNet(num_classes=NUM_CLASSES).to(device)
    diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)

    if os.path.exists(model_path):
        checkpoint = torch.load(model_path, map_location=device)

        if 'model_state_dict' in checkpoint:
            # Training checkpoint format: strip the leading 'model.' prefix so
            # the keys match the bare UNet
            state_dict = {
                k[6:]: v for k, v in checkpoint['model_state_dict'].items()
                if k.startswith('model.')
            }

            # Load UNet weights
            unet_model.load_state_dict(state_dict, strict=False)

            # Re-wrap the loaded UNet in a fresh DiffusionModel
            diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)

            print(f"Loaded UNet weights from {model_path}")
        else:
            # Direct state-dict format
            try:
                # First try loading a full DiffusionModel state dict
                diffusion_model.load_state_dict(checkpoint)
                print(f"Loaded full DiffusionModel from {model_path}")
            except RuntimeError:
                # If that fails, fall back to loading just the UNet weights
                unet_model.load_state_dict(checkpoint, strict=False)
                diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
                print(f"Loaded UNet weights only from {model_path}")
    else:
        print(f"Weights file not found at {model_path}")
        print("Using randomly initialized weights")

    diffusion_model.eval()
    return diffusion_model

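# Note: strict=False tolerates missing or unexpected keys, which keeps loading
# robust across slightly different checkpoint layouts at the cost of silently
# skipping weights that do not line up with the current architecture.
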
def cancel_generation():
    cancel_event.set()
    return "Generation cancelled"

def generate_images(label_str, num_images, progress=gr.Progress()):
    global loaded_model
    cancel_event.clear()

    if num_images < 1 or num_images > 10:
        raise gr.Error("Number of images must be between 1 and 10")

    label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
    if label_str not in label_map:
        raise gr.Error("Invalid condition selected")

    labels = torch.zeros(num_images, NUM_CLASSES)
    labels[:, label_map[label_str]] = 1

    try:
        def progress_callback(progress_val):
            progress(progress_val, desc="Generating...")
            if cancel_event.is_set():
                raise gr.Error("Generation was cancelled by user")

        with torch.no_grad():
            images = loaded_model.sample(
                num_images=num_images,
                img_size=IMG_SIZE,
                num_classes=NUM_CLASSES,
                labels=labels,
                device=device,
                progress_callback=progress_callback
            )

        if images is None:
            return None, None

        # Convert the [0, 1] float tensors to PIL images
        processed_images = []
        for img in images:
            img_np = img.cpu().permute(1, 2, 0).numpy()
            img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
            pil_img = Image.fromarray(img_np)
            processed_images.append(pil_img)

        if num_images == 1:
            return processed_images[0], processed_images
        else:
            return None, processed_images

    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Generation failed: {str(e)}")
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

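# Note: generate_images always returns a (single_image, gallery) pair; the first
# slot is populated only when num_images == 1, matching the visibility toggle in
# update_ui_based_on_count below.
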
# Load the model at startup; fall back to random weights so the UI still runs
MODEL_NAME = "model_weights.pth"
model_path = MODEL_NAME
print("Loading model...")
try:
    loaded_model = load_model(model_path, device)
    print("Model loaded successfully!")
except Exception as e:
    print(f"Failed to load model: {e}")
    print("Creating dummy model for demonstration")
    loaded_model = DiffusionModel(UNet(num_classes=NUM_CLASSES), timesteps=TIMESTEPS).to(device)

# Gradio UI
with gr.Blocks(theme=gr.themes.Soft(
    primary_hue="violet",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("Poppins")],
    text_size="md"
)) as demo:
    gr.Markdown("""
<center>
<h1>Synthetic X-ray Generator</h1>
<p><em>Generate synthetic chest X-rays conditioned on pathology</em></p>
</center>
""")

    with gr.Row():
        with gr.Column(scale=1):
            condition = gr.Dropdown(
                ["Pneumonia", "Pneumothorax"],
                label="Select Condition",
                value="Pneumonia",
                interactive=True
            )
            num_images = gr.Slider(
                1, 10, value=1, step=1,
                label="Number of Images",
                interactive=True
            )

            with gr.Row():
                submit_btn = gr.Button("Generate", variant="primary")
                cancel_btn = gr.Button("Cancel", variant="stop")

            gr.Markdown("""
<div style="text-align: center; margin-top: 10px;">
<small>Note: Generation may take several seconds per image</small>
</div>
""")

        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Output", id="output_tab"):
                    single_image = gr.Image(
                        label="Generated X-ray",
                        height=400,
                        visible=True
                    )
                    gallery = gr.Gallery(
                        label="Generated X-rays",
                        columns=3,
                        height="auto",
                        object_fit="contain",
                        visible=False
                    )

    def update_ui_based_on_count(num_images):
        # Show the single-image view for one image, the gallery otherwise
        if num_images == 1:
            return {
                single_image: gr.update(visible=True),
                gallery: gr.update(visible=False)
            }
        else:
            return {
                single_image: gr.update(visible=False),
                gallery: gr.update(visible=True)
            }

    num_images.change(
        fn=update_ui_based_on_count,
        inputs=num_images,
        outputs=[single_image, gallery]
    )

    submit_btn.click(
        fn=generate_images,
        inputs=[condition, num_images],
        outputs=[single_image, gallery]
    )

    cancel_btn.click(
        fn=cancel_generation,
        outputs=None
    )

388
+ demo.css = """
389
+ .gradio-container {
390
+ background: linear-gradient(135deg, #f5f7fa 0%, #e4e8f0 100%);
391
+ }
392
+ .gallery-container {
393
+ background-color: white !important;
394
+ }
395
+ """
396
+
397
+ if __name__ == "__main__":
398
+ demo.launch(server_name="0.0.0.0", server_port=7860)
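
# Headless usage sketch (hypothetical, not wired into the app): generate one
# Pneumonia sample without launching the UI, using the same label convention
# as generate_images.
#
#   model = load_model("model_weights.pth", device)
#   labels = torch.zeros(1, NUM_CLASSES)
#   labels[0, 0] = 1          # 0 = Pneumonia, 1 = Pneumothorax
#   images = model.sample(1, IMG_SIZE, NUM_CLASSES, labels, device)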