Spaces:
Running
Running
#!/usr/bin/env python3 | |
import os | |
import sys | |
import tempfile | |
import zipfile | |
import json | |
from pathlib import Path | |
from typing import Optional, Tuple | |
import uuid | |
from datetime import datetime | |
import traceback | |
import gradio as gr | |
# Put the directory that contains this script at the front of sys.path so the
# project's `src` package can be imported below.
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))

try:
    from src.state.poster_state import create_state
    from src.workflow.pipeline import create_workflow_graph
except ImportError as e:
    # sys.exit() with a string writes it to stderr and exits with status 1,
    # so the diagnostic no longer lands on stdout.
    sys.exit(f"Error importing modules: {e}")
def set_temp_api_keys(anthropic_key, openai_key, anthropic_base_url=None, openai_base_url=None):
    """Install API keys / base URLs into os.environ and return an undo callable.

    Only arguments that are non-empty after stripping whitespace are applied;
    the stripped value is what gets stored. Calling the returned function puts
    every touched variable back to its previous value, or removes it if it was
    not set before.
    """
    requested = (
        ("ANTHROPIC_API_KEY", anthropic_key),
        ("OPENAI_API_KEY", openai_key),
        ("ANTHROPIC_BASE_URL", anthropic_base_url),
        ("OPENAI_BASE_URL", openai_base_url),
    )

    # Previous value (or None when absent) of each variable we overwrite.
    previous = {}
    for env_name, raw_value in requested:
        if not raw_value:
            continue
        trimmed = raw_value.strip()
        if not trimmed:
            continue
        previous[env_name] = os.environ.get(env_name)
        os.environ[env_name] = trimmed

    def cleanup():
        """Restore original environment values"""
        for env_name, old_value in previous.items():
            if old_value is not None:
                os.environ[env_name] = old_value
            else:
                os.environ.pop(env_name, None)

    return cleanup
# Model identifiers offered in the text/vision dropdowns; the first entry is
# used as the default selection.
AVAILABLE_MODELS = [
    "claude-sonnet-4-20250514",
    "gpt-4o-2024-08-06",
    "gpt-4.1-2025-04-14",
    "gpt-4.1-mini-2025-04-14",
]
def create_job_directory() -> Path:
    """Create and return a fresh temporary working directory for one job.

    The directory name starts with ``job_<YYYYmmdd_HHMMSS>_<8 chars of a
    uuid4>_`` followed by the random suffix that ``tempfile.mkdtemp`` appends.
    """
    stamp = f"{datetime.now():%Y%m%d_%H%M%S}"
    short_id = uuid.uuid4().hex[:8]
    return Path(tempfile.mkdtemp(prefix=f"job_{stamp}_{short_id}_"))
def validate_inputs(pdf_file, logo_file, aff_logo_file, text_model, vision_model, poster_width, poster_height, anthropic_key, openai_key, anthropic_base_url, openai_base_url):
    """Check the poster-generation form values.

    Checks run in a fixed order and stop at the first failure: required
    uploads, model names, API keys, poster dimensions, then the PDF file name.

    Args:
        pdf_file, logo_file, aff_logo_file: Uploaded files; each must be truthy.
        text_model, vision_model: Must be entries of AVAILABLE_MODELS.
        poster_width, poster_height: Numeric dimensions; width / height must
            lie in [1.4, 2.0].
        anthropic_key, openai_key: API keys; blank or whitespace-only counts as
            missing.
        anthropic_base_url, openai_base_url: Accepted but not validated here.

    Returns:
        An error message string for the first problem found, or None when
        every check passes.
    """
    if not pdf_file:
        return "Please upload PDF paper"
    if not logo_file:
        return "Please upload conference logo"
    if not aff_logo_file:
        return "Please upload affiliation logo"

    if text_model not in AVAILABLE_MODELS:
        return f"Invalid text model: {text_model}"
    if vision_model not in AVAILABLE_MODELS:
        return f"Invalid vision model: {vision_model}"

    # Check API keys
    has_anthropic = bool(anthropic_key and anthropic_key.strip())
    has_openai = bool(openai_key and openai_key.strip())
    if not has_anthropic and not has_openai:
        return "Please provide at least one API key (Anthropic or OpenAI)"

    # Each selected model needs the key of its provider; the text model is
    # checked before the vision model.
    for model in (text_model, vision_model):
        if model.startswith("claude") and not has_anthropic:
            return "Anthropic API key required for Claude models"
        if model.startswith("gpt") and not has_openai:
            return "OpenAI API key required for GPT models"

    # Report bad dimensions as a validation message: dividing directly would
    # raise TypeError on None and ZeroDivisionError on a zero height.
    try:
        width = float(poster_width)
        height = float(poster_height)
    except (TypeError, ValueError):
        return "Poster width and height must be numbers"
    if width <= 0 or height <= 0:
        return "Poster width and height must be positive"

    ratio = width / height
    # Written as a chained comparison so that a NaN ratio is rejected too.
    if not 1.4 <= ratio <= 2.0:
        return f"Poster ratio {ratio:.2f} out of range (1.4-2.0)"

    # Check file type - only possible when the upload exposes a name attribute
    if hasattr(pdf_file, 'name') and not pdf_file.name.lower().endswith('.pdf'):
        return "Paper must be PDF format"

    return None
def _read_upload(upload):
    """Return the content of an upload given either a file-like object or raw bytes."""
    return upload.read() if hasattr(upload, "read") else upload


def generate_poster(pdf_file, logo_file, aff_logo_file, text_model, vision_model, poster_width, poster_height, anthropic_key, openai_key, anthropic_base_url, openai_base_url, progress=gr.Progress()):
    """Run the poster workflow on the uploaded files and package the output.

    Args:
        pdf_file, logo_file, aff_logo_file: Uploads, either objects with a
            ``read()`` method or raw bytes.
        text_model, vision_model: Model identifiers passed to ``create_state``.
        poster_width, poster_height: Poster dimensions; truncated to int for
            the workflow state.
        anthropic_key, openai_key, anthropic_base_url, openai_base_url:
            Credentials/endpoints placed in the environment for the duration
            of the run via ``set_temp_api_keys``.
        progress: Progress callback; the ``gr.Progress()`` default is kept on
            the signature as in the original code.

    Returns:
        ``(zip_path, message)`` on success, where ``zip_path`` is the path of
        the generated archive as a string; ``(None, message)`` on validation
        failure, workflow errors, or any exception.
    """
    # NOTE(review): the leading "β" in the status strings looks like a
    # mis-decoded emoji; the strings are kept exactly as found.

    # Bound before the try block so the finally clause can always decide
    # whether there is anything to restore, even if set_temp_api_keys raises.
    cleanup_api_keys = None
    try:
        # Validate first so that rejected requests never touch the environment.
        error_msg = validate_inputs(pdf_file, logo_file, aff_logo_file, text_model, vision_model, poster_width, poster_height, anthropic_key, openai_key, anthropic_base_url, openai_base_url)
        if error_msg:
            return None, f"β {error_msg}"

        # Set API keys temporarily
        cleanup_api_keys = set_temp_api_keys(anthropic_key, openai_key, anthropic_base_url, openai_base_url)

        progress(0.1, desc="Initializing...")
        job_dir = create_job_directory()
        pdf_path = job_dir / "paper.pdf"
        logo_path = job_dir / "logo.png"
        aff_logo_path = job_dir / "aff_logo.png"

        # Persist the uploads under fixed names inside the job directory.
        pdf_path.write_bytes(_read_upload(pdf_file))
        logo_path.write_bytes(_read_upload(logo_file))
        aff_logo_path.write_bytes(_read_upload(aff_logo_file))

        progress(0.2, desc="Setting up workflow...")
        state = create_state(
            pdf_path=str(pdf_path),
            text_model=text_model,
            vision_model=vision_model,
            width=int(poster_width),
            height=int(poster_height),
            url="",
            logo_path=str(logo_path),
            aff_logo_path=str(aff_logo_path),
        )

        progress(0.3, desc="Compiling workflow...")
        graph = create_workflow_graph()
        workflow = graph.compile()

        progress(0.5, desc="Processing paper...")
        final_state = workflow.invoke(state)

        progress(0.8, desc="Generating outputs...")
        errors = final_state.get("errors")
        if errors:
            # str() each entry so a non-string error cannot break the join.
            error_details = "; ".join(str(err) for err in errors)
            return None, f"β Generation errors: {error_details}"

        output_dir = Path(final_state["output_dir"])
        poster_name = final_state.get("poster_name", "poster")

        progress(0.9, desc="Packaging results...")
        zip_path = job_dir / f"{poster_name}_output.zip"
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            if output_dir.exists():
                for file_path in output_dir.rglob("*"):
                    if file_path.is_file():
                        # Store entries relative to the output directory.
                        zipf.write(file_path, file_path.relative_to(output_dir))

        progress(1.0, desc="β Completed!")
        success_msg = "\n".join([
            "β Poster generation successful!",
            f"Poster: {poster_name}",
            f"Output: {output_dir.name}",
            f"Package: {zip_path.name}",
            f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        ])
        return str(zip_path), success_msg
    except Exception as e:
        # Top-level boundary for the UI handler: report instead of raising.
        return None, f"β Error: {str(e)}\n\n{traceback.format_exc()}"
    finally:
        # Single restore point for every exit path (success, error, exception).
        if cleanup_api_keys is not None:
            cleanup_api_keys()
# Gradio UI: three panels (uploads, settings, status/download) followed by a
# generate button.
# NOTE(review): the source this was recovered from had lost its indentation, so
# the container nesting below (in particular how far the gr.Group extends) is a
# reconstruction - confirm against the rendered layout.
# NOTE(review): label prefixes such as "π" and "β" look like mis-decoded emoji;
# the strings are kept exactly as found.
with gr.Blocks(title="PosterGen", css="""
.gradio-column {
    margin-left: 10px !important;
    margin-right: 10px !important;
}
.gradio-column:first-child {
    margin-left: 0 !important;
}
.gradio-column:last-child {
    margin-right: 0 !important;
}
""") as demo:
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 30px;">
        <h1 style="margin-bottom: 10px;">
            PosterGen
        </h1>
        <p style="font-size: 18px; color: #666;">π¨ Generate design-aware academic posters from PDF papers</p>
    </div>
    """)

    with gr.Row():
        # Panel 1: file uploads, delivered to the handler as raw bytes.
        with gr.Column(scale=1, variant="panel"):
            gr.Markdown("### π Upload Files")
            pdf_file = gr.File(label="PDF Paper", file_types=[".pdf"], type="binary")
            with gr.Row():
                logo_file = gr.File(label="Conference Logo", file_types=["image"], type="binary")
                aff_logo_file = gr.File(label="Affiliation Logo", file_types=["image"], type="binary")

        # Panel 2: credentials, model selection and poster dimensions.
        with gr.Column(scale=1, variant="panel"):
            with gr.Group():
                gr.Markdown("### π API Keys")
                gr.Markdown("β οΈ Keys are processed securely and not stored")
                with gr.Row():
                    anthropic_key = gr.Textbox(
                        label="Anthropic API Key",
                        type="password",
                        placeholder="sk-ant-...",
                        info="Required for Claude models"
                    )
                    openai_key = gr.Textbox(
                        label="OpenAI API Key",
                        type="password",
                        placeholder="sk-...",
                        info="Required for GPT models"
                    )
                with gr.Row():
                    anthropic_base_url = gr.Textbox(
                        label="Anthropic Base URL (Optional)",
                        placeholder="https://api.anthropic.com",
                        info="Set the base url for compatible API services"
                    )
                    openai_base_url = gr.Textbox(
                        label="OpenAI Base URL (Optional)",
                        placeholder="https://api.openai.com/v1",
                        info="Set the base url for compatible API services"
                    )

            gr.Markdown("### π€ Model Settings")
            with gr.Row():
                text_model = gr.Dropdown(choices=AVAILABLE_MODELS, value=AVAILABLE_MODELS[0], label="Text Model")
                vision_model = gr.Dropdown(choices=AVAILABLE_MODELS, value=AVAILABLE_MODELS[0], label="Vision Model")

            gr.Markdown("### π Dimensions")
            with gr.Row():
                poster_width = gr.Number(value=54, minimum=20, maximum=100, step=0.1, label="Width (inches)")
                poster_height = gr.Number(value=36, minimum=10, maximum=60, step=0.1, label="Height (inches)")

        # Panel 3: run status and the downloadable result.
        with gr.Column(scale=1, variant="panel"):
            gr.Markdown("### π Status")
            status_output = gr.Textbox(label="Generation Status", placeholder="Click 'Generate Poster' to start...", lines=6)
            gr.Markdown("### π₯ Download")
            download_file = gr.File(label="Download Package")

    # Generate button spanning full width
    with gr.Row():
        generate_btn = gr.Button("π Generate Poster", variant="primary", size="lg")

    # Former *args wrapper, kept as an alias so the name stays importable.
    generate_and_display = generate_poster

    # The handler is registered directly instead of through a *args wrapper:
    # Gradio looks for the gr.Progress() default on the registered function's
    # own signature, which the wrapper hid. api_name pins the endpoint name the
    # wrapper used to provide. The order of `inputs` must match the positional
    # parameters of generate_poster.
    generate_btn.click(
        fn=generate_poster,
        inputs=[pdf_file, logo_file, aff_logo_file, text_model, vision_model, poster_width, poster_height, anthropic_key, openai_key, anthropic_base_url, openai_base_url],
        outputs=[download_file, status_output],
        api_name="generate_and_display",
    )
if __name__ == "__main__":
    # Script entry point: serve the UI on all interfaces, port 7860.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )