Spaces:
Running
Running
import gradio as gr | |
import requests | |
import pathlib | |
import zipfile | |
import io | |
import urllib.parse # For URL sanitization | |
import xml.etree.ElementTree as ET | |
import base64 | |
from PIL import Image # For handling image types if needed by Gradio gallery | |
import time | |
import shutil | |
import uuid # For unique filenames | |
import os | |
# --- Configuration ---
# Service endpoints come from environment variables (e.g. Space secrets);
# each one is None when the corresponding variable is not set.
TEMPLATE_API_BASE_URL = os.environ.get("mubanapi")  # base URL of the template server
SVG2PNG_API_URL = os.environ.get("convertapi")  # endpoint that converts an uploaded SVG into a PNG

# Scratch directory for thumbnails, generated PNGs, ZIP archives and debug SVGs.
TEMP_DIR = pathlib.Path("temp_gradio_app_files")
# Created at import time; any files left over from previous runs are kept.
TEMP_DIR.mkdir(exist_ok=True, parents=True)
# --- Helper Functions --- | |
def _is_internal_host(hostname) -> bool:
    """
    Return True when *hostname* is ``localhost`` or an IP literal in a loopback or
    private range (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, 127.0.0.0/8, ::1).

    Any other DNS name is treated as external, even one that merely starts with
    digits such as ``10.example.com``.
    """
    import ipaddress  # local import keeps this helper self-contained

    if not hostname:
        return False
    if hostname == "localhost":
        return True
    try:
        ip = ipaddress.ip_address(hostname)
    except ValueError:
        return False  # a DNS name, not an IP literal
    if ip.is_loopback:
        return True
    private_networks = (
        ipaddress.ip_network("10.0.0.0/8"),
        ipaddress.ip_network("172.16.0.0/12"),
        ipaddress.ip_network("192.168.0.0/16"),
    )
    # Membership across IP versions (IPv6 address vs IPv4 network) is simply False.
    return any(ip in network for network in private_networks)


def fix_internal_urls(url_string: str, base_url=None) -> str:
    """
    Rewrite an http(s) URL that points at an internal host onto the external API host.

    The template server may hand out URLs such as ``http://10.10.71.201:8002/...``
    that are unreachable from outside. For those URLs the scheme and host are
    replaced with the ones from *base_url*. The path, query and fragment are
    kept; any path prefix in *base_url* is not applied.

    Args:
        url_string: URL to inspect. Empty or non-string values are returned unchanged.
        base_url: External base URL to rewrite onto. Defaults to the module-level
            ``TEMPLATE_API_BASE_URL``.

    Returns:
        The rewritten URL. The input is returned unchanged when it is not an
        internal http(s) URL, when no usable base URL is configured, or when
        parsing fails.
    """
    if not url_string or not isinstance(url_string, str):
        return url_string
    try:
        parts = urllib.parse.urlsplit(url_string)
        # .hostname strips the port/userinfo and lower-cases the host.
        if parts.scheme not in ("http", "https") or not _is_internal_host(parts.hostname):
            return url_string

        if base_url is None:
            base_url = TEMPLATE_API_BASE_URL
        base_parts = urllib.parse.urlsplit(base_url) if base_url else None
        if base_parts is None or not base_parts.netloc:
            # Without an absolute base URL the rewrite would yield a broken relative URL.
            print(f"Cannot rewrite internal URL '{url_string}': no absolute external base URL configured.")
            return url_string

        new_url = urllib.parse.urlunsplit((
            base_parts.scheme,
            base_parts.netloc,
            parts.path,
            parts.query,
            parts.fragment,
        ))
        print(f"URL transformed: {url_string} -> {new_url}")
        return new_url
    except Exception as e:
        print(f"Error fixing internal URL '{url_string}': {e}")
        return url_string
# --- Helper Functions --- | |
def sanitize_url(url_string: str) -> str:
    """
    Percent-encode the path and query of an http(s) URL so it can be requested safely.

    Input that is already percent-encoded is normalized rather than double-encoded.
    The path is decoded once and re-encoded. Query parameters are decoded once (by
    ``parse_qsl``) and re-encoded, keeping their original order. Parentheses and the
    other characters listed in ``safe`` stay literal in the path.

    Args:
        url_string: URL to sanitize.

    Returns:
        The sanitized URL. ``data:`` URIs, non-http(s) URLs and empty or non-string
        input are returned unchanged. The original string is also returned when
        parsing fails.
    """
    if not url_string or not isinstance(url_string, str) or url_string.startswith('data:'):
        return url_string
    try:
        scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url_string)
        if scheme not in ['http', 'https']:
            # Leave other schemes (mailto, URNs, ...) untouched.
            return url_string

        # After unquote(), any '%', '?' or '#' left in the path is a literal character.
        # These must therefore be re-encoded (they are NOT in `safe`). Otherwise '%'
        # yields an invalid escape, and '?'/'#' would start a query/fragment.
        sanitized_path = urllib.parse.quote(urllib.parse.unquote(path), safe='/;:@&+$,=()~')

        sanitized_query = ""
        if query:
            # parse_qsl decodes each key/value exactly once. Unquoting the whole query first
            # would double-decode values and let an encoded '&' or '=' split a value apart.
            params = urllib.parse.parse_qsl(query, keep_blank_values=True)
            sanitized_query = urllib.parse.urlencode(params, quote_via=urllib.parse.quote)

        return urllib.parse.urlunsplit((scheme, netloc, sanitized_path, sanitized_query, fragment))
    except Exception as e:
        print(f"Warning: Could not sanitize URL '{url_string}': {e}")
        return url_string
def fetch_template_details(m_param, u_param, category_param, name_param):
    """
    Fetch template details from the template server's ``gradio_template_details`` endpoint.

    Args:
        m_param: Accepted for interface compatibility. It is not used in the request.
        u_param: Username path segment.
        category_param: Category path segment.
        name_param: Template name path segment.

    Returns:
        The decoded JSON object (expected keys include ``name``, ``thumbnail_url`` and
        ``svg_download_url``). Every absolute http(s) string value is passed through
        ``fix_internal_urls``. Returns None on request errors, invalid JSON, or a JSON
        payload that is not an object.
    """
    # safe='' percent-encodes '/' too, so a value containing a slash cannot add path segments.
    username = urllib.parse.quote(u_param, safe="")
    category = urllib.parse.quote(category_param, safe="")
    template_name = urllib.parse.quote(name_param, safe="")
    api_url = f"{TEMPLATE_API_BASE_URL}/gradio_template_details/{username}/{category}/{template_name}"
    print(f"Fetching template details from: {api_url}")

    try:
        response = requests.get(api_url, timeout=10)
        response.raise_for_status()
        template_data = response.json()
    except requests.exceptions.RequestException as e:
        print(f"Error fetching template details from {api_url}: {e}")
        return None
    except ValueError as e:
        # Depending on the requests version, invalid JSON may surface as a plain ValueError.
        print(f"Invalid JSON in template details response from {api_url}: {e}")
        return None

    if not isinstance(template_data, dict):
        print(f"Unexpected template details payload from {api_url}: {type(template_data).__name__}")
        return None

    # Rewrite every absolute URL (thumbnail_url, svg_download_url, ...) once so that
    # internal-network hosts become reachable from outside.
    return {
        key: fix_internal_urls(value)
        if isinstance(value, str) and value.startswith(("http://", "https://"))
        else value
        for key, value in template_data.items()
    }
def download_svg_content(svg_url):
    """
    Download an SVG document and return it as text.

    Internal-network URLs are first rewritten via ``fix_internal_urls``. When the
    server declares no charset, the body is decoded as UTF-8 (a leading BOM is
    tolerated), which is the XML default. Otherwise ``requests`` would fall back to
    ISO-8859-1 for ``text/*`` types, or to a statistical guess, and garble non-ASCII
    text such as Chinese labels.

    Args:
        svg_url: URL of the SVG template.

    Returns:
        The SVG markup as a string, or None if the request fails or returns an error status.
    """
    fixed_url = fix_internal_urls(svg_url)
    try:
        response = requests.get(fixed_url, timeout=10)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        print(f"Error downloading SVG content from {svg_url}: {e}")
        return None

    content_type = response.headers.get("Content-Type", "")
    if "charset=" not in content_type.lower():
        response.encoding = "utf-8-sig"
    return response.text
def _image_bytes_to_data_uri(image_bytes: bytes, resource_identifier: str) -> str:
    """
    Encode raw image bytes as a base64 ``data:`` URI.

    Args:
        image_bytes: Raw image file contents.
        resource_identifier: Either an ``image/*`` MIME type (e.g. from a Content-Type
            header) or a file path / http(s) URL whose extension determines the MIME type.

    Returns:
        ``data:<mime>;base64,<payload>``. The MIME type falls back to ``image/png``
        when it cannot be determined.
    """
    extension_to_mime = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".gif": "image/gif",
        ".webp": "image/webp",
        ".svg": "image/svg+xml",
    }
    identifier = resource_identifier or ""
    if identifier.startswith("image/"):
        mime_type = identifier
    else:
        if identifier.lower().startswith(("http://", "https://")):
            # Ignore query string / fragment so ".../a.jpg?v=2" still resolves to ".jpg".
            identifier = urllib.parse.urlsplit(identifier).path
        suffix = pathlib.PurePath(identifier).suffix.lower()
        mime_type = extension_to_mime.get(suffix, "image/png")
    encoded_image = base64.b64encode(image_bytes).decode('utf-8')
    return f"data:{mime_type};base64,{encoded_image}"


def _find_target_image(root, namespaces):
    """
    Return the <image> element that should receive the uploaded picture, or None.

    Preference order:
      1. the first <image> whose id is one of the known placeholder ids;
      2. the first <image> in the SVG namespace;
      3. the first un-namespaced <image>.
    """
    ids_to_check = ["background_image", "placeholder_image", "product_image", "main_image", "image_to_replace", "底图"]
    for img_id in ids_to_check:
        found = root.find(f".//svg:image[@id='{img_id}']", namespaces)
        if found is None:
            found = root.find(f".//{{*}}image[@id='{img_id}']")
        if found is not None:
            return found
    first_image = root.find('.//svg:image', namespaces)
    if first_image is None:
        first_image = root.find('.//image')
    return first_image


def _embed_uploaded_image(image_el, new_image_path, xlink_href):
    """
    Point *image_el* at the local file *new_image_path*, embedded as a data URI.

    Writes to whichever href attribute the element already uses (xlink:href first,
    then plain href). If it has neither, xlink:href is used. This is best effort:
    problems are printed and the element is left unchanged.
    """
    try:
        path_obj = pathlib.Path(new_image_path)
        if not path_obj.is_file():
            print(f"Error: New image file not found or is not a file: {new_image_path}")
            return
        data_uri = _image_bytes_to_data_uri(path_obj.read_bytes(), str(path_obj))
        if xlink_href in image_el.attrib:
            attr_name = xlink_href
        elif 'href' in image_el.attrib:
            attr_name = 'href'
        else:
            attr_name = xlink_href
        image_el.set(attr_name, data_uri)
    except Exception as e:
        print(f"Error processing and embedding new_image_path '{new_image_path}': {e}")


def _embed_remote_images(root, namespaces, skip_element):
    """
    Replace every http(s) image reference in the tree (except *skip_element*) with a data URI.

    Each download failure is printed and that element keeps its original URL.
    """
    xlink_href = f"{{{namespaces['xlink']}}}href"
    # SVG-namespaced images plus un-namespaced ones ('.//image' only matches the latter).
    candidates = root.findall('.//svg:image', namespaces) + root.findall('.//image')
    for image_el in candidates:
        if image_el is skip_element:
            continue  # already handled as the uploaded-image target
        if image_el.get(xlink_href) is not None:
            attr_name = xlink_href
        elif image_el.get('href') is not None:
            attr_name = 'href'
        else:
            continue
        href_val = image_el.get(attr_name)
        if not href_val.startswith(('http://', 'https://')):
            continue
        try:
            # Template SVGs can reference assets on the internal template server;
            # rewrite those like every other template URL before downloading.
            download_url = sanitize_url(fix_internal_urls(href_val))
            response = requests.get(download_url, timeout=20)
            response.raise_for_status()
            content_type = (response.headers.get('Content-Type') or '').split(';')[0].strip()
            # Only trust the header when it is an image type; otherwise infer from the URL.
            mime_source = content_type if content_type.startswith('image/') else href_val
            image_el.set(attr_name, _image_bytes_to_data_uri(response.content, mime_source))
        except requests.exceptions.RequestException as e_req:
            print(f"Failed to download or process remote image {href_val}: {e_req}")
        except Exception as e_proc:
            print(f"Error processing remote image {href_val} into data URI: {e_proc}")


def _save_debug_svg(svg_bytes):
    """Best-effort: write *svg_bytes* to a uniquely named file in TEMP_DIR for inspection."""
    # NOTE(review): one file is written per call and never cleaned up; with embedded
    # images these can be large.
    try:
        debug_svg_path = TEMP_DIR / f"debug_embedded_svg_{uuid.uuid4().hex[:8]}.svg"
        debug_svg_path.write_bytes(svg_bytes)
        print(f"Saved SVG with all images embedded for debugging to: {debug_svg_path}")
    except Exception as e_save:
        print(f"Error saving debug embedded SVG: {e_save}")


def replace_background_in_svg(svg_content_str: str, new_image_path: str):
    """
    Embed an uploaded image into an SVG template, and inline all other remote images.

    The target <image> is chosen by ``_find_target_image`` and its href is set to a
    base64 data URI of *new_image_path*. Every other <image> with an http(s) href is
    downloaded and inlined as well, so the result no longer depends on network access.

    Args:
        svg_content_str: SVG markup of the template.
        new_image_path: Path of the local image file to embed.

    Returns:
        UTF-8 encoded SVG bytes with an XML declaration, or None if the SVG cannot be
        parsed or an unexpected error occurs. Failures to embed individual images are
        printed and do not cause a None return.
    """
    namespaces = {'xlink': 'http://www.w3.org/1999/xlink', 'svg': 'http://www.w3.org/2000/svg'}
    xlink_href = f"{{{namespaces['xlink']}}}href"
    try:
        # Serialize SVG elements in the default namespace (<svg xmlns="...">) instead of
        # the prefixed <svg:svg xmlns:svg="..."> form, which some renderers handle poorly.
        ET.register_namespace('', namespaces['svg'])
        ET.register_namespace('xlink', namespaces['xlink'])
        root = ET.fromstring(svg_content_str)

        target_image = _find_target_image(root, namespaces)
        if target_image is None:
            print("Warning: No suitable <image> tag found in SVG to replace with new_image_path.")
        else:
            _embed_uploaded_image(target_image, new_image_path, xlink_href)

        _embed_remote_images(root, namespaces, skip_element=target_image)

        modified_svg_bytes = ET.tostring(root, encoding='UTF-8', method='xml', xml_declaration=True)
        _save_debug_svg(modified_svg_bytes)
        return modified_svg_bytes
    except ET.ParseError as e:
        print(f"Error parsing SVG XML: {e}")
        return None
    except Exception as e:
        print(f"An unexpected error occurred during SVG manipulation: {e}")
        return None
def _safe_filename_stem(name, default="template"):
    """
    Reduce *name* to characters that are safe inside a single file name.

    Letters and digits (including CJK), '-', '_' and '.' are kept. Everything else,
    notably path separators, becomes '_'. Leading dots are stripped so the result is
    neither hidden nor relative. Returns *default* when nothing usable remains.
    """
    text = str(name) if name else ""
    cleaned = "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in text).lstrip(".")
    return cleaned or default


def convert_svg_bytes_to_png_api(svg_content_bytes: bytes, original_template_name: str, index: int):
    """
    Convert SVG bytes to PNG via the external conversion API and save the result in TEMP_DIR.

    Args:
        svg_content_bytes: SVG document, uploaded as the multipart field ``svg_file``.
        original_template_name: Template name used as the output file-name prefix. It
            originates from the page URL and is sanitized before use in a path.
        index: 1-based position of the image in the batch, embedded in the file name.

    Returns:
        Path of the saved PNG as a string, or None when the input is empty, the API
        call fails, or the file cannot be written.
    """
    if not svg_content_bytes:
        return None
    try:
        files = {'svg_file': ('generated_svg.svg', svg_content_bytes, 'image/svg+xml')}
        response = requests.post(SVG2PNG_API_URL, files=files, timeout=60)
        response.raise_for_status()
        # The UUID suffix keeps names unique across batches and concurrent sessions.
        png_filename = f"{_safe_filename_stem(original_template_name)}_output_{index}_{uuid.uuid4().hex[:8]}.png"
        temp_png_path = TEMP_DIR / png_filename
        temp_png_path.write_bytes(response.content)
        return str(temp_png_path)
    except requests.exceptions.RequestException as e:
        print(f"Error converting SVG to PNG via API: {e}")
        if getattr(e, 'response', None) is not None:
            print(f"API Response: {e.response.text}")
        return None
    except Exception as e:
        print(f"Error saving PNG: {e}")
        return None
# --- Gradio App Logic --- | |
def generate_images_from_template(original_svg_download_url: str, template_name_for_file: str, uploaded_image_files: list, request: gr.Request):
    """
    Gradio handler: render the template once per uploaded image and bundle the results.

    Args:
        original_svg_download_url: SVG template URL held in session state.
        template_name_for_file: Template name from the page URL, used as the prefix of
            output file names (sanitized here before it reaches a path).
        uploaded_image_files: Uploaded files. Either objects exposing the temp path as
            ``.name`` (older Gradio) or plain path strings (newer Gradio).
        request: Injected by Gradio; unused.

    Returns:
        ``(png_paths, zip_path)``. Returns ``([], None)`` when inputs are missing or no
        image could be produced.

    Raises:
        gr.Error: if the SVG template cannot be downloaded (shown to the user by Gradio).
    """
    if not original_svg_download_url:
        gr.Warning("模板信息未加载,请确保URL参数正确。")
        return [], None
    if not uploaded_image_files:
        gr.Warning("请上传至少一张图片。")
        return [], None

    original_svg_content = download_svg_content(original_svg_download_url)
    if not original_svg_content:
        # gr.Error is an exception type: it only reaches the user when raised.
        raise gr.Error("无法下载原始SVG模板。")

    generated_png_paths = []
    for i, uploaded_file_obj in enumerate(uploaded_image_files, start=1):
        new_image_path = getattr(uploaded_file_obj, "name", uploaded_file_obj)
        modified_svg_bytes = replace_background_in_svg(original_svg_content, new_image_path)
        if not modified_svg_bytes:
            gr.Warning(f"处理图片 {i} 失败:无法修改SVG。")
            continue
        png_path = convert_svg_bytes_to_png_api(modified_svg_bytes, template_name_for_file, i)
        if png_path:
            generated_png_paths.append(png_path)
        else:
            gr.Warning(f"处理图片 {i} 失败:无法转换为PNG。")

    if not generated_png_paths:
        gr.Info("未能成功生成任何图片。")
        return [], None

    # The template name comes from the URL: keep only file-name-safe characters so it
    # cannot introduce path separators (e.g. "../").
    safe_name = "".join(
        ch if ch.isalnum() or ch in "-_." else "_" for ch in str(template_name_for_file or "")
    ).lstrip(".") or "template"
    temp_zip_path = TEMP_DIR / f"{safe_name}_batch_{uuid.uuid4().hex[:8]}.zip"
    with zipfile.ZipFile(temp_zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for png_path_str in generated_png_paths:
            png_file = pathlib.Path(png_path_str)
            zf.write(png_file, arcname=png_file.name)

    gr.Info(f"成功生成 {len(generated_png_paths)} 张图片!")
    return generated_png_paths, str(temp_zip_path)
def _download_thumbnail(thumb_url, template_name):
    """
    Download a template thumbnail into TEMP_DIR and return its local path, or None on failure.

    The file extension follows the response Content-Type (``.png`` when unknown).
    *template_name* comes from the page URL and is reduced to file-name-safe characters.
    """
    extension_by_type = {
        "image/png": ".png",
        "image/jpeg": ".jpg",
        "image/gif": ".gif",
        "image/webp": ".webp",
    }
    try:
        print(f"Fetching thumbnail from: {thumb_url}")
        thumb_response = requests.get(thumb_url, timeout=10)
        thumb_response.raise_for_status()
        content_type = (thumb_response.headers.get("Content-Type") or "").split(";")[0].strip().lower()
        extension = extension_by_type.get(content_type, ".png")
        safe_name = "".join(
            ch if ch.isalnum() or ch in "-_." else "_" for ch in str(template_name)
        ).lstrip(".") or "template"
        local_path = TEMP_DIR / f"thumb_{safe_name}_{uuid.uuid4().hex[:8]}{extension}"
        local_path.write_bytes(thumb_response.content)
        print(f"Thumbnail saved to: {local_path}")
        return local_path
    except requests.exceptions.RequestException as e:
        print(f"Error downloading thumbnail from {thumb_url}: {e}")
    except Exception as e:
        print(f"Error saving thumbnail: {e}")
    return None


def initial_load_template_info(request: gr.Request):
    """
    Gradio ``load`` handler: resolve the template named by the page's query parameters.

    Reads ``m``, ``u``, ``category`` and ``name`` from ``request.query_params``, fetches
    the template details, and downloads the thumbnail to a local file.

    Returns:
        A 4-tuple ``(thumbnail_path_or_None, display_name, svg_download_url_or_None,
        file_name_source)``. When parameters are missing, the last element is the
        placeholder ``"无模板"``. Otherwise it is the raw ``name`` parameter.
    """
    query_params = request.query_params
    m = query_params.get("m")
    u = query_params.get("u")
    category = query_params.get("category")
    name = query_params.get("name")  # template name, also used for output file names

    if not all([m, u, category, name]):
        print("Initial load: URL parameters (m, u, category, name) are incomplete or missing.")
        return None, "无模板信息 (请检查URL参数)", None, "无模板"

    details = fetch_template_details(m, u, category, name)
    if not isinstance(details, dict) or "name" not in details:
        print(f"Failed to load template details for m={m}, u={u}, category={category}, name={name}")
        return None, "错误: 无法加载模板信息", None, name

    local_thumbnail_path = None
    if details.get("thumbnail_url"):
        # thumbnail_url has already been passed through fix_internal_urls.
        local_thumbnail_path = _download_thumbnail(details["thumbnail_url"], name)

    return (
        str(local_thumbnail_path) if local_thumbnail_path else None,
        details["name"],
        details.get("svg_download_url"),
        name,
    )
# --- Gradio Interface --- | |
with gr.Blocks(title="SVG模板批量图片生成器", theme=gr.themes.Soft()) as demo:
    # Page header.
    gr.Markdown("## 使用SVG模板批量生成图片")
    gr.Markdown("从URL加载模板,上传您自己的图片替换模板中的底图,然后批量生成并下载。")

    # Hidden per-session values: filled by demo.load, consumed by the generate handler.
    original_svg_download_url_state = gr.State()  # SVG download URL from the template details
    template_name_for_file_state = gr.State()  # "name" URL parameter, used for output file names

    with gr.Row():
        # Left column: read-only preview of the template selected through the URL.
        with gr.Column(min_width=200, scale=1):
            template_thumbnail_display = gr.Image(
                type="filepath",
                height=200,
                interactive=False,
                label="当前模板缩略图",
            )
            template_name_display = gr.Textbox(interactive=False, label="当前模板名称")
        # Right column: image uploads and the trigger button.
        with gr.Column(scale=3):
            uploaded_images_input = gr.Files(
                file_types=["image"],
                file_count="multiple",
                label="上传您的图片 (可多选)",
            )
            generate_button = gr.Button("🚀 立即生成图片", scale=1, variant="primary")

    with gr.Accordion("生成结果预览与下载", open=True):
        output_gallery = gr.Gallery(
            elem_id="output_gallery",
            label="生成图片预览",
            show_label=True,
            columns=[4],
            height="auto",
            object_fit="contain",
        )
        output_zip_file = gr.File(type="filepath", interactive=False, label="下载所有生成图片的ZIP包")

    # Both handlers annotate a `request: gr.Request` parameter, which Gradio fills in
    # itself; it is therefore not listed among the event inputs.
    template_info_outputs = [
        template_thumbnail_display,
        template_name_display,
        original_svg_download_url_state,
        template_name_for_file_state,
    ]
    demo.load(initial_load_template_info, inputs=None, outputs=template_info_outputs)

    generation_inputs = [original_svg_download_url_state, template_name_for_file_state, uploaded_images_input]
    generation_outputs = [output_gallery, output_zip_file]
    generate_button.click(generate_images_from_template, inputs=generation_inputs, outputs=generation_outputs)
if __name__ == "__main__":
    # How to run:
    #   1. Start template_server.py (e.g. on http://localhost:8001). It must serve the
    #      /gradio_template_details/{username}/{category}/{template_name} endpoint
    #      that fetch_template_details queries under TEMPLATE_API_BASE_URL.
    #   2. Launch this script: python use_template_app.py
    #   3. Open the Gradio link with the template chosen via query parameters, e.g.
    #      http://127.0.0.1:7860/?m=pub&u=all&category=%E5%AE%B6%E7%94%B5&name=%E6%A8%A1%E6%9D%BF_%E5%86%B0%E7%82%B9%E4%BB%B7
    #   Without a running template server, a stub endpoint or hard-coded details can be used for testing.
    demo.launch()
    # TEMP_DIR is never cleaned automatically. A long-running server could register
    # a cleanup hook (probably too aggressive during development), for example:
    #   import atexit
    #   def cleanup_temp_dir():
    #       if TEMP_DIR.exists():
    #           print(f"Cleaning up temp directory: {TEMP_DIR}")
    #           shutil.rmtree(TEMP_DIR)
    #   atexit.register(cleanup_temp_dir)