"""
CompI Phase 3 Final Dashboard — Complete Integration (3.A → 3.E)

This is the ultimate CompI interface that integrates ALL Phase 3 components:
- Phase 3.A/3.B: True multimodal fusion with real processing
- Phase 3.C: Advanced references with role assignment and live ControlNet previews
- Phase 3.D: Professional workflow management (gallery, presets, export)
- Phase 3.E: Performance management and model switching

Features:
- All multimodal inputs (Text, Audio, Data, Emotion, Real-time, Multi-Reference)
- Advanced References: multi-image upload/URLs, style vs structure roles, ControlNet with live previews
- Model & Performance: SD 1.5/SDXL switching, LoRA integration, VRAM monitoring, OOM auto-retry
- Workflow & Export: gallery, filters, rating/tags/notes, presets save/load, portable export ZIP
- True fusion engine: real processing for all inputs, intelligent generation mode selection
"""

import os
import io
import csv
import sys
import json
import zipfile
import shutil
import platform
import requests
from datetime import datetime
from pathlib import Path
from typing import Optional, Dict, List

import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image
import torch

from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
)
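
# Optional diffusers components are probed below at import time; the UI only
# offers ControlNet / SDXL / latent-upscaler features when the installed
# diffusers build actually provides them.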
HAS_CONTROLNET = True
CN_IMG2IMG_AVAILABLE = True
try:
    from diffusers import (
        StableDiffusionControlNetPipeline,
        StableDiffusionControlNetImg2ImgPipeline,
        ControlNetModel,
    )
except Exception:
    HAS_CONTROLNET = False
    CN_IMG2IMG_AVAILABLE = False

HAS_SDXL = True
HAS_UPSCALER = True
try:
    from diffusers import StableDiffusionXLPipeline
except Exception:
    HAS_SDXL = False

try:
    from diffusers import StableDiffusionLatentUpscalePipeline
except Exception:
    HAS_UPSCALER = False


def _lazy_install(pkgs: str):
    """Install packages on demand into the current interpreter's environment."""
    # sys.executable -m pip guarantees we install into the running Python, not
    # whatever "pip" happens to be on PATH.
    os.system(f"{sys.executable} -m pip install -q {pkgs}")
try:
    import librosa
    import soundfile as sf
except Exception:
    _lazy_install("librosa soundfile")
    import librosa
    import soundfile as sf

try:
    import whisper
except Exception:
    _lazy_install("git+https://github.com/openai/whisper.git")
    import whisper

try:
    from textblob import TextBlob
except Exception:
    _lazy_install("textblob")
    from textblob import TextBlob

try:
    import feedparser
except Exception:
    _lazy_install("feedparser")
    import feedparser

try:
    import matplotlib.pyplot as plt
except Exception:
    _lazy_install("matplotlib")
    import matplotlib.pyplot as plt

try:
    import cv2
except Exception:
    _lazy_install("opencv-python-headless")
    import cv2


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

OUTPUT_DIR = Path("outputs")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

EXPORTS_DIR = Path("exports")
EXPORTS_DIR.mkdir(parents=True, exist_ok=True)

PRESETS_DIR = Path("presets")
PRESETS_DIR.mkdir(parents=True, exist_ok=True)

RUNLOG = OUTPUT_DIR / "phase3_run_log.csv"
RUNLOG_3C = OUTPUT_DIR / "phase3c_runs.csv"
RUNLOG_3E = OUTPUT_DIR / "phase3e_runlog.csv"
ANNOT_CSV = OUTPUT_DIR / "phase3d_annotations.csv"


def slugify(s: str, n: int = 30) -> str:
    """Create a safe filename fragment from a string."""
    if not s:
        return "none"
    return "_".join(s.lower().split())[:n]


def save_image(img: Image.Image, name: str) -> str:
    """Save an image to the outputs directory and return its path."""
    p = OUTPUT_DIR / name
    img.save(p)
    return str(p)


def vram_gb() -> Optional[float]:
    """Total VRAM in GB, or None on CPU/error."""
    if DEVICE == "cuda":
        try:
            return torch.cuda.get_device_properties(0).total_memory / (1024**3)
        except Exception:
            return None
    return None


def vram_used_gb() -> Optional[float]:
    """Currently allocated VRAM in GB, or None on CPU/error."""
    if DEVICE == "cuda":
        try:
            torch.cuda.synchronize()
            return torch.cuda.memory_allocated() / (1024**3)
        except Exception:
            return None
    return None


def attempt_enable_xformers(pipe) -> bool:
    """Try to enable xFormers memory-efficient attention; report success."""
    try:
        pipe.enable_xformers_memory_efficient_attention()
        return True
    except Exception:
        return False


def apply_perf(pipe, attn_slice=True, vae_slice=True, vae_tile=False):
    """Apply VRAM-saving optimizations to a pipeline."""
    if attn_slice:
        pipe.enable_attention_slicing()
    if vae_slice:
        try:
            pipe.enable_vae_slicing()
        except Exception:
            pass
    if vae_tile:
        try:
            pipe.enable_vae_tiling()
        except Exception:
            pass


def safe_retry_sizes(h, w, steps):
    """Yield progressively smaller (height, width, steps) fallbacks for OOM recovery."""
    sizes = [
        (h, w, steps),
        (max(384, h // 2), max(384, w // 2), max(steps - 8, 12)),
        (384, 384, max(steps - 12, 12)),
        (256, 256, max(steps - 16, 10)),
    ]
    seen = set()
    for it in sizes:
        if it not in seen:
            seen.add(it)
            yield it
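
# For example, safe_retry_sizes(768, 768, 40) yields (768, 768, 40),
# (384, 384, 32), (384, 384, 28), then (256, 256, 24): each CUDA OOM steps down
# to the next tuple, trading resolution and steps for a completed generation.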


def canny_map(img: Image.Image) -> Image.Image:
    """Create a Canny edge map from an image (for ControlNet conditioning)."""
    arr = np.array(img.convert("RGB"))
    edges = cv2.Canny(arr, 100, 200)
    edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
    return Image.fromarray(edges_rgb)


def depth_proxy(img: Image.Image) -> Image.Image:
    """Create a cheap depth-like proxy by replicating the grayscale channel."""
    gray = img.convert("L")
    return Image.merge("RGB", (gray, gray, gray))


def save_plot(fig) -> Image.Image:
    """Render a matplotlib figure to an RGB PIL Image and close it."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf).convert("RGB")


def env_snapshot() -> Dict:
    """Create an environment snapshot for reproducibility."""
    try:
        import importlib.metadata as im
    except Exception:
        import importlib_metadata as im

    pkgs = {}
    # "openai-whisper" is the distribution name for the whisper module.
    for pkg in ["torch", "diffusers", "transformers", "accelerate", "opencv-python-headless",
                "librosa", "openai-whisper", "textblob", "pandas", "numpy", "matplotlib",
                "feedparser", "streamlit", "Pillow"]:
        try:
            pkgs[pkg] = im.version(pkg)
        except Exception:
            pass

    return {
        "timestamp": datetime.now().isoformat(),
        "python_version": sys.version,
        "platform": platform.platform(),
        "packages": pkgs,
    }


def mk_readme(bundle_meta: Dict, df_meta: pd.DataFrame) -> str:
    """Generate a README for an export bundle."""
    lines = []
    lines.append(f"# CompI Export — {bundle_meta['bundle_name']}\n")
    lines.append(f"_Created: {bundle_meta['created_at']}_\n")
    lines += [
        "## What's inside",
        "- Selected images",
        "- `manifest.json` (environment + settings)",
        "- `metadata.csv` (merged logs)",
        "- `annotations.csv` (ratings/tags/notes)",
    ]
    # The manifest records the preset under includes["preset_json"].
    if bundle_meta.get("includes", {}).get("preset_json"):
        lines.append("- `preset.json` (saved generation settings)")

    lines.append("\n## Summary of selected runs")
    if not df_meta.empty and "mode" in df_meta.columns:
        counts = df_meta["mode"].value_counts().to_dict()
        lines.append("Modes:")
        for k, v in counts.items():
            lines.append(f"- {k}: {v}")

    lines.append("\n## Reproducing")
    lines.append("1. Install versions in `manifest.json`.")
    lines.append("2. Use `preset.json` or copy prompt/params from `metadata.csv`.")
    lines.append("3. Run the dashboard with these settings.")

    return "\n".join(lines)


@st.cache_resource(show_spinner=True)
def load_sd15(txt2img=True):
    """Load a Stable Diffusion 1.5 pipeline (txt2img or img2img)."""
    cls = StableDiffusionPipeline if txt2img else StableDiffusionImg2ImgPipeline
    pipe = cls.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        safety_checker=None,
        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    )
    return pipe.to(DEVICE)


@st.cache_resource(show_spinner=True)
def load_sdxl():
    """Load the SDXL base pipeline."""
    if not HAS_SDXL:
        return None
    # SDXL has no safety_checker component, so that kwarg is omitted here.
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    )
    return pipe.to(DEVICE)


@st.cache_resource(show_spinner=True)
def load_upscaler():
    """Load the x2 latent upscaler pipeline."""
    if not HAS_UPSCALER:
        return None
    up = StableDiffusionLatentUpscalePipeline.from_pretrained(
        "stabilityai/sd-x2-latent-upscaler",
        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    )
    return up.to(DEVICE)
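
# st.cache_resource keeps each pipeline resident across Streamlit reruns, so a
# given model is downloaded and moved to the device only once per process.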
@st.cache_resource(show_spinner=True)
def load_controlnet(cn_type: str):
    """Load a ControlNet txt2img pipeline (Canny or Depth)."""
    if not HAS_CONTROLNET:
        return None
    cn_id = "lllyasviel/sd-controlnet-canny" if cn_type == "Canny" else "lllyasviel/sd-controlnet-depth"
    controlnet = ControlNetModel.from_pretrained(
        cn_id, torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
    )
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        safety_checker=None,
        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    ).to(DEVICE)
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pass
    pipe.enable_attention_slicing()
    return pipe


@st.cache_resource(show_spinner=True)
def load_controlnet_img2img(cn_type: str):
    """Load a ControlNet + img2img hybrid pipeline."""
    global CN_IMG2IMG_AVAILABLE
    if not HAS_CONTROLNET:
        return None
    try:
        cn_id = "lllyasviel/sd-controlnet-canny" if cn_type == "Canny" else "lllyasviel/sd-controlnet-depth"
        controlnet = ControlNetModel.from_pretrained(
            cn_id, torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
        )
        pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=controlnet,
            safety_checker=None,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
        ).to(DEVICE)
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            pass
        pipe.enable_attention_slicing()
        return pipe
    except Exception:
        CN_IMG2IMG_AVAILABLE = False
        return None


st.set_page_config(page_title="CompI — Phase 3 Final Dashboard", layout="wide")
st.title("🧪 CompI — Final Integrated Dashboard (3.A → 3.E)")


def inject_minimal_css():
    st.markdown(
        """
        <style>
        .block-container {padding-top: 1.2rem; padding-bottom: 2rem; max-width: 1200px;}
        .stTabs [role="tablist"] {gap: 6px;}
        .stTabs [role="tab"] {padding: 6px 10px; border-radius: 8px; background: rgba(255,255,255,0.02); border: 1px solid rgba(255,255,255,0.08);}
        .stTabs [aria-selected="true"] {background: rgba(255,255,255,0.04); border-color: rgba(255,255,255,0.16);}
        h1, h2, h3 {margin-bottom: .3rem;}
        .section {padding: 14px 16px; border: 1px solid rgba(255,255,255,0.08); border-radius: 12px; background: rgba(255,255,255,0.02); margin-bottom: 14px;}
        .muted {color: rgba(255,255,255,0.6); text-transform: uppercase; letter-spacing: .08em; font-size: .75rem; margin-bottom: .25rem;}
        .stButton>button {border-radius: 10px; height: 44px;}
        .stButton>button[kind="primary"] {background: #2563eb; border-color: #2563eb;}
        .stTextInput input, .stTextArea textarea {border-radius: 10px;}
        .stMultiSelect [data-baseweb="tag"] {border-radius: 8px;}
        pre, code {border-radius: 10px;}
        #MainMenu, footer {visibility: hidden;}
        </style>
        """,
        unsafe_allow_html=True,
    )


inject_minimal_css()


colA, colB, colC, colD = st.columns(4)
# Query VRAM once per rerun instead of twice per metric.
total_vram = vram_gb()
used_vram = vram_used_gb()
with colA:
    st.metric("Device", DEVICE)
with colB:
    st.metric("VRAM (GB)", f"{total_vram:.2f}" if total_vram else "N/A")
with colC:
    st.metric("Used VRAM (GB)", f"{used_vram:.2f}" if used_vram else "N/A")
with colD:
    st.caption(f"PyTorch {torch.__version__} • diffusers ready")
if st.session_state.get("clear_inputs", False):
    keys_to_clear = [
        "main_prompt_input", "style_input", "mood_input", "neg_prompt_input", "style_ms", "mood_ms",
        "emo_free_textarea", "ref_urls_textarea",
        "audio_file_uploader", "data_file_uploader", "formula_input", "ref_images_uploader",
        "enable_emo_checkbox", "enable_rt_checkbox", "enable_ref_checkbox",
        "model_choice_selectbox", "gen_mode_selectbox",
        "use_lora_checkbox", "lora_path_input", "lora_scale_slider",
        "width_input", "height_input", "steps_input", "guidance_input",
        "batch_input", "seed_input", "upsample_checkbox",
        "use_xformers_checkbox", "attn_slice_checkbox", "vae_slice_checkbox", "vae_tile_checkbox",
        "oom_retry_checkbox",
        "city_input", "headlines_slider",
    ]
    for k in keys_to_clear:
        st.session_state.pop(k, None)

    st.session_state["generated_images"] = []
    st.session_state["generation_results"] = []
    st.session_state["clear_inputs"] = False


tab_inputs, tab_refs, tab_model, tab_gallery, tab_presets, tab_export = st.tabs([
    "🧩 Inputs (Text/Audio/Data/Emotion/Real‑time)",
    "🖼️ Advanced References",
    "⚙️ Model & Performance",
    "🖼️ Gallery & Annotate",
    "💾 Presets",
    "📦 Export",
])


with tab_inputs:
    st.markdown("<div class='section'>", unsafe_allow_html=True)
    st.subheader("🧩 Multimodal Inputs")

    st.markdown("<div class='muted'>Text & Style</div>", unsafe_allow_html=True)
    main_prompt = st.text_input(
        "Main prompt",
        value=st.session_state.get("main_prompt_input", ""),
        placeholder="A serene cyberpunk alley at dawn",
        key="main_prompt_input",
    )

    STYLE_OPTIONS = [
        "digital painting", "watercolor", "oil painting", "pixel art", "anime",
        "3D render", "photorealistic", "line art", "low poly", "cyberpunk",
        "isometric", "concept art", "cel shading", "comic book", "impressionist",
    ]
    MOOD_OPTIONS = [
        "dreamy", "luminous", "dark and moody", "whimsical", "serene",
        "epic", "melancholic", "vibrant", "mysterious", "dystopian",
        "hopeful", "playful", "contemplative", "energetic", "ethereal",
    ]

    style_selected = st.multiselect(
        "Style (choose one or more)",
        options=STYLE_OPTIONS,
        default=st.session_state.get("style_ms", []),
        key="style_ms",
        help="Pick one or more styles to condition the artwork",
    )
    mood_selected = st.multiselect(
        "Mood (choose one or more)",
        options=MOOD_OPTIONS,
        default=st.session_state.get("mood_ms", []),
        key="mood_ms",
        help="Pick one or more moods to influence the atmosphere",
    )

    style = ", ".join(style_selected)
    mood = ", ".join(mood_selected)

    neg_prompt = st.text_input(
        "Negative prompt (optional)",
        value=st.session_state.get("neg_prompt_input", ""),
        placeholder="e.g., low quality, bad anatomy",
        key="neg_prompt_input",
    )

    st.markdown("</div>", unsafe_allow_html=True)

    col1, col2, col3, col4 = st.columns(4)

    with col1:
        st.markdown("### 🎵 Audio Analysis")
        enable_audio = st.checkbox("Enable Audio Processing", value=False)
        audio_caption = ""
        audio_tags = []
        tempo = None

        if enable_audio:
            audio_file = st.file_uploader("Upload audio (.wav/.mp3)", type=["wav", "mp3"], key="audio_file_uploader")
            if audio_file:
                audio_path = OUTPUT_DIR / "tmp_audio.wav"
                with open(audio_path, "wb") as f:
                    f.write(audio_file.read())

                y, sr = librosa.load(audio_path.as_posix(), sr=16000)
                dur = librosa.get_duration(y=y, sr=sr)
                st.caption(f"Duration: {dur:.1f}s")

                try:
                    tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
                    tempo = float(tempo)  # newer librosa can return a 0-d array
                except Exception:
                    tempo = None

                rms = float(np.sqrt(np.mean(y**2)))
                zcr = float(np.mean(librosa.feature.zero_crossing_rate(y)))

                if tempo:
                    if tempo < 90:
                        audio_tags.append("slow tempo")
                    elif tempo > 140:
                        audio_tags.append("fast tempo")

                if rms > 0.04:
                    audio_tags.append("energetic")
                if zcr > 0.12:
                    audio_tags.append("percussive")

                st.info("Transcribing audio (Whisper base)…")
                w = whisper.load_model("base", device=DEVICE)
                wav = whisper.load_audio(audio_path.as_posix())
                wav = whisper.pad_or_trim(wav)
                mel = whisper.log_mel_spectrogram(wav).to(DEVICE)
                dec = whisper.DecodingOptions(language="en", fp16=(DEVICE == "cuda"))
                res = whisper.decode(w, mel, dec)
                audio_caption = res.text.strip()

                st.success(f"Caption: '{audio_caption}'")
                if audio_tags:
                    st.write("Audio tags:", ", ".join(audio_tags))
    with col2:
        st.markdown("### 📊 Data Analysis")
        enable_data = st.checkbox("Enable Data Processing", value=False)
        data_summary = ""
        data_plot = None

        if enable_data:
            data_file = st.file_uploader("Upload CSV", type=["csv"], key="data_file_uploader")
            formula = st.text_input("Or numpy formula", placeholder="np.sin(np.linspace(0, 20, 200))", key="formula_input")

            if data_file is not None:
                df = pd.read_csv(data_file)
                st.dataframe(df.head(), use_container_width=True)

                num = df.select_dtypes(include=np.number)
                if not num.empty:
                    means, mins, maxs, stds = num.mean(), num.min(), num.max(), num.std()
                    data_summary = f"{len(num)} rows x {num.shape[1]} cols; " + " ".join([
                        f"{c}: avg {means[c]:.2f}, min {mins[c]:.2f}, max {maxs[c]:.2f}."
                        for c in num.columns[:3]
                    ])
                    data_summary += " Variability " + ("high." if stds.mean() > 1 else "gentle.")

                    fig = plt.figure(figsize=(6, 3))
                    if num.shape[1] == 1:
                        plt.plot(num.iloc[:, 0])
                        plt.title(f"Pattern: {num.columns[0]}")
                    else:
                        plt.plot(num.iloc[:, 0], label=num.columns[0])
                        plt.plot(num.iloc[:, 1], label=num.columns[1])
                        plt.legend()
                        plt.title("Data Patterns")
                    plt.tight_layout()
                    data_plot = save_plot(fig)
                    st.image(data_plot, caption="Data pattern")

            elif formula.strip():
                try:
                    # Evaluated with numpy only and builtins stripped; this limits
                    # what an expression can reach but is not a hard sandbox.
                    arr = eval(formula, {"np": np, "__builtins__": {}})
                    arr = np.array(arr)
                    data_summary = f"Mathematical pattern with {arr.size} points."

                    fig = plt.figure(figsize=(6, 3))
                    plt.plot(arr)
                    plt.title("Formula Pattern")
                    plt.tight_layout()
                    data_plot = save_plot(fig)
                    st.image(data_plot, caption="Formula pattern")
                except Exception as e:
                    st.error(f"Formula error: {e}")
    with col3:
        st.markdown("### 💭 Emotion Analysis")
        enable_emo = st.checkbox("Enable Emotion Processing", value=False, key="enable_emo_checkbox")
        emo_free = st.text_area(
            "Describe a feeling/context",
            value=st.session_state.get("emo_free_textarea", ""),
            key="emo_free_textarea",
        ) if enable_emo else ""
        emo_label = ""

        if enable_emo and emo_free.strip():
            tb = TextBlob(emo_free)
            pol = tb.sentiment.polarity
            emo_label = "positive, uplifting" if pol > 0.3 else (
                "sad, melancholic" if pol < -0.3 else "neutral, contemplative"
            )
            st.info(f"Sentiment: {emo_label} (polarity {pol:.2f})")
    with col4:
        st.markdown("### 🌎 Real-time Data")
        enable_rt = st.checkbox("Enable Real-time Feeds", value=False, key="enable_rt_checkbox")
        rt_context = ""

        if enable_rt:
            city = st.text_input("City (weather)", "Toronto", key="city_input")
            headlines_num = st.slider("Headlines", 1, 5, 3, key="headlines_slider")

            def get_weather(city):
                try:
                    # The OpenWeather key comes from Streamlit secrets; never ship
                    # a hardcoded API key in source.
                    key = st.secrets.get("OPENWEATHER_KEY", None) if hasattr(st, "secrets") else None
                    if not key:
                        return "unavailable (set OPENWEATHER_KEY in secrets)"
                    url = "https://api.openweathermap.org/data/2.5/weather"
                    params = {"q": city, "units": "metric", "appid": key}
                    r = requests.get(url, params=params, timeout=6).json()
                    return f"{r['weather'][0]['description']}, {r['main']['temp']:.1f}°C"
                except Exception as e:
                    return f"unavailable ({e})"

            def get_news(n):
                try:
                    feed = feedparser.parse("https://feeds.bbci.co.uk/news/rss.xml")
                    return "; ".join([e["title"] for e in feed.entries[:n]])
                except Exception as e:
                    return f"unavailable ({e})"

            w = get_weather(city)
            n = get_news(headlines_num)
            st.caption(f"Weather: {w}")
            st.caption(f"News: {n}")
            rt_context = f"Current weather in {city}: {w}. Today's news: {n}."
with tab_refs:
    st.subheader("🖼️ Advanced Multi‑Reference + ControlNet")
    enable_ref = st.checkbox("Enable Multi-Reference Processing", value=False, key="enable_ref_checkbox")
    ref_images: List[Image.Image] = []
    style_idxs = []
    cn_images = []
    img2img_strength = 0.55
    cn_type = "Canny"
    cn_scale = 1.0
    if enable_ref:
        colU, colURL = st.columns(2)

        with colU:
            st.markdown("**📁 Upload Images**")
            uploads = st.file_uploader(
                "Upload reference images",
                type=["png", "jpg", "jpeg"],
                accept_multiple_files=True,
                key="ref_images_uploader",
            )
            if uploads:
                for u in uploads:
                    try:
                        im = Image.open(u).convert("RGB")
                        ref_images.append(im)
                    except Exception as e:
                        st.warning(f"Upload failed: {e}")

        with colURL:
            st.markdown("**🔗 Image URLs**")
            block = st.text_area(
                "Paste image URLs (one per line)",
                value=st.session_state.get("ref_urls_textarea", ""),
                key="ref_urls_textarea",
            )
            if block.strip():
                for line in block.splitlines():
                    url = line.strip()
                    if not url:
                        continue
                    try:
                        r = requests.get(url, timeout=8)
                        # Surface non-200 responses as warnings instead of
                        # silently dropping the URL.
                        r.raise_for_status()
                        im = Image.open(io.BytesIO(r.content)).convert("RGB")
                        ref_images.append(im)
                    except Exception as e:
                        st.warning(f"URL failed: {e}")
        if ref_images:
            st.image(
                ref_images,
                width=180,
                caption=[f"Ref {i+1}" for i in range(len(ref_images))],
            )

            st.markdown("### 🎨 Reference Role Assignment")
            style_idxs = st.multiselect(
                "Use as **Style References (img2img)**",
                list(range(1, len(ref_images) + 1)),
                default=list(range(1, len(ref_images) + 1)),
                help="These images will influence the artistic style and mood",
            )

            use_cn = st.checkbox("Use **ControlNet** for structure", value=HAS_CONTROLNET)
            if use_cn and not HAS_CONTROLNET:
                st.warning("ControlNet not available in this environment.")
                use_cn = False

            if use_cn:
                cn_type = st.selectbox("ControlNet type", ["Canny", "Depth"], index=0)
                pick = st.selectbox(
                    "Pick **one** structural reference",
                    list(range(1, len(ref_images) + 1)),
                    index=0,
                    help="This image will control the composition and structure",
                )

                base = ref_images[int(pick) - 1].resize((512, 512))
                # "Depth" uses the grayscale proxy from depth_proxy(), not a
                # learned depth estimator, so structural conditioning is approximate.
                cn_map = canny_map(base) if cn_type == "Canny" else depth_proxy(base)

                st.markdown("**🔍 Live ControlNet Preview**")
                st.image(
                    [base, cn_map],
                    width=240,
                    caption=["Selected Reference", f"{cn_type} Map"],
                )
                cn_images = [cn_map]
                cn_scale = st.slider("ControlNet conditioning scale", 0.1, 2.0, 1.0, 0.05)

            img2img_strength = st.slider(
                "img2img strength (style adherence)",
                0.2, 0.85, 0.55, 0.05,
                help="Higher values follow style references more closely",
            )
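
        # Downstream, style_idxs selects img2img init images and cn_images feeds
        # ControlNet conditioning; when both are present the generator takes the
        # hybrid CN+I2I path (see the mode selection under the Generate button).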


with tab_model:
    st.subheader("⚙️ Model & Performance Management")
    st.caption("Choose a base model, optional style add‑ons (LoRA), and tune speed/quality settings.")

    # st.dialog needs a recent Streamlit (1.34+); the glossary opens as a modal.
    @st.dialog("Glossary: Common terms")
    def show_glossary():
        st.markdown(
            """
            - Base model: The foundation that generates images (SD 1.5 = fast, SDXL = higher detail).
            - Generation mode:
              - txt2img: Create from your text prompt only.
              - img2img: Start from an input image and transform it using your text.
            - LoRA: A small add‑on that injects a trained style or subject. Use a .safetensors/.pt file.
            - Width/Height: Image size in pixels. Bigger = more detail but slower and more VRAM.
            - Steps: How long the model refines the image. More steps usually means cleaner details.
            - Guidance: How strongly to follow your text. 6–9 is a good range; too high can look unnatural.
            - Batch size: How many images at once. Higher uses more VRAM.
            - Seed: Randomness control. Reuse the same non‑zero seed to reproduce a result.
            - Upscale ×2: Quickly doubles resolution after generation.
            - xFormers attention: GPU speed‑up if supported.
            - Attention/VAE slicing: Reduce VRAM usage (slightly slower). Keep on for stability.
            - VAE tiling: For very large images; decodes in tiles.
            - Auto‑retry on CUDA OOM: If VRAM runs out, try again with safer settings.
            """
        )
        st.button("Close", use_container_width=True)
    def apply_preset(name: str):
        ss = st.session_state

        def s(k, v):
            ss[k] = v

        if name == "fast":
            s("model_choice_selectbox", "SD 1.5 (v1-5)")
            s("gen_mode_selectbox", "txt2img")
            s("width_input", 512); s("height_input", 512)
            s("steps_input", 30); s("guidance_input", 7.5)
            s("batch_input", 1); s("seed_input", 0)
            s("upsample_checkbox", False)
            s("use_xformers_checkbox", True); s("attn_slice_checkbox", True)
            s("vae_slice_checkbox", True); s("vae_tile_checkbox", False)
            s("oom_retry_checkbox", True)
        elif name == "high":
            model = "SDXL Base 1.0" if HAS_SDXL else "SD 1.5 (v1-5)"
            s("model_choice_selectbox", model)
            s("gen_mode_selectbox", "txt2img")
            s("width_input", 768); s("height_input", 768)
            s("steps_input", 40); s("guidance_input", 7.0)
            s("batch_input", 1); s("seed_input", 0)
            s("upsample_checkbox", True)
            s("use_xformers_checkbox", True); s("attn_slice_checkbox", True)
            s("vae_slice_checkbox", True); s("vae_tile_checkbox", False)
            s("oom_retry_checkbox", True)
        elif name == "low_vram":
            s("model_choice_selectbox", "SD 1.5 (v1-5)")
            s("gen_mode_selectbox", "txt2img")
            s("width_input", 448); s("height_input", 448)
            s("steps_input", 25); s("guidance_input", 7.5)
            s("batch_input", 1); s("seed_input", 0)
            s("upsample_checkbox", False)
            s("use_xformers_checkbox", True); s("attn_slice_checkbox", True)
            s("vae_slice_checkbox", True); s("vae_tile_checkbox", False)
            s("oom_retry_checkbox", True)
        elif name == "portrait":
            s("gen_mode_selectbox", "txt2img")
            s("width_input", 512); s("height_input", 768)
            s("steps_input", 30); s("guidance_input", 7.5)
            s("batch_input", 1)
        elif name == "landscape":
            s("gen_mode_selectbox", "txt2img")
            s("width_input", 768); s("height_input", 512)
            s("steps_input", 30); s("guidance_input", 7.5)
            s("batch_input", 1)
        elif name == "instagram":
            s("gen_mode_selectbox", "txt2img")
            s("width_input", 1024); s("height_input", 1024)
            s("steps_input", 35); s("guidance_input", 7.0)
            s("batch_input", 1); s("upsample_checkbox", False)
        elif name == "defaults":
            s("model_choice_selectbox", "SD 1.5 (v1-5)")
            s("gen_mode_selectbox", "txt2img")
            s("width_input", 512); s("height_input", 512)
            s("steps_input", 30); s("guidance_input", 7.5)
            s("batch_input", 1); s("seed_input", 0)
            s("upsample_checkbox", False)
            s("use_xformers_checkbox", True); s("attn_slice_checkbox", True)
            s("vae_slice_checkbox", True); s("vae_tile_checkbox", False)
            s("oom_retry_checkbox", True)
        st.rerun()
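
    # Presets write widget values into session_state and immediately st.rerun(),
    # so the widgets below are rebuilt with the new values on the next pass.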
    colA, colB, colC, colD = st.columns(4)
    with colA:
        if st.button("⚡ Fast Start"):
            apply_preset("fast")
    with colB:
        if st.button("🔍 High Detail"):
            apply_preset("high")
    with colC:
        if st.button("💻 Low VRAM"):
            apply_preset("low_vram")
    with colD:
        if st.button("❓ Glossary"):
            show_glossary()
    def estimate_pixels(w, h):
        return int(w) * int(h)

    def vram_risk_level(w, h, steps, batch, model_name):
        """Rough load heuristic relative to a 512x512 / 30-step SD 1.5 baseline."""
        px = estimate_pixels(w, h)
        multiplier = 1.0 if "1.5" in model_name else 2.0
        load = (px / (512 * 512)) * (steps / 30.0) * max(1, batch) * multiplier
        if load < 1.2:
            return "✅ Likely safe"
        elif load < 2.2:
            return "⚠️ May be heavy — consider smaller size or steps"
        else:
            return "🟥 High risk of OOM — reduce size/batch/steps"

    risk_msg = vram_risk_level(
        st.session_state.get("width_input", 512),
        st.session_state.get("height_input", 512),
        st.session_state.get("steps_input", 30),
        st.session_state.get("batch_input", 1),
        st.session_state.get("model_choice_selectbox", "SD 1.5 (v1-5)"),
    )
    st.info(f"VRAM safety: {risk_msg}")
    colP0, colP1a, colP2a, colP3a, colP4a = st.columns(5)
    with colP0:
        if st.button("🧼 Reset to defaults"):
            apply_preset("defaults")
    with colP1a:
        if st.button("🧍 Portrait"):
            apply_preset("portrait")
    with colP2a:
        if st.button("🏞️ Landscape"):
            apply_preset("landscape")
    with colP3a:
        if st.button("📸 Instagram Post"):
            apply_preset("instagram")
    with colP4a:
        st.write("")
    st.markdown("### 🤖 Model Selection")
    model_choice = st.selectbox(
        "Base model",
        ["SD 1.5 (v1-5)"] + (["SDXL Base 1.0"] if HAS_SDXL else []),
        index=0,
        help="Choose SD 1.5 for speed/low VRAM. Choose SDXL for higher detail (needs more VRAM/CPU).",
        key="model_choice_selectbox",
    )
    gen_mode = st.selectbox(
        "Generation mode",
        ["txt2img", "img2img"],
        index=0,
        help="txt2img: make an image from your text. img2img: start from a reference image and transform it.",
        key="gen_mode_selectbox",
    )

    st.markdown("### 🎭 LoRA Integration")
    use_lora = st.checkbox("Attach LoRA", value=False, help="LoRA = small add-on that injects a learned style or subject into the base model.", key="use_lora_checkbox")
    lora_path = st.text_input("LoRA path", "", help="Path to the .safetensors/.pt LoRA file.", key="lora_path_input") if use_lora else ""
    lora_scale = st.slider("LoRA scale", 0.1, 1.5, 0.8, 0.05, help="How strongly to apply the LoRA. Start at 0.7–0.9.", key="lora_scale_slider") if use_lora else 0.0
    st.markdown("### 🎛️ Generation Parameters")
    colP1, colP2, colP3, colP4 = st.columns(4)
    with colP1:
        width = st.number_input("Width", 256, 1536, 512, 64, help="Image width in pixels. Larger = more detail but slower and more VRAM.", key="width_input")
    with colP2:
        height = st.number_input("Height", 256, 1536, 512, 64, help="Image height in pixels. Common pairs: 512x512 (square), 768x512 (wide).", key="height_input")
    with colP3:
        steps = st.number_input("Steps", 10, 100, 30, 1, help="How long to refine the image. More steps = better quality but slower.", key="steps_input")
    with colP4:
        guidance = st.number_input("Guidance", 1.0, 20.0, 7.5, 0.5, help="How strongly to follow your text prompt. 6–9 is a good range.", key="guidance_input")

    colP5, colP6, colP7 = st.columns(3)
    with colP5:
        batch = st.number_input("Batch size", 1, 6, 1, 1, help="How many images to generate at once. Higher uses more VRAM.", key="batch_input")
    with colP6:
        seed = st.number_input("Seed (0=random)", 0, 2**31 - 1, 0, 1, help="Use the same seed to reproduce a result. 0 picks a random seed.", key="seed_input")
    with colP7:
        upsample_x2 = st.checkbox("Upscale ×2 (latent upscaler)", value=False, help="Quickly doubles the resolution after generation.", key="upsample_checkbox")
    st.markdown("### ⚡ Performance & Reliability")
    st.caption("These options help run on limited VRAM and reduce crashes. If you are new, keep the defaults on.")
    colT1, colT2, colT3, colT4 = st.columns(4)
    with colT1:
        use_xformers = st.checkbox("xFormers attention", value=True, help="Speeds up attention on GPUs that support it.", key="use_xformers_checkbox")
    with colT2:
        attn_slice = st.checkbox("Attention slicing", value=True, help="Reduces VRAM usage, slightly slower.", key="attn_slice_checkbox")
    with colT3:
        vae_slice = st.checkbox("VAE slicing", value=True, help="Lower VRAM for the decoder, usually safe to keep on.", key="vae_slice_checkbox")
    with colT4:
        vae_tile = st.checkbox("VAE tiling", value=False, help="For very large images. Uses tiles to decode.", key="vae_tile_checkbox")

    oom_retry = st.checkbox("Auto‑retry on CUDA OOM", value=True, help="If out‑of‑memory happens, try again with safer settings.", key="oom_retry_checkbox")

    with st.expander("New to this? Quick tips"):
        st.markdown(
            "- For fast, reliable results: SD 1.5, 512×512, Steps 25–35, Guidance 7.5, Batch 1.\n"
            "- Higher detail: try SDXL (needs more VRAM), Steps 30–50, bigger size like 768×768.\n"
            "- Seed: 0 = random. Reuse a non‑zero seed to recreate a result.\n"
            "- Out‑of‑memory? Lower width/height, set Batch = 1, keep slicing options on.\n"
            "- LoRA: paste path to a .safetensors/.pt file. Start scale at 0.7–0.9.\n"
            "- Modes: txt2img = from text; img2img = transform an existing image.\n"
            "- Upscale ×2: quickly increases resolution after generation."
        )


with tab_inputs:
    st.markdown("<div class='section'>", unsafe_allow_html=True)
    st.subheader("🎛️ Fusion & Generation")

    parts = [p for p in [main_prompt, style, mood] if p and p.strip()]

    if 'audio_caption' in locals() and enable_audio and audio_caption:
        parts.append(f"(sound of: {audio_caption})")
    if 'tempo' in locals() and enable_audio and tempo:
        tempo_desc = "slow tempo" if tempo < 90 else ("fast tempo" if tempo > 140 else "")
        if tempo_desc:
            parts.append(tempo_desc)
    if 'audio_tags' in locals() and enable_audio and audio_tags:
        parts.extend(audio_tags)

    if 'data_summary' in locals() and enable_data and data_summary:
        parts.append(f"reflecting data patterns: {data_summary}")

    if 'emo_label' in locals() and enable_emo and emo_label:
        parts.append(f"with a {emo_label} atmosphere")
    elif enable_emo and emo_free.strip():
        parts.append(f"evoking the feeling: {emo_free.strip()}")

    if 'rt_context' in locals() and enable_rt and rt_context:
        parts.append(rt_context)

    final_prompt = ", ".join([p for p in parts if p])
    st.markdown("</div>", unsafe_allow_html=True)

    st.markdown("### 🔮 Fused Prompt Preview")
    st.code(final_prompt, language="text")
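
    # Example fusion (illustrative inputs): prompt "misty harbor", style
    # "watercolor", mood "serene", plus audio analysis might produce:
    #   misty harbor, watercolor, serene, (sound of: waves on rocks), slow tempo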

    init_image = None
    if gen_mode == "img2img" and enable_ref and style_idxs:
        init_image = ref_images[style_idxs[0] - 1].resize((int(width), int(height)))

    col_gen, col_clear = st.columns([3, 1])
    with col_gen:
        go = st.button("🚀 Generate Multimodal Art", type="primary", use_container_width=True)
    with col_clear:
        clear = st.button("🧹 Clear", use_container_width=True)

    if 'generated_images' not in st.session_state:
        st.session_state.generated_images = []
    if 'generation_results' not in st.session_state:
        st.session_state.generation_results = []

    if clear:
        st.session_state["clear_inputs"] = True
        st.success("Cleared current prompt and output. Ready for a new prompt.")
        st.rerun()

    # Thin cached wrappers; the underlying loaders are themselves cache_resource,
    # so these mainly provide stable per-argument handles inside this tab.
    @st.cache_resource(show_spinner=True)
    def get_txt2img():
        return load_sd15(txt2img=True)

    @st.cache_resource(show_spinner=True)
    def get_img2img():
        return load_sd15(txt2img=False)

    @st.cache_resource(show_spinner=True)
    def get_sdxl():
        return load_sdxl()

    @st.cache_resource(show_spinner=True)
    def get_upscaler():
        return load_upscaler()

    @st.cache_resource(show_spinner=True)
    def get_cn(cn_type: str):
        return load_controlnet(cn_type)

    @st.cache_resource(show_spinner=True)
    def get_cn_i2i(cn_type: str):
        return load_controlnet_img2img(cn_type)

    def apply_lora(pipe, lora_path, lora_scale):
        """Attach LoRA weights to the pipeline, best-effort across diffusers versions."""
        if not lora_path:
            return "No LoRA"
        try:
            pipe.load_lora_weights(lora_path)
            try:
                pipe.fuse_lora(lora_scale=lora_scale)
            except Exception:
                # Fallback for builds without fuse_lora; the adapter name can
                # differ by version, so failures are swallowed.
                try:
                    pipe.set_adapters(["default"], adapter_weights=[lora_scale])
                except Exception:
                    pass
            return f"LoRA loaded: {os.path.basename(lora_path)} (scale {lora_scale})"
        except Exception as e:
            return f"LoRA failed: {e}"

    def upsample_if_any(img: Image.Image):
        """Run the x2 latent upscaler if enabled; return (image, did_upscale, tag)."""
        if not upsample_x2 or not HAS_UPSCALER:
            return img, False, "none"
        try:
            up = get_upscaler()
            with (torch.autocast(DEVICE) if DEVICE == "cuda" else torch.no_grad()):
                out = up(prompt="sharp, detailed, high quality", image=img)
            return out.images[0], True, "latent_x2"
        except Exception as e:
            return img, False, f"fail:{e}"

    def log_rows(rows, log_path):
        """Append generation records to the CSV run log, writing the header once."""
        exists = Path(log_path).exists()
        header = [
            "filepath", "prompt", "neg_prompt", "steps", "guidance", "mode", "seed",
            "width", "height", "model", "img2img_strength", "cn_type", "cn_scale",
            "upscaled", "timestamp",
        ]
        with open(log_path, "a", newline="", encoding="utf-8") as f:
            w = csv.writer(f)
            if not exists:
                w.writerow(header)
            for r in rows:
                w.writerow([r.get(k, "") for k in header])
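
    # A logged row pairs the saved file with every parameter needed to reproduce
    # it, e.g. (values illustrative):
    #   outputs/20250101_120000_T2I_512x512_s30_g7.5_seed42.png, "...", 30, 7.5, T2I, 42, ...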

    if go:
        images, paths = [], []

        if model_choice.startswith("SDXL") and HAS_SDXL and gen_mode == "txt2img":
            pipe = get_sdxl()
            model_id = "SDXL-Base-1.0"
        else:
            if gen_mode == "txt2img":
                pipe = get_txt2img()
                model_id = "SD-1.5"
            else:
                pipe = get_img2img()
                model_id = "SD-1.5 (img2img)"

        xformed = attempt_enable_xformers(pipe) if use_xformers else False
        apply_perf(pipe, attn_slice, vae_slice, vae_tile)

        lora_msg = ""
        if use_lora:
            lora_msg = apply_lora(pipe, lora_path, lora_scale)
            if lora_msg:
                st.caption(lora_msg)

        have_style = bool(style_idxs)
        have_cn = enable_ref and bool(cn_images)

        # Pick the richest generation mode the current inputs support.
        mode = "T2I"
        if have_cn and have_style and HAS_CONTROLNET:
            mode = "CN+I2I"
        elif have_cn and HAS_CONTROLNET:
            mode = "CN"
        elif have_style:
            mode = "I2I"

        st.info(f"Mode: **{mode}** • Model: **{model_id}** • xFormers: `{xformed}`")

        rows = []
        attempt_list = list(safe_retry_sizes(height, width, steps)) if oom_retry else [(height, width, steps)]

        for b in range(int(batch)):
            ok = False
            last_err = None

            for (h_try, w_try, s_try) in attempt_list:
                try:
                    # Per-item seed: random when seed==0, else deterministic seed+b.
                    seed_eff = torch.seed() if seed == 0 else seed + b
                    gen = torch.manual_seed(seed_eff) if DEVICE == "cpu" else torch.Generator(DEVICE).manual_seed(seed_eff)

                    with (torch.autocast(DEVICE) if DEVICE == "cuda" else torch.no_grad()):
                        if mode == "CN+I2I":
                            if CN_IMG2IMG_AVAILABLE:
                                cn_pipe = get_cn_i2i(cn_type)
                                init_ref = ref_images[style_idxs[min(b, len(style_idxs) - 1)] - 1].resize((w_try, h_try))
                                out = cn_pipe(
                                    prompt=final_prompt,
                                    image=init_ref,
                                    control_image=[im for im in cn_images],
                                    controlnet_conditioning_scale=cn_scale,
                                    strength=img2img_strength,
                                    num_inference_steps=s_try,
                                    guidance_scale=guidance,
                                    negative_prompt=neg_prompt if neg_prompt.strip() else None,
                                    generator=gen,
                                )
                                img = out.images[0]
                            else:
                                # Fallback: run ControlNet txt2img for structure, then
                                # img2img over a blend of the style ref and that result.
                                cn_pipe = get_cn(cn_type)
                                cn_out = cn_pipe(
                                    prompt=final_prompt,
                                    image=[im for im in cn_images],
                                    controlnet_conditioning_scale=cn_scale,
                                    num_inference_steps=max(s_try // 2, 12),
                                    guidance_scale=guidance,
                                    negative_prompt=neg_prompt if neg_prompt.strip() else None,
                                    generator=gen,
                                )
                                struct_img = cn_out.images[0].resize((w_try, h_try))
                                i2i = get_img2img()
                                init_ref = ref_images[style_idxs[min(b, len(style_idxs) - 1)] - 1].resize((w_try, h_try))
                                blend = Image.blend(init_ref, struct_img, 0.5)
                                out = i2i(
                                    prompt=final_prompt,
                                    image=blend,
                                    strength=img2img_strength,
                                    num_inference_steps=s_try,
                                    guidance_scale=guidance,
                                    negative_prompt=neg_prompt if neg_prompt.strip() else None,
                                    generator=gen,
                                )
                                img = out.images[0]
elif mode == "CN": |
|
|
|
cn_pipe = get_cn(cn_type) |
|
out = cn_pipe( |
|
prompt=final_prompt, |
|
image=[im for im in cn_images], |
|
controlnet_conditioning_scale=cn_scale, |
|
num_inference_steps=s_try, |
|
guidance_scale=guidance, |
|
negative_prompt=neg_prompt if neg_prompt.strip() else None, |
|
generator=gen, |
|
) |
|
img = out.images[0] |
|
|
|
elif mode == "I2I": |
|
|
|
i2i = get_img2img() |
|
init_ref = ref_images[style_idxs[min(b, len(style_idxs)-1)]-1].resize((w_try, h_try)) |
|
out = i2i( |
|
prompt=final_prompt, |
|
image=init_ref, |
|
strength=img2img_strength, |
|
num_inference_steps=s_try, |
|
guidance_scale=guidance, |
|
negative_prompt=neg_prompt if neg_prompt.strip() else None, |
|
generator=gen, |
|
) |
|
img = out.images[0] |
|
|
|
else: |
|
|
|
kwargs = dict( |
|
prompt=final_prompt, |
|
num_inference_steps=s_try, |
|
guidance_scale=guidance, |
|
negative_prompt=neg_prompt if neg_prompt.strip() else None, |
|
generator=gen, |
|
) |
|
if not (model_choice.startswith("SDXL") and HAS_SDXL): |
|
kwargs.update({"height": h_try, "width": w_try}) |
|
out = pipe(**kwargs) |
|
img = out.images[0] |

                    upscaled = "none"
                    if upsample_x2 and HAS_UPSCALER:
                        img, did_upscale, upscaled = upsample_if_any(img)

                    fname = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{mode}_{w_try}x{h_try}_s{s_try}_g{guidance}_seed{seed_eff}.png"
                    path = save_image(img, fname)
                    st.image(img, caption=fname, use_container_width=True)
                    paths.append(path)
                    images.append(img)

                    rows.append({
                        "filepath": path,
                        "prompt": final_prompt,
                        "neg_prompt": neg_prompt,
                        "steps": s_try,
                        "guidance": guidance,
                        "mode": mode,
                        "seed": seed_eff,
                        "width": w_try,
                        "height": h_try,
                        "model": model_id,
                        "img2img_strength": img2img_strength if mode in ["I2I", "CN+I2I"] else "",
                        "cn_type": cn_type if mode in ["CN", "CN+I2I"] else "",
                        "cn_scale": cn_scale if mode in ["CN", "CN+I2I"] else "",
                        "upscaled": upscaled,
                        "timestamp": datetime.now().isoformat(),
                    })
                    ok = True
                    break

                except RuntimeError as e:
                    if "out of memory" in str(e).lower() and oom_retry and DEVICE == "cuda":
                        torch.cuda.empty_cache()
                        st.warning("CUDA OOM — retrying at smaller size/steps…")
                        continue
                    else:
                        st.error(f"Runtime error: {e}")
                        last_err = str(e)
                        break
                except Exception as e:
                    st.error(f"Error: {e}")
                    last_err = str(e)
                    break

            if not ok and last_err:
                st.error(f"Failed item {b+1}: {last_err}")

        if rows:
            log_rows(rows, RUNLOG)
            st.success(f"Saved {len(rows)} image(s). Run log updated: {RUNLOG}")


with tab_gallery:
    st.subheader("🖼️ Gallery & Filters")

    def read_logs():
        """Read and merge all run logs into one DataFrame."""
        frames = []
        for p in [RUNLOG, RUNLOG_3C, RUNLOG_3E]:
            if Path(p).exists():
                try:
                    df = pd.read_csv(p)
                    df["source_log"] = Path(p).name
                    frames.append(df)
                except Exception as e:
                    st.warning(f"Failed reading {p}: {e}")
        if not frames:
            return pd.DataFrame(columns=["filepath"])
        return pd.concat(frames, ignore_index=True).drop_duplicates(subset=["filepath"])

    def scan_images():
        """Scan the output directory for generated images."""
        rows = [{"filepath": str(p), "filename": p.name} for p in OUTPUT_DIR.glob("*.png")]
        return pd.DataFrame(rows)

    def load_annotations():
        """Load existing annotations, or an empty frame."""
        if ANNOT_CSV.exists():
            try:
                return pd.read_csv(ANNOT_CSV)
            except Exception:
                pass
        return pd.DataFrame(columns=["filepath", "rating", "tags", "notes"])

    def save_annotations(df):
        """Persist annotations to CSV."""
        df.to_csv(ANNOT_CSV, index=False)
    imgs_df = scan_images()
    logs_df = read_logs()
    ann_df = load_annotations()
    meta_df = imgs_df.merge(logs_df, on="filepath", how="left")

    if meta_df.empty:
        st.info("No images found in outputs/. Generate some images first.")
    else:
        st.markdown("### 🔍 Filter Images")
        colf1, colf2, colf3 = st.columns(3)

        with colf1:
            mode_opt = ["(all)"] + sorted([m for m in meta_df.get("mode", pd.Series([])).dropna().unique()])
            sel_mode = st.selectbox("Filter by mode", mode_opt, index=0)

        with colf2:
            prompt_filter = st.text_input("Filter prompt contains", "")

        with colf3:
            min_steps = st.number_input("Min steps", 0, 200, 0, 1)

        filtered = meta_df.copy()
        if sel_mode != "(all)" and "mode" in filtered.columns:
            filtered = filtered[filtered["mode"] == sel_mode]
        if prompt_filter.strip() and "prompt" in filtered.columns:
            # regex=False: treat user input as a literal substring so characters
            # like "(" cannot crash the filter.
            filtered = filtered[filtered["prompt"].fillna("").str.contains(prompt_filter, case=False, regex=False)]
        if "steps" in filtered.columns:
            try:
                filtered = filtered[pd.to_numeric(filtered["steps"], errors="coerce").fillna(0) >= min_steps]
            except Exception:
                pass

        st.caption(f"{len(filtered)} image(s) match filters.")

        if not filtered.empty:
            st.markdown("### 🖼️ Image Gallery")
            cols = st.columns(4)
            for i, row in filtered.reset_index(drop=True).iterrows():
                with cols[i % 4]:
                    p = row["filepath"]
                    try:
                        st.image(p, use_container_width=True, caption=os.path.basename(p))
                    except Exception:
                        st.write(os.path.basename(p))
                    if "prompt" in row and pd.notna(row["prompt"]):
                        st.caption(row["prompt"][:120])

        st.markdown("---")
        st.subheader("✍️ Annotate / Rate / Tag")
        choose = st.multiselect("Pick images to annotate", meta_df["filepath"].tolist())

        if choose:
            for path in choose:
                st.markdown("---")
                st.write(f"**{os.path.basename(path)}**")
                try:
                    st.image(path, width=320)
                except Exception:
                    pass

                prev = ann_df[ann_df["filepath"] == path]
                rating_val = int(prev.iloc[0]["rating"]) if not prev.empty and not pd.isna(prev.iloc[0]["rating"]) else 3
                # Guard against NaN from the CSV: text widgets need strings.
                tags_val = str(prev.iloc[0]["tags"]) if not prev.empty and pd.notna(prev.iloc[0]["tags"]) else ""
                notes_val = str(prev.iloc[0]["notes"]) if not prev.empty and pd.notna(prev.iloc[0]["notes"]) else ""

                colE1, colE2, colE3 = st.columns([1, 1, 2])
                with colE1:
                    rating = st.slider(
                        f"Rating {os.path.basename(path)}",
                        1, 5, rating_val, 1,
                        key=f"rate_{path}",
                    )
                with colE2:
                    tags = st.text_input("Tags", tags_val, key=f"tags_{path}")
                with colE3:
                    notes = st.text_area("Notes", notes_val, key=f"notes_{path}")

                if (ann_df["filepath"] == path).any():
                    ann_df.loc[ann_df["filepath"] == path, ["rating", "tags", "notes"]] = [rating, tags, notes]
                else:
                    ann_df.loc[len(ann_df)] = [path, rating, tags, notes]

            if st.button("💾 Save annotations", use_container_width=True):
                save_annotations(ann_df)
                st.success("Annotations saved!")
        else:
            st.info("Select images above to annotate them.")
with tab_presets:
    st.subheader("💾 Create / Save / Load Presets")

    st.markdown("### 🎛️ Create New Preset")
    colP1, colP2 = st.columns(2)

    with colP1:
        preset_name = st.text_input("Preset name", "my_style", key="preset_name_input")
        p_prompt = st.text_input("Prompt", main_prompt or "A serene cyberpunk alley at dawn", key="preset_prompt_input")
        p_style = st.text_input("Style", style or "digital painting", key="preset_style_input")
        p_mood = st.text_input("Mood", mood or ", ".join(MOOD_OPTIONS[:2]), key="preset_mood_input")
        p_neg = st.text_input("Negative", neg_prompt or "", key="preset_neg_input")

    with colP2:
        p_steps = st.number_input("Steps", 10, 100, steps or 30, 1, key="preset_steps_input")
        p_guid = st.number_input("Guidance", 1.0, 20.0, guidance or 7.5, 0.5, key="preset_guidance_input")
        p_i2i = st.slider("img2img strength", 0.2, 0.9, 0.55, 0.05, key="preset_i2i_slider")
        p_cn_type = st.selectbox("ControlNet type", ["Canny", "Depth"], key="preset_cn_type_selectbox")
        p_cn_scale = st.slider("ControlNet scale", 0.1, 2.0, 1.0, 0.05, key="preset_cn_scale_slider")

    preset = {
        "name": preset_name,
        "prompt": p_prompt,
        "style": p_style,
        "mood": p_mood,
        "negative": p_neg,
        "steps": p_steps,
        "guidance": p_guid,
        "img2img_strength": p_i2i,
        "controlnet": {"type": p_cn_type, "scale": p_cn_scale},
        "created_at": datetime.now().isoformat(),
    }

    st.markdown("### 📋 Preset Preview")
    st.code(json.dumps(preset, indent=2), language="json")

    colPS1, colPS2 = st.columns(2)

    with colPS1:
        st.markdown("### 💾 Save Preset")
        if st.button("💾 Save preset", use_container_width=True, key="save_preset_button"):
            if preset_name.strip():
                # slugify keeps the filename filesystem-safe.
                fp = PRESETS_DIR / f"{slugify(preset_name)}.json"
                with open(fp, "w", encoding="utf-8") as f:
                    json.dump(preset, f, indent=2)
                st.success(f"Saved {fp}")
            else:
                st.error("Please enter a preset name")

    with colPS2:
        st.markdown("### 📂 Load Preset")
        existing = sorted([p.name for p in PRESETS_DIR.glob("*.json")])
        if existing:
            sel = st.selectbox("Load preset", ["(choose)"] + existing, key="load_preset_selectbox")
            if sel != "(choose)":
                with open(PRESETS_DIR / sel, "r", encoding="utf-8") as f:
                    loaded = json.load(f)
                st.success(f"Loaded {sel}")
                st.code(json.dumps(loaded, indent=2), language="json")
        else:
            st.info("No presets found. Create your first preset above!")


with tab_export:
    st.subheader("📦 Export Bundle (ZIP)")

    def read_logs_all():
        """Read all run logs for export."""
        frames = []
        for p in [RUNLOG, RUNLOG_3C, RUNLOG_3E]:
            if Path(p).exists():
                try:
                    df = pd.read_csv(p)
                    df["source_log"] = Path(p).name
                    frames.append(df)
                except Exception as e:
                    st.warning(f"Read fail {p}: {e}")
        if not frames:
            return pd.DataFrame(columns=["filepath"])
        return pd.concat(frames, ignore_index=True).drop_duplicates(subset=["filepath"])

    def scan_imgs():
        """Scan output images for export."""
        return pd.DataFrame([
            {"filepath": str(p), "filename": p.name}
            for p in OUTPUT_DIR.glob("*.png")
        ])

    imgs_df = scan_imgs()
    logs_df = read_logs_all()

    if imgs_df.empty:
        st.info("No images to export yet. Generate some images first.")
    else:
        meta_df = imgs_df.merge(logs_df, on="filepath", how="left")

        st.markdown("### 📋 Available Images")
        display_cols = ["filepath", "prompt", "mode", "steps", "guidance"]
        available_cols = [col for col in display_cols if col in meta_df.columns]
        st.dataframe(
            meta_df[available_cols].fillna("").astype(str),
            use_container_width=True,
            height=240,
        )

        st.markdown("### 🎯 Export Selection")
        sel = st.multiselect(
            "Select images to export",
            meta_df["filepath"].tolist(),
            default=meta_df["filepath"].tolist()[:8],
            key="export_images_multiselect",
        )

        include_preset = st.checkbox("Include preset.json", value=False, key="include_preset_checkbox")
        preset_blob = None
        if include_preset:
            ex = sorted([p.name for p in PRESETS_DIR.glob("*.json")])
            if ex:
                choose = st.selectbox("Choose preset", ex, key="export_preset_selectbox")
                with open(PRESETS_DIR / choose, "r", encoding="utf-8") as f:
                    preset_blob = json.load(f)
            else:
                st.warning("No presets found in /presets")
                include_preset = False

        bundle_name = st.text_input(
            "Bundle name (no spaces)",
            f"compi_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            key="bundle_name_input",
        )

        if st.button("📦 Create Export Bundle", type="primary", use_container_width=True, key="create_bundle_button"):
            if not sel:
                st.error("Pick at least one image.")
            elif not bundle_name.strip():
                st.error("Please enter a bundle name.")
            else:
                with st.spinner("Creating export bundle..."):
                    # Stage the bundle in a temp folder, zip it, then delete the folder.
                    tmp_dir = EXPORTS_DIR / bundle_name
                    if tmp_dir.exists():
                        shutil.rmtree(tmp_dir)
                    (tmp_dir / "images").mkdir(parents=True, exist_ok=True)

                    for p in sel:
                        try:
                            shutil.copy2(p, tmp_dir / "images" / os.path.basename(p))
                        except Exception as e:
                            st.warning(f"Copy failed: {p} ({e})")

                    msel = meta_df[meta_df["filepath"].isin(sel)].copy()
                    msel.to_csv(tmp_dir / "metadata.csv", index=False)

                    if ANNOT_CSV.exists():
                        shutil.copy2(ANNOT_CSV, tmp_dir / "annotations.csv")
                    else:
                        pd.DataFrame(columns=["filepath", "rating", "tags", "notes"]).to_csv(
                            tmp_dir / "annotations.csv", index=False
                        )

                    manifest = {
                        "bundle_name": bundle_name,
                        "created_at": datetime.now().isoformat(),
                        "environment": env_snapshot(),
                        "includes": {
                            "images": True,
                            "metadata_csv": True,
                            "annotations_csv": True,
                            "preset_json": bool(preset_blob),
                            "readme_md": True,
                        },
                    }
                    with open(tmp_dir / "manifest.json", "w", encoding="utf-8") as f:
                        json.dump(manifest, f, indent=2)

                    if preset_blob:
                        with open(tmp_dir / "preset.json", "w", encoding="utf-8") as f:
                            json.dump(preset_blob, f, indent=2)

                    with open(tmp_dir / "README.md", "w", encoding="utf-8") as f:
                        f.write(mk_readme(manifest, msel))

                    zpath = EXPORTS_DIR / f"{bundle_name}.zip"
                    if zpath.exists():
                        zpath.unlink()

                    with zipfile.ZipFile(zpath, 'w', zipfile.ZIP_DEFLATED) as zf:
                        for root, _, files in os.walk(tmp_dir):
                            for file in files:
                                full = Path(root) / file
                                zf.write(full, full.relative_to(tmp_dir))

                    shutil.rmtree(tmp_dir, ignore_errors=True)

                    st.success(f"✅ Export created: {zpath}")
                    st.info(f"📁 Bundle size: {zpath.stat().st_size / (1024*1024):.1f} MB")

                    with open(zpath, "rb") as f:
                        st.download_button(
                            label="📥 Download Export Bundle",
                            data=f.read(),
                            file_name=f"{bundle_name}.zip",
                            mime="application/zip",
                            use_container_width=True,
                        )
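
                    # Resulting ZIP layout:
                    #   images/*.png, metadata.csv, annotations.csv,
                    #   manifest.json, preset.json (optional), README.md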


st.markdown("---")
st.markdown("""
<div style='text-align: center; color: #666; padding: 20px;'>
<strong>🧪 CompI Phase 3 Final Dashboard</strong><br>
Complete integration of all Phase 3 components (3.A → 3.E)<br>
<em>Multimodal AI Art Generation • Advanced References • Performance Management • Professional Workflow</em>
</div>
""", unsafe_allow_html=True)
|