Vedansh-7 commited on
Commit
f2a0342
·
verified ·
1 Parent(s): f69a0b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -11
app.py CHANGED
@@ -7,11 +7,11 @@ import math
7
  import os
8
  from threading import Event
9
  import traceback
10
- import cv2 # Added for bilateral filtering
11
 
12
  # Constants
13
  IMG_SIZE = 128
14
- TIMESTEPS = 300 # From second code
15
  NUM_CLASSES = 2
16
 
17
  # Global Cancellation Flag
@@ -27,13 +27,13 @@ class SinusoidalPositionEmbeddings(nn.Module):
27
  self.dim = dim
28
  half_dim = dim // 2
29
  emb = math.log(10000) / (half_dim - 1)
30
- emb = torch.exp(torch.arange(half_dim) * -emb) # From second code (no dtype specified)
31
  self.register_buffer('embeddings', emb)
32
 
33
  def forward(self, time):
34
- device = time.device # From second code
35
  embeddings = self.embeddings.to(device)
36
- embeddings = time[:, None] * embeddings[None, :] # From second code
37
  return torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)
38
 
39
  class UNet(nn.Module):
@@ -130,7 +130,7 @@ class DiffusionModel(nn.Module):
130
  self.timesteps = timesteps
131
  self.time_dim = time_dim
132
 
133
- # Linear beta schedule with scaling from second code
134
  scale = 1000 / timesteps
135
  beta_start = scale * 0.0001
136
  beta_end = scale * 0.02
@@ -165,7 +165,7 @@ class DiffusionModel(nn.Module):
165
  else:
166
  labels = labels.to(device)
167
 
168
- # REVERTED SAMPLING LOOP WITH NOISE REDUCTION
169
  for t in reversed(range(self.timesteps)):
170
  if cancel_event.is_set():
171
  return None
@@ -199,7 +199,7 @@ class DiffusionModel(nn.Module):
199
  x_0 = std * x_0 + mean
200
  x_0 = torch.clamp(x_0, 0., 1.)
201
 
202
- # ENHANCED SHARPENING
203
  # First apply mild bilateral filtering to reduce noise while preserving edges
204
  x_np = x_0.cpu().permute(0, 2, 3, 1).numpy()
205
  filtered = []
@@ -208,7 +208,7 @@ class DiffusionModel(nn.Module):
208
  filtered_img = cv2.bilateralFilter(img, d=5, sigmaColor=15, sigmaSpace=15)
209
  filtered.append(filtered_img / 255.0)
210
  x_0 = torch.tensor(np.array(filtered), device=device).permute(0, 3, 1, 2)
211
-
212
  # Then apply stronger unsharp masking
213
  kernel = torch.ones(3, 1, 5, 5, device=device) / 75
214
  kernel = kernel.to(x_0.dtype)
@@ -328,7 +328,7 @@ except Exception as e:
328
  print("Creating dummy model for demonstration")
329
  loaded_model = DiffusionModel(UNet(num_classes=NUM_CLASSES), timesteps=TIMESTEPS).to(device)
330
 
331
- # Gradio UI (from first code)
332
  with gr.Blocks(theme=gr.themes.Soft(
333
  primary_hue="violet",
334
  neutral_hue="slate",
@@ -421,4 +421,4 @@ with gr.Blocks(theme=gr.themes.Soft(
421
  """
422
 
423
  if __name__ == "__main__":
424
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
7
  import os
8
  from threading import Event
9
  import traceback
10
+ import cv2
11
 
12
  # Constants
13
  IMG_SIZE = 128
14
+ TIMESTEPS = 300
15
  NUM_CLASSES = 2
16
 
17
  # Global Cancellation Flag
 
27
  self.dim = dim
28
  half_dim = dim // 2
29
  emb = math.log(10000) / (half_dim - 1)
30
+ emb = torch.exp(torch.arange(half_dim) * -emb)
31
  self.register_buffer('embeddings', emb)
32
 
33
  def forward(self, time):
34
+ device = time.device
35
  embeddings = self.embeddings.to(device)
36
+ embeddings = time[:, None] * embeddings[None, :]
37
  return torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)
38
 
39
  class UNet(nn.Module):
 
130
  self.timesteps = timesteps
131
  self.time_dim = time_dim
132
 
133
+ # Linear beta schedule with scaling
134
  scale = 1000 / timesteps
135
  beta_start = scale * 0.0001
136
  beta_end = scale * 0.02
 
165
  else:
166
  labels = labels.to(device)
167
 
168
+ # ---- REVERTED SAMPLING LOOP WITH NOISE REDUCTION ----
169
  for t in reversed(range(self.timesteps)):
170
  if cancel_event.is_set():
171
  return None
 
199
  x_0 = std * x_0 + mean
200
  x_0 = torch.clamp(x_0, 0., 1.)
201
 
202
+ # ---- ENHANCED SHARPENING ----
203
  # First apply mild bilateral filtering to reduce noise while preserving edges
204
  x_np = x_0.cpu().permute(0, 2, 3, 1).numpy()
205
  filtered = []
 
208
  filtered_img = cv2.bilateralFilter(img, d=5, sigmaColor=15, sigmaSpace=15)
209
  filtered.append(filtered_img / 255.0)
210
  x_0 = torch.tensor(np.array(filtered), device=device).permute(0, 3, 1, 2)
211
+
212
  # Then apply stronger unsharp masking
213
  kernel = torch.ones(3, 1, 5, 5, device=device) / 75
214
  kernel = kernel.to(x_0.dtype)
 
328
  print("Creating dummy model for demonstration")
329
  loaded_model = DiffusionModel(UNet(num_classes=NUM_CLASSES), timesteps=TIMESTEPS).to(device)
330
 
331
+ # Gradio UI
332
  with gr.Blocks(theme=gr.themes.Soft(
333
  primary_hue="violet",
334
  neutral_hue="slate",
 
421
  """
422
 
423
  if __name__ == "__main__":
424
+ demo.launch(server_name="0.0.0.0", server_port=7860,share= True)