fantos committed
Commit 2cdbfb7 · verified · 1 Parent(s): 85946be

Update app.py

Files changed (1): app.py (+213 -81)
app.py CHANGED
@@ -4,43 +4,44 @@ import torch.nn as nn
 import gradio as gr
 from PIL import Image
 import torchvision.transforms as transforms
-import os # 📁 For file operations
+import os
+from huggingface_hub import hf_hub_download
+import torch.nn.functional as F
+
+# Check for CUDA availability but fallback to CPU
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-# 🧠 Neural network layers
 norm_layer = nn.InstanceNorm2d
 
-# 🧱 Building block for the generator
 class ResidualBlock(nn.Module):
     def __init__(self, in_features):
         super(ResidualBlock, self).__init__()
-
+
         conv_block = [ nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        norm_layer(in_features),
                        nn.ReLU(inplace=True),
                        nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
-                       norm_layer(in_features)
-                     ]
-
+                       norm_layer(in_features) ]
+
         self.conv_block = nn.Sequential(*conv_block)
 
     def forward(self, x):
         return x + self.conv_block(x)
 
-# 🎨 Generator model for creating line drawings
 class Generator(nn.Module):
     def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
         super(Generator, self).__init__()
-
-        # 🏁 Initial convolution block
+
+        # Initial convolution block
         model0 = [ nn.ReflectionPad2d(3),
                    nn.Conv2d(input_nc, 64, 7),
                    norm_layer(64),
                    nn.ReLU(inplace=True) ]
         self.model0 = nn.Sequential(*model0)
 
-        # 🔽 Downsampling
+        # Downsampling
         model1 = []
         in_features = 64
         out_features = in_features*2
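A note on the residual block in the hunk above: each `ReflectionPad2d(1)` + 3×3 convolution pair preserves height and width, which is what makes the skip connection `x + self.conv_block(x)` shape-compatible. A minimal standalone check, illustrative only and not part of the commit:

```python
import torch
import torch.nn as nn

norm_layer = nn.InstanceNorm2d

# Same layer sequence as ResidualBlock.conv_block, for a 64-channel input.
block = nn.Sequential(
    nn.ReflectionPad2d(1),
    nn.Conv2d(64, 64, 3),
    norm_layer(64),
    nn.ReLU(inplace=True),
    nn.ReflectionPad2d(1),
    nn.Conv2d(64, 64, 3),
    norm_layer(64),
)
x = torch.randn(1, 64, 32, 32)
assert (x + block(x)).shape == x.shape  # pad(1) + 3x3 conv keeps H and W
```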
@@ -52,13 +53,13 @@ class Generator(nn.Module):
             out_features = in_features*2
         self.model1 = nn.Sequential(*model1)
 
-        # 🔁 Residual blocks
+        # Residual blocks
         model2 = []
         for _ in range(n_residual_blocks):
             model2 += [ResidualBlock(in_features)]
         self.model2 = nn.Sequential(*model2)
 
-        # 🔼 Upsampling
+        # Upsampling
         model3 = []
         out_features = in_features//2
         for _ in range(2):
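The downsampling loop doubles the channel count on each of its two passes (64 → 128 → 256) while the upsampling loop halves it back, so the decoder mirrors the encoder. The layers themselves sit in unchanged lines the diff elides; a typical mirrored pair for this kind of generator, sketched under that assumption and not taken from the commit, would be:

```python
import torch.nn as nn

norm_layer = nn.InstanceNorm2d
in_features, out_features = 64, 128

# Encoder step (assumed pattern): stride-2 conv halves H/W, doubles channels.
down = nn.Sequential(
    nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
    norm_layer(out_features),
    nn.ReLU(inplace=True),
)
# Decoder step (assumed pattern): transposed conv restores H/W, halves channels.
up = nn.Sequential(
    nn.ConvTranspose2d(out_features, in_features, 3, stride=2,
                       padding=1, output_padding=1),
    norm_layer(in_features),
    nn.ReLU(inplace=True),
)
```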
@@ -69,85 +70,216 @@ class Generator(nn.Module):
             out_features = in_features//2
         self.model3 = nn.Sequential(*model3)
 
-        # 🎭 Output layer
+        # Output layer
         model4 = [ nn.ReflectionPad2d(3),
-                   nn.Conv2d(64, output_nc, 7)]
+                   nn.Conv2d(64, output_nc, 7)]
         if sigmoid:
             model4 += [nn.Sigmoid()]
-
+
         self.model4 = nn.Sequential(*model4)
 
-    def forward(self, x, cond=None):
+    def forward(self, x):
         out = self.model0(x)
         out = self.model1(out)
         out = self.model2(out)
         out = self.model3(out)
         out = self.model4(out)
-
         return out
 
-# 🔧 Load the models
-model1 = Generator(3, 1, 3)
-model1.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu'), weights_only=True))
-model1.eval()
-
-model2 = Generator(3, 1, 3)
-model2.load_state_dict(torch.load('model2.pth', map_location=torch.device('cpu'), weights_only=True))
-model2.eval()
-
-# 🖼️ Function to process the image and create line drawing
-def predict(input_img, ver):
-    # Open the image and get its original size
-    original_img = Image.open(input_img)
-    original_size = original_img.size
-
-    # Define the transformation pipeline
-    transform = transforms.Compose([
-        transforms.Resize(256, Image.BICUBIC),
-        transforms.ToTensor(),
-        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
-    ])
-
-    # Apply the transformation
-    input_tensor = transform(original_img)
-    input_tensor = input_tensor.unsqueeze(0)
-
-    # Process the image through the model
-    with torch.no_grad():
-        if ver == 'Simple Lines':
-            output = model2(input_tensor)
-        else:
-            output = model1(input_tensor)
-
-    # Convert the output tensor to an image
-    output_img = transforms.ToPILImage()(output.squeeze().cpu().clamp(0, 1))
-
-    # Resize the output image back to the original size
-    output_img = output_img.resize(original_size, Image.BICUBIC)
-
-    return output_img
-
-# 📝 Title for the Gradio interface
-title="🖌️ Image to Line Drawings - Complex and Simple Portraits and Landscapes"
-
-# 🖼️ Dynamically generate examples from images in the directory
-examples = []
-image_dir = '.' # Assuming images are in the current directory
-for file in os.listdir(image_dir):
-    if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif')):
-        examples.append([file, 'Simple Lines'])
-        examples.append([file, 'Complex Lines'])
-
-# 🚀 Create and launch the Gradio interface
-iface = gr.Interface(
-    fn=predict,
-    inputs=[
-        gr.Image(type='filepath'),
-        gr.Radio(['Complex Lines', 'Simple Lines'], label='version', value='Simple Lines')
-    ],
-    outputs=gr.Image(type="pil"),
-    title=title,
-    examples=examples
-)
+# Initialize models
+def load_models():
+    model1 = Generator(3, 1, 3).to(device)
+    model2 = Generator(3, 1, 3).to(device)
+
+    # Download models from HuggingFace Hub
+    model1_path = hf_hub_download(repo_id="your-hf-repo/line-drawing", filename="model.pth")
+    model2_path = hf_hub_download(repo_id="your-hf-repo/line-drawing", filename="model2.pth")
+
+    model1.load_state_dict(torch.load(model1_path, map_location=device))
+    model2.load_state_dict(torch.load(model2_path, map_location=device))
+
+    model1.eval()
+    model2.eval()
+    return model1, model2
+
+model1, model2 = load_models()
+
+def apply_style_transfer(img, strength=1.0):
+    """Apply artistic style transfer effect"""
+    img_array = np.array(img)
+    processed = F.interpolate(
+        torch.from_numpy(img_array).float().unsqueeze(0),
+        size=(256, 256),
+        mode='bilinear',
+        align_corners=False
+    )
+    return processed * strength
+
+def enhance_lines(img, contrast=1.0, brightness=1.0):
+    """Enhance line drawing with contrast and brightness adjustments"""
+    enhanced = np.array(img)
+    enhanced = enhanced * contrast
+    enhanced = np.clip(enhanced + brightness, 0, 1)
+    return Image.fromarray((enhanced * 255).astype(np.uint8))
+
+def predict(input_img, version, line_thickness=1.0, contrast=1.0, brightness=1.0, enable_enhancement=False):
+    try:
+        # Open and process input image
+        original_img = Image.open(input_img)
+        original_size = original_img.size
+
+        # Transform pipeline
+        transform = transforms.Compose([
+            transforms.Resize(256, Image.BICUBIC),
+            transforms.ToTensor(),
+            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+        ])
+
+        input_tensor = transform(original_img).unsqueeze(0).to(device)
+
+        # Process through selected model
+        with torch.no_grad():
+            if version == 'Simple Lines':
+                output = model2(input_tensor)
+            else:
+                output = model1(input_tensor)
+
+        # Apply line thickness adjustment
+        output = output * line_thickness
+
+        # Convert to image
+        output_img = transforms.ToPILImage()(output.squeeze().cpu().clamp(0, 1))
+
+        # Apply enhancements if enabled
+        if enable_enhancement:
+            output_img = enhance_lines(output_img, contrast, brightness)
+
+        # Resize to original
+        output_img = output_img.resize(original_size, Image.BICUBIC)
+
+        return output_img
+
+    except Exception as e:
+        raise gr.Error(f"Error processing image: {str(e)}")
+
+# Custom CSS for better UI
+custom_css = """
+.gradio-container {
+    font-family: 'Helvetica Neue', Arial, sans-serif;
+}
+.gr-button {
+    border-radius: 8px;
+    background: linear-gradient(45deg, #3498db, #2980b9);
+    border: none;
+    color: white;
+}
+.gr-button:hover {
+    background: linear-gradient(45deg, #2980b9, #3498db);
+    transform: translateY(-2px);
+    transition: all 0.3s ease;
+}
+.gr-input {
+    border-radius: 8px;
+    border: 2px solid #3498db;
+}
+"""
+
+# Create Gradio interface with enhanced UI
+with gr.Blocks(css=custom_css) as iface:
+    gr.Markdown("# 🎨 Advanced Line Drawing Generator")
+    gr.Markdown("Transform your images into beautiful line drawings with advanced controls")
+
+    with gr.Row():
+        with gr.Column():
+            input_image = gr.Image(type="filepath", label="Upload Image")
+            version = gr.Radio(
+                choices=['Complex Lines', 'Simple Lines'],
+                value='Simple Lines',
+                label="Drawing Style"
+            )
+
+            with gr.Accordion("Advanced Settings", open=False):
+                line_thickness = gr.Slider(
+                    minimum=0.1,
+                    maximum=2.0,
+                    value=1.0,
+                    step=0.1,
+                    label="Line Thickness"
+                )
+                enable_enhancement = gr.Checkbox(
+                    label="Enable Enhancement",
+                    value=False
+                )
+                with gr.Group(visible=False) as enhancement_controls:
+                    contrast = gr.Slider(
+                        minimum=0.5,
+                        maximum=2.0,
+                        value=1.0,
+                        step=0.1,
+                        label="Contrast"
+                    )
+                    brightness = gr.Slider(
+                        minimum=0.5,
+                        maximum=1.5,
+                        value=1.0,
+                        step=0.1,
+                        label="Brightness"
+                    )
+
+            enable_enhancement.change(
+                fn=lambda x: gr.Group(visible=x),
+                inputs=[enable_enhancement],
+                outputs=[enhancement_controls]
+            )
+
+        with gr.Column():
+            output_image = gr.Image(type="pil", label="Generated Line Drawing")
+
+    with gr.Row():
+        generate_btn = gr.Button("Generate Drawing", variant="primary")
+        clear_btn = gr.Button("Clear", variant="secondary")
+
+    # Load example images
+    example_images = []
+    for file in os.listdir('.'):
+        if file.lower().endswith(('.png', '.jpg', '.jpeg')):
+            example_images.append(file)
+
+    if example_images:
+        gr.Examples(
+            examples=[[img, "Simple Lines"] for img in example_images],
+            inputs=[input_image, version],
+            outputs=output_image,
+            fn=predict,
+            cache_examples=True
+        )
+
+    # Set up event handlers
+    generate_btn.click(
+        fn=predict,
+        inputs=[
+            input_image,
+            version,
+            line_thickness,
+            contrast,
+            brightness,
+            enable_enhancement
+        ],
+        outputs=output_image
+    )
+
+    clear_btn.click(
+        fn=lambda: (None, "Simple Lines", 1.0, 1.0, 1.0, False),
+        inputs=[],
+        outputs=[
+            input_image,
+            version,
+            line_thickness,
+            contrast,
+            brightness,
+            enable_enhancement
+        ]
+    )
 
+# Launch the interface
 iface.launch()
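Two things stand out in the new `load_models()`: `repo_id="your-hf-repo/line-drawing"` reads like a placeholder that still needs the Space's real Hub repo, and the rewrite drops the `weights_only=True` guard the old `torch.load` calls had. A minimal sketch that keeps the guard (the repo id here is equally hypothetical; `weights_only` requires PyTorch ≥ 1.13):

```python
from huggingface_hub import hf_hub_download
import torch

def load_generator(repo_id, filename, device):
    """Fetch a checkpoint from the Hub and load it onto `device`."""
    path = hf_hub_download(repo_id=repo_id, filename=filename)
    model = Generator(3, 1, 3).to(device)  # Generator as defined in the diff above
    # weights_only=True restricts unpickling to tensors and primitive types.
    state = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state)
    return model.eval()
```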
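`enhance_lines` (and the never-called `apply_style_transfer`) use `np.array`, but the new file never imports numpy, so the app would crash with a `NameError` the first time enhancement runs. The math also mixes scales: `np.array(img)` on an 8-bit image yields 0–255 values, which `np.clip(..., 0, 1)` flattens to ~1 everywhere before the `* 255` rescale, washing the output out. A corrected sketch (my fix, not the committed code) that works in [0, 1] floats:

```python
import numpy as np
from PIL import Image

def enhance_lines(img, contrast=1.0, brightness=1.0):
    """Adjust contrast/brightness of a line drawing in [0, 1] float space."""
    arr = np.asarray(img).astype(np.float32) / 255.0  # 0-255 -> 0.0-1.0
    arr = arr * contrast
    # The brightness slider spans 0.5-1.5, so treat 1.0 as neutral.
    arr = np.clip(arr + (brightness - 1.0), 0.0, 1.0)
    return Image.fromarray((arr * 255).astype(np.uint8))
```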
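Inside `predict`, `transforms.Resize(256, Image.BICUBIC)` passes a PIL constant positionally, which newer torchvision releases deprecate in favor of `InterpolationMode`; and since `Normalize` expects three channels, a grayscale or RGBA upload would fail unless converted first. Note also that `output * line_thickness` scales pixel intensity rather than stroke width, so the slider effectively controls line darkness. A sketch of the preprocessing in the current torchvision idiom:

```python
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode

transform = transforms.Compose([
    transforms.Resize(256, interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),                                   # HWC uint8 -> CHW float in [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # -> roughly [-1, 1]
])
# original_img and device as in predict() above; convert guards non-RGB uploads.
input_tensor = transform(original_img.convert('RGB')).unsqueeze(0).to(device)
```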
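The visibility toggle returns `gr.Group(visible=x)` from its callback. On Gradio 4.x, returning a component instance like this does update the bound output; on 3.x the equivalent idiom was `gr.update(visible=x)`, which 4.x still accepts. A version-tolerant sketch:

```python
# gr.update(...) changes only the named properties of the bound output
# and works on both Gradio 3.x and 4.x.
enable_enhancement.change(
    fn=lambda checked: gr.update(visible=checked),
    inputs=[enable_enhancement],
    outputs=[enhancement_controls],
)
```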
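With `cache_examples=True`, Gradio runs `predict` over every matching image in the Space's working directory at startup to build the cache; each two-element example leans on `predict`'s default values for the remaining four parameters. On a CPU Space with several bundled images this can slow startup noticeably; recent Gradio releases also accept lazy caching, sketched here on the assumption the installed version supports it:

```python
gr.Examples(
    examples=[[img, "Simple Lines"] for img in example_images],
    inputs=[input_image, version],
    outputs=output_image,
    fn=predict,
    cache_examples="lazy",  # cache each example on first click, not at startup
)
```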