Vedansh-7 commited on
Commit
f25462b
·
1 Parent(s): 071deee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -6
app.py CHANGED
@@ -7,6 +7,9 @@ import math
7
  import os
8
  from threading import Event
9
  import traceback
 
 
 
10
 
11
  # Constants
12
  IMG_SIZE = 128
@@ -153,8 +156,10 @@ class DiffusionModel(nn.Module):
153
 
154
  @torch.no_grad()
155
  def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
 
156
  x_t = torch.randn(num_images, 3, img_size, img_size).to(device)
157
 
 
158
  if labels.ndim == 1:
159
  labels_one_hot = torch.zeros(num_images, num_classes).to(device)
160
  labels_one_hot[torch.arange(num_images), labels] = 1
@@ -162,6 +167,7 @@ class DiffusionModel(nn.Module):
162
  else:
163
  labels = labels.to(device)
164
 
 
165
  for t in reversed(range(self.timesteps)):
166
  if cancel_event.is_set():
167
  return None
@@ -169,6 +175,7 @@ class DiffusionModel(nn.Module):
169
  t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
170
  predicted_noise = self.model(x_t, labels, t_tensor)
171
 
 
172
  beta_t = self.betas[t].to(device)
173
  alpha_t = self.alphas[t].to(device)
174
  alpha_bar_t = self.alpha_bars[t].to(device)
@@ -176,8 +183,9 @@ class DiffusionModel(nn.Module):
176
  mean = (1 / torch.sqrt(alpha_t)) * (x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * predicted_noise)
177
  variance = beta_t
178
 
 
179
  if t > 0:
180
- noise = torch.randn_like(x_t)
181
  else:
182
  noise = torch.zeros_like(x_t)
183
 
@@ -186,14 +194,34 @@ class DiffusionModel(nn.Module):
186
  if progress_callback:
187
  progress_callback((self.timesteps - t) / self.timesteps)
188
 
 
189
  x_0 = torch.clamp(x_t, -1., 1.)
190
-
191
- # Normalization
192
  mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
193
  std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
194
  x_0 = std * x_0 + mean
195
  x_0 = torch.clamp(x_0, 0., 1.)
196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  return x_0
198
 
199
  def load_model(model_path, device):
@@ -274,9 +302,34 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
274
 
275
  processed_images = []
276
  for img in images:
277
- img_np = img.cpu().permute(1, 2, 0).numpy()
278
- img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
279
- pil_img = Image.fromarray(img_np)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  processed_images.append(pil_img)
281
 
282
  if num_images == 1:
 
7
  import os
8
  from threading import Event
9
  import traceback
10
+ import cv2 # Added for bilateral filtering
11
+ import matplotlib.pyplot as plt
12
+ from io import BytesIO
13
 
14
  # Constants
15
  IMG_SIZE = 128
 
156
 
157
  @torch.no_grad()
158
  def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
159
+ # Start with random noise
160
  x_t = torch.randn(num_images, 3, img_size, img_size).to(device)
161
 
162
+ # Label handling (one-hot if needed)
163
  if labels.ndim == 1:
164
  labels_one_hot = torch.zeros(num_images, num_classes).to(device)
165
  labels_one_hot[torch.arange(num_images), labels] = 1
 
167
  else:
168
  labels = labels.to(device)
169
 
170
+ # REVERTED SAMPLING LOOP WITH NOISE REDUCTION
171
  for t in reversed(range(self.timesteps)):
172
  if cancel_event.is_set():
173
  return None
 
175
  t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
176
  predicted_noise = self.model(x_t, labels, t_tensor)
177
 
178
+ # Calculate coefficients
179
  beta_t = self.betas[t].to(device)
180
  alpha_t = self.alphas[t].to(device)
181
  alpha_bar_t = self.alpha_bars[t].to(device)
 
183
  mean = (1 / torch.sqrt(alpha_t)) * (x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * predicted_noise)
184
  variance = beta_t
185
 
186
+ # Reduced noise injection with lower multiplier
187
  if t > 0:
188
+ noise = torch.randn_like(x_t) * 0.8 # Reduced noise by 20%
189
  else:
190
  noise = torch.zeros_like(x_t)
191
 
 
194
  if progress_callback:
195
  progress_callback((self.timesteps - t) / self.timesteps)
196
 
197
+ # Clamp and denormalize
198
  x_0 = torch.clamp(x_t, -1., 1.)
 
 
199
  mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
200
  std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
201
  x_0 = std * x_0 + mean
202
  x_0 = torch.clamp(x_0, 0., 1.)
203
 
204
+ # ENHANCED SHARPENING
205
+ # First apply mild bilateral filtering to reduce noise while preserving edges
206
+ x_np = x_0.cpu().permute(0, 2, 3, 1).numpy()
207
+ filtered = []
208
+ for img in x_np:
209
+ img = (img * 255).astype(np.uint8)
210
+ filtered_img = cv2.bilateralFilter(img, d=5, sigmaColor=15, sigmaSpace=15)
211
+ filtered.append(filtered_img / 255.0)
212
+ x_0 = torch.tensor(np.array(filtered), device=device).permute(0, 3, 1, 2)
213
+
214
+ # Then apply stronger unsharp masking
215
+ kernel = torch.ones(3, 1, 5, 5, device=device) / 75
216
+ kernel = kernel.to(x_0.dtype)
217
+ blurred = torch.nn.functional.conv2d(
218
+ x_0,
219
+ kernel,
220
+ padding=2,
221
+ groups=3
222
+ )
223
+ x_0 = torch.clamp(1.5 * x_0 - 0.5 * blurred, 0., 1.) # Increased sharpening factor
224
+
225
  return x_0
226
 
227
  def load_model(model_path, device):
 
302
 
303
  processed_images = []
304
  for img in images:
305
+ # Convert to grayscale and apply bone colormap
306
+ img_np = img.cpu().permute(1, 2, 0).mean(dim=-1).numpy()
307
+
308
+ # Normalize to 0-1
309
+ img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)
310
+
311
+ # Apply additional sharpening with OpenCV
312
+ img_np_uint8 = (img_np * 255).astype(np.uint8)
313
+
314
+ # Apply unsharp mask for additional sharpness
315
+ blurred = cv2.GaussianBlur(img_np_uint8, (0, 0), 2.0)
316
+ sharpened = cv2.addWeighted(img_np_uint8, 1.5, blurred, -0.5, 0)
317
+
318
+ # Apply bone colormap using matplotlib - FIXED APPROACH
319
+ # Create a simple bone-like colormap manually to avoid matplotlib issues
320
+ sharpened_normalized = sharpened / 255.0
321
+ # Simulate bone colormap: black to white with blueish tones
322
+ r = np.clip(sharpened_normalized * 1.2 - 0.1, 0, 1)
323
+ g = np.clip(sharpened_normalized * 1.1 - 0.05, 0, 1)
324
+ b = np.clip(sharpened_normalized * 1.0 + 0.1, 0, 1)
325
+
326
+ # Combine channels and convert to uint8
327
+ bone_colored = np.stack([r, g, b], axis=-1)
328
+ bone_colored_uint8 = (bone_colored * 255).astype(np.uint8)
329
+
330
+ # Create PIL image
331
+ pil_img = Image.fromarray(bone_colored_uint8)
332
+
333
  processed_images.append(pil_img)
334
 
335
  if num_images == 1: