Spaces: Running on Zero
| import os, time, zipfile, io | |
| import random | |
| import gradio as gr | |
| from langdetect import detect, DetectorFactory | |
| import numpy as np | |
| import spaces | |
| import torch | |
| from transformers import CLIPTokenizer | |
# Set the environment variable DEV_MODE_=1 to swap the real diffusers
# pipelines for mock.MockPipe instances running on CPU.
DEV_MODE = os.getenv("DEV_MODE_", "0") == "1"

# CLIP tokenizer used only to log and count prompt tokens (not for generation).
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Upper bound for the seed slider and for randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound (pixels) for the width / height sliders.
MAX_IMAGE_SIZE = 1920

# UI model key -> Hugging Face repo id of the SDXL checkpoint.
MODELS = dict(
    v150="John6666/wai-nsfw-illustrious-sdxl-v150-sdxl",
    v140="Ine007/waiNSFWIllustrious_v140",
    v130="dhead/waiNSFWIllustrious_v130",
    v120="votepurchase/waiNSFWIllustrious_v120",
)
# Optional translation LLM; stays None in DEV_MODE or when `llm_repo` is unset.
LLM_PIPELINE = None
MAX_NEW_TOKENS = 80  # generation budget for one translation

# NOTE(review): indentation was lost in the copy this was recovered from; the
# LLM loading is assumed to live in the non-DEV branch (8-bit bitsandbytes
# loading needs a GPU) — verify against the deployed Space.
if DEV_MODE:
    # Local development: every model key resolves to a MockPipe on CPU.
    from mock import MockPipe
    from collections import defaultdict

    pipes = defaultdict(MockPipe)
    device = "cpu"
else:
    from diffusers import DiffusionPipeline

    device = "cuda"
    pipes = {}
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    for model_name, model_repo_id in MODELS.items():
        pipes[model_name] = DiffusionPipeline.from_pretrained(
            model_repo_id,
            torch_dtype=torch_dtype,
            use_safetensors=True,
            # NOTE(review): None appears to mean "pipeline default"; pass
            # False if the intent is to disable the invisible watermark.
            add_watermarker=None,
        ).to(device)
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True

    from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

    # Treat an unset/empty HF_TOKEN as "no explicit token" (None) so the hub
    # client falls back to its default credentials instead of receiving "".
    hf_token = os.getenv("HF_TOKEN") or None
    model_id = os.getenv("llm_repo", "")
    if model_id:
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
        # Distinct names so the UI's `model` component does not shadow these.
        llm_tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
        llm_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=bnb_config,
            device_map="auto",
            token=hf_token,
        )
        LLM_PIPELINE = pipeline("text-generation", model=llm_model, tokenizer=llm_tokenizer)
    else:
        # Without a repo id, from_pretrained("") would abort the whole app;
        # run without translation instead.
        print("llm_repo is not set; prompt translation is disabled.")

# Fix langdetect's seed so repeated detections of the same text agree.
DetectorFactory.seed = 0
| def _apply_preset_ui(preset): | |
| w, h = apply_preset(preset) | |
| return int(w), int(h) | |
| def apply_preset(preset): | |
| mapping = { | |
| "768×768 (square)": (768, 768), | |
| "1024×1024": (1024, 1024), | |
| "832×1216 (portrait)": (832, 1216), | |
| "1152×896 (landscape)": (1152, 896), | |
| "768×1344 (portrait, lighter)": (768, 1344), | |
| } | |
| return mapping.get(preset, (1024, 768)) | |
def detect_language(text):
    """Return the language code langdetect assigns to *text*.

    Any exception from ``detect`` is swallowed and reported as ``"en"``, so an
    undetectable prompt never raises and never triggers a language warning.
    """
    try:
        return detect(text)
    except Exception:  # deliberately broad: detection is advisory only
        return "en"
def infer(
    model: str,
    prompt: str,
    quality_prompt: str,
    negative_prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    num_images: int,
    history: list,
    use_quality: bool,
    language_warning_count,
    progress=gr.Progress(track_tqdm=True),
):
    """Gradio entry point for "Run" / prompt submit.

    Warns (at most once per session) when the prompt does not look English,
    then delegates image generation to ``_infer``.

    Returns:
        ``_infer``'s five outputs followed by the updated
        ``language_warning_count`` so the session state is persisted.
    """
    # Detect non-English text only until the first warning has been shown.
    if language_warning_count < 1:
        prompt_lang = detect_language(prompt)
        if prompt_lang != "en":
            language_warning_count += 1
            gr.Warning(
                f"If your prompt contains non-English characters ({prompt_lang}), "
                f"enable translation in advanced settings."
            )
    # _infer does not take the warning counter. Forward the progress tracker
    # Gradio injected for this event rather than creating a new, unbound one.
    last_fit, last_raw, base_seed, history, history_dup = _infer(
        model, prompt, quality_prompt, negative_prompt, seed, randomize_seed,
        width, height, guidance_scale, num_inference_steps, num_images,
        history, use_quality, progress=progress,
    )
    # Return the updated state as the last output.
    return last_fit, last_raw, base_seed, history, history_dup, language_warning_count
# ZeroGPU: CUDA work must run inside a @spaces.GPU-decorated call (the
# decorator has no effect on other hardware).
# NOTE(review): the default GPU time slot may be too short for 10 images at
# high step counts / sizes — consider spaces.GPU(duration=...).
@spaces.GPU
def _infer(
    model: str,
    prompt: str,
    quality_prompt: str,
    negative_prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    num_images: int,
    history: list,
    use_quality: bool,
    progress=gr.Progress(track_tqdm=True),
) -> tuple:
    """Generate ``num_images`` images sequentially and append them to the history.

    Image ``i`` uses seed ``base_seed + i``, so each image in a batch can be
    reproduced individually from its caption.

    Returns:
        ``(last_img, last_img, base_seed, history, history)``: the newest image
        for both the fit and raw views, the seed used for the first image, and
        the updated history for the gallery and for the session state.
    """
    pipe = pipes[model]
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    base_seed = int(seed)

    full_prompt = (prompt + "," + quality_prompt) if use_quality else prompt
    print(f"original: {full_prompt}")
    ids = tokenizer(full_prompt)["input_ids"]
    print(f"ids: {ids}\n------------------------------------")

    # Work on a copy so the caller's list (the gr.State value) is not mutated.
    history = list(history or [])
    last_img = None
    for offset in range(int(num_images)):
        image_seed = base_seed + offset
        gen = torch.Generator(device=device).manual_seed(image_seed)
        img = pipe(
            prompt=full_prompt,
            negative_prompt=negative_prompt or None,
            guidance_scale=float(guidance_scale),
            num_inference_steps=int(num_inference_steps),
            width=int(width),
            height=int(height),
            generator=gen,
        ).images[0]
        # The leading "seed=<n>" field is parsed back by download_all.
        caption = f"seed={image_seed}, {width}x{height}, steps={num_inference_steps}, cfg={guidance_scale}, model={model}"
        history.append((img, caption))
        last_img = img
    # Send the same image to both views (fit + raw); return the base seed.
    return last_img, last_img, base_seed, history, history
def clear_history():
    """Reset the history gallery and its backing session state.

    Returns two independent empty lists: one for the gallery, one for the state.
    """
    gallery_items: list = []
    state_items: list = []
    return gallery_items, state_items
| def download_all(history): | |
| if not history: | |
| return None | |
| ts = time.strftime("%Y%m%d_%H%M%S") | |
| zip_path = f"/tmp/sdxl_session_{ts}.zip" | |
| with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf: | |
| for idx, (img, caption) in enumerate(history, start=1): | |
| try: | |
| seed_val = caption.split(",")[0].split("=")[1].strip() | |
| base = f"{idx:03d}_seed{seed_val}" | |
| except Exception: | |
| base = f"{idx:03d}" | |
| buf = io.BytesIO() | |
| img.save(buf, format="PNG") | |
| buf.seek(0) | |
| zf.writestr(f"{base}.png", buf.read()) | |
| return zip_path | |
def toggle_controls(hide: bool):
    """Return a Gradio update that hides the target component when *hide* is truthy."""
    show = not hide
    return gr.update(visible=show)
def compute_token_count(prompt: str, quality_prompt: str, use_quality: bool):
    """Count CLIP tokens in the prompt text that generation will actually use.

    Args:
        prompt: user prompt (``None`` treated as empty).
        quality_prompt: quality suffix (``None`` treated as empty).
        use_quality: whether the quality suffix is appended.

    Returns:
        Number of content tokens, excluding the two special (start/end)
        tokens the tokenizer adds around the text.
    """
    prompt = prompt or ""
    quality_prompt = quality_prompt or ""
    # Same "prompt,quality" assembly as _infer, so the displayed count matches
    # what the pipeline receives (the comma itself is a token).
    full_prompt = (prompt + "," + quality_prompt) if use_quality else prompt
    return len(tokenizer(full_prompt)["input_ids"]) - 2
def toggle_quality_prompt(enabled: bool):
    """Return a Gradio update making the quality-prompt box editable only when *enabled*."""
    update = gr.update(interactive=enabled)
    return update
# Sample prompts offered through gr.Examples; clicking one fills the Prompt box.
examples = [
    "anime girl with flowing silver hair, cherry blossoms in the background, "
    "soft lighting, detailed eyes, high resolution",
    "samurai standing in a bamboo forest at night, glowing lanterns, "
    "cinematic lighting, dramatic pose",
    "school rooftop at sunset, two characters looking at each other, "
    "detailed clouds, anime style",
    "cyberpunk city street, neon lights, rainy atmosphere, "
    "anime illustration, high detail",
    "fantasy anime landscape with floating islands and waterfalls, "
    "vibrant colors, wide shot",
]
# Page-level CSS passed to gr.Blocks: centred container, sticky control column
# on wide screens, capped preview height, and a mobile-only "hide controls" row.
css = """
#col-container { margin: 0 auto; max-width: 1250px; width: 100%; padding: 0 12px; }
#left-col { position: sticky; top: 12px; align-self: start; }
/* responsive image (FIT view) */
#result_fit img { max-height: 700px; width: auto; height: auto; }
/* 'Hide controls' toggle only visible on small screens */
#hide_controls_row { display: none; }
#title { text-align: center; }
@media (max-width: 768px) {
#hide_controls_row { display: block; margin-bottom: 8px; }
#left-col { position: static; } /* less sticky on small screens */
}
"""
# Custom Soft theme: violet/fuchsia accents, semi-transparent panels and
# gradient primary buttons, with separate light and dark values throughout.
custom_theme = gr.themes.Soft(
    primary_hue="violet",  # overall accent = violet
    secondary_hue="fuchsia",  # secondary accents
    neutral_hue="slate",  # neutral surfaces/text
).set(
    # --- Backgrounds ---
    # Values must be CSS value strings. Light mode uses a pale
    # violet -> fuchsia gradient; dark mode a solid deep indigo.
    body_background_fill="linear-gradient(160deg, #f5f3ff 0%, #ede9fe 45%, #fae8ff 100%)",
    body_background_fill_dark="#0c0a24",
    # --- Blocks / cards (semi-transparent to fit bg) ---
    block_background_fill="rgba(255, 255, 255, 0.65)",  # light: milky panel
    block_background_fill_dark="rgba(20, 18, 40, 0.55)",  # dark: inky panel
    block_border_color="rgba(84, 76, 140, 0.35)",  # light: violet-gray outline
    block_border_color_dark="rgba(164, 148, 255, 0.18)",  # dark: subtle lilac outline
    block_shadow="0 12px 30px rgba(93, 87, 160, 0.25)",  # light shadow
    block_shadow_dark="0 16px 36px rgba(0, 0, 0, 0.45)",  # dark shadow
    # --- Inputs (textboxes, dropdowns, sliders) ---
    input_background_fill="rgba(255, 255, 255, 0.9)",  # light input bg
    input_background_fill_dark="rgba(14, 12, 30, 0.65)",  # dark input bg
    input_border_color="rgba(107, 114, 255, 0.45)",  # light border indigo
    input_border_color_dark="rgba(131, 118, 255, 0.28)",  # dark border lilac
    input_placeholder_color="rgba(23, 19, 43, 0.45)",  # light placeholder
    input_placeholder_color_dark="rgba(246, 245, 255, 0.45)",  # dark placeholder
    # === PRIMARY BUTTONS ===
    button_primary_text_color="#ffffff",
    button_primary_text_color_dark="#ffffff",
    button_primary_background_fill="linear-gradient(135deg, #7b5cff 0%, #c14cff 100%)",  # light
    button_primary_background_fill_dark="linear-gradient(135deg, #5c47d6 0%, #8a2ec9 100%)",  # dark
    button_primary_background_fill_hover="linear-gradient(135deg, #8b6bff 0%, #d85cff 100%)",  # light hover
    button_primary_background_fill_hover_dark="linear-gradient(135deg, #6a56ea 0%, #a23bdd 100%)",  # dark hover
    # === SECONDARY BUTTONS ===
    button_secondary_text_color="#2a2550",  # light
    button_secondary_text_color_dark="#e8e6ff",  # dark
    button_secondary_background_fill="rgba(255,255,255,0.55)",  # light
    button_secondary_background_fill_dark="rgba(255,255,255,0.10)",  # dark
    button_secondary_background_fill_hover="rgba(255,255,255,0.75)",
    button_secondary_background_fill_hover_dark="rgba(255,255,255,0.18)",
    # --- Text colors tuned for readability on the purple backgrounds ---
    body_text_color="#000000",
    body_text_color_dark="#e9e8ff",
    body_text_color_subdued="#000000",
    body_text_color_subdued_dark="rgba(233,232,255,0.75)",
    link_text_color="#000000",  # black on light bg
    link_text_color_dark="#a78bfa",  # violet-400
    link_text_color_active="#000000",  # black on light bg
    link_text_color_active_dark="#e9d5ff",  # purple-200
    # === SLIDER / CHECK / RADIO ACCENTS ===
    slider_color="#7048ff",  # light rail/handle
    slider_color_dark="#b89cff",  # dark rail/handle
    checkbox_label_text_color="#1f1a39",  # light label
    checkbox_label_text_color_dark="#ecebff",  # dark label
)
# NOTE(review): indentation was lost in the copy this was recovered from; the
# nesting of layout containers below is a best reconstruction — verify it
# against the deployed Space.
with gr.Blocks(css=css, theme=custom_theme) as demo:
    # Per-session state: (image, caption) history and how many
    # "non-English prompt" warnings have been shown.
    history_state = gr.State([])
    language_warning_count = gr.State(0)

    with gr.Column(elem_id="col-container"):
        gr.Markdown("# SDXL Text-to-Image (waiNSFWIllustrious_v12-v14)", elem_id="title")
        with gr.Row():
            # LEFT: controls
            with gr.Column(scale=1, elem_id="left-col"):
                # Row is only displayed on narrow screens (see #hide_controls_row CSS).
                with gr.Row(elem_id="hide_controls_row"):
                    hide_controls_cb = gr.Checkbox(label="Hide advanced controls (mobile friendly)", value=False)
                with gr.Group(visible=True) as controls_group:
                    # Translation input + button, shown only while translation is enabled.
                    non_english_text = gr.Textbox(
                        label="Prompt to translate", placeholder="Enter text to translate", interactive=True, visible=False
                    )
                    translate_btn = gr.Button("translate", visible=False)
                    prompt = gr.Text(
                        label="Prompt",
                        lines=2,
                        max_lines=6,
                        placeholder="Enter your prompt",
                        scale=1,
                        min_width=0,
                        autofocus=True,
                    )
                    quality_prompt = gr.Text(
                        label="Quality prompt",
                        value="masterpiece, best quality, fine details",
                    )
                with gr.Group(visible=True) as adv_controls_group:
                    quality_prompt_toggle = gr.Checkbox(
                        label="Use quality prompt",
                        value=True,
                    )
                    generations = gr.Slider(
                        label="Generations",
                        maximum=10,
                        minimum=1,
                        step=1,
                        value=1,
                        info="Control how many images are generated sequentially.",
                    )
                    model = gr.Radio(
                        choices=list(MODELS),
                        value="v140",
                        info="choose the model you want.",
                        label="Model",
                    )
                    token_count = gr.Number(
                        label="Token count",
                        info="SDXL models work best when the token count is <= 77.",
                    )
                run_button = gr.Button("Run", variant="primary")
            # RIGHT: image + toggle
            with gr.Column(scale=2):
                # Two image views: FIT (responsive) and RAW (no scaling).
                result_fit = gr.Image(label="Result", show_label=False, elem_id="result_fit", visible=True)
                result_raw = gr.Image(label="Result (original size)", show_label=False, visible=False)
                # Advanced settings
                with gr.Accordion("Advanced Settings", open=False):
                    no_rescale_cb = gr.Checkbox(
                        label="Do not rescale to fit screen",
                        value=False,
                        info="Uncheck = fit preview to screen (default).",
                        visible=True,
                    )
                    translation_cb = gr.Checkbox(
                        label="Enable translation",
                        value=False,
                        info="Enable translation for the prompt.",
                        visible=True,
                    )
                    negative_prompt = gr.Text(
                        label="Negative prompt",
                        max_lines=1,
                        placeholder="Enter a negative prompt",
                        value="blurry, low quality, watermark, monochrome, text",
                    )
                    with gr.Row():
                        size_preset = gr.Dropdown(
                            ["768×768 (square)", "1024×1024", "832×1216 (portrait)", "1152×896 (landscape)", "768×1344 (portrait, lighter)"],
                            value="1024×1024",
                            label="Size preset",
                        )
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    with gr.Row():
                        width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                        height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                    with gr.Row():
                        guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=12.0, step=0.1, value=6)
                        num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=75, step=1, value=30)
                gr.Examples(examples=examples, inputs=[prompt])
                # Gallery + actions
                gallery = gr.Gallery(label="History", preview=True, columns=3, height=320)
                with gr.Row():
                    clear_btn = gr.Button("Clear history", variant="secondary")
                    download_btn = gr.Button("Download all")
                # Declared in the layout (not inline in the click outputs) so it
                # renders next to the download button.
                download_file = gr.File(label="images.zip")

    size_preset.change(fn=_apply_preset_ui, inputs=[size_preset], outputs=[width, height])

    def toggle_rescale(no_rescale: bool):
        """Show the FIT view when unchecked and the RAW view when checked."""
        return gr.update(visible=not no_rescale), gr.update(visible=no_rescale)

    no_rescale_cb.change(fn=toggle_rescale, inputs=[no_rescale_cb], outputs=[result_fit, result_raw])

    def toggle_translate(on: bool):
        """Show or hide the translation textbox and button together."""
        return gr.update(visible=on), gr.update(visible=on)

    def move_prompt_to_non_english(prompt_text: str, on: bool = True):
        """Move the prompt into the translation box when translation is switched on.

        When *on* is false (translation switched off) both boxes are left
        unchanged, so the user's prompt is not moved into a hidden textbox.
        """
        if not on:
            return gr.update(), gr.update()
        return gr.update(value=prompt_text), gr.update(value="")

    translation_cb.change(fn=toggle_translate, inputs=[translation_cb], outputs=[non_english_text, translate_btn])
    translation_cb.change(fn=move_prompt_to_non_english, inputs=[prompt, translation_cb], outputs=[non_english_text, prompt])

    # ZeroGPU: the LLM runs on the GPU, so the call must happen inside @spaces.GPU.
    @spaces.GPU
    def _generate_translation(text):
        """Ask the translation LLM to translate *text* into English; return the generated text."""
        messages = [
            {"role": "user", "content": f"translate the following text into English: <start>{text}<end>. return the translated text only!"},
        ]
        translated_text = LLM_PIPELINE(messages, max_new_tokens=MAX_NEW_TOKENS, return_full_text=False)
        return translated_text[0]['generated_text']

    def translate_text(text):
        """Translate *text* into the Prompt box; leave the prompt unchanged if nothing to do."""
        if LLM_PIPELINE is None:
            # DEV_MODE, or no `llm_repo` configured: nothing to call.
            gr.Warning("Translation is unavailable: no translation model is loaded.")
            return gr.update()
        if not text or not text.strip():
            return gr.update()
        result = _generate_translation(text)
        print("--------------------translation:------------------------ \n"
              f"non eng text: {text}\n"
              f"translated: {result}")
        return result

    translate_btn.click(fn=translate_text, inputs=[non_english_text], outputs=[prompt])

    # Mobile: hide/show the advanced controls group.
    hide_controls_cb.change(fn=toggle_controls, inputs=[hide_controls_cb], outputs=[adv_controls_group])

    # Clear & Download
    clear_btn.click(fn=clear_history, inputs=None, outputs=[gallery, history_state])
    download_btn.click(fn=download_all, inputs=[history_state], outputs=[download_file])

    quality_prompt_toggle.change(
        fn=toggle_quality_prompt,
        inputs=[quality_prompt_toggle],
        outputs=[quality_prompt],
    )

    # The token count depends on all three inputs, including the toggle.
    token_inputs = [prompt, quality_prompt, quality_prompt_toggle]
    gr.on(
        triggers=[prompt.change, quality_prompt.change, quality_prompt_toggle.change],
        fn=compute_token_count,
        inputs=token_inputs,
        outputs=[token_count],
    )
    # Populate the count for the initial quality prompt as well.
    demo.load(fn=compute_token_count, inputs=token_inputs, outputs=[token_count])

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[model, prompt, quality_prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, generations, history_state, quality_prompt_toggle, language_warning_count],
        outputs=[result_fit, result_raw, seed, gallery, history_state, language_warning_count],
    )
if __name__ == "__main__":
    # Start the Gradio server when this file is executed directly.
    demo.launch()