Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -65,7 +65,110 @@ class HistoryManager:
|
|
65 |
# Initialize history manager
|
66 |
history_manager = HistoryManager()
|
67 |
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
def apply_preset(preset_name):
|
71 |
"""Apply a style preset and return the settings"""
|
@@ -78,29 +181,12 @@ def apply_preset(preset_name):
|
|
78 |
)
|
79 |
return (1.0, 1.0, 1.0, False)
|
80 |
|
81 |
-
def
|
82 |
-
"""
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
# Save metadata
|
88 |
-
metadata_path = output_path + ".json"
|
89 |
-
with open(metadata_path, 'w') as f:
|
90 |
-
json.dump({
|
91 |
-
"processing_date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
92 |
-
"settings": settings
|
93 |
-
}, f)
|
94 |
-
except Exception as e:
|
95 |
-
print(f"Error saving image metadata: {e}")
|
96 |
-
|
97 |
-
def get_image_download_link(image):
|
98 |
-
"""Create a download link for the processed image"""
|
99 |
-
buffered = io.BytesIO()
|
100 |
-
image.save(buffered, format="PNG")
|
101 |
-
img_str = base64.b64encode(buffered.getvalue()).decode()
|
102 |
-
href = f'data:image/png;base64,{img_str}'
|
103 |
-
return href
|
104 |
|
105 |
def predict(input_img, version, preset_name, line_thickness=1.0, contrast=1.0,
|
106 |
brightness=1.0, enable_enhancement=False, output_size="Original"):
|
@@ -166,7 +252,7 @@ def predict(input_img, version, preset_name, line_thickness=1.0, contrast=1.0,
|
|
166 |
except Exception as e:
|
167 |
raise gr.Error(f"Error processing image: {str(e)}")
|
168 |
|
169 |
-
#
|
170 |
custom_css = """
|
171 |
.gradio-container {
|
172 |
font-family: 'Helvetica Neue', Arial, sans-serif;
|
@@ -208,7 +294,7 @@ custom_css = """
|
|
208 |
}
|
209 |
"""
|
210 |
|
211 |
-
# Create Gradio interface
|
212 |
with gr.Blocks(css=custom_css) as iface:
|
213 |
with gr.Row(elem_classes="gr-header"):
|
214 |
gr.Markdown("# 🎨 Advanced Line Drawing Generator")
|
|
|
65 |
# Initialize history manager
|
66 |
history_manager = HistoryManager()
|
67 |
|
68 |
+
norm_layer = nn.InstanceNorm2d
|
69 |
+
|
70 |
+
class ResidualBlock(nn.Module):
|
71 |
+
def __init__(self, in_features):
|
72 |
+
super(ResidualBlock, self).__init__()
|
73 |
+
|
74 |
+
conv_block = [ nn.ReflectionPad2d(1),
|
75 |
+
nn.Conv2d(in_features, in_features, 3),
|
76 |
+
norm_layer(in_features),
|
77 |
+
nn.ReLU(inplace=True),
|
78 |
+
nn.ReflectionPad2d(1),
|
79 |
+
nn.Conv2d(in_features, in_features, 3),
|
80 |
+
norm_layer(in_features) ]
|
81 |
+
|
82 |
+
self.conv_block = nn.Sequential(*conv_block)
|
83 |
+
|
84 |
+
def forward(self, x):
|
85 |
+
return x + self.conv_block(x)
|
86 |
+
|
87 |
+
class Generator(nn.Module):
|
88 |
+
def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
|
89 |
+
super(Generator, self).__init__()
|
90 |
+
|
91 |
+
# Initial convolution block
|
92 |
+
model0 = [ nn.ReflectionPad2d(3),
|
93 |
+
nn.Conv2d(input_nc, 64, 7),
|
94 |
+
norm_layer(64),
|
95 |
+
nn.ReLU(inplace=True) ]
|
96 |
+
self.model0 = nn.Sequential(*model0)
|
97 |
+
|
98 |
+
# Downsampling
|
99 |
+
model1 = []
|
100 |
+
in_features = 64
|
101 |
+
out_features = in_features*2
|
102 |
+
for _ in range(2):
|
103 |
+
model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
|
104 |
+
norm_layer(out_features),
|
105 |
+
nn.ReLU(inplace=True) ]
|
106 |
+
in_features = out_features
|
107 |
+
out_features = in_features*2
|
108 |
+
self.model1 = nn.Sequential(*model1)
|
109 |
+
|
110 |
+
# Residual blocks
|
111 |
+
model2 = []
|
112 |
+
for _ in range(n_residual_blocks):
|
113 |
+
model2 += [ResidualBlock(in_features)]
|
114 |
+
self.model2 = nn.Sequential(*model2)
|
115 |
+
|
116 |
+
# Upsampling
|
117 |
+
model3 = []
|
118 |
+
out_features = in_features//2
|
119 |
+
for _ in range(2):
|
120 |
+
model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
|
121 |
+
norm_layer(out_features),
|
122 |
+
nn.ReLU(inplace=True) ]
|
123 |
+
in_features = out_features
|
124 |
+
out_features = in_features//2
|
125 |
+
self.model3 = nn.Sequential(*model3)
|
126 |
+
|
127 |
+
# Output layer
|
128 |
+
model4 = [ nn.ReflectionPad2d(3),
|
129 |
+
nn.Conv2d(64, output_nc, 7)]
|
130 |
+
if sigmoid:
|
131 |
+
model4 += [nn.Sigmoid()]
|
132 |
+
|
133 |
+
self.model4 = nn.Sequential(*model4)
|
134 |
+
|
135 |
+
def forward(self, x):
|
136 |
+
out = self.model0(x)
|
137 |
+
out = self.model1(out)
|
138 |
+
out = self.model2(out)
|
139 |
+
out = self.model3(out)
|
140 |
+
out = self.model4(out)
|
141 |
+
return out
|
142 |
+
|
143 |
+
# Initialize models
|
144 |
+
def load_models():
|
145 |
+
try:
|
146 |
+
print("Initializing models in CPU mode...")
|
147 |
+
model1 = Generator(3, 1, 3)
|
148 |
+
model2 = Generator(3, 1, 3)
|
149 |
+
|
150 |
+
model1.load_state_dict(torch.load('model.pth', map_location='cpu'))
|
151 |
+
model2.load_state_dict(torch.load('model2.pth', map_location='cpu'))
|
152 |
+
|
153 |
+
model1.eval()
|
154 |
+
model2.eval()
|
155 |
+
torch.set_grad_enabled(False)
|
156 |
+
|
157 |
+
print("Models loaded successfully in CPU mode")
|
158 |
+
return model1, model2
|
159 |
+
except Exception as e:
|
160 |
+
error_msg = f"Error loading models: {str(e)}"
|
161 |
+
print(error_msg)
|
162 |
+
raise gr.Error("Failed to initialize models. Please check the model files and system configuration.")
|
163 |
+
|
164 |
+
# Load models
|
165 |
+
try:
|
166 |
+
print("Starting model initialization...")
|
167 |
+
model1, model2 = load_models()
|
168 |
+
print("Model initialization completed")
|
169 |
+
except Exception as e:
|
170 |
+
print(f"Critical error during model initialization: {str(e)}")
|
171 |
+
raise gr.Error("Failed to start the application due to model initialization error.")
|
172 |
|
173 |
def apply_preset(preset_name):
|
174 |
"""Apply a style preset and return the settings"""
|
|
|
181 |
)
|
182 |
return (1.0, 1.0, 1.0, False)
|
183 |
|
184 |
+
def enhance_lines(img, contrast=1.0, brightness=1.0):
|
185 |
+
"""Enhance line drawing with contrast and brightness adjustments"""
|
186 |
+
enhanced = np.array(img)
|
187 |
+
enhanced = enhanced * contrast
|
188 |
+
enhanced = np.clip(enhanced + brightness, 0, 1)
|
189 |
+
return Image.fromarray((enhanced * 255).astype(np.uint8))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
190 |
|
191 |
def predict(input_img, version, preset_name, line_thickness=1.0, contrast=1.0,
|
192 |
brightness=1.0, enable_enhancement=False, output_size="Original"):
|
|
|
252 |
except Exception as e:
|
253 |
raise gr.Error(f"Error processing image: {str(e)}")
|
254 |
|
255 |
+
# Custom CSS
|
256 |
custom_css = """
|
257 |
.gradio-container {
|
258 |
font-family: 'Helvetica Neue', Arial, sans-serif;
|
|
|
294 |
}
|
295 |
"""
|
296 |
|
297 |
+
# Create Gradio interface
|
298 |
with gr.Blocks(css=custom_css) as iface:
|
299 |
with gr.Row(elem_classes="gr-header"):
|
300 |
gr.Markdown("# 🎨 Advanced Line Drawing Generator")
|