fantos commited on
Commit
cdc202d
·
verified ·
1 Parent(s): 9c857dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -26
app.py CHANGED
@@ -65,7 +65,110 @@ class HistoryManager:
65
  # Initialize history manager
66
  history_manager = HistoryManager()
67
 
68
- [Previous model and generator code remains the same...]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  def apply_preset(preset_name):
71
  """Apply a style preset and return the settings"""
@@ -78,29 +181,12 @@ def apply_preset(preset_name):
78
  )
79
  return (1.0, 1.0, 1.0, False)
80
 
81
- def save_image_with_metadata(image, output_path, settings):
82
- """Save image with processing metadata"""
83
- try:
84
- # Save image
85
- image.save(output_path)
86
-
87
- # Save metadata
88
- metadata_path = output_path + ".json"
89
- with open(metadata_path, 'w') as f:
90
- json.dump({
91
- "processing_date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
92
- "settings": settings
93
- }, f)
94
- except Exception as e:
95
- print(f"Error saving image metadata: {e}")
96
-
97
- def get_image_download_link(image):
98
- """Create a download link for the processed image"""
99
- buffered = io.BytesIO()
100
- image.save(buffered, format="PNG")
101
- img_str = base64.b64encode(buffered.getvalue()).decode()
102
- href = f'data:image/png;base64,{img_str}'
103
- return href
104
 
105
  def predict(input_img, version, preset_name, line_thickness=1.0, contrast=1.0,
106
  brightness=1.0, enable_enhancement=False, output_size="Original"):
@@ -166,7 +252,7 @@ def predict(input_img, version, preset_name, line_thickness=1.0, contrast=1.0,
166
  except Exception as e:
167
  raise gr.Error(f"Error processing image: {str(e)}")
168
 
169
- # Extended custom CSS
170
  custom_css = """
171
  .gradio-container {
172
  font-family: 'Helvetica Neue', Arial, sans-serif;
@@ -208,7 +294,7 @@ custom_css = """
208
  }
209
  """
210
 
211
- # Create Gradio interface with enhanced UI
212
  with gr.Blocks(css=custom_css) as iface:
213
  with gr.Row(elem_classes="gr-header"):
214
  gr.Markdown("# 🎨 Advanced Line Drawing Generator")
 
65
  # Initialize history manager
66
  history_manager = HistoryManager()
67
 
68
+ norm_layer = nn.InstanceNorm2d
69
+
70
+ class ResidualBlock(nn.Module):
71
+ def __init__(self, in_features):
72
+ super(ResidualBlock, self).__init__()
73
+
74
+ conv_block = [ nn.ReflectionPad2d(1),
75
+ nn.Conv2d(in_features, in_features, 3),
76
+ norm_layer(in_features),
77
+ nn.ReLU(inplace=True),
78
+ nn.ReflectionPad2d(1),
79
+ nn.Conv2d(in_features, in_features, 3),
80
+ norm_layer(in_features) ]
81
+
82
+ self.conv_block = nn.Sequential(*conv_block)
83
+
84
+ def forward(self, x):
85
+ return x + self.conv_block(x)
86
+
87
+ class Generator(nn.Module):
88
+ def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
89
+ super(Generator, self).__init__()
90
+
91
+ # Initial convolution block
92
+ model0 = [ nn.ReflectionPad2d(3),
93
+ nn.Conv2d(input_nc, 64, 7),
94
+ norm_layer(64),
95
+ nn.ReLU(inplace=True) ]
96
+ self.model0 = nn.Sequential(*model0)
97
+
98
+ # Downsampling
99
+ model1 = []
100
+ in_features = 64
101
+ out_features = in_features*2
102
+ for _ in range(2):
103
+ model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
104
+ norm_layer(out_features),
105
+ nn.ReLU(inplace=True) ]
106
+ in_features = out_features
107
+ out_features = in_features*2
108
+ self.model1 = nn.Sequential(*model1)
109
+
110
+ # Residual blocks
111
+ model2 = []
112
+ for _ in range(n_residual_blocks):
113
+ model2 += [ResidualBlock(in_features)]
114
+ self.model2 = nn.Sequential(*model2)
115
+
116
+ # Upsampling
117
+ model3 = []
118
+ out_features = in_features//2
119
+ for _ in range(2):
120
+ model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
121
+ norm_layer(out_features),
122
+ nn.ReLU(inplace=True) ]
123
+ in_features = out_features
124
+ out_features = in_features//2
125
+ self.model3 = nn.Sequential(*model3)
126
+
127
+ # Output layer
128
+ model4 = [ nn.ReflectionPad2d(3),
129
+ nn.Conv2d(64, output_nc, 7)]
130
+ if sigmoid:
131
+ model4 += [nn.Sigmoid()]
132
+
133
+ self.model4 = nn.Sequential(*model4)
134
+
135
+ def forward(self, x):
136
+ out = self.model0(x)
137
+ out = self.model1(out)
138
+ out = self.model2(out)
139
+ out = self.model3(out)
140
+ out = self.model4(out)
141
+ return out
142
+
143
+ # Initialize models
144
+ def load_models():
145
+ try:
146
+ print("Initializing models in CPU mode...")
147
+ model1 = Generator(3, 1, 3)
148
+ model2 = Generator(3, 1, 3)
149
+
150
+ model1.load_state_dict(torch.load('model.pth', map_location='cpu'))
151
+ model2.load_state_dict(torch.load('model2.pth', map_location='cpu'))
152
+
153
+ model1.eval()
154
+ model2.eval()
155
+ torch.set_grad_enabled(False)
156
+
157
+ print("Models loaded successfully in CPU mode")
158
+ return model1, model2
159
+ except Exception as e:
160
+ error_msg = f"Error loading models: {str(e)}"
161
+ print(error_msg)
162
+ raise gr.Error("Failed to initialize models. Please check the model files and system configuration.")
163
+
164
+ # Load models
165
+ try:
166
+ print("Starting model initialization...")
167
+ model1, model2 = load_models()
168
+ print("Model initialization completed")
169
+ except Exception as e:
170
+ print(f"Critical error during model initialization: {str(e)}")
171
+ raise gr.Error("Failed to start the application due to model initialization error.")
172
 
173
  def apply_preset(preset_name):
174
  """Apply a style preset and return the settings"""
 
181
  )
182
  return (1.0, 1.0, 1.0, False)
183
 
184
+ def enhance_lines(img, contrast=1.0, brightness=1.0):
185
+ """Enhance line drawing with contrast and brightness adjustments"""
186
+ enhanced = np.array(img)
187
+ enhanced = enhanced * contrast
188
+ enhanced = np.clip(enhanced + brightness, 0, 1)
189
+ return Image.fromarray((enhanced * 255).astype(np.uint8))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
  def predict(input_img, version, preset_name, line_thickness=1.0, contrast=1.0,
192
  brightness=1.0, enable_enhancement=False, output_size="Original"):
 
252
  except Exception as e:
253
  raise gr.Error(f"Error processing image: {str(e)}")
254
 
255
+ # Custom CSS
256
  custom_css = """
257
  .gradio-container {
258
  font-family: 'Helvetica Neue', Arial, sans-serif;
 
294
  }
295
  """
296
 
297
+ # Create Gradio interface
298
  with gr.Blocks(css=custom_css) as iface:
299
  with gr.Row(elem_classes="gr-header"):
300
  gr.Markdown("# 🎨 Advanced Line Drawing Generator")