|
|
|
import base64
import hashlib
import json
import os
import time
from pathlib import Path
from typing import Any, Dict, Optional

import streamlit as st
from dotenv import load_dotenv
from gradio_client import Client
|
|
|
|
|
# Load variables from a local .env file into the process environment.
load_dotenv()
# Hugging Face token read from the environment; None when HF_TOKEN is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Directory handed to DiagramCache for cached images and the cache index.
CACHE_DIR = Path("./cache")
CACHE_DIR.mkdir(exist_ok=True)
|
|
|
|
|
# Pre-defined examples with a fixed seed, keyed by example id.
# generate_cached_example() looks entries up by that id.
# NOTE(review): the tree-branch glyphs inside the prompt look mis-encoded
# (probably box-drawing characters saved with the wrong codec) and the nested
# indentation appears to be lost; confirm against the original prompt text.
CACHED_EXAMPLES = {
    "literacy_mental": {
        "title": "Literacy Mental Map",
        "prompt": """A handrawn colorful mind map diagram, rugosity drawn lines, clear shapes, brain silhouette, text areas. must include the texts
LITERACY/MENTAL
βββ PEACE [Dove Icon]
βββ HEALTH [Vitruvian Man ~60px]
βββ CONNECT [Brain-Mind Connection Icon]
βββ INTELLIGENCE
β βββ EVERYTHING [Globe Icon ~50px]
βββ MEMORY
βββ READING [Book Icon ~40px]
βββ SPEED [Speedometer Icon]
βββ CREATIVITY
βββ INTELLIGENCE [Lightbulb + Infinity ~30px]""",
        "width": 1024,
        "height": 1024,
        "seed": 1872187377,
        "cache_path": "literacy_mental.png",
    },
}
|
|
|
|
|
# Prompt templates offered in the sidebar selectbox of main(); the selected
# entry's "prompt" pre-fills the prompt text area.
# NOTE(review): the tree-branch glyphs inside the prompts look mis-encoded
# (probably box-drawing characters saved with the wrong codec); confirm
# against the original prompt text.
DIAGRAM_EXAMPLES = [
    {
        "title": "Project Management Flow",
        "prompt": """A handrawn colorful mind map diagram, rugosity drawn lines, clear shapes, project management flow.
PROJECT MANAGEMENT
βββ INITIATION [Rocket Icon]
βββ PLANNING [Calendar Icon]
βββ EXECUTION [Gear Icon]
βββ MONITORING
β βββ CONTROL [Dashboard Icon]
βββ CLOSURE [Checkmark Icon]""",
        "width": 1024,
        "height": 1024,
    },
    {
        "title": "Digital Marketing Strategy",
        "prompt": """A handrawn colorful mind map diagram, rugosity drawn lines, modern style, marketing concept.
DIGITAL MARKETING
βββ SEO [Magnifying Glass]
βββ SOCIAL MEDIA [Network Icon]
βββ CONTENT
β βββ BLOG [Document Icon]
β βββ VIDEO [Play Button]
βββ ANALYTICS [Graph Icon]""",
        "width": 1024,
        "height": 1024,
    },
]
|
|
|
|
|
# Extra prompt templates in the same shape as DIAGRAM_EXAMPLES.
# NOTE(review): nothing in the visible code reads this list — confirm whether
# it is meant to be appended to the sidebar templates.
# NOTE(review): the tree-branch glyphs inside the prompt look mis-encoded
# (probably box-drawing characters saved with the wrong codec).
ADDITIONAL_EXAMPLES = [
    {
        "title": "Health & Wellness",
        "prompt": """A handrawn colorful mind map diagram, wellness-focused style, health aspects.
WELLNESS
βββ PHYSICAL [Dumbbell Icon]
βββ MENTAL [Brain Icon]
βββ NUTRITION [Apple Icon]
βββ SLEEP
βββ QUALITY [Star Icon]
βββ DURATION [Clock Icon]""",
        "width": 1024,
        "height": 1024,
    },
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DiagramCache:
    """Disk cache that maps generation parameters to image files.

    Every cached image is stored inside ``cache_dir`` as
    ``<md5 of the JSON-encoded params><original suffix>``. A
    ``cache_index.json`` file in the same directory records, per key, the
    filename, the insertion timestamp and the parameters themselves.
    """

    INDEX_FILENAME = "cache_index.json"

    def __init__(self, cache_dir: Path):
        """Create the cache directory if needed and load the saved index.

        Args:
            cache_dir: Directory that holds the images and the index file.
        """
        self.cache_dir = Path(cache_dir)
        # parents=True so a nested cache location also works on first use.
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self._load_cache()

    def _load_cache(self):
        """Populate ``self.cache_index`` from disk.

        A missing, unreadable or malformed index leaves the index empty
        instead of raising, so a damaged file only costs a cold cache.
        """
        self.cache_index = {}
        index_path = self.cache_dir / self.INDEX_FILENAME
        try:
            with open(index_path, "r") as f:
                loaded = json.load(f)
        except (OSError, ValueError):
            # OSError: file missing/unreadable; ValueError: invalid JSON or
            # undecodable bytes (JSONDecodeError/UnicodeDecodeError).
            return
        if isinstance(loaded, dict):
            self.cache_index = loaded

    def _save_cache_index(self):
        """Write the index to disk.

        The data goes to a temporary sibling file first and is then moved
        over the real index, so an interrupted write cannot leave a
        half-written ``cache_index.json`` behind.
        """
        index_path = self.cache_dir / self.INDEX_FILENAME
        tmp_path = self.cache_dir / (self.INDEX_FILENAME + ".tmp")
        with open(tmp_path, "w") as f:
            json.dump(self.cache_index, f)
        tmp_path.replace(index_path)

    def _get_cache_key(self, params: Dict[str, Any]) -> str:
        """Return the hex MD5 digest of ``params`` serialized as JSON.

        Keys are sorted before hashing, so dict ordering does not affect the
        result. MD5 is used purely as a fingerprint here.

        Raises:
            TypeError: If ``params`` contains values ``json.dumps`` rejects.
        """
        param_str = json.dumps(params, sort_keys=True)
        return hashlib.md5(param_str.encode()).hexdigest()

    def get(self, params: Dict[str, Any]) -> Optional[Path]:
        """Return the cached file for ``params``, or ``None`` on a miss.

        An index entry whose file no longer exists counts as a miss.
        """
        cache_info = self.cache_index.get(self._get_cache_key(params))
        if not isinstance(cache_info, dict):
            return None
        filename = cache_info.get("filename")
        if not filename:
            return None
        cache_path = self.cache_dir / filename
        return cache_path if cache_path.exists() else None

    def put(self, params: Dict[str, Any], result_path: Path):
        """Copy ``result_path`` into the cache and record it in the index.

        Args:
            params: Parameters that produced the file; they form the key.
            result_path: Existing file to copy; its suffix is preserved.

        Raises:
            OSError: If the source cannot be read or the copy/index cannot
                be written.
        """
        result_path = Path(result_path)
        cache_key = self._get_cache_key(params)
        filename = f"{cache_key}{result_path.suffix}"
        cache_path = self.cache_dir / filename

        # Read the whole source before opening the destination for writing:
        # if result_path already is cache_path, opening it with "wb" first
        # would truncate the data we are about to copy.
        data = result_path.read_bytes()
        cache_path.write_bytes(data)

        self.cache_index[cache_key] = {
            "filename": filename,
            "timestamp": time.time(),
            "params": params,
        }
        self._save_cache_index()
|
|
|
|
|
|
|
# Module-level cache instance rooted at CACHE_DIR; used by the generator
# functions in this file.
diagram_cache = DiagramCache(CACHE_DIR)
|
|
|
@st.cache_data
def generate_cached_example(example_id: str) -> str:
    """Return the image for a predefined example as a base64 string.

    The disk cache is consulted first; the FLUX Space is contacted only on a
    miss, and the generated file is then stored in the disk cache. The
    function is also wrapped with ``st.cache_data``.

    Args:
        example_id: Key into ``CACHED_EXAMPLES``.

    Returns:
        Base64 text of the image file's bytes.

    Raises:
        KeyError: If ``example_id`` is not in ``CACHED_EXAMPLES``.
    """
    example = CACHED_EXAMPLES[example_id]

    cache_path = diagram_cache.get(example)
    if cache_path:
        return base64.b64encode(Path(cache_path).read_bytes()).decode()

    # Connect only on a cache miss so cached examples need no network access.
    client = Client("black-forest-labs/FLUX.1-schnell")
    result = client.predict(
        prompt=example["prompt"],
        seed=example["seed"],
        randomize_seed=False,
        width=example["width"],
        height=example["height"],
        num_inference_steps=4,
        api_name="/infer",
    )

    # NOTE(review): the endpoint may return (image_path, seed) rather than a
    # bare path — confirm on the Space's API page. Both shapes are accepted.
    image_path = Path(result[0] if isinstance(result, (tuple, list)) else result)

    diagram_cache.put(example, image_path)
    return base64.b64encode(image_path.read_bytes()).decode()
|
|
|
def generate_diagram(
    prompt: str, width: int, height: int, seed: Optional[int] = None
) -> Optional[str]:
    """Generate a diagram image for ``prompt`` and return it base64-encoded.

    Results are looked up in, and stored to, the module-level disk cache
    keyed by prompt, seed, width and height.

    Args:
        prompt: Text prompt sent to the model.
        width: Requested image width.
        height: Requested image height.
        seed: Generation seed. ``None`` selects the default 1872187377; an
            explicit ``0`` is passed through as a real seed.

    Returns:
        Base64 text of the image file's bytes, or ``None`` when generation
        failed (the error is shown via ``st.error``).
    """
    params = {
        "prompt": prompt,
        # Compare against None so that a seed of 0 is not replaced.
        "seed": seed if seed is not None else 1872187377,
        "width": width,
        "height": height,
    }

    cache_path = diagram_cache.get(params)
    if cache_path:
        return base64.b64encode(Path(cache_path).read_bytes()).decode()

    try:
        # Created inside the try block so connection failures are reported
        # in the UI like any other generation error.
        client = Client("black-forest-labs/FLUX.1-schnell")
        result = client.predict(
            prompt=prompt,
            seed=params["seed"],
            randomize_seed=False,
            width=width,
            height=height,
            num_inference_steps=4,
            api_name="/infer",
        )

        # NOTE(review): the endpoint may return (image_path, seed) rather
        # than a bare path — confirm on the Space's API page. Both shapes
        # are accepted.
        image_path = Path(
            result[0] if isinstance(result, (tuple, list)) else result
        )

        diagram_cache.put(params, image_path)
        return base64.b64encode(image_path.read_bytes()).decode()
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing.
        st.error(f"Error generating diagram: {str(e)}")
        return None
|
|
|
|
|
def main():
    """Render the Streamlit page.

    The sidebar offers the templates in ``DIAGRAM_EXAMPLES``. The left column
    holds the prompt editor, size/seed inputs and the generate button; the
    right column shows static tips and a template skeleton.
    """
    # NOTE(review): the emoji prefixes in the title, sidebar title and button
    # label, and the tree glyphs in the template skeleton, look mis-encoded;
    # they are left untouched pending confirmation of the intended glyphs.
    st.set_page_config(page_title="FLUX Diagram Generator", layout="wide")

    st.title("π¨ FLUX Diagram Generator")
    st.markdown("Generate beautiful hand-drawn style diagrams using FLUX AI")

    st.sidebar.title("π Example Templates")
    selected_example = st.sidebar.selectbox(
        "Choose a template",
        options=range(len(DIAGRAM_EXAMPLES)),
        format_func=lambda x: DIAGRAM_EXAMPLES[x]["title"],
    )

    col1, col2 = st.columns([2, 1])

    with col1:
        prompt = st.text_area(
            "Diagram Prompt",
            value=DIAGRAM_EXAMPLES[selected_example]["prompt"],
            height=200,
        )

        with st.expander("Advanced Configuration"):
            width = st.number_input("Width", min_value=512, max_value=2048, value=1024, step=128)
            height = st.number_input("Height", min_value=512, max_value=2048, value=1024, step=128)
            seed = st.number_input("Seed (optional)", value=None, step=1)

        if st.button("π¨ Generate Diagram"):
            with st.spinner("Generating your diagram..."):
                result = generate_diagram(prompt, width, height, seed)
            if result:
                # generate_diagram returns base64 text; st.image interprets a
                # plain str as a URL/path, so pass the decoded bytes instead.
                st.image(
                    base64.b64decode(result),
                    caption="Generated Diagram",
                    use_column_width=True,
                )

    with col2:
        st.subheader("Tips for Better Results")
        st.markdown("""
- Use clear hierarchical structures
- Include icon descriptions in brackets
- Keep text concise and meaningful
- Use consistent formatting
""")

        st.subheader("Template Structure")
        st.code("""
MAIN TOPIC
βββ SUBTOPIC 1 [Icon]
βββ SUBTOPIC 2 [Icon]
βββ SUBTOPIC 3
βββ DETAIL 1 [Icon]
βββ DETAIL 2 [Icon]
""")
|
|
|
# Run the app when the file is executed as a script.
if __name__ == "__main__":
    main()
|
|
|
|