Update app.py
app.py CHANGED
@@ -7,11 +7,11 @@ import math
 import os
 from threading import Event
 import traceback
-import cv2
+import cv2
 
 # Constants
 IMG_SIZE = 128
-TIMESTEPS = 300
+TIMESTEPS = 300
 NUM_CLASSES = 2
 
 # Global Cancellation Flag
@@ -27,13 +27,13 @@ class SinusoidalPositionEmbeddings(nn.Module):
         self.dim = dim
         half_dim = dim // 2
         emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim) * -emb)
+        emb = torch.exp(torch.arange(half_dim) * -emb)
         self.register_buffer('embeddings', emb)
 
     def forward(self, time):
-        device = time.device
+        device = time.device
         embeddings = self.embeddings.to(device)
-        embeddings = time[:, None] * embeddings[None, :]
+        embeddings = time[:, None] * embeddings[None, :]
         return torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)
 
 class UNet(nn.Module):
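For readers skimming the diff, here is a minimal standalone sketch of the module this hunk touches, reconstructed from the visible lines; anything outside them (the super().__init__() call, the usage comment) is an assumption:

import math
import torch
import torch.nn as nn

class SinusoidalPositionEmbeddings(nn.Module):
    # Maps a batch of integer timesteps to dim-dimensional sin/cos embeddings.
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        half_dim = dim // 2
        # Geometric ladder of frequencies, as in the Transformer positional encoding
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim) * -emb)
        self.register_buffer('embeddings', emb)

    def forward(self, time):
        device = time.device
        embeddings = self.embeddings.to(device)
        embeddings = time[:, None] * embeddings[None, :]   # (B, half_dim)
        return torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)  # (B, dim)

# e.g. SinusoidalPositionEmbeddings(64)(torch.arange(8)) returns a tensor of shape (8, 64)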
@@ -130,7 +130,7 @@ class DiffusionModel(nn.Module):
         self.timesteps = timesteps
         self.time_dim = time_dim
 
-        # Linear beta schedule with scaling
+        # Linear beta schedule with scaling
         scale = 1000 / timesteps
         beta_start = scale * 0.0001
         beta_end = scale * 0.02
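The "linear beta schedule with scaling" set up in this hunk can be written as a small free function. The buffer names and the torch.linspace call below are assumptions, since the diff only shows the endpoints being computed:

import torch

def linear_beta_schedule(timesteps=300):
    # Scale the standard DDPM endpoints (1e-4 and 0.02 at 1000 steps)
    # so shorter schedules keep a comparable total noise budget.
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    betas = torch.linspace(beta_start, beta_end, timesteps)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    return betas, alphas, alphas_cumprod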
@@ -165,7 +165,7 @@ class DiffusionModel(nn.Module):
         else:
             labels = labels.to(device)
 
-        # REVERTED SAMPLING LOOP WITH NOISE REDUCTION
+        # ---- REVERTED SAMPLING LOOP WITH NOISE REDUCTION ----
         for t in reversed(range(self.timesteps)):
             if cancel_event.is_set():
                 return None
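The loop header above walks the reverse process from t = timesteps - 1 down to 0, checking the cancel flag at every step. A generic single DDPM reverse step looks roughly like the sketch below; the model call signature and the exact "noise reduction" tweak mentioned in the comment are not visible in this diff, so treat it as an illustrative approximation:

import torch

@torch.no_grad()
def p_sample(model, x_t, t, labels, betas, alphas, alphas_cumprod):
    # Predict the noise, form the posterior mean, and add fresh noise except at t == 0.
    t_batch = torch.full((x_t.shape[0],), t, device=x_t.device, dtype=torch.long)
    eps = model(x_t, t_batch, labels)  # assumed signature: (image, timestep, class labels)
    mean = (x_t - betas[t] / torch.sqrt(1 - alphas_cumprod[t]) * eps) / torch.sqrt(alphas[t])
    if t == 0:
        return mean
    return mean + torch.sqrt(betas[t]) * torch.randn_like(x_t)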
@@ -199,7 +199,7 @@ class DiffusionModel(nn.Module):
         x_0 = std * x_0 + mean
         x_0 = torch.clamp(x_0, 0., 1.)
 
-        # ENHANCED SHARPENING
+        # ---- ENHANCED SHARPENING ----
         # First apply mild bilateral filtering to reduce noise while preserving edges
         x_np = x_0.cpu().permute(0, 2, 3, 1).numpy()
         filtered = []
@@ -208,7 +208,7 @@ class DiffusionModel(nn.Module):
             filtered_img = cv2.bilateralFilter(img, d=5, sigmaColor=15, sigmaSpace=15)
             filtered.append(filtered_img / 255.0)
         x_0 = torch.tensor(np.array(filtered), device=device).permute(0, 3, 1, 2)
-
+
         # Then apply stronger unsharp masking
         kernel = torch.ones(3, 1, 5, 5, device=device) / 75
         kernel = kernel.to(x_0.dtype)
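Taken together, the two hunks above post-process the sampled batch: an edge-preserving bilateral filter via OpenCV, then an unsharp mask built from a depthwise 5x5 box blur. A self-contained sketch with a normalized kernel and an assumed sharpening amount (the strength and exact kernel in app.py may differ):

import cv2
import numpy as np
import torch
import torch.nn.functional as F

def denoise_and_sharpen(x_0, amount=0.5):
    # x_0: (B, 3, H, W) tensor of images in [0, 1]
    device = x_0.device
    # Bilateral filtering smooths noise while keeping edges; OpenCV wants HWC uint8 images
    x_np = (x_0.cpu().permute(0, 2, 3, 1).numpy() * 255).astype(np.uint8)
    filtered = [cv2.bilateralFilter(img, d=5, sigmaColor=15, sigmaSpace=15) / 255.0
                for img in x_np]
    x = torch.tensor(np.array(filtered), device=device, dtype=torch.float32).permute(0, 3, 1, 2)

    # Unsharp mask: add back the difference between the image and a blurred copy
    kernel = torch.ones(3, 1, 5, 5, device=device) / 25.0   # per-channel 5x5 box blur
    blurred = F.conv2d(x, kernel, padding=2, groups=3)
    return torch.clamp(x + amount * (x - blurred), 0.0, 1.0)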
@@ -328,7 +328,7 @@ except Exception as e:
     print("Creating dummy model for demonstration")
     loaded_model = DiffusionModel(UNet(num_classes=NUM_CLASSES), timesteps=TIMESTEPS).to(device)
 
-# Gradio UI
+# Gradio UI
 with gr.Blocks(theme=gr.themes.Soft(
     primary_hue="violet",
     neutral_hue="slate",
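The UI is wrapped in a gr.Blocks context using Gradio's built-in Soft theme with the hue overrides shown above. A minimal shell with the same settings (the real components and callbacks are not part of this diff):

import gradio as gr

with gr.Blocks(theme=gr.themes.Soft(primary_hue="violet", neutral_hue="slate")) as demo:
    gr.Markdown("UI components from app.py go here")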
@@ -421,4 +421,4 @@ with gr.Blocks(theme=gr.themes.Soft(
     """
 
 if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0", server_port=7860)
+    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
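The functional change in the last hunk is the added share=True: it asks Gradio to open a temporary public gradio.live tunnel in addition to serving on 0.0.0.0:7860. When the app runs on Hugging Face Spaces the flag is generally ignored, since the Space already exposes the server, so it mainly matters for local runs.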