innoai commited on
Commit
111288a
·
verified ·
1 Parent(s): a3c0a8d

Update use_template_app.py

Browse files
Files changed (1) hide show
  1. use_template_app.py +464 -463
use_template_app.py CHANGED
@@ -1,463 +1,464 @@
1
- import gradio as gr
2
- import requests
3
- import pathlib
4
- import zipfile
5
- import io
6
- import urllib.parse # For URL sanitization
7
- import xml.etree.ElementTree as ET
8
- import base64
9
- from PIL import Image # For handling image types if needed by Gradio gallery
10
- import time
11
- import shutil
12
- import uuid # For unique filenames
13
-
14
- # --- Configuration ---
15
- TEMPLATE_API_BASE_URL = "http://127.0.0.1:8001" # Explicitly use 127.0.0.1
16
- SVG2PNG_API_URL = 'https://innoai-svg2png-api.hf.space/convert' # From svg2png_guiapp.py
17
- TEMP_DIR = pathlib.Path("temp_gradio_app_files") # Renamed for clarity
18
-
19
- # Ensure temp directory exists
20
- TEMP_DIR.mkdir(parents=True, exist_ok=True) # No need to clean it aggressively on each start
21
-
22
- # --- Helper Functions ---
23
-
24
- def sanitize_url(url_string: str) -> str:
25
- """
26
- Safely percent-encodes the path and query components of an HTTP/HTTPS URL.
27
- Attempts to be idempotent if parts of the URL are already percent-encoded.
28
- Parentheses in the path are preserved.
29
- """
30
- if not url_string or not isinstance(url_string, str) or url_string.startswith('data:'):
31
- return url_string
32
- try:
33
- scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url_string)
34
-
35
- if scheme not in ['http', 'https']:
36
- # For other schemes, or if splitting fails to identify a scheme, return original.
37
- # This could happen for URNs, mailto, etc., which we don't want to modify.
38
- return url_string
39
-
40
- # Sanitize path: unquote to normalize, then quote.
41
- # `safe` keeps slashes, existing valid percent-encodings (%), and specified chars like parentheses.
42
- sanitized_path = urllib.parse.quote(urllib.parse.unquote(path), safe='/%;:@&+$,=?#()~')
43
-
44
- sanitized_query = ""
45
- if query:
46
- # Unquote the original query to handle cases where it might be partially encoded
47
- unquoted_query = urllib.parse.unquote(query)
48
- # Parse the unquoted query string into a dictionary of parameters
49
- parsed_query_params = urllib.parse.parse_qs(unquoted_query, keep_blank_values=True)
50
- # urlencode will quote keys and values appropriately for the query string.
51
- sanitized_query = urllib.parse.urlencode(parsed_query_params, doseq=True, quote_via=urllib.parse.quote)
52
-
53
- final_url = urllib.parse.urlunsplit((scheme, netloc, sanitized_path, sanitized_query, fragment))
54
- # print(f"Sanitize URL Input: {url_string}, Output: {final_url}") # Debugging line
55
- return final_url
56
- except Exception as e:
57
- print(f"Warning: Could not sanitize URL '{url_string}': {e}")
58
- return url_string
59
-
60
- def fetch_template_details(m_param, u_param, category_param, name_param):
61
- """
62
- Fetches template details from the template_server.py API.
63
- Expected to return a dict with 'name', 'thumbnail_url', 'svg_download_url'.
64
- """
65
- username = urllib.parse.quote(u_param)
66
- category = urllib.parse.quote(category_param)
67
- template_name = urllib.parse.quote(name_param)
68
-
69
- api_url = f"{TEMPLATE_API_BASE_URL}/gradio_template_details/{username}/{category}/{template_name}"
70
- print(f"Fetching template details from: {api_url}")
71
- try:
72
- response = requests.get(api_url, timeout=10)
73
- response.raise_for_status()
74
- return response.json()
75
- except requests.exceptions.RequestException as e:
76
- print(f"Error fetching template details from {api_url}: {e}")
77
- return None
78
-
79
- def download_svg_content(svg_url):
80
- """Downloads the raw SVG content from a given URL."""
81
- try:
82
- response = requests.get(svg_url, timeout=10)
83
- response.raise_for_status()
84
- return response.text # SVG is XML text
85
- except requests.exceptions.RequestException as e:
86
- print(f"Error downloading SVG content from {svg_url}: {e}")
87
- return None
88
-
89
- def replace_background_in_svg(svg_content_str: str, new_image_path: str):
90
- """
91
- Replaces a designated image in the SVG string with a base64 data URI.
92
- Also converts all other external http/https image URLs in the SVG to base64 data URIs.
93
- """
94
- try:
95
- namespaces = {'xlink': 'http://www.w3.org/1999/xlink', 'svg': 'http://www.w3.org/2000/svg'}
96
- for prefix, uri in namespaces.items():
97
- ET.register_namespace(prefix, uri if prefix else '')
98
- root = ET.fromstring(svg_content_str)
99
-
100
- def _image_bytes_to_data_uri(image_bytes: bytes, resource_identifier: str) -> str:
101
- """
102
- Converts image bytes to a data URI.
103
- mime_type_source can be a full MIME type string, a file path, or a URL.
104
- """
105
- mime_type = "image/png" # Default
106
- if "/" in resource_identifier: # Check if it looks like a MIME type string or a URL/path with extension
107
- # Try to get from Content-Type header if it's a requests.Response object (not applicable here directly)
108
- # For now, infer from path/URL extension or if it's already a MIME type string
109
- if resource_identifier.startswith("image/"): # Already a MIME type
110
- mime_type = resource_identifier
111
- else: # Infer from extension
112
- ext = pathlib.Path(resource_identifier).suffix.lower()
113
- if ext == ".png": mime_type = "image/png"
114
- elif ext in [".jpg", ".jpeg"]: mime_type = "image/jpeg"
115
- elif ext == ".gif": mime_type = "image/gif"
116
- elif ext == ".webp": mime_type = "image/webp"
117
- elif ext == ".svg": mime_type = "image/svg+xml"
118
-
119
- encoded_image = base64.b64encode(image_bytes).decode('utf-8')
120
- return f"data:{mime_type};base64,{encoded_image}"
121
-
122
- # 1. Handle the user-uploaded image (new_image_path)
123
- background_image_element = None
124
- ids_to_check = ["background_image", "placeholder_image", "product_image", "main_image", "image_to_replace", "底图"]
125
- for img_id in ids_to_check:
126
- found_element = root.find(f".//svg:image[@id='{img_id}']", namespaces)
127
- if found_element is None:
128
- found_element = root.find(f".//{{*}}image[@id='{img_id}']")
129
- if found_element is not None:
130
- background_image_element = found_element
131
- break
132
-
133
- if background_image_element is None:
134
- first_image_svg_ns = root.find('.//svg:image', namespaces)
135
- if first_image_svg_ns is not None:
136
- background_image_element = first_image_svg_ns
137
- else:
138
- first_image_any_ns = root.find('.//image')
139
- if first_image_any_ns is not None:
140
- background_image_element = first_image_any_ns
141
- else:
142
- print("Warning: No suitable <image> tag found in SVG to replace with new_image_path.")
143
-
144
- if background_image_element is not None:
145
- try:
146
- new_image_path_obj = pathlib.Path(new_image_path)
147
- if not new_image_path_obj.is_file():
148
- print(f"Error: New image file not found or is not a file: {new_image_path}")
149
- # Potentially return None or raise error, for now, it will skip this replacement
150
- else:
151
- with open(new_image_path_obj, "rb") as f:
152
- new_image_bytes = f.read()
153
- data_uri = _image_bytes_to_data_uri(new_image_bytes, new_image_path) # Pass path for MIME
154
-
155
- xlink_href_attr_namespaced = f"{{{namespaces['xlink']}}}href"
156
- attr_to_set_on_target = None
157
- if xlink_href_attr_namespaced in background_image_element.attrib:
158
- attr_to_set_on_target = xlink_href_attr_namespaced
159
- elif 'href' in background_image_element.attrib:
160
- attr_to_set_on_target = 'href'
161
- else:
162
- attr_to_set_on_target = xlink_href_attr_namespaced if 'xlink' in namespaces else 'href'
163
- background_image_element.set(attr_to_set_on_target, data_uri)
164
- except Exception as e:
165
- print(f"Error processing and embedding new_image_path '{new_image_path}': {e}")
166
-
167
-
168
- # 2. Process all other http/https images in the SVG
169
- if root is not None:
170
- # Create a list of all image elements to iterate over.
171
- # This avoids issues with modifying the tree while iterating over findall() results directly.
172
- all_image_elements_in_tree = list(root.findall('.//svg:image', namespaces)) + \
173
- [el for el in root.findall('.//image') if el.tag != f"{{{namespaces['svg']}}}image"]
174
-
175
- for image_el in all_image_elements_in_tree:
176
- if image_el == background_image_element: # Already processed (or attempted)
177
- continue
178
-
179
- href_val = None
180
- attr_name_to_set = None
181
- xlink_href_qname = f"{{{namespaces['xlink']}}}href"
182
-
183
- if image_el.get(xlink_href_qname) is not None:
184
- href_val = image_el.get(xlink_href_qname)
185
- attr_name_to_set = xlink_href_qname
186
- elif image_el.get('href') is not None:
187
- href_val = image_el.get('href')
188
- attr_name_to_set = 'href'
189
-
190
- if href_val and href_val.startswith(('http://', 'https://')):
191
- try:
192
- safe_download_url = sanitize_url(href_val) # Sanitize before download
193
- # print(f"DEBUG: Downloading remote image for embedding: {safe_download_url}")
194
- response = requests.get(safe_download_url, timeout=20, stream=True)
195
- response.raise_for_status()
196
-
197
- image_bytes = response.content
198
-
199
- content_type_header = response.headers.get('Content-Type')
200
- mime_type_source = content_type_header.split(';')[0].strip() if content_type_header and '/' in content_type_header else href_val
201
-
202
- data_uri = _image_bytes_to_data_uri(image_bytes, mime_type_source)
203
- image_el.set(attr_name_to_set, data_uri)
204
- # print(f"DEBUG: Replaced remote URL {href_val} with data URI.")
205
- except requests.exceptions.RequestException as e_req:
206
- print(f"Failed to download or process remote image {href_val}: {e_req}")
207
- except Exception as e_proc:
208
- print(f"Error processing remote image {href_val} into data URI: {e_proc}")
209
-
210
- if root is None:
211
- print("Error: SVG root is None after parsing, cannot proceed.")
212
- return None
213
-
214
- modified_svg_bytes = ET.tostring(root, encoding='UTF-8', method='xml', xml_declaration=True)
215
-
216
- try:
217
- debug_svg_filename = f"debug_embedded_svg_{uuid.uuid4().hex[:8]}.svg"
218
- debug_svg_path = TEMP_DIR / debug_svg_filename
219
- with open(debug_svg_path, "wb") as f_debug:
220
- f_debug.write(modified_svg_bytes)
221
- print(f"Saved SVG with all images embedded for debugging to: {debug_svg_path}")
222
- except Exception as e_save:
223
- print(f"Error saving debug embedded SVG: {e_save}")
224
-
225
- return modified_svg_bytes
226
-
227
- except ET.ParseError as e:
228
- print(f"Error parsing SVG XML: {e}")
229
- return None
230
- except Exception as e:
231
- print(f"An unexpected error occurred during SVG manipulation: {e}")
232
- return None
233
-
234
- def convert_svg_bytes_to_png_api(svg_content_bytes: bytes, original_template_name: str, index: int):
235
- """
236
- Converts SVG bytes to PNG bytes using the external API.
237
- Saves the PNG to a temporary file and returns the path.
238
- """
239
- if not svg_content_bytes:
240
- return None
241
- try:
242
- # The API expects a file upload
243
- files = {'svg_file': ('generated_svg.svg', svg_content_bytes, 'image/svg+xml')}
244
- response = requests.post(SVG2PNG_API_URL, files=files, timeout=60)
245
- response.raise_for_status()
246
-
247
- # Use UUID for more robust unique filenames
248
- png_filename = f"{original_template_name.replace(' ','_')}_output_{index}_{uuid.uuid4().hex[:8]}.png"
249
- temp_png_path = TEMP_DIR / png_filename
250
- with open(temp_png_path, "wb") as f:
251
- f.write(response.content)
252
- return str(temp_png_path)
253
-
254
- except requests.exceptions.RequestException as e:
255
- print(f"Error converting SVG to PNG via API: {e}")
256
- if hasattr(e, 'response') and e.response is not None:
257
- print(f"API Response: {e.response.text}")
258
- return None
259
- except Exception as e:
260
- print(f"Error saving PNG: {e}")
261
- return None
262
-
263
- # --- Gradio App Logic ---
264
-
265
- def generate_images_from_template(original_svg_download_url: str, template_name_for_file: str, uploaded_image_files: list, request: gr.Request):
266
- """
267
- Main processing function for Gradio.
268
- Takes original SVG URL, template name, and list of uploaded image file objects.
269
- """
270
- if not original_svg_download_url:
271
- gr.Warning("模板信息未加载,请确保URL参数正确。")
272
- return [], None
273
- if not uploaded_image_files:
274
- gr.Warning("请上传至少一张图片。")
275
- return [], None
276
-
277
- # 1. Download original SVG content
278
- original_svg_content = download_svg_content(original_svg_download_url)
279
- if not original_svg_content:
280
- gr.Error("无法下载原始SVG模板。")
281
- return [], None
282
-
283
- generated_png_paths = []
284
- processed_count = 0
285
-
286
- for i, uploaded_file_obj in enumerate(uploaded_image_files):
287
- # uploaded_file_obj.name is the temporary path to the uploaded file
288
- new_image_path = uploaded_file_obj.name
289
-
290
- # 2. Replace background in SVG for each uploaded image
291
- modified_svg_bytes = replace_background_in_svg(original_svg_content, new_image_path)
292
- if not modified_svg_bytes:
293
- gr.Warning(f"处理图片 {i+1} 失败:无法修改SVG。")
294
- continue # Skip this image
295
-
296
- # 3. Convert modified SVG to PNG
297
- # Use template_name_for_file for unique output filenames
298
- png_path = convert_svg_bytes_to_png_api(modified_svg_bytes, template_name_for_file, i + 1)
299
- if png_path:
300
- generated_png_paths.append(png_path)
301
- processed_count += 1
302
- else:
303
- gr.Warning(f"处理图片 {i+1} 失败:无法转换为PNG。")
304
-
305
- if not generated_png_paths:
306
- gr.Info("未能成功生成任何图片。")
307
- return [], None
308
-
309
- zip_buffer = io.BytesIO()
310
- with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
311
- for png_path_str in generated_png_paths:
312
- png_file = pathlib.Path(png_path_str)
313
- zf.write(png_file, arcname=png_file.name)
314
-
315
- zip_buffer.seek(0)
316
-
317
- zip_filename = f"{template_name_for_file.replace(' ','_')}_batch_{uuid.uuid4().hex[:8]}.zip" # Added UUID
318
- temp_zip_path = TEMP_DIR / zip_filename
319
- with open(temp_zip_path, "wb") as f:
320
- f.write(zip_buffer.getvalue())
321
-
322
- gr.Info(f"成功生成 {processed_count} 张图片!")
323
- return generated_png_paths, str(temp_zip_path)
324
-
325
-
326
- def initial_load_template_info(request: gr.Request):
327
- """
328
- Loads initial template information based on URL query parameters.
329
- Downloads the thumbnail and returns its local path.
330
- """
331
- query_params = request.query_params
332
- m = query_params.get("m")
333
- u = query_params.get("u")
334
- category = query_params.get("category")
335
- name = query_params.get("name") # This is the template name
336
-
337
- if not all([m, u, category, name]):
338
- print("Initial load: URL parameters (m, u, category, name) are incomplete or missing.")
339
- return None, "无模板信息 (请检查URL参数)", None, "无模板"
340
-
341
- details = fetch_template_details(m, u, category, name)
342
-
343
- local_thumbnail_path = None
344
- template_display_name = "错误: 无法加载模板信息" # Default error message
345
- svg_download_url = None
346
-
347
- if details and "name" in details:
348
- template_display_name = details["name"] # Use name from details for display
349
- svg_download_url = details.get("svg_download_url")
350
-
351
- if details.get("thumbnail_url"):
352
- try:
353
- thumb_url = details["thumbnail_url"]
354
- print(f"Fetching thumbnail from: {thumb_url}")
355
- thumb_response = requests.get(thumb_url, timeout=10)
356
- thumb_response.raise_for_status()
357
-
358
- # Create a unique filename for the thumbnail in TEMP_DIR
359
- # Use name from URL params for file naming consistency if needed, or details['name']
360
- safe_template_name_for_file = name.replace(' ','_').replace('/','_').replace('\\\\','_') # Basic sanitization
361
- thumb_filename = f"thumb_{safe_template_name_for_file}_{uuid.uuid4().hex[:8]}.png"
362
- local_thumbnail_path = TEMP_DIR / thumb_filename
363
- with open(local_thumbnail_path, "wb") as f:
364
- f.write(thumb_response.content)
365
- print(f"Thumbnail saved to: {local_thumbnail_path}")
366
- except requests.exceptions.RequestException as e:
367
- print(f"Error downloading thumbnail from {details.get('thumbnail_url')}: {e}")
368
- local_thumbnail_path = None
369
- except Exception as e:
370
- print(f"Error saving thumbnail: {e}")
371
- local_thumbnail_path = None
372
- else:
373
- print(f"Failed to load template details for m={m}, u={u}, category={category}, name={name}")
374
- # template_display_name is already set to the default error message
375
-
376
- return str(local_thumbnail_path) if local_thumbnail_path else None, \
377
- template_display_name, \
378
- svg_download_url, \
379
- name # 'name' from URL param for template_name_for_file_state
380
-
381
- # --- Gradio Interface ---
382
- with gr.Blocks(theme=gr.themes.Soft(), title="SVG模板批量图片生成器") as demo:
383
- gr.Markdown("## 使用SVG模板批量生成图片")
384
- gr.Markdown("从URL加载模板,上传您自己的图片替换模板中的底图,然后批量生成并下载。")
385
-
386
- # Hidden state to store original SVG download URL and template name for file ops
387
- original_svg_download_url_state = gr.State()
388
- template_name_for_file_state = gr.State() # To preserve the name from URL param for consistent file naming
389
-
390
- with gr.Row():
391
- with gr.Column(scale=1, min_width=200):
392
- template_thumbnail_display = gr.Image(label="当前模板缩略图", interactive=False, height=200, type="filepath") # Ensure type is filepath
393
- template_name_display = gr.Textbox(label="当前模板名称", interactive=False)
394
- with gr.Column(scale=3):
395
- uploaded_images_input = gr.Files(
396
- label="上传您的图片 (可多选)",
397
- file_count="multiple",
398
- file_types=["image"] # Accepts .png, .jpg, .jpeg, .gif, .webp etc.
399
- )
400
-
401
- generate_button = gr.Button("🚀 立即生成图片", variant="primary", scale=1)
402
-
403
- with gr.Accordion("生成结果预览与下载", open=True):
404
- output_gallery = gr.Gallery(
405
- label="生成图片预览",
406
- show_label=True,
407
- elem_id="output_gallery",
408
- columns=[4],
409
- object_fit="contain",
410
- height="auto"
411
- # type="filepath" is default for Gallery if fed filepaths
412
- )
413
- output_zip_file = gr.File(label="下载所有生成图片的ZIP包", interactive=False, type="filepath") # Ensure type is filepath
414
-
415
- # Load initial template info based on URL parameters when the interface loads.
416
- # The `initial_load_template_info` function will parse request.query_params.
417
- # It's crucial that `gr.Request` is passed to it.
418
- # `inputs=None` with `request: gr.Request` in function signature works.
419
- demo.load(
420
- initial_load_template_info,
421
- inputs=None, # gr.Request is implicitly passed if type-hinted in the function
422
- outputs=[
423
- template_thumbnail_display,
424
- template_name_display,
425
- original_svg_download_url_state,
426
- template_name_for_file_state
427
- ]
428
- )
429
-
430
- generate_button.click(
431
- generate_images_from_template,
432
- inputs=[
433
- original_svg_download_url_state,
434
- template_name_for_file_state,
435
- uploaded_images_input
436
- # gr.Request is also implicitly passed here
437
- ],
438
- outputs=[output_gallery, output_zip_file]
439
- )
440
-
441
- if __name__ == "__main__":
442
- # To run this app:
443
- # 1. Ensure your template_server.py is running (e.g., on http://localhost:8001)
444
- # and has the /template_details/{username}/{category}/{template_name} endpoint.
445
- # 2. Run this script: python use_template_app.py
446
- # 3. Open your browser to the Gradio link, appending parameters, e.g.:
447
- # http://127.0.0.1:7860/?m=pub&u=all&category=%E5%AE%B6%E7%94%B5&name=%E6%A8%A1%E6%9D%BF_%E5%86%B0%E7%82%B9%E4%BB%B7
448
-
449
- # For easy testing, you might want to create a dummy template_server.py endpoint
450
- # or hardcode some details if the server isn't ready.
451
-
452
- demo.launch()
453
-
454
- # Optional: Cleanup TEMP_DIR logic can be added here if needed for long-running servers
455
- # For development, manual cleanup or OS temp cleaning is often sufficient.
456
- # Example:
457
- # importate atexit
458
- # def cleanup_temp_dir():
459
- # if TEMP_DIR.exists():
460
- # print(f"Cleaning up temp directory: {TEMP_DIR}")
461
- # shutil.rmtree(TEMP_DIR)
462
- # atexit.register(cleanup_temp_dir) # This might be too aggressive for dev
463
-
 
 
1
+ import gradio as gr
2
+ import requests
3
+ import pathlib
4
+ import zipfile
5
+ import io
6
+ import urllib.parse # For URL sanitization
7
+ import xml.etree.ElementTree as ET
8
+ import base64
9
+ from PIL import Image # For handling image types if needed by Gradio gallery
10
+ import time
11
+ import shutil
12
+ import uuid # For unique filenames
13
+ import os
14
+
15
+ # --- Configuration ---
16
+ TEMPLATE_API_BASE_URL = os.getenv("mubanapi")
17
+ SVG2PNG_API_URL = os.getenv("convertapi")
18
+ TEMP_DIR = pathlib.Path("temp_gradio_app_files") # Renamed for clarity
19
+
20
+ # Ensure temp directory exists
21
+ TEMP_DIR.mkdir(parents=True, exist_ok=True) # No need to clean it aggressively on each start
22
+
23
+ # --- Helper Functions ---
24
+
25
+ def sanitize_url(url_string: str) -> str:
26
+ """
27
+ Safely percent-encodes the path and query components of an HTTP/HTTPS URL.
28
+ Attempts to be idempotent if parts of the URL are already percent-encoded.
29
+ Parentheses in the path are preserved.
30
+ """
31
+ if not url_string or not isinstance(url_string, str) or url_string.startswith('data:'):
32
+ return url_string
33
+ try:
34
+ scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url_string)
35
+
36
+ if scheme not in ['http', 'https']:
37
+ # For other schemes, or if splitting fails to identify a scheme, return original.
38
+ # This could happen for URNs, mailto, etc., which we don't want to modify.
39
+ return url_string
40
+
41
+ # Sanitize path: unquote to normalize, then quote.
42
+ # `safe` keeps slashes, existing valid percent-encodings (%), and specified chars like parentheses.
43
+ sanitized_path = urllib.parse.quote(urllib.parse.unquote(path), safe='/%;:@&+$,=?#()~')
44
+
45
+ sanitized_query = ""
46
+ if query:
47
+ # Unquote the original query to handle cases where it might be partially encoded
48
+ unquoted_query = urllib.parse.unquote(query)
49
+ # Parse the unquoted query string into a dictionary of parameters
50
+ parsed_query_params = urllib.parse.parse_qs(unquoted_query, keep_blank_values=True)
51
+ # urlencode will quote keys and values appropriately for the query string.
52
+ sanitized_query = urllib.parse.urlencode(parsed_query_params, doseq=True, quote_via=urllib.parse.quote)
53
+
54
+ final_url = urllib.parse.urlunsplit((scheme, netloc, sanitized_path, sanitized_query, fragment))
55
+ # print(f"Sanitize URL Input: {url_string}, Output: {final_url}") # Debugging line
56
+ return final_url
57
+ except Exception as e:
58
+ print(f"Warning: Could not sanitize URL '{url_string}': {e}")
59
+ return url_string
60
+
61
+ def fetch_template_details(m_param, u_param, category_param, name_param):
62
+ """
63
+ Fetches template details from the template_server.py API.
64
+ Expected to return a dict with 'name', 'thumbnail_url', 'svg_download_url'.
65
+ """
66
+ username = urllib.parse.quote(u_param)
67
+ category = urllib.parse.quote(category_param)
68
+ template_name = urllib.parse.quote(name_param)
69
+
70
+ api_url = f"{TEMPLATE_API_BASE_URL}/gradio_template_details/{username}/{category}/{template_name}"
71
+ print(f"Fetching template details from: {api_url}")
72
+ try:
73
+ response = requests.get(api_url, timeout=10)
74
+ response.raise_for_status()
75
+ return response.json()
76
+ except requests.exceptions.RequestException as e:
77
+ print(f"Error fetching template details from {api_url}: {e}")
78
+ return None
79
+
80
+ def download_svg_content(svg_url):
81
+ """Downloads the raw SVG content from a given URL."""
82
+ try:
83
+ response = requests.get(svg_url, timeout=10)
84
+ response.raise_for_status()
85
+ return response.text # SVG is XML text
86
+ except requests.exceptions.RequestException as e:
87
+ print(f"Error downloading SVG content from {svg_url}: {e}")
88
+ return None
89
+
90
+ def replace_background_in_svg(svg_content_str: str, new_image_path: str):
91
+ """
92
+ Replaces a designated image in the SVG string with a base64 data URI.
93
+ Also converts all other external http/https image URLs in the SVG to base64 data URIs.
94
+ """
95
+ try:
96
+ namespaces = {'xlink': 'http://www.w3.org/1999/xlink', 'svg': 'http://www.w3.org/2000/svg'}
97
+ for prefix, uri in namespaces.items():
98
+ ET.register_namespace(prefix, uri if prefix else '')
99
+ root = ET.fromstring(svg_content_str)
100
+
101
+ def _image_bytes_to_data_uri(image_bytes: bytes, resource_identifier: str) -> str:
102
+ """
103
+ Converts image bytes to a data URI.
104
+ mime_type_source can be a full MIME type string, a file path, or a URL.
105
+ """
106
+ mime_type = "image/png" # Default
107
+ if "/" in resource_identifier: # Check if it looks like a MIME type string or a URL/path with extension
108
+ # Try to get from Content-Type header if it's a requests.Response object (not applicable here directly)
109
+ # For now, infer from path/URL extension or if it's already a MIME type string
110
+ if resource_identifier.startswith("image/"): # Already a MIME type
111
+ mime_type = resource_identifier
112
+ else: # Infer from extension
113
+ ext = pathlib.Path(resource_identifier).suffix.lower()
114
+ if ext == ".png": mime_type = "image/png"
115
+ elif ext in [".jpg", ".jpeg"]: mime_type = "image/jpeg"
116
+ elif ext == ".gif": mime_type = "image/gif"
117
+ elif ext == ".webp": mime_type = "image/webp"
118
+ elif ext == ".svg": mime_type = "image/svg+xml"
119
+
120
+ encoded_image = base64.b64encode(image_bytes).decode('utf-8')
121
+ return f"data:{mime_type};base64,{encoded_image}"
122
+
123
+ # 1. Handle the user-uploaded image (new_image_path)
124
+ background_image_element = None
125
+ ids_to_check = ["background_image", "placeholder_image", "product_image", "main_image", "image_to_replace", "底图"]
126
+ for img_id in ids_to_check:
127
+ found_element = root.find(f".//svg:image[@id='{img_id}']", namespaces)
128
+ if found_element is None:
129
+ found_element = root.find(f".//{{*}}image[@id='{img_id}']")
130
+ if found_element is not None:
131
+ background_image_element = found_element
132
+ break
133
+
134
+ if background_image_element is None:
135
+ first_image_svg_ns = root.find('.//svg:image', namespaces)
136
+ if first_image_svg_ns is not None:
137
+ background_image_element = first_image_svg_ns
138
+ else:
139
+ first_image_any_ns = root.find('.//image')
140
+ if first_image_any_ns is not None:
141
+ background_image_element = first_image_any_ns
142
+ else:
143
+ print("Warning: No suitable <image> tag found in SVG to replace with new_image_path.")
144
+
145
+ if background_image_element is not None:
146
+ try:
147
+ new_image_path_obj = pathlib.Path(new_image_path)
148
+ if not new_image_path_obj.is_file():
149
+ print(f"Error: New image file not found or is not a file: {new_image_path}")
150
+ # Potentially return None or raise error, for now, it will skip this replacement
151
+ else:
152
+ with open(new_image_path_obj, "rb") as f:
153
+ new_image_bytes = f.read()
154
+ data_uri = _image_bytes_to_data_uri(new_image_bytes, new_image_path) # Pass path for MIME
155
+
156
+ xlink_href_attr_namespaced = f"{{{namespaces['xlink']}}}href"
157
+ attr_to_set_on_target = None
158
+ if xlink_href_attr_namespaced in background_image_element.attrib:
159
+ attr_to_set_on_target = xlink_href_attr_namespaced
160
+ elif 'href' in background_image_element.attrib:
161
+ attr_to_set_on_target = 'href'
162
+ else:
163
+ attr_to_set_on_target = xlink_href_attr_namespaced if 'xlink' in namespaces else 'href'
164
+ background_image_element.set(attr_to_set_on_target, data_uri)
165
+ except Exception as e:
166
+ print(f"Error processing and embedding new_image_path '{new_image_path}': {e}")
167
+
168
+
169
+ # 2. Process all other http/https images in the SVG
170
+ if root is not None:
171
+ # Create a list of all image elements to iterate over.
172
+ # This avoids issues with modifying the tree while iterating over findall() results directly.
173
+ all_image_elements_in_tree = list(root.findall('.//svg:image', namespaces)) + \
174
+ [el for el in root.findall('.//image') if el.tag != f"{{{namespaces['svg']}}}image"]
175
+
176
+ for image_el in all_image_elements_in_tree:
177
+ if image_el == background_image_element: # Already processed (or attempted)
178
+ continue
179
+
180
+ href_val = None
181
+ attr_name_to_set = None
182
+ xlink_href_qname = f"{{{namespaces['xlink']}}}href"
183
+
184
+ if image_el.get(xlink_href_qname) is not None:
185
+ href_val = image_el.get(xlink_href_qname)
186
+ attr_name_to_set = xlink_href_qname
187
+ elif image_el.get('href') is not None:
188
+ href_val = image_el.get('href')
189
+ attr_name_to_set = 'href'
190
+
191
+ if href_val and href_val.startswith(('http://', 'https://')):
192
+ try:
193
+ safe_download_url = sanitize_url(href_val) # Sanitize before download
194
+ # print(f"DEBUG: Downloading remote image for embedding: {safe_download_url}")
195
+ response = requests.get(safe_download_url, timeout=20, stream=True)
196
+ response.raise_for_status()
197
+
198
+ image_bytes = response.content
199
+
200
+ content_type_header = response.headers.get('Content-Type')
201
+ mime_type_source = content_type_header.split(';')[0].strip() if content_type_header and '/' in content_type_header else href_val
202
+
203
+ data_uri = _image_bytes_to_data_uri(image_bytes, mime_type_source)
204
+ image_el.set(attr_name_to_set, data_uri)
205
+ # print(f"DEBUG: Replaced remote URL {href_val} with data URI.")
206
+ except requests.exceptions.RequestException as e_req:
207
+ print(f"Failed to download or process remote image {href_val}: {e_req}")
208
+ except Exception as e_proc:
209
+ print(f"Error processing remote image {href_val} into data URI: {e_proc}")
210
+
211
+ if root is None:
212
+ print("Error: SVG root is None after parsing, cannot proceed.")
213
+ return None
214
+
215
+ modified_svg_bytes = ET.tostring(root, encoding='UTF-8', method='xml', xml_declaration=True)
216
+
217
+ try:
218
+ debug_svg_filename = f"debug_embedded_svg_{uuid.uuid4().hex[:8]}.svg"
219
+ debug_svg_path = TEMP_DIR / debug_svg_filename
220
+ with open(debug_svg_path, "wb") as f_debug:
221
+ f_debug.write(modified_svg_bytes)
222
+ print(f"Saved SVG with all images embedded for debugging to: {debug_svg_path}")
223
+ except Exception as e_save:
224
+ print(f"Error saving debug embedded SVG: {e_save}")
225
+
226
+ return modified_svg_bytes
227
+
228
+ except ET.ParseError as e:
229
+ print(f"Error parsing SVG XML: {e}")
230
+ return None
231
+ except Exception as e:
232
+ print(f"An unexpected error occurred during SVG manipulation: {e}")
233
+ return None
234
+
235
+ def convert_svg_bytes_to_png_api(svg_content_bytes: bytes, original_template_name: str, index: int):
236
+ """
237
+ Converts SVG bytes to PNG bytes using the external API.
238
+ Saves the PNG to a temporary file and returns the path.
239
+ """
240
+ if not svg_content_bytes:
241
+ return None
242
+ try:
243
+ # The API expects a file upload
244
+ files = {'svg_file': ('generated_svg.svg', svg_content_bytes, 'image/svg+xml')}
245
+ response = requests.post(SVG2PNG_API_URL, files=files, timeout=60)
246
+ response.raise_for_status()
247
+
248
+ # Use UUID for more robust unique filenames
249
+ png_filename = f"{original_template_name.replace(' ','_')}_output_{index}_{uuid.uuid4().hex[:8]}.png"
250
+ temp_png_path = TEMP_DIR / png_filename
251
+ with open(temp_png_path, "wb") as f:
252
+ f.write(response.content)
253
+ return str(temp_png_path)
254
+
255
+ except requests.exceptions.RequestException as e:
256
+ print(f"Error converting SVG to PNG via API: {e}")
257
+ if hasattr(e, 'response') and e.response is not None:
258
+ print(f"API Response: {e.response.text}")
259
+ return None
260
+ except Exception as e:
261
+ print(f"Error saving PNG: {e}")
262
+ return None
263
+
264
+ # --- Gradio App Logic ---
265
+
266
+ def generate_images_from_template(original_svg_download_url: str, template_name_for_file: str, uploaded_image_files: list, request: gr.Request):
267
+ """
268
+ Main processing function for Gradio.
269
+ Takes original SVG URL, template name, and list of uploaded image file objects.
270
+ """
271
+ if not original_svg_download_url:
272
+ gr.Warning("模板信息未加载,请确保URL参数正确。")
273
+ return [], None
274
+ if not uploaded_image_files:
275
+ gr.Warning("请上传至少一张图片。")
276
+ return [], None
277
+
278
+ # 1. Download original SVG content
279
+ original_svg_content = download_svg_content(original_svg_download_url)
280
+ if not original_svg_content:
281
+ gr.Error("无法下载原始SVG模板。")
282
+ return [], None
283
+
284
+ generated_png_paths = []
285
+ processed_count = 0
286
+
287
+ for i, uploaded_file_obj in enumerate(uploaded_image_files):
288
+ # uploaded_file_obj.name is the temporary path to the uploaded file
289
+ new_image_path = uploaded_file_obj.name
290
+
291
+ # 2. Replace background in SVG for each uploaded image
292
+ modified_svg_bytes = replace_background_in_svg(original_svg_content, new_image_path)
293
+ if not modified_svg_bytes:
294
+ gr.Warning(f"处理图片 {i+1} 失败:无法修改SVG。")
295
+ continue # Skip this image
296
+
297
+ # 3. Convert modified SVG to PNG
298
+ # Use template_name_for_file for unique output filenames
299
+ png_path = convert_svg_bytes_to_png_api(modified_svg_bytes, template_name_for_file, i + 1)
300
+ if png_path:
301
+ generated_png_paths.append(png_path)
302
+ processed_count += 1
303
+ else:
304
+ gr.Warning(f"处理图片 {i+1} 失败:无法转换为PNG。")
305
+
306
+ if not generated_png_paths:
307
+ gr.Info("未能成功生成任何图片。")
308
+ return [], None
309
+
310
+ zip_buffer = io.BytesIO()
311
+ with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
312
+ for png_path_str in generated_png_paths:
313
+ png_file = pathlib.Path(png_path_str)
314
+ zf.write(png_file, arcname=png_file.name)
315
+
316
+ zip_buffer.seek(0)
317
+
318
+ zip_filename = f"{template_name_for_file.replace(' ','_')}_batch_{uuid.uuid4().hex[:8]}.zip" # Added UUID
319
+ temp_zip_path = TEMP_DIR / zip_filename
320
+ with open(temp_zip_path, "wb") as f:
321
+ f.write(zip_buffer.getvalue())
322
+
323
+ gr.Info(f"成功生成 {processed_count} 张图片!")
324
+ return generated_png_paths, str(temp_zip_path)
325
+
326
+
327
+ def initial_load_template_info(request: gr.Request):
328
+ """
329
+ Loads initial template information based on URL query parameters.
330
+ Downloads the thumbnail and returns its local path.
331
+ """
332
+ query_params = request.query_params
333
+ m = query_params.get("m")
334
+ u = query_params.get("u")
335
+ category = query_params.get("category")
336
+ name = query_params.get("name") # This is the template name
337
+
338
+ if not all([m, u, category, name]):
339
+ print("Initial load: URL parameters (m, u, category, name) are incomplete or missing.")
340
+ return None, "无模板信息 (请检查URL参数)", None, "无模板"
341
+
342
+ details = fetch_template_details(m, u, category, name)
343
+
344
+ local_thumbnail_path = None
345
+ template_display_name = "错误: 无法加载模板信息" # Default error message
346
+ svg_download_url = None
347
+
348
+ if details and "name" in details:
349
+ template_display_name = details["name"] # Use name from details for display
350
+ svg_download_url = details.get("svg_download_url")
351
+
352
+ if details.get("thumbnail_url"):
353
+ try:
354
+ thumb_url = details["thumbnail_url"]
355
+ print(f"Fetching thumbnail from: {thumb_url}")
356
+ thumb_response = requests.get(thumb_url, timeout=10)
357
+ thumb_response.raise_for_status()
358
+
359
+ # Create a unique filename for the thumbnail in TEMP_DIR
360
+ # Use name from URL params for file naming consistency if needed, or details['name']
361
+ safe_template_name_for_file = name.replace(' ','_').replace('/','_').replace('\\\\','_') # Basic sanitization
362
+ thumb_filename = f"thumb_{safe_template_name_for_file}_{uuid.uuid4().hex[:8]}.png"
363
+ local_thumbnail_path = TEMP_DIR / thumb_filename
364
+ with open(local_thumbnail_path, "wb") as f:
365
+ f.write(thumb_response.content)
366
+ print(f"Thumbnail saved to: {local_thumbnail_path}")
367
+ except requests.exceptions.RequestException as e:
368
+ print(f"Error downloading thumbnail from {details.get('thumbnail_url')}: {e}")
369
+ local_thumbnail_path = None
370
+ except Exception as e:
371
+ print(f"Error saving thumbnail: {e}")
372
+ local_thumbnail_path = None
373
+ else:
374
+ print(f"Failed to load template details for m={m}, u={u}, category={category}, name={name}")
375
+ # template_display_name is already set to the default error message
376
+
377
+ return str(local_thumbnail_path) if local_thumbnail_path else None, \
378
+ template_display_name, \
379
+ svg_download_url, \
380
+ name # 'name' from URL param for template_name_for_file_state
381
+
382
+ # --- Gradio Interface ---
383
+ with gr.Blocks(theme=gr.themes.Soft(), title="SVG模板批量图片生成器") as demo:
384
+ gr.Markdown("## 使用SVG模板批量生成图片")
385
+ gr.Markdown("从URL加载模板,上传您自己的图片替换模板中的底图,然后批量生成并下载。")
386
+
387
+ # Hidden state to store original SVG download URL and template name for file ops
388
+ original_svg_download_url_state = gr.State()
389
+ template_name_for_file_state = gr.State() # To preserve the name from URL param for consistent file naming
390
+
391
+ with gr.Row():
392
+ with gr.Column(scale=1, min_width=200):
393
+ template_thumbnail_display = gr.Image(label="当前模板缩略图", interactive=False, height=200, type="filepath") # Ensure type is filepath
394
+ template_name_display = gr.Textbox(label="当前模板名称", interactive=False)
395
+ with gr.Column(scale=3):
396
+ uploaded_images_input = gr.Files(
397
+ label="上传您的图片 (可多选)",
398
+ file_count="multiple",
399
+ file_types=["image"] # Accepts .png, .jpg, .jpeg, .gif, .webp etc.
400
+ )
401
+
402
+ generate_button = gr.Button("🚀 立即生成图片", variant="primary", scale=1)
403
+
404
+ with gr.Accordion("生成结果预览与下载", open=True):
405
+ output_gallery = gr.Gallery(
406
+ label="生成图片预览",
407
+ show_label=True,
408
+ elem_id="output_gallery",
409
+ columns=[4],
410
+ object_fit="contain",
411
+ height="auto"
412
+ # type="filepath" is default for Gallery if fed filepaths
413
+ )
414
+ output_zip_file = gr.File(label="下载所有生成图片的ZIP包", interactive=False, type="filepath") # Ensure type is filepath
415
+
416
+ # Load initial template info based on URL parameters when the interface loads.
417
+ # The `initial_load_template_info` function will parse request.query_params.
418
+ # It's crucial that `gr.Request` is passed to it.
419
+ # `inputs=None` with `request: gr.Request` in function signature works.
420
+ demo.load(
421
+ initial_load_template_info,
422
+ inputs=None, # gr.Request is implicitly passed if type-hinted in the function
423
+ outputs=[
424
+ template_thumbnail_display,
425
+ template_name_display,
426
+ original_svg_download_url_state,
427
+ template_name_for_file_state
428
+ ]
429
+ )
430
+
431
+ generate_button.click(
432
+ generate_images_from_template,
433
+ inputs=[
434
+ original_svg_download_url_state,
435
+ template_name_for_file_state,
436
+ uploaded_images_input
437
+ # gr.Request is also implicitly passed here
438
+ ],
439
+ outputs=[output_gallery, output_zip_file]
440
+ )
441
+
442
+ if __name__ == "__main__":
443
+ # To run this app:
444
+ # 1. Ensure your template_server.py is running (e.g., on http://localhost:8001)
445
+ # and has the /template_details/{username}/{category}/{template_name} endpoint.
446
+ # 2. Run this script: python use_template_app.py
447
+ # 3. Open your browser to the Gradio link, appending parameters, e.g.:
448
+ # http://127.0.0.1:7860/?m=pub&u=all&category=%E5%AE%B6%E7%94%B5&name=%E6%A8%A1%E6%9D%BF_%E5%86%B0%E7%82%B9%E4%BB%B7
449
+
450
+ # For easy testing, you might want to create a dummy template_server.py endpoint
451
+ # or hardcode some details if the server isn't ready.
452
+
453
+ demo.launch()
454
+
455
+ # Optional: Cleanup TEMP_DIR logic can be added here if needed for long-running servers
456
+ # For development, manual cleanup or OS temp cleaning is often sufficient.
457
+ # Example:
458
+ # importate atexit
459
+ # def cleanup_temp_dir():
460
+ # if TEMP_DIR.exists():
461
+ # print(f"Cleaning up temp directory: {TEMP_DIR}")
462
+ # shutil.rmtree(TEMP_DIR)
463
+ # atexit.register(cleanup_temp_dir) # This might be too aggressive for dev
464
+