# Source: diagram/app.py, Hugging Face Space by aiqcamp
# Commit dd5d6cc ("Update app.py", verified), 8.65 kB
# app.py
import base64
import hashlib
import json
import os
import shutil
import time
from pathlib import Path
from typing import Any, Dict, Optional

import streamlit as st
from dotenv import load_dotenv
from gradio_client import Client
# Populate the process environment from a local .env file (when one exists),
# then read the Hugging Face token; HF_TOKEN is None if the variable is unset.
# NOTE(review): the token is not currently passed to the gradio Client calls.
load_dotenv()
HF_TOKEN = os.environ.get("HF_TOKEN")
# Directory (relative to the working directory) holding generated images and
# the cache index; it is created at import time if it does not exist yet.
CACHE_DIR = Path(".") / "cache"
os.makedirs(CACHE_DIR, exist_ok=True)
# Example diagrams rendered with a fixed seed by generate_cached_example().
# That function uses the whole entry dict (title, prompt, size, seed and
# cache_path) as the disk-cache key, so editing any field forces regeneration.
# NOTE(review): "cache_path" is not otherwise read by the code in this file.
CACHED_EXAMPLES = {
    "literacy_mental": {
        "title": "Literacy Mental Map",
        "prompt": """A handrawn colorful mind map diagram, rugosity drawn lines, clear shapes, brain silhouette, text areas. must include the texts
LITERACY/MENTAL
├── PEACE [Dove Icon]
├── HEALTH [Vitruvian Man ~60px]
├── CONNECT [Brain-Mind Connection Icon]
├── INTELLIGENCE
│   └── EVERYTHING [Globe Icon ~50px]
└── MEMORY
    ├── READING [Book Icon ~40px]
    ├── SPEED [Speedometer Icon]
    └── CREATIVITY
        └── INTELLIGENCE [Lightbulb + Infinity ~30px]""",
        "width": 1024,
        "height": 1024,
        "seed": 1872187377,
        "cache_path": "literacy_mental.png",
    }
}
# Prompt templates offered in the sidebar; each "prompt" is a hand-drawn mind
# map description followed by a text tree of the nodes to draw.
DIAGRAM_EXAMPLES = [
    {
        "title": "Project Management Flow",
        "prompt": """A handrawn colorful mind map diagram, rugosity drawn lines, clear shapes, project management flow.
PROJECT MANAGEMENT
├── INITIATION [Rocket Icon]
├── PLANNING [Calendar Icon]
├── EXECUTION [Gear Icon]
├── MONITORING
│   └── CONTROL [Dashboard Icon]
└── CLOSURE [Checkmark Icon]""",
        "width": 1024,
        "height": 1024,
    },
    {
        "title": "Digital Marketing Strategy",
        "prompt": """A handrawn colorful mind map diagram, rugosity drawn lines, modern style, marketing concept.
DIGITAL MARKETING
├── SEO [Magnifying Glass]
├── SOCIAL MEDIA [Network Icon]
├── CONTENT
│   ├── BLOG [Document Icon]
│   └── VIDEO [Play Button]
└── ANALYTICS [Graph Icon]""",
        "width": 1024,
        "height": 1024,
    },
]
# Additional prompt templates, in the same shape as DIAGRAM_EXAMPLES.
# Only one is defined so far; the remaining examples are still to be added.
ADDITIONAL_EXAMPLES = [
    {
        "title": "Health & Wellness",
        "prompt": """A handrawn colorful mind map diagram, wellness-focused style, health aspects.
WELLNESS
├── PHYSICAL [Dumbbell Icon]
├── MENTAL [Brain Icon]
├── NUTRITION [Apple Icon]
└── SLEEP
    ├── QUALITY [Star Icon]
    └── DURATION [Clock Icon]""",
        "width": 1024,
        "height": 1024,
    },
    # ... (remaining examples)
]
class DiagramCache:
    """Disk-backed cache mapping generation parameters to image files.

    Each entry is stored in ``cache_dir`` as ``<md5 of params><source suffix>``.
    The JSON index ``cache_index.json`` maps each key to its filename, the time
    it was stored and the parameters that produced it.
    """

    INDEX_FILENAME = "cache_index.json"

    def __init__(self, cache_dir: Path):
        """Create ``cache_dir`` (including missing parents) and load its index."""
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self._load_cache()

    @property
    def _index_path(self) -> Path:
        """Location of the JSON index file inside ``cache_dir``."""
        return self.cache_dir / self.INDEX_FILENAME

    def _load_cache(self) -> None:
        """Load the index from disk into ``self.cache_index``.

        A missing, unreadable or malformed index yields an empty cache instead
        of an exception, so a damaged index cannot stop the app from starting;
        affected images are simply regenerated.
        """
        self.cache_index: Dict[str, Any] = {}
        try:
            with open(self._index_path, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except (OSError, ValueError):
            # OSError: file missing/unreadable. ValueError: includes
            # json.JSONDecodeError and undecodable bytes.
            return
        if isinstance(loaded, dict):
            self.cache_index = loaded

    def _save_cache_index(self) -> None:
        """Persist the index, replacing the previous file atomically.

        Writing to a temporary file and then ``os.replace``-ing it means an
        interrupted write leaves the old index intact rather than a truncated one.
        """
        tmp_path = self.cache_dir / (self.INDEX_FILENAME + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(self.cache_index, f)
        os.replace(tmp_path, self._index_path)

    def _get_cache_key(self, params: Dict[str, Any]) -> str:
        """Return a stable hex key for ``params``.

        Keys are sorted before serialization so dict ordering does not matter.
        MD5 serves only as a content fingerprint here, not for security.
        """
        param_str = json.dumps(params, sort_keys=True)
        return hashlib.md5(param_str.encode()).hexdigest()

    def get(self, params: Dict[str, Any]) -> Optional[Path]:
        """Return the cached file for ``params``, or ``None`` on a miss.

        An index entry whose file no longer exists counts as a miss.
        """
        cache_info = self.cache_index.get(self._get_cache_key(params))
        if not isinstance(cache_info, dict) or "filename" not in cache_info:
            return None
        cache_path = self.cache_dir / cache_info["filename"]
        return cache_path if cache_path.exists() else None

    def put(self, params: Dict[str, Any], result_path: Path) -> None:
        """Copy ``result_path`` into the cache under the key for ``params``.

        The cached copy keeps the source file's suffix; the index is updated
        and written to disk immediately.
        """
        result_path = Path(result_path)
        cache_key = self._get_cache_key(params)
        filename = f"{cache_key}{result_path.suffix}"
        shutil.copyfile(result_path, self.cache_dir / filename)
        self.cache_index[cache_key] = {
            "filename": filename,
            "timestamp": time.time(),
            "params": params,
        }
        self._save_cache_index()
# Single module-level cache instance used by both generation functions.
diagram_cache = DiagramCache(cache_dir=CACHE_DIR)
@st.cache_data
def generate_cached_example(example_id: str) -> str:
    """Return the example diagram ``example_id`` as a base64-encoded string.

    The image is served from ``diagram_cache`` when present. Otherwise it is
    generated with the FLUX.1-schnell Space using the example's fixed seed and
    then stored in the disk cache. ``st.cache_data`` also memoizes the returned
    string across Streamlit reruns.

    Raises:
        KeyError: if ``example_id`` is not a key of ``CACHED_EXAMPLES``.
    """
    example = CACHED_EXAMPLES[example_id]

    # Check the disk cache first so a hit needs no connection to the Space.
    cache_path = diagram_cache.get(example)
    if cache_path:
        return base64.b64encode(cache_path.read_bytes()).decode()

    client = Client("black-forest-labs/FLUX.1-schnell")
    result = client.predict(
        prompt=example["prompt"],
        seed=example["seed"],
        randomize_seed=False,
        width=example["width"],
        height=example["height"],
        num_inference_steps=4,
        api_name="/infer",
    )
    # The /infer endpoint may return an (image_path, seed) pair rather than a
    # bare path; keep only the image path.
    image_path = Path(result[0] if isinstance(result, (list, tuple)) else result)

    diagram_cache.put(example, image_path)
    return base64.b64encode(image_path.read_bytes()).decode()
def generate_diagram(prompt: str, width: int, height: int, seed: Optional[int] = None) -> Optional[str]:
    """Generate (or fetch from the disk cache) a diagram image for ``prompt``.

    Args:
        prompt: Text prompt sent to the FLUX.1-schnell Space.
        width: Requested image width.
        height: Requested image height.
        seed: Generation seed. ``None`` selects the default seed 1872187377;
            any other value, including 0, is used as given.

    Returns:
        The image file contents as a base64-encoded string, or ``None`` if
        generation failed (the error is reported with ``st.error``).
    """
    params = {
        "prompt": prompt,
        # Compare against None explicitly: a user-chosen seed of 0 is valid.
        "seed": seed if seed is not None else 1872187377,
        "width": width,
        "height": height,
    }

    # Serve from the disk cache when possible; a hit needs no Space connection.
    cache_path = diagram_cache.get(params)
    if cache_path:
        return base64.b64encode(cache_path.read_bytes()).decode()

    try:
        # Created inside the try so connection failures are shown in the UI
        # instead of aborting the script run.
        client = Client("black-forest-labs/FLUX.1-schnell")
        result = client.predict(
            prompt=prompt,
            seed=params["seed"],
            randomize_seed=False,
            width=width,
            height=height,
            num_inference_steps=4,
            api_name="/infer",
        )
        # The /infer endpoint may return an (image_path, seed) pair rather than
        # a bare path; keep only the image path.
        image_path = Path(result[0] if isinstance(result, (list, tuple)) else result)

        diagram_cache.put(params, image_path)
        return base64.b64encode(image_path.read_bytes()).decode()
    except Exception as e:
        # UI boundary: surface any failure to the user and signal it with None.
        st.error(f"Error generating diagram: {str(e)}")
        return None
def main():
    """Render the FLUX Diagram Generator Streamlit page.

    The sidebar lists the example templates (``DIAGRAM_EXAMPLES`` followed by
    ``ADDITIONAL_EXAMPLES``); the chosen template pre-fills the prompt box.
    Clicking the button calls ``generate_diagram`` and shows the image.
    """
    st.set_page_config(page_title="FLUX Diagram Generator", layout="wide")
    st.title("🎨 FLUX Diagram Generator")
    st.markdown("Generate beautiful hand-drawn style diagrams using FLUX AI")

    # Sidebar: offer both template lists so ADDITIONAL_EXAMPLES is reachable.
    templates = DIAGRAM_EXAMPLES + ADDITIONAL_EXAMPLES
    st.sidebar.title("📚 Example Templates")
    selected_example = st.sidebar.selectbox(
        "Choose a template",
        options=range(len(templates)),
        format_func=lambda idx: templates[idx]["title"],
    )

    # Main content area: inputs on the left, tips on the right.
    col1, col2 = st.columns([2, 1])

    with col1:
        prompt = st.text_area(
            "Diagram Prompt",
            value=templates[selected_example]["prompt"],
            height=200,
        )

        with st.expander("Advanced Configuration"):
            width = st.number_input("Width", min_value=512, max_value=2048, value=1024, step=128)
            height = st.number_input("Height", min_value=512, max_value=2048, value=1024, step=128)
            # value=None leaves the field empty, which generate_diagram maps to
            # its default seed.
            seed = st.number_input("Seed (optional)", value=None, step=1)

        if st.button("🎨 Generate Diagram"):
            with st.spinner("Generating your diagram..."):
                result = generate_diagram(prompt, width, height, seed)
            if result:
                # generate_diagram returns base64 text, but st.image treats a
                # plain string as a URL or file path; decode to raw image bytes.
                st.image(
                    base64.b64decode(result),
                    caption="Generated Diagram",
                    use_column_width=True,
                )

    with col2:
        st.subheader("Tips for Better Results")
        st.markdown("""
- Use clear hierarchical structures
- Include icon descriptions in brackets
- Keep text concise and meaningful
- Use consistent formatting
""")
        st.subheader("Template Structure")
        st.code("""
MAIN TOPIC
├── SUBTOPIC 1 [Icon]
├── SUBTOPIC 2 [Icon]
└── SUBTOPIC 3
    ├── DETAIL 1 [Icon]
    └── DETAIL 2 [Icon]
""")


if __name__ == "__main__":
    main()