# app.py
"""FLUX Diagram Generator: a Streamlit app that renders hand-drawn style
mind-map diagrams through the ``black-forest-labs/FLUX.1-schnell`` Space.

Generated images are stored on disk under ``./cache`` and indexed by the
MD5 of their JSON-serialised generation parameters, so repeating a request
reuses the stored image instead of calling the remote model again.
"""
import base64
import hashlib
import json
import os
import shutil
import time
from pathlib import Path
from typing import Any, Dict, Optional

import streamlit as st
from dotenv import load_dotenv
from gradio_client import Client

# Load environment variables (e.g. HF_TOKEN from a local .env file).
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")

# Hugging Face Space serving the FLUX model, and the seed used when the
# caller does not supply one.
FLUX_SPACE = "black-forest-labs/FLUX.1-schnell"
DEFAULT_SEED = 1872187377

# Cache directory setup
CACHE_DIR = Path("./cache")
CACHE_DIR.mkdir(exist_ok=True)

# Cached example diagrams
CACHED_EXAMPLES = {
    "literacy_mental": {
        "title": "Literacy Mental Map",
        "prompt": """A handrawn colorful mind map diagram, rugosity drawn lines, clear shapes, brain silhouette, text areas. must include the texts
LITERACY/MENTAL
├── PEACE [Dove Icon]
├── HEALTH [Vitruvian Man ~60px]
├── CONNECT [Brain-Mind Connection Icon]
├── INTELLIGENCE
│   └── EVERYTHING [Globe Icon ~50px]
└── MEMORY
    ├── READING [Book Icon ~40px]
    ├── SPEED [Speedometer Icon]
    └── CREATIVITY
        └── INTELLIGENCE [Lightbulb + Infinity ~30px]""",
        "width": 1024,
        "height": 1024,
        "seed": 1872187377,
        "cache_path": "literacy_mental.png",
    }
}

# Example diagrams for various use cases
DIAGRAM_EXAMPLES = [
    {
        "title": "Project Management Flow",
        "prompt": """A handrawn colorful mind map diagram, rugosity drawn lines, clear shapes, project management flow.
PROJECT MANAGEMENT
├── INITIATION [Rocket Icon]
├── PLANNING [Calendar Icon]
├── EXECUTION [Gear Icon]
├── MONITORING
│   └── CONTROL [Dashboard Icon]
└── CLOSURE [Checkmark Icon]""",
        "width": 1024,
        "height": 1024,
    },
    {
        "title": "Digital Marketing Strategy",
        "prompt": """A handrawn colorful mind map diagram, rugosity drawn lines, modern style, marketing concept.
DIGITAL MARKETING
├── SEO [Magnifying Glass]
├── SOCIAL MEDIA [Network Icon]
├── CONTENT
│   ├── BLOG [Document Icon]
│   └── VIDEO [Play Button]
└── ANALYTICS [Graph Icon]""",
        "width": 1024,
        "height": 1024,
    },
]

# Additional example templates (shown in the sidebar after DIAGRAM_EXAMPLES)
ADDITIONAL_EXAMPLES = [
    {
        "title": "Health & Wellness",
        "prompt": """A handrawn colorful mind map diagram, wellness-focused style, health aspects.
WELLNESS
├── PHYSICAL [Dumbbell Icon]
├── MENTAL [Brain Icon]
├── NUTRITION [Apple Icon]
└── SLEEP
    ├── QUALITY [Star Icon]
    └── DURATION [Clock Icon]""",
        "width": 1024,
        "height": 1024,
    },
    # ... (remaining examples)
]

# Static help text rendered in the right-hand column of the UI.
TIPS_MARKDOWN = """
- Use clear hierarchical structures
- Include icon descriptions in brackets
- Keep text concise and meaningful
- Use consistent formatting
"""

TEMPLATE_STRUCTURE = """
MAIN TOPIC
├── SUBTOPIC 1 [Icon]
├── SUBTOPIC 2 [Icon]
└── SUBTOPIC 3
    ├── DETAIL 1 [Icon]
    └── DETAIL 2 [Icon]
"""


class DiagramCache:
    """Disk cache mapping generation parameters to image files.

    ``cache_index.json`` inside *cache_dir* maps the MD5 of the sorted-key
    JSON of the parameters to ``{"filename", "timestamp", "params"}``.
    """

    INDEX_FILENAME = "cache_index.json"

    def __init__(self, cache_dir: Path):
        self.cache_dir = cache_dir
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self._load_cache()

    @property
    def _index_path(self) -> Path:
        """Location of the JSON index file."""
        return self.cache_dir / self.INDEX_FILENAME

    def _load_cache(self):
        """Load existing cache entries; start empty if the index is missing or unreadable."""
        self.cache_index: Dict[str, Dict[str, Any]] = {}
        index_path = self._index_path
        if not index_path.exists():
            return
        try:
            with open(index_path, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except (OSError, json.JSONDecodeError):
            # A corrupt or partially written index must not take the whole app
            # down; affected diagrams are simply regenerated on demand.
            return
        if isinstance(loaded, dict):
            self.cache_index = loaded

    def _save_cache_index(self):
        """Save the cache index to disk, replacing the old file only once fully written."""
        tmp_path = self._index_path.with_suffix(".json.tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(self.cache_index, f)
        tmp_path.replace(self._index_path)

    def _get_cache_key(self, params: Dict[str, Any]) -> str:
        """Generate a cache key (MD5 hex digest) from the JSON-serialised parameters."""
        param_str = json.dumps(params, sort_keys=True)
        return hashlib.md5(param_str.encode()).hexdigest()

    def get(self, params: Dict[str, Any]) -> Optional[Path]:
        """Return the cached file for *params*, or ``None`` if absent or deleted."""
        cache_info = self.cache_index.get(self._get_cache_key(params))
        if not cache_info:
            return None
        cache_path = self.cache_dir / cache_info["filename"]
        return cache_path if cache_path.exists() else None

    def put(self, params: Dict[str, Any], result_path: Path) -> Path:
        """Copy *result_path* into the cache under *params* and persist the index.

        Returns the path of the cached copy.
        """
        result_path = Path(result_path)
        cache_key = self._get_cache_key(params)
        filename = f"{cache_key}{result_path.suffix}"
        cache_path = self.cache_dir / filename

        shutil.copyfile(result_path, cache_path)

        self.cache_index[cache_key] = {
            "filename": filename,
            "timestamp": time.time(),
            "params": params,
        }
        self._save_cache_index()
        return cache_path


# Initialize cache
diagram_cache = DiagramCache(CACHE_DIR)


def _get_client() -> Client:
    """Create a client for the FLUX Space, authenticating when HF_TOKEN is set."""
    # Pass the token only when configured so the client keeps its own default
    # behaviour otherwise.
    if HF_TOKEN:
        return Client(FLUX_SPACE, hf_token=HF_TOKEN)
    return Client(FLUX_SPACE)


def _result_path(result: Any) -> Path:
    """Extract the local image file path from a ``Client.predict`` result.

    Raises:
        ValueError: if the endpoint returned an empty sequence.
    """
    # NOTE(review): the /infer endpoint appears to return ``(image_path, seed)``;
    # confirm on the Space's API page. A bare path (or ``{"path": ...}``) is
    # accepted too.
    if isinstance(result, (list, tuple)):
        if not result:
            raise ValueError("FLUX endpoint returned an empty result")
        result = result[0]
    if isinstance(result, dict) and "path" in result:
        result = result["path"]
    return Path(result)


def _read_b64(path) -> str:
    """Return the bytes of the file at *path* as a base64-encoded string."""
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode()


def _generate_with_cache(cache_params: Dict[str, Any]) -> str:
    """Return the base64 image for *cache_params*, generating it on a cache miss.

    *cache_params* must contain ``prompt``, ``seed``, ``width`` and ``height``;
    the whole dict (including any extra keys) is used as the cache key.
    Errors from the remote client propagate to the caller.
    """
    cache_path = diagram_cache.get(cache_params)
    if cache_path:
        return _read_b64(cache_path)

    # Connect to the Space only when the image is not already cached.
    client = _get_client()
    result = client.predict(
        prompt=cache_params["prompt"],
        seed=cache_params["seed"],
        randomize_seed=False,
        width=cache_params["width"],
        height=cache_params["height"],
        num_inference_steps=4,
        api_name="/infer",
    )
    stored_path = diagram_cache.put(cache_params, _result_path(result))
    return _read_b64(stored_path)


@st.cache_data
def generate_cached_example(example_id: str) -> str:
    """Generate (or load from the disk cache) one of the CACHED_EXAMPLES diagrams.

    Args:
        example_id: key into CACHED_EXAMPLES.

    Returns:
        The image bytes encoded as a base64 string.

    Raises:
        KeyError: if *example_id* is not a known example. Errors from the
        remote client are not caught here.
    """
    example = CACHED_EXAMPLES[example_id]
    # The full example dict (title, cache_path, ...) is the cache key, which
    # keeps keys identical to those written by earlier versions.
    return _generate_with_cache(example)


def generate_diagram(
    prompt: str, width: int, height: int, seed: Optional[int] = None
) -> Optional[str]:
    """Generate a new diagram (or reuse a cached one) for the given settings.

    Args:
        prompt: text prompt sent to the FLUX model.
        width: image width in pixels.
        height: image height in pixels.
        seed: generation seed; ``None`` selects DEFAULT_SEED.

    Returns:
        The image bytes encoded as a base64 string, or ``None`` after reporting
        the error in the UI via ``st.error``.
    """
    params = {
        "prompt": prompt,
        # ``is not None`` so that an explicit seed of 0 is honoured.
        "seed": seed if seed is not None else DEFAULT_SEED,
        "width": width,
        "height": height,
    }
    try:
        return _generate_with_cache(params)
    except Exception as e:
        st.error(f"Error generating diagram: {str(e)}")
        return None


def main():
    """Render the Streamlit UI: template picker, prompt editor and result view."""
    st.set_page_config(page_title="FLUX Diagram Generator", layout="wide")
    st.title("🎨 FLUX Diagram Generator")
    st.markdown("Generate beautiful hand-drawn style diagrams using FLUX AI")

    # Offer both the core and the additional example templates.
    templates = DIAGRAM_EXAMPLES + ADDITIONAL_EXAMPLES

    # Sidebar for examples
    st.sidebar.title("📚 Example Templates")
    selected_example = st.sidebar.selectbox(
        "Choose a template",
        options=range(len(templates)),
        format_func=lambda idx: templates[idx]["title"],
    )

    # Main content area
    col1, col2 = st.columns([2, 1])

    with col1:
        # Input area
        prompt = st.text_area(
            "Diagram Prompt",
            value=templates[selected_example]["prompt"],
            height=200,
        )

        # Configuration
        with st.expander("Advanced Configuration"):
            width = st.number_input(
                "Width", min_value=512, max_value=2048, value=1024, step=128
            )
            height = st.number_input(
                "Height", min_value=512, max_value=2048, value=1024, step=128
            )
            seed = st.number_input("Seed (optional)", value=None, step=1)

        if st.button("🎨 Generate Diagram"):
            with st.spinner("Generating your diagram..."):
                result = generate_diagram(prompt, width, height, seed)
            if result:
                # generate_diagram returns base64 text; st.image needs the raw
                # image bytes (a bare base64 string is not a valid path/URL).
                st.image(
                    base64.b64decode(result),
                    caption="Generated Diagram",
                    use_column_width=True,
                )

    with col2:
        st.subheader("Tips for Better Results")
        st.markdown(TIPS_MARKDOWN)

        st.subheader("Template Structure")
        st.code(TEMPLATE_STRUCTURE)


if __name__ == "__main__":
    main()