Spaces:
Running
Running
File size: 24,595 Bytes
111288a b53bf98 111288a f813322 111288a b53bf98 111288a b53bf98 111288a b53bf98 111288a b53bf98 111288a b53bf98 111288a b53bf98 111288a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 |
import gradio as gr
import requests
import pathlib
import zipfile
import io
import urllib.parse # For URL sanitization
import xml.etree.ElementTree as ET
import base64
from PIL import Image # For handling image types if needed by Gradio gallery
import time
import shutil
import uuid # For unique filenames
import os
# --- Configuration ---
# External endpoints come from environment variables; each is None when unset.
TEMPLATE_API_BASE_URL = os.environ.get("mubanapi")  # base URL of the template server
SVG2PNG_API_URL = os.environ.get("convertapi")  # SVG -> PNG conversion endpoint
# Scratch directory for thumbnails, generated PNGs, ZIP archives and debug SVGs.
TEMP_DIR = pathlib.Path("temp_gradio_app_files")
# Created once at import time; existing contents are intentionally kept.
os.makedirs(TEMP_DIR, exist_ok=True)
# --- Helper Functions ---
def fix_internal_urls(url_string: str, base_url: "str | None" = None) -> str:
    """
    Rewrite a URL pointing at an internal host so it uses the external API host.

    The template server may return URLs such as ``http://10.10.71.201:8002/...``
    that are only reachable inside its own network. For ``http``/``https`` URLs
    whose host is a private IPv4 address (10.0.0.0/8, 172.16.0.0/12,
    192.168.0.0/16), a loopback address or ``localhost``, the scheme and netloc
    are replaced by those of ``base_url``; path, query and fragment are kept.

    The host is compared as a parsed IP address rather than by string prefix,
    so public hosts such as ``172.200.1.1`` or ``localhost.example.com`` are
    left untouched.

    Args:
        url_string: URL to inspect. Empty or non-string values are returned as-is.
        base_url: External base URL whose scheme and netloc are substituted in.
            Defaults to the module-level ``TEMPLATE_API_BASE_URL``.

    Returns:
        The rewritten URL, or ``url_string`` unchanged when its host is not
        internal, no base URL is configured, or it cannot be parsed.
    """
    import ipaddress  # local import keeps this helper self-contained

    if not url_string or not isinstance(url_string, str):
        return url_string
    try:
        parts = urllib.parse.urlsplit(url_string)
        if parts.scheme not in ("http", "https"):
            return url_string
        host = parts.hostname  # lower-cased, with userinfo and port stripped
        if not host:
            return url_string
        if host != "localhost":
            try:
                address = ipaddress.ip_address(host)
            except ValueError:
                return url_string  # an ordinary DNS name, not an internal IP
            internal_networks = (
                ipaddress.ip_network("10.0.0.0/8"),
                ipaddress.ip_network("172.16.0.0/12"),
                ipaddress.ip_network("192.168.0.0/16"),
            )
            if not (address.is_loopback or any(address in net for net in internal_networks)):
                return url_string
        if base_url is None:
            base_url = TEMPLATE_API_BASE_URL
        if not base_url:
            print(f"Cannot rewrite internal URL '{url_string}': no external base URL configured.")
            return url_string
        base_parts = urllib.parse.urlsplit(base_url)
        # Take scheme + netloc from the external base; keep everything else.
        new_url = urllib.parse.urlunsplit((
            base_parts.scheme,
            base_parts.netloc,
            parts.path,
            parts.query,
            parts.fragment,
        ))
        print(f"URL transformed: {url_string} -> {new_url}")
        return new_url
    except Exception as e:
        # Best effort: an un-rewritable URL is still returned unchanged.
        print(f"Error fixing internal URL '{url_string}': {e}")
        return url_string
# --- Helper Functions ---
def sanitize_url(url_string: str) -> str:
    """
    Percent-encode the path and query of an http/https URL so it can be requested.

    Path: decoded once and re-encoded, so already-encoded input is preserved
    (``a%20b`` stays ``a%20b``) while raw characters get encoded (``a b`` ->
    ``a%20b``). Slashes, parentheses and a few sub-delimiters stay literal.
    ``?``, ``#`` and ``%`` are *not* kept literal: after decoding they are real
    path characters and must remain encoded, otherwise ``%3F`` would turn into
    a query separator or a lone ``%`` into a bogus escape.

    Query: split into key/value pairs by :func:`urllib.parse.parse_qsl` (which
    performs the single decode) and re-encoded in the original order, so an
    encoded ``&`` or ``=`` inside a value is not mistaken for a separator.
    A literal ``+`` in the query is read as a space and re-emitted as ``%20``.

    Args:
        url_string: URL to sanitize. Empty, non-string, ``data:`` and
            non-http(s) values are returned unchanged.

    Returns:
        The sanitized URL, or ``url_string`` unchanged if it cannot be processed.
    """
    if not url_string or not isinstance(url_string, str) or url_string.startswith('data:'):
        return url_string
    try:
        scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url_string)
        if scheme not in ('http', 'https'):
            # mailto:, urn:, relative references etc. are not ours to modify.
            return url_string
        sanitized_path = urllib.parse.quote(urllib.parse.unquote(path), safe="/;:@&+$,=()~")
        sanitized_query = ""
        if query:
            query_pairs = urllib.parse.parse_qsl(query, keep_blank_values=True)
            sanitized_query = urllib.parse.urlencode(query_pairs, quote_via=urllib.parse.quote)
        return urllib.parse.urlunsplit((scheme, netloc, sanitized_path, sanitized_query, fragment))
    except Exception as e:
        print(f"Warning: Could not sanitize URL '{url_string}': {e}")
        return url_string
def fetch_template_details(m_param, u_param, category_param, name_param):
    """
    Fetch the details of one template from the template server API.

    Requests ``{TEMPLATE_API_BASE_URL}/gradio_template_details/{u}/{category}/{name}``.
    Each path component is percent-encoded with no "safe" characters, so a
    ``/`` inside a username, category or template name cannot split the route
    into extra path segments.

    Every top-level string value that is an http(s) URL (for example
    ``thumbnail_url`` and ``svg_download_url``) is passed through
    :func:`fix_internal_urls` so internal-network URLs become reachable.

    Args:
        m_param: The ``m`` URL parameter; not used in the request, kept for
            interface compatibility.
        u_param: Username that owns the template.
        category_param: Template category.
        name_param: Template name.

    Returns:
        The decoded JSON object (expected keys include ``name``,
        ``thumbnail_url`` and ``svg_download_url``), or ``None`` if the request
        fails, the body is not valid JSON, or the JSON is not an object.
    """
    username, category, template_name = (
        urllib.parse.quote(value, safe="") for value in (u_param, category_param, name_param)
    )
    api_url = f"{TEMPLATE_API_BASE_URL}/gradio_template_details/{username}/{category}/{template_name}"
    print(f"Fetching template details from: {api_url}")
    try:
        response = requests.get(api_url, timeout=10)
        response.raise_for_status()
        template_data = response.json()
    except requests.exceptions.RequestException as e:
        print(f"Error fetching template details from {api_url}: {e}")
        return None
    except ValueError as e:
        # Depending on the requests version, a non-JSON body may surface as a
        # plain ValueError subclass rather than a RequestException.
        print(f"Error decoding template details from {api_url}: {e}")
        return None
    if not isinstance(template_data, dict):
        print(f"Unexpected template details payload from {api_url}: {type(template_data).__name__}")
        return None
    # Rewrite every URL-valued field once (covers thumbnail_url and svg_download_url).
    for key, value in template_data.items():
        if isinstance(value, str) and value.startswith(("http://", "https://")):
            template_data[key] = fix_internal_urls(value)
    return template_data
def download_svg_content(svg_url):
    """
    Download an SVG document and return its markup as text.

    The URL is passed through :func:`fix_internal_urls` first so that
    internal-network addresses are replaced by the external API host.

    Decoding: if the response's Content-Type declares a charset, that charset
    is used (via ``response.text``). Otherwise the body is decoded as UTF-8,
    the XML default, dropping a leading BOM. Only if that fails does it fall back
    to ``response.text``. Relying on ``response.text`` alone would, for example,
    decode ``text/xml`` without a charset as ISO-8859-1 and garble non-ASCII
    content such as the ``底图`` image id.

    Args:
        svg_url: URL of the SVG file.

    Returns:
        The SVG markup as ``str``, or ``None`` if the download fails.
    """
    fixed_url = fix_internal_urls(svg_url)
    try:
        response = requests.get(fixed_url, timeout=10)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        print(f"Error downloading SVG content from {svg_url}: {e}")
        return None
    content_type = response.headers.get("Content-Type", "")
    if "charset=" in content_type.lower():
        return response.text
    try:
        return response.content.decode("utf-8-sig")
    except UnicodeDecodeError:
        return response.text  # not UTF-8: defer to requests' own decoding
def _image_bytes_to_data_uri(image_bytes: bytes, resource_identifier: str) -> str:
    """
    Encode ``image_bytes`` as a base64 ``data:`` URI.

    ``resource_identifier`` is either a MIME type starting with ``image/``
    (used as-is) or a local path / http(s) URL whose file extension selects
    the MIME type. For URLs, the query string and fragment are ignored. A bare
    filename without any ``/`` is handled like a full path. Unknown or missing
    extensions fall back to ``image/png``.
    """
    extension_mime_types = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".gif": "image/gif",
        ".webp": "image/webp",
        ".svg": "image/svg+xml",
    }
    identifier = resource_identifier or ""
    if identifier.startswith("image/"):
        mime_type = identifier
    else:
        if identifier.startswith(("http://", "https://")):
            identifier = urllib.parse.urlsplit(identifier).path
        suffix = pathlib.PurePath(identifier).suffix.lower()
        mime_type = extension_mime_types.get(suffix, "image/png")
    encoded_image = base64.b64encode(image_bytes).decode("utf-8")
    return f"data:{mime_type};base64,{encoded_image}"


def _find_replaceable_image(root, namespaces):
    """
    Return the ``<image>`` element that should receive the user's picture, or None.

    Preference order:
    1. The first image whose ``id`` is one of the placeholder ids below
       (SVG namespace first, then any namespace).
    2. The first SVG-namespaced ``<image>``.
    3. The first un-namespaced ``<image>``.
    """
    ids_to_check = ["background_image", "placeholder_image", "product_image", "main_image", "image_to_replace", "底图"]
    for img_id in ids_to_check:
        found_element = root.find(f".//svg:image[@id='{img_id}']", namespaces)
        if found_element is None:
            found_element = root.find(f".//{{*}}image[@id='{img_id}']")
        if found_element is not None:
            return found_element
    first_image = root.find(".//svg:image", namespaces)
    if first_image is None:
        first_image = root.find(".//image")
    return first_image


def _embed_local_image(image_el, new_image_path, xlink_href):
    """
    Point ``image_el`` at the local file ``new_image_path``, embedded as a data URI.

    This is best effort: a missing file or a read error is printed, and the
    element is left unchanged.
    """
    try:
        new_image_path_obj = pathlib.Path(new_image_path)
        if not new_image_path_obj.is_file():
            print(f"Error: New image file not found or is not a file: {new_image_path}")
            return
        data_uri = _image_bytes_to_data_uri(new_image_path_obj.read_bytes(), str(new_image_path))
        # SVG 2 renderers use a plain 'href' over 'xlink:href' when both exist,
        # so update every variant present; add 'xlink:href' only if neither is.
        existing_attrs = [attr for attr in (xlink_href, "href") if attr in image_el.attrib]
        for attr in existing_attrs or [xlink_href]:
            image_el.set(attr, data_uri)
    except Exception as e:
        print(f"Error processing and embedding new_image_path '{new_image_path}': {e}")


def _embed_remote_image(image_el, xlink_href):
    """
    Replace an http(s) href on ``image_el`` with a data URI of the downloaded image.

    Internal-network URLs are rewritten with ``fix_internal_urls`` and then
    sanitized before the download. This is best effort: failures are printed,
    and the original href is kept.
    """
    if image_el.get(xlink_href) is not None:
        attr_name = xlink_href
    elif image_el.get("href") is not None:
        attr_name = "href"
    else:
        return
    href_val = image_el.get(attr_name)
    if not href_val.startswith(("http://", "https://")):
        return
    try:
        download_url = sanitize_url(fix_internal_urls(href_val))
        response = requests.get(download_url, timeout=20)
        response.raise_for_status()
        header_mime = (response.headers.get("Content-Type") or "").split(";")[0].strip().lower()
        # Trust the header only when it names an image type; otherwise infer from the URL.
        mime_source = header_mime if header_mime.startswith("image/") else href_val
        image_el.set(attr_name, _image_bytes_to_data_uri(response.content, mime_source))
    except requests.exceptions.RequestException as e_req:
        print(f"Failed to download or process remote image {href_val}: {e_req}")
    except Exception as e_proc:
        print(f"Error processing remote image {href_val} into data URI: {e_proc}")


def _save_debug_svg(svg_bytes):
    """Write ``svg_bytes`` to a uniquely named file in ``TEMP_DIR`` for debugging (best effort)."""
    try:
        debug_svg_path = TEMP_DIR / f"debug_embedded_svg_{uuid.uuid4().hex[:8]}.svg"
        debug_svg_path.write_bytes(svg_bytes)
        print(f"Saved SVG with all images embedded for debugging to: {debug_svg_path}")
    except Exception as e_save:
        print(f"Error saving debug embedded SVG: {e_save}")


def replace_background_in_svg(svg_content_str: str, new_image_path: str):
    """
    Embed a local image into an SVG template and inline all remote images.

    1. The placeholder ``<image>`` (chosen by ``_find_replaceable_image``) gets
       its href replaced by a base64 data URI of ``new_image_path``.
    2. Every other ``<image>`` whose href is an http(s) URL is downloaded and
       replaced by a data URI, so the SVG no longer needs network access.
    3. A copy of the result is written to ``TEMP_DIR`` for debugging.

    Failures inside steps 1-3 are printed and skipped; the SVG is still returned.

    Args:
        svg_content_str: SVG markup.
        new_image_path: Path of the image file to embed as the background.

    Returns:
        UTF-8 encoded SVG bytes with an XML declaration, or ``None`` if the
        markup cannot be parsed or another unexpected error occurs.
    """
    svg_ns = "http://www.w3.org/2000/svg"
    xlink_ns = "http://www.w3.org/1999/xlink"
    namespaces = {"xlink": xlink_ns, "svg": svg_ns}  # prefixes for find() queries only
    xlink_href = f"{{{xlink_ns}}}href"
    try:
        # Serialize SVG elements in the default namespace (<svg xmlns="...">)
        # rather than as <svg:svg xmlns:svg="...">, and keep the usual xlink prefix.
        ET.register_namespace("", svg_ns)
        ET.register_namespace("xlink", xlink_ns)
        root = ET.fromstring(svg_content_str)

        background_image_element = _find_replaceable_image(root, namespaces)
        if background_image_element is None:
            print("Warning: No suitable <image> tag found in SVG to replace with new_image_path.")
        else:
            _embed_local_image(background_image_element, new_image_path, xlink_href)

        all_images = [*root.iter(f"{{{svg_ns}}}image"), *root.iter("image")]
        for image_el in all_images:
            if image_el is not background_image_element:
                _embed_remote_image(image_el, xlink_href)

        modified_svg_bytes = ET.tostring(root, encoding="UTF-8", method="xml", xml_declaration=True)
        _save_debug_svg(modified_svg_bytes)
        return modified_svg_bytes
    except ET.ParseError as e:
        print(f"Error parsing SVG XML: {e}")
        return None
    except Exception as e:
        print(f"An unexpected error occurred during SVG manipulation: {e}")
        return None
def convert_svg_bytes_to_png_api(svg_content_bytes: bytes, original_template_name: str, index: int):
    """
    Convert SVG bytes to PNG via the external API and save the result to TEMP_DIR.

    The SVG is uploaded as the multipart field ``svg_file`` to ``SVG2PNG_API_URL``.
    The response body is written verbatim to a uniquely named ``.png`` file.

    ``original_template_name`` originates from a URL query parameter, so it is
    reduced to letters, digits, ``-``, ``_`` and ``.`` (everything else becomes
    ``_``, spaces included) and truncated. Path separators can therefore never
    make the file land outside ``TEMP_DIR``, and long names cannot exceed
    filesystem limits.

    Args:
        svg_content_bytes: Encoded SVG document.
        original_template_name: Template name used as the filename stem;
            ``None``/empty falls back to ``"template"``.
        index: 1-based position of the image in the batch (part of the filename).

    Returns:
        Path of the saved PNG as ``str``, or ``None`` on empty input, on an API
        error, or on a write error.
    """
    if not svg_content_bytes:
        return None
    safe_stem = "".join(
        ch if ch.isalnum() or ch in "-_." else "_"
        for ch in str(original_template_name or "template")
    )[:60]
    try:
        # The API expects a file upload.
        files = {'svg_file': ('generated_svg.svg', svg_content_bytes, 'image/svg+xml')}
        response = requests.post(SVG2PNG_API_URL, files=files, timeout=60)
        response.raise_for_status()
        # The UUID suffix keeps filenames unique across concurrent requests.
        png_filename = f"{safe_stem}_output_{index}_{uuid.uuid4().hex[:8]}.png"
        temp_png_path = TEMP_DIR / png_filename
        temp_png_path.write_bytes(response.content)
        return str(temp_png_path)
    except requests.exceptions.RequestException as e:
        print(f"Error converting SVG to PNG via API: {e}")
        if getattr(e, 'response', None) is not None:
            print(f"API Response: {e.response.text}")
        return None
    except Exception as e:
        print(f"Error saving PNG: {e}")
        return None
# --- Gradio App Logic ---
def generate_images_from_template(original_svg_download_url: str, template_name_for_file: str, uploaded_image_files: list, request: gr.Request):
    """
    Gradio handler: render the template once per uploaded image and bundle the PNGs.

    For each uploaded image:
    1. The template's background image is replaced (``replace_background_in_svg``).
    2. The result is converted to PNG (``convert_svg_bytes_to_png_api``).

    All PNGs are then also written into one ZIP archive in ``TEMP_DIR``.
    Per-image failures are reported with ``gr.Warning`` and skipped.

    Args:
        original_svg_download_url: URL of the template SVG (from session state).
        template_name_for_file: Template name used to build output filenames.
        uploaded_image_files: Items from the ``gr.Files`` component. Each is
            either a path string or an object carrying the path in ``.name``.
            The exact form presumably depends on the Gradio version.
        request: The incoming Gradio request (unused; supplied by Gradio).

    Returns:
        ``(png_paths, zip_path)``: the generated PNG paths and the ZIP path, or
        ``([], None)`` when there is nothing to show.

    Raises:
        gr.Error: If the template SVG cannot be downloaded.
    """
    if not original_svg_download_url:
        gr.Warning("模板信息未加载,请确保URL参数正确。")
        return [], None
    if not uploaded_image_files:
        gr.Warning("请上传至少一张图片。")
        return [], None

    # 1. Download the template once; it is reused for every uploaded image.
    original_svg_content = download_svg_content(original_svg_download_url)
    if not original_svg_content:
        # gr.Error is only shown to the user when raised; constructing it is a no-op.
        raise gr.Error("无法下载原始SVG模板。")

    generated_png_paths = []
    for i, uploaded_file in enumerate(uploaded_image_files, start=1):
        new_image_path = getattr(uploaded_file, "name", uploaded_file)
        # 2. Replace the background image in the SVG.
        modified_svg_bytes = replace_background_in_svg(original_svg_content, new_image_path)
        if not modified_svg_bytes:
            gr.Warning(f"处理图片 {i} 失败:无法修改SVG。")
            continue
        # 3. Convert the modified SVG to PNG.
        png_path = convert_svg_bytes_to_png_api(modified_svg_bytes, template_name_for_file, i)
        if png_path:
            generated_png_paths.append(png_path)
        else:
            gr.Warning(f"处理图片 {i} 失败:无法转换为PNG。")

    if not generated_png_paths:
        gr.Info("未能成功生成任何图片。")
        return [], None

    # The template name comes from a URL parameter: keep only filename-safe
    # characters so the archive cannot be written outside TEMP_DIR.
    safe_stem = "".join(
        ch if ch.isalnum() or ch in "-_." else "_"
        for ch in str(template_name_for_file or "template")
    )[:60]
    temp_zip_path = TEMP_DIR / f"{safe_stem}_batch_{uuid.uuid4().hex[:8]}.zip"
    # Write the archive straight to disk rather than buffering it in memory first.
    with zipfile.ZipFile(temp_zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for png_path_str in generated_png_paths:
            png_file = pathlib.Path(png_path_str)
            zf.write(png_file, arcname=png_file.name)

    gr.Info(f"成功生成 {len(generated_png_paths)} 张图片!")
    return generated_png_paths, str(temp_zip_path)
def initial_load_template_info(request: gr.Request):
    """
    Load template information from the page's URL query parameters.

    Reads ``m``, ``u``, ``category`` and ``name`` from the query string, fetches
    the template details, and downloads the thumbnail into ``TEMP_DIR``.

    Args:
        request: The Gradio request whose ``query_params`` carry the template id.

    Returns:
        A 4-tuple ``(thumbnail_path, display_name, svg_download_url, file_name)``:

        - ``thumbnail_path``: local thumbnail path as ``str``, or ``None``.
        - ``display_name``: template name to show, or an error message.
        - ``svg_download_url``: template SVG URL, or ``None``.
        - ``file_name``: raw ``name`` query parameter, used later for output
          filenames (``"无模板"`` when parameters are missing).
    """
    query_params = request.query_params
    m = query_params.get("m")
    u = query_params.get("u")
    category = query_params.get("category")
    name = query_params.get("name")  # the template name
    if not all([m, u, category, name]):
        print("Initial load: URL parameters (m, u, category, name) are incomplete or missing.")
        return None, "无模板信息 (请检查URL参数)", None, "无模板"

    details = fetch_template_details(m, u, category, name)
    if not details or "name" not in details:
        print(f"Failed to load template details for m={m}, u={u}, category={category}, name={name}")
        return None, "错误: 无法加载模板信息", None, name

    template_display_name = details["name"]
    svg_download_url = details.get("svg_download_url")
    local_thumbnail_path = None
    thumb_url = details.get("thumbnail_url")  # already rewritten by fetch_template_details
    if thumb_url:
        try:
            print(f"Fetching thumbnail from: {thumb_url}")
            thumb_response = requests.get(thumb_url, timeout=10)
            thumb_response.raise_for_status()
            # Filesystem-safe stem: every character other than letters, digits,
            # '-', '_' and '.' becomes '_'. This covers both '/' and a single
            # '\'; the old replace('\\\\', ...) only matched a doubled backslash.
            safe_template_name_for_file = "".join(
                ch if ch.isalnum() or ch in "-_." else "_" for ch in name
            )[:60]
            thumb_filename = f"thumb_{safe_template_name_for_file}_{uuid.uuid4().hex[:8]}.png"
            local_thumbnail_path = TEMP_DIR / thumb_filename
            local_thumbnail_path.write_bytes(thumb_response.content)
            print(f"Thumbnail saved to: {local_thumbnail_path}")
        except requests.exceptions.RequestException as e:
            print(f"Error downloading thumbnail from {details.get('thumbnail_url')}: {e}")
            local_thumbnail_path = None
        except Exception as e:
            print(f"Error saving thumbnail: {e}")
            local_thumbnail_path = None

    return (
        str(local_thumbnail_path) if local_thumbnail_path else None,
        template_display_name,
        svg_download_url,
        name,  # raw URL param, stored in template_name_for_file_state
    )
# --- Gradio Interface ---
# Build the page. Inside a Blocks context, the creation order of components
# and the Row/Column/Accordion nesting define the on-page layout, so the
# statements below should not be reordered.
with gr.Blocks(theme=gr.themes.Soft(), title="SVG模板批量图片生成器") as demo:
    gr.Markdown("## 使用SVG模板批量生成图片")
    gr.Markdown("从URL加载模板,上传您自己的图片替换模板中的底图,然后批量生成并下载。")
    # Per-session hidden state: the template SVG URL and the raw template name
    # from the URL query parameter (used for output filenames). Both are filled
    # by demo.load below and read by the generate button's click handler.
    original_svg_download_url_state = gr.State()
    template_name_for_file_state = gr.State() # To preserve the name from URL param for consistent file naming
    with gr.Row():
        # Left column: read-only preview of the loaded template.
        with gr.Column(scale=1, min_width=200):
            template_thumbnail_display = gr.Image(label="当前模板缩略图", interactive=False, height=200, type="filepath") # type="filepath": fed a local file path by initial_load_template_info
            template_name_display = gr.Textbox(label="当前模板名称", interactive=False)
        # Right column: multi-file image upload and the generate action.
        with gr.Column(scale=3):
            uploaded_images_input = gr.Files(
                label="上传您的图片 (可多选)",
                file_count="multiple",
                file_types=["image"] # Restrict uploads to image files
            )
            generate_button = gr.Button("🚀 立即生成图片", variant="primary", scale=1)
    # Results: preview gallery of the generated PNGs plus a ZIP with all of them.
    with gr.Accordion("生成结果预览与下载", open=True):
        output_gallery = gr.Gallery(
            label="生成图片预览",
            show_label=True,
            elem_id="output_gallery",
            columns=[4],
            object_fit="contain",
            height="auto"
            # Receives a list of PNG file paths from generate_images_from_template.
        )
        output_zip_file = gr.File(label="下载所有生成图片的ZIP包", interactive=False, type="filepath") # Receives the ZIP file path
    # When the page loads, populate the template preview and session state from
    # the URL query parameters. initial_load_template_info receives the request
    # because its signature is type-hinted with gr.Request; no explicit inputs
    # are needed.
    demo.load(
        initial_load_template_info,
        inputs=None, # gr.Request is injected because of the type hint in the function
        outputs=[
            template_thumbnail_display,
            template_name_display,
            original_svg_download_url_state,
            template_name_for_file_state
        ]
    )
    # Generate button: state values + uploaded files in, gallery + ZIP out.
    generate_button.click(
        generate_images_from_template,
        inputs=[
            original_svg_download_url_state,
            template_name_for_file_state,
            uploaded_images_input
            # gr.Request is also injected here via the function's type hint
        ],
        outputs=[output_gallery, output_zip_file]
    )
if __name__ == "__main__":
    # To run this app:
    # 1. Set the environment variables read at the top of this file:
    #    - `mubanapi`: base URL of the template server. It must serve
    #      /gradio_template_details/{username}/{category}/{template_name}.
    #    - `convertapi`: URL of the SVG -> PNG conversion endpoint.
    # 2. Run this script: python use_template_app.py
    # 3. Open the Gradio link in a browser with the query parameters m, u,
    #    category and name (percent-encoded), e.g.:
    # http://127.0.0.1:7860/?m=pub&u=all&category=%E5%AE%B6%E7%94%B5&name=%E6%A8%A1%E6%9D%BF_%E5%86%B0%E7%82%B9%E4%BB%B7
    # For quick testing, a stub template server endpoint or hard-coded
    # template details can stand in for the real server.
    demo.launch()
    # Optional: TEMP_DIR is never cleaned automatically; long-running servers
    # may want to add cleanup. During development, manual cleanup or the OS's
    # temp cleaning is usually sufficient.
    # Example:
    # import atexit
    # def cleanup_temp_dir():
    #     if TEMP_DIR.exists():
    #         print(f"Cleaning up temp directory: {TEMP_DIR}")
    #         shutil.rmtree(TEMP_DIR)
    # atexit.register(cleanup_temp_dir) # This might be too aggressive for dev
|