# svg2png-templates-public / use_template_app.py
# Author: innoai
# Commit b53bf98 (verified): "Update use_template_app.py"
import base64
import io
import ipaddress
import os
import pathlib
import shutil
import time
import urllib.parse # For URL sanitization
import uuid # For unique filenames
import xml.etree.ElementTree as ET
import zipfile

import gradio as gr
import requests
from PIL import Image # For handling image types if needed by Gradio gallery
# --- Configuration ---
# Service endpoints come from environment variables; each is None when the variable is unset.
TEMPLATE_API_BASE_URL = os.environ.get("mubanapi")  # base URL of the template server
SVG2PNG_API_URL = os.environ.get("convertapi")  # SVG -> PNG conversion endpoint
# Scratch directory for downloaded thumbnails, generated PNG/ZIP files and debug SVGs.
TEMP_DIR = pathlib.Path("temp_gradio_app_files")
# Created once at import time; anything already inside is left alone.
TEMP_DIR.mkdir(exist_ok=True, parents=True)
# --- Helper Functions ---
# Private IPv4 ranges (RFC 1918) plus loopback. URLs pointing at these hosts are only
# reachable from inside the deployment, so they are rewritten onto the public API host.
_INTERNAL_NETWORKS = tuple(
    ipaddress.ip_network(cidr)
    for cidr in ("10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "127.0.0.0/8")
)


def _is_internal_host(hostname) -> bool:
    """Return True if *hostname* is ``localhost`` or an IPv4 literal in a private/loopback range."""
    if not hostname:
        return False
    if hostname.lower() == "localhost":
        return True
    try:
        address = ipaddress.ip_address(hostname)
    except ValueError:
        return False  # a DNS name rather than an IP literal
    # Membership test is False (not an error) for IPv6 addresses against IPv4 networks.
    return any(address in network for network in _INTERNAL_NETWORKS)


def fix_internal_urls(url_string: str) -> str:
    """
    Rewrite an http(s) URL that points at an internal host onto TEMPLATE_API_BASE_URL.

    "Internal" means ``localhost`` or an IPv4 address in 10.0.0.0/8, 172.16.0.0/12,
    192.168.0.0/16 or 127.0.0.0/8 (e.g. http://10.10.71.201:8002). The scheme and
    host[:port] are taken from TEMPLATE_API_BASE_URL; the original path, query and
    fragment are kept.

    Args:
        url_string: URL to inspect. Non-string or empty values are returned as-is.

    Returns:
        The rewritten URL, or the input unchanged when it is not an internal http(s)
        URL, when TEMPLATE_API_BASE_URL is not configured, or when parsing fails.
    """
    if not url_string or not isinstance(url_string, str):
        return url_string
    try:
        parts = urllib.parse.urlsplit(url_string)
        # .hostname drops any userinfo/port and lower-cases the host.
        if parts.scheme not in ("http", "https") or not _is_internal_host(parts.hostname):
            return url_string
        if not TEMPLATE_API_BASE_URL:
            print(f"Warning: TEMPLATE_API_BASE_URL is not configured; leaving internal URL unchanged: {url_string}")
            return url_string
        base_parts = urllib.parse.urlsplit(TEMPLATE_API_BASE_URL)
        new_url = urllib.parse.urlunsplit((
            base_parts.scheme,
            base_parts.netloc,
            parts.path,
            parts.query,
            parts.fragment,
        ))
        print(f"URL transformed: {url_string} -> {new_url}")
        return new_url
    except Exception as e:
        print(f"Error fixing internal URL '{url_string}': {e}")
        return url_string
# --- Helper Functions ---
def sanitize_url(url_string: str) -> str:
    """
    Safely percent-encode the path and query of an HTTP/HTTPS URL.

    The function is idempotent (applying it twice equals applying it once), so URLs that
    are already partly or fully percent-encoded are not double-encoded. Parentheses in
    the path are preserved. Query parameter order is preserved.

    Args:
        url_string: URL to clean. Empty/non-string values, ``data:`` URIs and URLs with a
            scheme other than http/https are returned unchanged.

    Returns:
        The sanitized URL, or the original string if it cannot be processed.
    """
    if not url_string or not isinstance(url_string, str) or url_string.startswith('data:'):
        return url_string
    try:
        scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url_string)
        if scheme not in ('http', 'https'):
            # URNs, mailto:, etc. are left alone.
            return url_string
        # Decode once, then re-encode. '%' is deliberately NOT in `safe`: after unquoting,
        # any remaining '%' is a literal percent sign and must become '%25'. '?' and '#'
        # are excluded too, so a decoded %3F/%23 cannot turn part of the path into a
        # query string or fragment.
        sanitized_path = urllib.parse.quote(urllib.parse.unquote(path), safe='/;:@&+$,=()~')
        # parse_qsl decodes each key/value exactly once. Unquoting the whole query first
        # would turn encoded '&', '=' and '+' into separators/spaces.
        query_pairs = urllib.parse.parse_qsl(query, keep_blank_values=True)
        sanitized_query = urllib.parse.urlencode(query_pairs, quote_via=urllib.parse.quote)
        return urllib.parse.urlunsplit((scheme, netloc, sanitized_path, sanitized_query, fragment))
    except Exception as e:
        print(f"Warning: Could not sanitize URL '{url_string}': {e}")
        return url_string
def fetch_template_details(m_param, u_param, category_param, name_param):
    """
    Fetch template details from the template server's ``/gradio_template_details`` endpoint.

    Args:
        m_param: mode parameter from the page URL; accepted but not sent to the server.
        u_param: username path segment.
        category_param: category path segment.
        name_param: template name path segment.

    Returns:
        The decoded JSON body (expected to be a dict with 'name', 'thumbnail_url' and
        'svg_download_url'). When it is a dict, every http(s) string value is passed
        through fix_internal_urls. Returns None if the request fails or the body is not
        valid JSON.
    """
    # safe='' percent-encodes '/' too, so each value stays a single path segment.
    username = urllib.parse.quote(u_param, safe='')
    category = urllib.parse.quote(category_param, safe='')
    template_name = urllib.parse.quote(name_param, safe='')
    api_url = f"{TEMPLATE_API_BASE_URL}/gradio_template_details/{username}/{category}/{template_name}"
    print(f"Fetching template details from: {api_url}")
    try:
        response = requests.get(api_url, timeout=10)
        response.raise_for_status()
        template_data = response.json()
    except requests.exceptions.RequestException as e:
        print(f"Error fetching template details from {api_url}: {e}")
        return None
    except ValueError as e:
        # Depending on the installed requests version, invalid JSON may raise a
        # ValueError subclass that is not a RequestException.
        print(f"Error decoding template details JSON from {api_url}: {e}")
        return None
    if isinstance(template_data, dict):
        # Covers thumbnail_url, svg_download_url and any other absolute URL the server returns.
        # Reassigning existing keys while iterating is safe (the dict size does not change).
        for key, value in template_data.items():
            if isinstance(value, str) and value.startswith(("http://", "https://")):
                template_data[key] = fix_internal_urls(value)
    return template_data
def download_svg_content(svg_url):
    """
    Download an SVG document and return it as text.

    The URL is first passed through fix_internal_urls. When the response does not declare
    a charset, the body is decoded as UTF-8 (the XML default) instead of requests' fallback
    (ISO-8859-1 for text/* types), which would garble non-ASCII text such as Chinese labels.

    Args:
        svg_url: URL of the SVG file.

    Returns:
        The SVG markup as a str, or None if the request fails or returns an error status.
    """
    try:
        fixed_url = fix_internal_urls(svg_url)
        response = requests.get(fixed_url, timeout=10)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        print(f"Error downloading SVG content from {svg_url}: {e}")
        return None
    content_type = response.headers.get("Content-Type", "")
    if "charset=" not in content_type.lower():
        # "utf-8-sig" decodes UTF-8 and also strips a leading byte-order mark if present.
        response.encoding = "utf-8-sig"
    return response.text  # SVG is XML text
_SVG_NS = "http://www.w3.org/2000/svg"
_XLINK_NS = "http://www.w3.org/1999/xlink"
_XLINK_HREF = f"{{{_XLINK_NS}}}href"  # ElementTree (Clark-notation) name of xlink:href
_SVG_IMAGE_TAG = f"{{{_SVG_NS}}}image"
# ids that mark the <image> meant to receive the uploaded picture, in priority order.
_REPLACEABLE_IMAGE_IDS = (
    "background_image",
    "placeholder_image",
    "product_image",
    "main_image",
    "image_to_replace",
    "底图",
)
# File-extension -> MIME mapping used when no image/* MIME type is available.
_EXT_TO_MIME = {
    ".png": "image/png",
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".gif": "image/gif",
    ".webp": "image/webp",
    ".svg": "image/svg+xml",
}


def _guess_image_mime(resource_identifier: str) -> str:
    """
    Return a MIME type for *resource_identifier*, defaulting to ``image/png``.

    Accepts an ``image/*`` MIME string (returned lower-cased), a URL (only its path is
    inspected, so query strings are ignored), or a file path using '/' or '\\' separators.
    """
    ident = str(resource_identifier).strip()
    if ident.lower().startswith("image/"):
        return ident.lower()
    if "://" in ident:
        ident = urllib.parse.urlsplit(ident).path
    # PureWindowsPath accepts both '/' and '\\' as separators, so POSIX paths, Windows
    # paths and URL paths all yield the right suffix.
    suffix = pathlib.PureWindowsPath(ident).suffix.lower()
    return _EXT_TO_MIME.get(suffix, "image/png")


def _image_bytes_to_data_uri(image_bytes: bytes, resource_identifier: str) -> str:
    """Encode *image_bytes* as a base64 ``data:`` URI typed via _guess_image_mime(resource_identifier)."""
    encoded_image = base64.b64encode(image_bytes).decode("ascii")
    return f"data:{_guess_image_mime(resource_identifier)};base64,{encoded_image}"


def _find_replaceable_image(root):
    """
    Return the <image> element that should receive the uploaded picture, or None.

    Preference: the first id from _REPLACEABLE_IMAGE_IDS that exists (SVG namespace first,
    then any namespace); otherwise the first SVG-namespaced <image>; otherwise the first
    un-namespaced <image>.
    """
    namespaces = {"svg": _SVG_NS}
    for img_id in _REPLACEABLE_IMAGE_IDS:
        found = root.find(f".//svg:image[@id='{img_id}']", namespaces)
        if found is None:
            found = root.find(f".//{{*}}image[@id='{img_id}']")
        if found is not None:
            return found
    first_image = root.find(".//svg:image", namespaces)
    if first_image is None:
        first_image = root.find(".//image")
    return first_image


def _href_attribute_names(element) -> list:
    """Return the href attribute names present on *element*, xlink:href first."""
    return [name for name in (_XLINK_HREF, "href") if element.get(name) is not None]


def _embed_local_image(element, image_path) -> None:
    """Point *element* at the file *image_path*, inlined as a data URI; prints and returns if the file is missing."""
    path = pathlib.Path(image_path)
    if not path.is_file():
        print(f"Error: New image file not found or is not a file: {image_path}")
        return
    data_uri = _image_bytes_to_data_uri(path.read_bytes(), str(image_path))
    # Update every href form present. In SVG 2 a plain `href` takes precedence over
    # `xlink:href`, so updating only one could leave the template's old picture visible.
    for attr_name in _href_attribute_names(element) or [_XLINK_HREF]:
        element.set(attr_name, data_uri)


def _embed_remote_images(root, skip_element) -> None:
    """Best effort: download every other http(s)-referenced <image> and inline it as a data URI."""
    image_elements = list(root.iter(_SVG_IMAGE_TAG)) + list(root.iter("image"))
    for image_el in image_elements:
        if image_el is skip_element:  # the upload target was already handled
            continue
        attr_names = _href_attribute_names(image_el)
        if not attr_names:
            continue
        attr_name = attr_names[0]
        href_val = image_el.get(attr_name)
        if not href_val.startswith(("http://", "https://")):
            continue
        try:
            # Template assets may reference internal hosts, like the other URLs in this app.
            safe_download_url = sanitize_url(fix_internal_urls(href_val))
            response = requests.get(safe_download_url, timeout=20)
            response.raise_for_status()
            content_type = response.headers.get("Content-Type", "").split(";")[0].strip()
            # Trust only an image/* Content-Type; otherwise infer the type from the URL.
            mime_source = content_type if content_type.lower().startswith("image/") else href_val
            image_el.set(attr_name, _image_bytes_to_data_uri(response.content, mime_source))
        except requests.exceptions.RequestException as e_req:
            print(f"Failed to download or process remote image {href_val}: {e_req}")
        except Exception as e_proc:
            print(f"Error processing remote image {href_val} into data URI: {e_proc}")


def _save_debug_svg(svg_bytes: bytes) -> None:
    """Best effort: write *svg_bytes* to a uniquely named file in TEMP_DIR for inspection."""
    try:
        debug_svg_path = TEMP_DIR / f"debug_embedded_svg_{uuid.uuid4().hex[:8]}.svg"
        debug_svg_path.write_bytes(svg_bytes)
        print(f"Saved SVG with all images embedded for debugging to: {debug_svg_path}")
    except Exception as e_save:
        print(f"Error saving debug embedded SVG: {e_save}")


def replace_background_in_svg(svg_content_str: str, new_image_path: str):
    """
    Embed an uploaded image into an SVG template and inline all remote images.

    The target <image> is chosen by _find_replaceable_image and its href(s) are replaced
    with a base64 data URI of *new_image_path*. Every other <image> whose href is an
    http(s) URL is downloaded and inlined too, so the result is self-contained. Failures
    for individual images are printed and skipped. A copy of the result is written to
    TEMP_DIR for debugging.

    Args:
        svg_content_str: SVG markup (str or bytes accepted by ET.fromstring).
        new_image_path: local path of the uploaded image.

    Returns:
        UTF-8 encoded SVG bytes with an XML declaration, or None if the SVG cannot be
        parsed or an unexpected error occurs.
    """
    try:
        for prefix, uri in (("xlink", _XLINK_NS), ("svg", _SVG_NS)):
            ET.register_namespace(prefix, uri)
        root = ET.fromstring(svg_content_str)

        target_image = _find_replaceable_image(root)
        if target_image is None:
            print("Warning: No suitable <image> tag found in SVG to replace with new_image_path.")
        else:
            try:
                _embed_local_image(target_image, new_image_path)
            except Exception as e:
                print(f"Error processing and embedding new_image_path '{new_image_path}': {e}")

        _embed_remote_images(root, skip_element=target_image)

        modified_svg_bytes = ET.tostring(root, encoding='UTF-8', method='xml', xml_declaration=True)
        _save_debug_svg(modified_svg_bytes)
        return modified_svg_bytes
    except ET.ParseError as e:
        print(f"Error parsing SVG XML: {e}")
        return None
    except Exception as e:
        print(f"An unexpected error occurred during SVG manipulation: {e}")
        return None
def _safe_file_stem(name) -> str:
    """Reduce an untrusted display name to a filename component containing no path separators."""
    # str.isalnum is Unicode-aware, so Chinese names survive. Everything else, notably '/',
    # '\\', ':' and spaces, becomes '_', so the name cannot point outside the target directory.
    return "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in str(name or "template"))


def convert_svg_bytes_to_png_api(svg_content_bytes: bytes, original_template_name: str, index: int):
    """
    Convert SVG bytes to PNG via the external SVG2PNG_API_URL service and save the result.

    Args:
        svg_content_bytes: SVG document to upload (sent as multipart field ``svg_file``).
        original_template_name: template name used in the output filename. It comes from
            the page URL, so it is sanitized before being used in a path.
        index: 1-based position of the source image, embedded in the filename.

    Returns:
        Path (str) of the PNG written into TEMP_DIR, or None if the input is empty, the
        API call fails, or the file cannot be written.
    """
    if not svg_content_bytes:
        return None
    try:
        files = {'svg_file': ('generated_svg.svg', svg_content_bytes, 'image/svg+xml')}
        response = requests.post(SVG2PNG_API_URL, files=files, timeout=60)
        response.raise_for_status()
        # The UUID suffix keeps names unique across batches and concurrent users.
        png_filename = f"{_safe_file_stem(original_template_name)}_output_{index}_{uuid.uuid4().hex[:8]}.png"
        temp_png_path = TEMP_DIR / png_filename
        temp_png_path.write_bytes(response.content)
        return str(temp_png_path)
    except requests.exceptions.RequestException as e:
        print(f"Error converting SVG to PNG via API: {e}")
        if hasattr(e, 'response') and e.response is not None:
            print(f"API Response: {e.response.text}")
        return None
    except Exception as e:
        print(f"Error saving PNG: {e}")
        return None
# --- Gradio App Logic ---
def generate_images_from_template(original_svg_download_url: str, template_name_for_file: str, uploaded_image_files: list, request: gr.Request):
    """
    Gradio click handler: render one PNG per uploaded image and bundle them into a ZIP.

    Args:
        original_svg_download_url: URL of the template SVG (from session state).
        template_name_for_file: template name from the page URL, used to name output files.
        uploaded_image_files: items from the gr.Files input. Each is either a path string
            or an object exposing the temporary path as ``.name``.
        request: the Gradio request; unused, but declared so Gradio injects it.

    Returns:
        (list of generated PNG paths, ZIP path), or ([], None) when there is nothing to
        process or no image could be generated.

    Raises:
        gr.Error: when the template SVG cannot be downloaded (shown to the user).
    """
    if not original_svg_download_url:
        gr.Warning("模板信息未加载,请确保URL参数正确。")
        return [], None
    if not uploaded_image_files:
        gr.Warning("请上传至少一张图片。")
        return [], None

    # 1. Download the template once; it is re-used for every uploaded image.
    original_svg_content = download_svg_content(original_svg_download_url)
    if not original_svg_content:
        # gr.Error is an exception: it only reaches the UI when raised.
        raise gr.Error("无法下载原始SVG模板。")

    generated_png_paths = []
    for i, uploaded_file_obj in enumerate(uploaded_image_files, start=1):
        # Older Gradio returns file wrappers with .name; newer versions may return plain paths.
        new_image_path = getattr(uploaded_file_obj, "name", uploaded_file_obj)
        # 2. Embed the upload into the template.
        modified_svg_bytes = replace_background_in_svg(original_svg_content, new_image_path)
        if not modified_svg_bytes:
            gr.Warning(f"处理图片 {i} 失败:无法修改SVG。")
            continue
        # 3. Render the modified SVG to PNG.
        png_path = convert_svg_bytes_to_png_api(modified_svg_bytes, template_name_for_file, i)
        if png_path:
            generated_png_paths.append(png_path)
        else:
            gr.Warning(f"处理图片 {i} 失败:无法转换为PNG。")

    if not generated_png_paths:
        gr.Info("未能成功生成任何图片。")
        return [], None

    # The template name comes from the page URL: keep only filename-safe characters so it
    # cannot introduce path separators into the ZIP path.
    safe_stem = "".join(
        ch if ch.isalnum() or ch in "-_." else "_" for ch in str(template_name_for_file or "template")
    )
    temp_zip_path = TEMP_DIR / f"{safe_stem}_batch_{uuid.uuid4().hex[:8]}.zip"
    with zipfile.ZipFile(temp_zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for png_path_str in generated_png_paths:
            png_file = pathlib.Path(png_path_str)
            zf.write(png_file, arcname=png_file.name)
    gr.Info(f"成功生成 {len(generated_png_paths)} 张图片!")
    return generated_png_paths, str(temp_zip_path)
def initial_load_template_info(request: gr.Request):
    """
    Page-load handler: resolve the template named in the URL query and fetch its thumbnail.

    Reads the ``m``, ``u``, ``category`` and ``name`` query parameters, looks the template
    up via fetch_template_details and saves its thumbnail into TEMP_DIR.

    Returns:
        A 4-tuple feeding (thumbnail image, name textbox, SVG-URL state, file-name state):
        local thumbnail path or None; template display name or an error message; SVG
        download URL or None; and the raw ``name`` query parameter ("无模板" when the
        parameters are incomplete).
    """
    query_params = request.query_params
    m = query_params.get("m")
    u = query_params.get("u")
    category = query_params.get("category")
    name = query_params.get("name")  # template name; also reused for output file names
    if not all([m, u, category, name]):
        print("Initial load: URL parameters (m, u, category, name) are incomplete or missing.")
        return None, "无模板信息 (请检查URL参数)", None, "无模板"

    details = fetch_template_details(m, u, category, name)
    if not (isinstance(details, dict) and "name" in details):
        print(f"Failed to load template details for m={m}, u={u}, category={category}, name={name}")
        return None, "错误: 无法加载模板信息", None, name

    template_display_name = details["name"]
    svg_download_url = details.get("svg_download_url")
    local_thumbnail_path = None
    thumb_url = details.get("thumbnail_url")  # already passed through fix_internal_urls
    if thumb_url:
        try:
            print(f"Fetching thumbnail from: {thumb_url}")
            thumb_response = requests.get(thumb_url, timeout=10)
            thumb_response.raise_for_status()
            # `name` comes from the URL: keep only filename-safe characters so '/', '\\'
            # and the like cannot escape TEMP_DIR.
            safe_template_name_for_file = "".join(
                ch if ch.isalnum() or ch in "-_." else "_" for ch in name
            )
            local_thumbnail_path = TEMP_DIR / f"thumb_{safe_template_name_for_file}_{uuid.uuid4().hex[:8]}.png"
            local_thumbnail_path.write_bytes(thumb_response.content)
            print(f"Thumbnail saved to: {local_thumbnail_path}")
        except requests.exceptions.RequestException as e:
            print(f"Error downloading thumbnail from {thumb_url}: {e}")
            local_thumbnail_path = None
        except Exception as e:
            print(f"Error saving thumbnail: {e}")
            local_thumbnail_path = None

    return (
        str(local_thumbnail_path) if local_thumbnail_path else None,
        template_display_name,
        svg_download_url,
        name,  # raw URL param, stored in template_name_for_file_state
    )
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="SVG模板批量图片生成器") as demo:
    # Page layout: title and intro, a row with the template preview and the upload column,
    # then an accordion with the generated results. Event wiring is at the bottom.
    gr.Markdown("## 使用SVG模板批量生成图片")
    gr.Markdown("从URL加载模板,上传您自己的图片替换模板中的底图,然后批量生成并下载。")
    # Per-session hidden state, filled by the page-load handler and read by the generate button:
    # the template's SVG download URL, and the raw `name` URL parameter used to name output files.
    original_svg_download_url_state = gr.State()
    template_name_for_file_state = gr.State()
    with gr.Row():
        with gr.Column(scale=1, min_width=200):
            # Read-only preview; type="filepath" because the load handler returns a local file path.
            template_thumbnail_display = gr.Image(label="当前模板缩略图", interactive=False, height=200, type="filepath")
            template_name_display = gr.Textbox(label="当前模板名称", interactive=False)
        with gr.Column(scale=3):
            uploaded_images_input = gr.Files(
                label="上传您的图片 (可多选)",
                file_count="multiple",
                file_types=["image"]  # restricts the file picker to image files
            )
            generate_button = gr.Button("🚀 立即生成图片", variant="primary", scale=1)
    with gr.Accordion("生成结果预览与下载", open=True):
        # Shows the list of PNG file paths returned by generate_images_from_template.
        output_gallery = gr.Gallery(
            label="生成图片预览",
            show_label=True,
            elem_id="output_gallery",
            columns=[4],
            object_fit="contain",
            height="auto"
        )
        # Offers the ZIP archive path returned by generate_images_from_template for download.
        output_zip_file = gr.File(label="下载所有生成图片的ZIP包", interactive=False, type="filepath")
    # On page load, resolve the template from the URL query string (?m=&u=&category=&name=).
    # No input components are needed: initial_load_template_info declares `request: gr.Request`,
    # which Gradio injects automatically.
    demo.load(
        initial_load_template_info,
        inputs=None,
        outputs=[
            template_thumbnail_display,
            template_name_display,
            original_svg_download_url_state,
            template_name_for_file_state
        ]
    )
    # Generate PNGs from the stored template URL/name and the uploaded files.
    generate_button.click(
        generate_images_from_template,
        inputs=[
            original_svg_download_url_state,
            template_name_for_file_state,
            uploaded_images_input
            # The handler's `request: gr.Request` parameter is injected by Gradio, not listed here.
        ],
        outputs=[output_gallery, output_zip_file]
    )
if __name__ == "__main__":
    # To run this app:
    # 1. Set the environment variables read in the Configuration section:
    #      mubanapi   -> base URL of the template server (TEMPLATE_API_BASE_URL); it must serve
    #                    /gradio_template_details/{username}/{category}/{template_name}
    #      convertapi -> URL of the SVG-to-PNG conversion endpoint (SVG2PNG_API_URL)
    # 2. Run this script: python use_template_app.py
    # 3. Open the Gradio link with the template selected via query parameters, e.g.:
    #    http://127.0.0.1:7860/?m=pub&u=all&category=%E5%AE%B6%E7%94%B5&name=%E6%A8%A1%E6%9D%BF_%E5%86%B0%E7%82%B9%E4%BB%B7
    demo.launch()
    # NOTE: this file contains no cleanup for TEMP_DIR (thumbnails, PNGs, ZIPs and debug SVGs
    # accumulate). For long-running servers consider a periodic cleanup job. Removing the
    # directory at exit is possible but may be too aggressive during development:
    #   import atexit
    #   atexit.register(shutil.rmtree, TEMP_DIR, ignore_errors=True)