import base64
import io
import ipaddress
import mimetypes
import os
import pathlib
import re
import shutil
import time
import urllib.parse  # For URL sanitization
import uuid  # For unique filenames
import xml.etree.ElementTree as ET
import zipfile

import gradio as gr
import requests
from PIL import Image  # For handling image types if needed by Gradio gallery

# --- Configuration ---
# Base URL of the template server; None when the "mubanapi" env var is unset.
TEMPLATE_API_BASE_URL = os.getenv("mubanapi")
# Endpoint of the SVG -> PNG conversion service; None when "convertapi" is unset.
SVG2PNG_API_URL = os.getenv("convertapi")
TEMP_DIR = pathlib.Path("temp_gradio_app_files")  # Renamed for clarity
# Opt-in: save every generated SVG (with all images embedded as base64) to TEMP_DIR.
# Off by default because each file can be several MB and TEMP_DIR is never cleaned.
SAVE_DEBUG_SVG = os.getenv("SAVE_DEBUG_SVG", "").strip().lower() in {"1", "true", "yes"}

# Ensure temp directory exists (it is intentionally not cleaned on each start).
TEMP_DIR.mkdir(parents=True, exist_ok=True)

# --- SVG / MIME constants ---
_SVG_NS = "http://www.w3.org/2000/svg"
_XLINK_NS = "http://www.w3.org/1999/xlink"
_XLINK_HREF = f"{{{_XLINK_NS}}}href"
_XPATH_NAMESPACES = {"svg": _SVG_NS, "xlink": _XLINK_NS}
# Element ids (checked in this order) that mark the image to be replaced by the upload.
_BACKGROUND_IMAGE_IDS = (
    "background_image",
    "placeholder_image",
    "product_image",
    "main_image",
    "image_to_replace",
    "底图",
)
# Explicit table for common formats; other extensions fall back to `mimetypes`.
_EXTENSION_MIME_TYPES = {
    ".png": "image/png",
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".gif": "image/gif",
    ".webp": "image/webp",
    ".svg": "image/svg+xml",
}
# Characters left unescaped inside a single (already unquoted) URL path segment.
# '%', '?', '#' and '/' are deliberately absent: after unquoting, a literal one of
# those must be re-encoded or it would change the URL's meaning.
_PATH_SEGMENT_SAFE_CHARS = ";:@&+$,=()~"


# --- Helper Functions ---
def _safe_filename_component(name, default="template", max_length=80):
    """
    Turn an arbitrary (possibly user-supplied) name into one safe filename component.

    Runs of characters other than Unicode word characters, '.' and '-' become '_',
    and leading/trailing dots are stripped. The result can therefore never contain
    a path separator or be '.'/'..'. This matters because template names come from
    URL query parameters.

    Args:
        name: Raw name (may be None/empty).
        default: Returned when nothing usable remains.
        max_length: Maximum length of the returned component.
    """
    if not name:
        return default
    cleaned = re.sub(r"[^\w.-]+", "_", str(name)).strip(".")[:max_length]
    return cleaned or default


def _is_internal_host(hostname) -> bool:
    """Return True if `hostname` is 'localhost' or a private/loopback IP literal."""
    if not hostname:
        return False
    if hostname == "localhost":
        return True
    try:
        address = ipaddress.ip_address(hostname)
    except ValueError:
        # A DNS name rather than an IP literal; treat as external.
        return False
    return address.is_private or address.is_loopback


def fix_internal_urls(url_string: str) -> str:
    """
    Rewrite internal http(s) URLs to point at TEMPLATE_API_BASE_URL.

    A URL is internal when its host is 'localhost' or a private/loopback IP
    (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, 127.0.0.0/8, ...), e.g.
    http://10.10.71.201:8002/x. Its scheme and host:port are replaced by those of
    TEMPLATE_API_BASE_URL, while its path, query and fragment are kept. This
    makes internal URLs reachable externally.

    Returns the input unchanged when it is not a string, not http(s), not
    internal, when no base URL is configured, or when rewriting fails.
    """
    if not url_string or not isinstance(url_string, str):
        return url_string
    if not TEMPLATE_API_BASE_URL:
        return url_string
    try:
        parts = urllib.parse.urlsplit(url_string)
        # Compare on the parsed hostname (no userinfo/port) against real IP ranges.
        # Plain string prefixes like "172.2" would also match public hosts such as 172.217.x.x.
        if parts.scheme not in ("http", "https") or not _is_internal_host(parts.hostname):
            return url_string
        base_parts = urllib.parse.urlsplit(TEMPLATE_API_BASE_URL)
        new_url = urllib.parse.urlunsplit((
            base_parts.scheme,
            base_parts.netloc,
            parts.path,
            parts.query,
            parts.fragment,
        ))
        print(f"URL transformed: {url_string} -> {new_url}")
        return new_url
    except Exception as e:  # Best effort: never break a caller over URL rewriting.
        print(f"Error fixing internal URL '{url_string}': {e}")
        return url_string


def sanitize_url(url_string: str) -> str:
    """
    Percent-encode the path and query of an http(s) URL, idempotently.

    Each path segment is unquoted and then re-quoted separately, so existing
    encodings (including an encoded '/', '%', '?' or '#' inside a segment)
    survive unchanged. Parentheses are preserved. The query is parsed with
    `parse_qsl`, which does its own unquoting, and re-encoded with `quote`.
    Encoded '&', '=' and '+' inside values therefore stay part of the value.
    Non-http(s) URLs, data: URIs and non-strings are returned unchanged.
    """
    if not url_string or not isinstance(url_string, str) or url_string.startswith('data:'):
        return url_string
    try:
        scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url_string)
        if scheme not in ('http', 'https'):
            # URNs, mailto:, etc. are left untouched.
            return url_string

        sanitized_path = "/".join(
            urllib.parse.quote(urllib.parse.unquote(segment), safe=_PATH_SEGMENT_SAFE_CHARS)
            for segment in path.split("/")
        )

        sanitized_query = ""
        if query:
            query_pairs = urllib.parse.parse_qsl(query, keep_blank_values=True)
            sanitized_query = urllib.parse.urlencode(query_pairs, quote_via=urllib.parse.quote)

        return urllib.parse.urlunsplit((scheme, netloc, sanitized_path, sanitized_query, fragment))
    except Exception as e:
        print(f"Warning: Could not sanitize URL '{url_string}': {e}")
        return url_string


def fetch_template_details(m_param, u_param, category_param, name_param):
    """
    Fetch template details from the template server API.

    Calls GET {TEMPLATE_API_BASE_URL}/gradio_template_details/{u}/{category}/{name}.
    Each part is escaped as a single path segment, so '/' in a name cannot
    split the route. The response is expected to be a dict with 'name',
    'thumbnail_url' and 'svg_download_url'. Every string value is passed
    through `fix_internal_urls`.

    Args:
        m_param: Currently not used to build the request URL.
        u_param: Username path segment.
        category_param: Category path segment.
        name_param: Template name path segment.

    Returns:
        The decoded JSON, or None on a network/HTTP/JSON error or a missing base URL.
    """
    base_url = (TEMPLATE_API_BASE_URL or "").rstrip("/")
    if not base_url:
        print("Error: template API base URL (env var 'mubanapi') is not configured.")
        return None

    username = urllib.parse.quote(u_param, safe="")
    category = urllib.parse.quote(category_param, safe="")
    template_name = urllib.parse.quote(name_param, safe="")
    api_url = f"{base_url}/gradio_template_details/{username}/{category}/{template_name}"
    print(f"Fetching template details from: {api_url}")
    try:
        response = requests.get(api_url, timeout=10)
        response.raise_for_status()
        template_data = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers invalid JSON on requests versions whose JSONDecodeError
        # is not a RequestException.
        print(f"Error fetching template details from {api_url}: {e}")
        return None

    if isinstance(template_data, dict):
        # fix_internal_urls only rewrites internal http(s) URLs; other strings pass through.
        for key, value in template_data.items():
            if isinstance(value, str):
                template_data[key] = fix_internal_urls(value)
    return template_data


def download_svg_content(svg_url):
    """
    Download the raw SVG document at `svg_url` (after internal-URL rewriting).

    Returns the SVG as text, or None on a network/HTTP error.
    """
    fixed_url = fix_internal_urls(svg_url)
    try:
        response = requests.get(fixed_url, timeout=10)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        print(f"Error downloading SVG content from {svg_url}: {e}")
        return None
    content_type = response.headers.get("Content-Type", "").lower()
    if "charset" not in content_type and (response.encoding or "").lower() == "iso-8859-1":
        # requests assumes ISO-8859-1 for text/* without a charset, which mangles
        # UTF-8 SVGs (e.g. Chinese text). UTF-8 is XML's default encoding.
        response.encoding = "utf-8"
    return response.text  # SVG is XML text


def _guess_image_mime_type(resource_identifier: str) -> str:
    """
    Return an image MIME type for a MIME string, local file path, or URL.

    'image/...' strings are returned as-is. Otherwise the type is inferred from
    the file extension; for URLs only the path is used, so query strings are
    ignored. Falls back to 'image/png' when nothing matches.
    """
    if resource_identifier.startswith("image/"):
        return resource_identifier
    if resource_identifier.startswith(("http://", "https://")):
        path_part = urllib.parse.urlsplit(resource_identifier).path
    else:
        path_part = resource_identifier
    ext = pathlib.PurePath(path_part).suffix.lower()
    if ext in _EXTENSION_MIME_TYPES:
        return _EXTENSION_MIME_TYPES[ext]
    if ext:
        guessed, _ = mimetypes.guess_type(f"file{ext}")
        if guessed and guessed.startswith("image/"):
            return guessed
    return "image/png"


def _image_bytes_to_data_uri(image_bytes: bytes, resource_identifier: str) -> str:
    """Encode `image_bytes` as a base64 data URI typed from `resource_identifier`."""
    mime_type = _guess_image_mime_type(resource_identifier)
    encoded_image = base64.b64encode(image_bytes).decode('utf-8')
    return f"data:{mime_type};base64,{encoded_image}"


def _get_image_href(element):
    """Return the element's xlink:href if present, else its plain href (or None)."""
    value = element.get(_XLINK_HREF)
    return value if value is not None else element.get("href")


def _set_image_href(element, value):
    """
    Set every href attribute the element already has (xlink:href and/or href).

    If it has neither, xlink:href is added. Updating both matters because SVG 2
    renderers prefer plain `href` over `xlink:href`.
    """
    present = [attr for attr in (_XLINK_HREF, "href") if attr in element.attrib]
    for attr in present or [_XLINK_HREF]:
        element.set(attr, value)


def _find_target_image_element(root):
    """
    Pick the <image> element that should receive the uploaded image.

    Order: the first <image> (any namespace) whose id is in _BACKGROUND_IMAGE_IDS,
    then the first SVG-namespace <image>, then the first un-namespaced <image>.
    Returns None if none exists.
    """
    for image_id in _BACKGROUND_IMAGE_IDS:
        element = root.find(f".//{{*}}image[@id='{image_id}']")
        if element is not None:
            return element
    element = root.find(".//svg:image", _XPATH_NAMESPACES)
    if element is not None:
        return element
    return root.find(".//image")


def _embed_local_image(element, new_image_path):
    """Replace `element`'s href with a data URI of the local file; errors are logged, not raised."""
    try:
        image_path = pathlib.Path(new_image_path)
        if not image_path.is_file():
            print(f"Error: New image file not found or is not a file: {new_image_path}")
            return
        data_uri = _image_bytes_to_data_uri(image_path.read_bytes(), str(new_image_path))
        _set_image_href(element, data_uri)
    except Exception as e:
        print(f"Error processing and embedding new_image_path '{new_image_path}': {e}")


def _download_image_as_data_uri(url):
    """Download a remote image and return it as a data URI, or None on failure (logged)."""
    try:
        response = requests.get(sanitize_url(url), timeout=20)
        response.raise_for_status()
    except requests.exceptions.RequestException as e_req:
        print(f"Failed to download or process remote image {url}: {e_req}")
        return None
    try:
        content_type = response.headers.get("Content-Type", "").split(";")[0].strip().lower()
        # Trust the header only if it names an image type (not e.g. application/octet-stream);
        # otherwise infer from the URL's extension.
        mime_source = content_type if content_type.startswith("image/") else url
        return _image_bytes_to_data_uri(response.content, mime_source)
    except Exception as e_proc:
        print(f"Error processing remote image {url} into data URI: {e_proc}")
        return None


def _inline_remote_image(image_el, remote_image_cache):
    """
    Replace an http(s) href on `image_el` with a data URI.

    Downloaded data URIs are memoised in `remote_image_cache` (URL -> data URI).
    Non-http(s) hrefs and failed downloads leave the element unchanged.
    """
    href_val = _get_image_href(image_el)
    if not href_val or not href_val.startswith(("http://", "https://")):
        return
    data_uri = remote_image_cache.get(href_val)
    if data_uri is None:
        data_uri = _download_image_as_data_uri(href_val)
        if data_uri is None:
            return
        remote_image_cache[href_val] = data_uri
    _set_image_href(image_el, data_uri)


def _save_debug_svg(svg_bytes):
    """Write `svg_bytes` to a uniquely named file in TEMP_DIR (best effort)."""
    try:
        debug_svg_path = TEMP_DIR / f"debug_embedded_svg_{uuid.uuid4().hex[:8]}.svg"
        debug_svg_path.write_bytes(svg_bytes)
        print(f"Saved SVG with all images embedded for debugging to: {debug_svg_path}")
    except OSError as e_save:
        print(f"Error saving debug embedded SVG: {e_save}")


def replace_background_in_svg(svg_content_str: str, new_image_path: str, remote_image_cache=None):
    """
    Embed the uploaded image into the SVG template and inline all remote images.

    The target <image> is chosen by `_find_target_image_element`, and its href
    becomes a base64 data URI of `new_image_path`. Every other <image>, in any
    or no namespace, whose href is an http(s) URL is downloaded and inlined as a
    data URI.

    Args:
        svg_content_str: SVG document text.
        new_image_path: Local path of the image to embed.
        remote_image_cache: Optional dict (remote URL -> data URI) that is read
            and updated. Pass the same dict across a batch so each remote asset
            is downloaded once instead of once per uploaded image.

    Returns:
        The modified SVG as UTF-8 bytes with an XML declaration. Returns None
        if parsing or processing fails.
    """
    # Serialise SVG elements in the default namespace (plain <svg>, <image>, ...)
    # rather than with an "svg:" prefix, and keep the conventional xlink prefix.
    ET.register_namespace("", _SVG_NS)
    ET.register_namespace("xlink", _XLINK_NS)
    try:
        # NOTE(security): ElementTree is not hardened against malicious XML (e.g.
        # entity expansion). This assumes the SVG comes from the trusted template server.
        root = ET.fromstring(svg_content_str)

        # 1. Embed the user-uploaded image into the chosen target element.
        target_element = _find_target_image_element(root)
        if target_element is None:
            print("Warning: No suitable tag found in SVG to replace with new_image_path.")
        else:
            _embed_local_image(target_element, new_image_path)

        # 2. Inline every other remote image so the converter needs no network access.
        cache = remote_image_cache if remote_image_cache is not None else {}
        for image_el in list(root.iterfind(".//{*}image")):
            if image_el is target_element:  # Already processed (or attempted)
                continue
            _inline_remote_image(image_el, cache)

        modified_svg_bytes = ET.tostring(root, encoding='UTF-8', method='xml', xml_declaration=True)
        if SAVE_DEBUG_SVG:
            _save_debug_svg(modified_svg_bytes)
        return modified_svg_bytes
    except ET.ParseError as e:
        print(f"Error parsing SVG XML: {e}")
        return None
    except Exception as e:
        print(f"An unexpected error occurred during SVG manipulation: {e}")
        return None


def convert_svg_bytes_to_png_api(svg_content_bytes: bytes, original_template_name: str, index: int):
    """
    Convert SVG bytes to a PNG via the external API and save it in TEMP_DIR.

    The SVG is uploaded as multipart field 'svg_file'. The output file name is
    derived from a sanitised template name, so a name taken from the URL cannot
    escape TEMP_DIR.

    Returns:
        The PNG's path as a string, or None on an API or file-system error.
    """
    if not svg_content_bytes:
        return None
    try:
        # The API expects a file upload.
        files = {'svg_file': ('generated_svg.svg', svg_content_bytes, 'image/svg+xml')}
        response = requests.post(SVG2PNG_API_URL, files=files, timeout=60)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        print(f"Error converting SVG to PNG via API: {e}")
        if getattr(e, 'response', None) is not None:
            print(f"API Response: {e.response.text}")
        return None

    png_filename = (
        f"{_safe_filename_component(original_template_name)}"
        f"_output_{index}_{uuid.uuid4().hex[:8]}.png"
    )
    temp_png_path = TEMP_DIR / png_filename
    try:
        temp_png_path.write_bytes(response.content)
    except OSError as e:
        print(f"Error saving PNG: {e}")
        return None
    return str(temp_png_path)


# --- Gradio App Logic ---
def generate_images_from_template(original_svg_download_url: str, template_name_for_file: str, uploaded_image_files: list, request: gr.Request):
    """
    Gradio handler: render one PNG per uploaded image and bundle them into a ZIP.

    Args:
        original_svg_download_url: URL of the template SVG (from gr.State).
        template_name_for_file: Template name used in output file names (sanitised).
        uploaded_image_files: Values from gr.Files. These are either objects with a
            `.name` path attribute or plain path strings.
        request: Gradio request (unused).

    Returns:
        (list of PNG paths, ZIP path), or ([], None) when nothing was generated.

    Raises:
        gr.Error: If the template SVG cannot be downloaded.
    """
    if not original_svg_download_url:
        gr.Warning("模板信息未加载,请确保URL参数正确。")
        return [], None

    if not uploaded_image_files:
        gr.Warning("请上传至少一张图片。")
        return [], None

    # 1. Download original SVG content.
    original_svg_content = download_svg_content(original_svg_download_url)
    if not original_svg_content:
        # gr.Error only reaches the UI when raised.
        raise gr.Error("无法下载原始SVG模板。")

    # Shared across the batch so remote template assets are downloaded once.
    remote_image_cache = {}
    generated_png_paths = []
    for image_number, uploaded_file_obj in enumerate(uploaded_image_files, start=1):
        new_image_path = getattr(uploaded_file_obj, "name", uploaded_file_obj)

        # 2. Replace background in SVG for this uploaded image.
        modified_svg_bytes = replace_background_in_svg(original_svg_content, new_image_path, remote_image_cache)
        if not modified_svg_bytes:
            gr.Warning(f"处理图片 {image_number} 失败:无法修改SVG。")
            continue

        # 3. Convert modified SVG to PNG.
        png_path = convert_svg_bytes_to_png_api(modified_svg_bytes, template_name_for_file, image_number)
        if png_path:
            generated_png_paths.append(png_path)
        else:
            gr.Warning(f"处理图片 {image_number} 失败:无法转换为PNG。")

    if not generated_png_paths:
        gr.Info("未能成功生成任何图片。")
        return [], None

    zip_filename = f"{_safe_filename_component(template_name_for_file)}_batch_{uuid.uuid4().hex[:8]}.zip"
    temp_zip_path = TEMP_DIR / zip_filename
    # Write the archive straight to disk instead of buffering it in memory first.
    with zipfile.ZipFile(temp_zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for png_path_str in generated_png_paths:
            png_file = pathlib.Path(png_path_str)
            zf.write(png_file, arcname=png_file.name)

    gr.Info(f"成功生成 {len(generated_png_paths)} 张图片!")
    return generated_png_paths, str(temp_zip_path)


def initial_load_template_info(request: gr.Request):
    """
    Load template info from the URL query parameters m, u, category and name.

    Fetches template details, downloads the thumbnail into TEMP_DIR, and returns:
    (thumbnail path or None, display name, SVG download URL or None, name param).
    """
    query_params = request.query_params
    m = query_params.get("m")
    u = query_params.get("u")
    category = query_params.get("category")
    name = query_params.get("name")  # This is the template name

    if not all([m, u, category, name]):
        print("Initial load: URL parameters (m, u, category, name) are incomplete or missing.")
        return None, "无模板信息 (请检查URL参数)", None, "无模板"

    details = fetch_template_details(m, u, category, name)

    local_thumbnail_path = None
    template_display_name = "错误: 无法加载模板信息"  # Default error message
    svg_download_url = None

    # The API may return non-dict JSON; only a dict carrying "name" is usable.
    if isinstance(details, dict) and "name" in details:
        template_display_name = details["name"]
        svg_download_url = details.get("svg_download_url")

        if details.get("thumbnail_url"):
            try:
                thumb_url = details["thumbnail_url"]  # Already fixed in fetch_template_details
                print(f"Fetching thumbnail from: {thumb_url}")
                thumb_response = requests.get(thumb_url, timeout=10)
                thumb_response.raise_for_status()

                # `name` comes from the URL, so sanitise it before using it in a path.
                thumb_filename = f"thumb_{_safe_filename_component(name)}_{uuid.uuid4().hex[:8]}.png"
                local_thumbnail_path = TEMP_DIR / thumb_filename
                local_thumbnail_path.write_bytes(thumb_response.content)
                print(f"Thumbnail saved to: {local_thumbnail_path}")
            except requests.exceptions.RequestException as e:
                print(f"Error downloading thumbnail from {details.get('thumbnail_url')}: {e}")
                local_thumbnail_path = None
            except Exception as e:
                print(f"Error saving thumbnail: {e}")
                local_thumbnail_path = None
    else:
        print(f"Failed to load template details for m={m}, u={u}, category={category}, name={name}")

    return (
        str(local_thumbnail_path) if local_thumbnail_path else None,
        template_display_name,
        svg_download_url,
        name,  # 'name' from URL param for template_name_for_file_state
    )


# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="SVG模板批量图片生成器") as demo:
    gr.Markdown("## 使用SVG模板批量生成图片")
    gr.Markdown("从URL加载模板,上传您自己的图片替换模板中的底图,然后批量生成并下载。")

    # Hidden state: original SVG download URL and the URL-param template name (for file names).
    original_svg_download_url_state = gr.State()
    template_name_for_file_state = gr.State()

    with gr.Row():
        with gr.Column(scale=1, min_width=200):
            template_thumbnail_display = gr.Image(label="当前模板缩略图", interactive=False, height=200, type="filepath")
            template_name_display = gr.Textbox(label="当前模板名称", interactive=False)
        with gr.Column(scale=3):
            uploaded_images_input = gr.Files(
                label="上传您的图片 (可多选)",
                file_count="multiple",
                file_types=["image"],  # Accepts .png, .jpg, .jpeg, .gif, .webp etc.
            )
            generate_button = gr.Button("🚀 立即生成图片", variant="primary", scale=1)

    with gr.Accordion("生成结果预览与下载", open=True):
        output_gallery = gr.Gallery(
            label="生成图片预览",
            show_label=True,
            elem_id="output_gallery",
            columns=[4],
            object_fit="contain",
            height="auto",
        )
        output_zip_file = gr.File(label="下载所有生成图片的ZIP包", interactive=False, type="filepath")

    # Load template info from the URL query parameters when the page loads.
    # gr.Request is injected because it is type-hinted in the handler signature.
    demo.load(
        initial_load_template_info,
        inputs=None,
        outputs=[
            template_thumbnail_display,
            template_name_display,
            original_svg_download_url_state,
            template_name_for_file_state,
        ],
    )

    generate_button.click(
        generate_images_from_template,
        inputs=[
            original_svg_download_url_state,
            template_name_for_file_state,
            uploaded_images_input,
            # gr.Request is also injected here via the type hint.
        ],
        outputs=[output_gallery, output_zip_file],
    )


if __name__ == "__main__":
    # To run this app:
    # 1. Set env var "mubanapi" to the template server base URL; that server must expose
    #    /gradio_template_details/{username}/{category}/{template_name}.
    #    Set env var "convertapi" to the SVG -> PNG conversion endpoint.
    # 2. Run this script: python use_template_app.py
    # 3. Open the Gradio link with query parameters appended, e.g.:
    #    http://127.0.0.1:7860/?m=pub&u=all&category=%E5%AE%B6%E7%94%B5&name=%E6%A8%A1%E6%9D%BF_%E5%86%B0%E7%82%B9%E4%BB%B7
    # Optional: set SAVE_DEBUG_SVG=1 to keep each generated SVG in TEMP_DIR for inspection.
    demo.launch()

# Optional: TEMP_DIR cleanup for long-running servers (not enabled; may be too aggressive for dev).
# Example:
# import atexit
# def cleanup_temp_dir():
#     if TEMP_DIR.exists():
#         print(f"Cleaning up temp directory: {TEMP_DIR}")
#         shutil.rmtree(TEMP_DIR)
# atexit.register(cleanup_temp_dir)