import torch
import gradio as gr
from torchvision import transforms as tfms
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

# 1. Device and dtype: use float16 on CUDA, float32 on CPU.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch_device == "cuda" else torch.float32
print(f"Using device: {torch_device}, dtype: {torch_dtype}")  # Helpful for debugging

# 2.  Model Path and Loading: Use a more efficient scheduler and reduce memory usage.
model_path = "CompVis/stable-diffusion-v1-4"

# Use DPMSolverMultistepScheduler for faster and higher-quality sampling
scheduler = DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder="scheduler")

sd_pipeline = DiffusionPipeline.from_pretrained(
    model_path,
    torch_dtype=torch_dtype,
    scheduler=scheduler,           # Use the DPM-Solver scheduler
    # low_cpu_mem_usage defaults to True (when accelerate is installed), so it is not set explicitly.
    safety_checker=None,           # Disable the safety checker to avoid false positives blocking generation
    requires_safety_checker=False,
).to(torch_device)

# Optimize attention for memory efficiency (if using CUDA)
if torch_device == "cuda":
    try:
        sd_pipeline.enable_xformers_memory_efficient_attention()  # Fastest option; requires xformers
    except Exception as e:
        print(f"xformers unavailable ({e}); falling back to attention slicing")
        sd_pipeline.enable_attention_slicing()  # Built-in fallback: slightly slower, but lowers VRAM use
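# A hedged alternative for very low-VRAM GPUs (not used here): diffusers pipelines can
# offload submodules to the CPU between denoising steps. If enabled, it requires the
# `accelerate` package and should replace the .to(torch_device) call above.
# sd_pipeline.enable_model_cpu_offload()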

# 3. Textual Inversion Loading: load only the concepts the app uses, one at a time,
#    so a single failing download can be caught without blocking the others.

style_token_dict = {
    "Illustration Style": '<illustration-style>',
    "Line Art": '<line-art>',
    "Hitokomoru Style": '<hitokomoru-style-nao>',
    "Marc Allante": '<Marc_Allante>',
    "Midjourney": '<midjourney-style>',
    "Hanfu Anime": '<hanfu-anime-style>',
    "Birb Style": '<birb-style>'
}

# Load inversions one at a time; the try/except reports failures per concept.
def load_inversion(concept_name, token):
    try:
        sd_pipeline.load_textual_inversion(f"sd-concepts-library/{concept_name}", token=token)
        print(f"Loaded textual inversion: {concept_name}")
    except Exception as e:
        print(f"Error loading {concept_name}: {e}")

# Load each style individually.
load_inversion("illustration-style", style_token_dict["Illustration Style"])
load_inversion("line-art", style_token_dict["Line Art"])
load_inversion("hitokomoru-style-nao", style_token_dict["Hitokomoru Style"])
load_inversion("style-of-marc-allante", style_token_dict["Marc Allante"])
load_inversion("midjourney-style", style_token_dict["Midjourney"])
load_inversion("hanfu-anime-style", style_token_dict["Hanfu Anime"])
load_inversion("birb-style", style_token_dict["Birb Style"])



# 4. Guidance function: lightweight post-processing blends applied to the generated image.
def apply_guidance(image, guidance_method, loss_scale):
    img_tensor = tfms.ToTensor()(image).unsqueeze(0).to(torch_device)
    loss_scale = loss_scale / 10000.0  # Normalize the slider value (100-10000) to a small blend factor

    if guidance_method == 'Grayscale':
        gray = tfms.Grayscale(num_output_channels=3)(img_tensor) # keep 3 channels
        guided = img_tensor + (gray - img_tensor) * loss_scale
    elif guidance_method == 'Bright':
        guided = torch.clamp(img_tensor * (1 + loss_scale), 0, 1)  # Direct brightness adjustment
    elif guidance_method == 'Contrast':
        mean = img_tensor.mean()
        guided = torch.clamp((img_tensor - mean) * (1 + loss_scale) + mean, 0, 1) # Contrast adjustment
    elif guidance_method == 'Symmetry':
        flipped = torch.flip(img_tensor, [3])
        guided = img_tensor + (flipped - img_tensor) * loss_scale
    elif guidance_method == 'Saturation':
        # Use torchvision's functional approach for efficiency.
        guided = tfms.functional.adjust_saturation(img_tensor, 1 + loss_scale)
        guided = torch.clamp(guided, 0, 1)
    else:
        return image

    # Convert back to a PIL Image
    guided = tfms.ToPILImage()(guided.squeeze(0).cpu())
    return guided
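# Quick sanity check on the scaling (illustrative values): with loss_scale=1000 the blend
# factor is 0.1, so 'Grayscale' moves each pixel 10% of the way toward its grayscale value
# and 'Bright' multiplies intensities by 1.1 before clamping. `some_pil_image` is hypothetical:
# preview = apply_guidance(some_pil_image, 'Grayscale', 1000)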


# 5. Inference function: build the prompt, run the pipeline, then apply guidance.
def inference(text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale, image_size):
    prompt = f"{text} {style_token_dict[style]}"
    width, height = map(int, image_size.split('x'))
    generator = torch.Generator(device=torch_device).manual_seed(seed)

    # Generate the base image with the pipeline
    image_pipeline = sd_pipeline(
        prompt,
        num_inference_steps=inference_step,
        guidance_scale=guidance_scale,
        generator=generator,
        height=height,
        width=width,
    ).images[0]

    image_guide = apply_guidance(image_pipeline, guidance_method, loss_scale)
    return image_pipeline, image_guide
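# Hypothetical direct call for smoke-testing outside the UI (argument values are
# illustrative; the style must be a key of style_token_dict):
# base_img, guided_img = inference("Magical forest at dusk", "Line Art",
#                                  20, 7.5, 42, "Contrast", 500, "256x256")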

# 6. Gradio Interface: page styling (CSS) and layout.
css_and_html = """
<style>
    /* Base page styling */
    body {
        background: linear-gradient(135deg, #1a1c2c, #4a4e69, #9a8c98);
        font-family: 'Arial', sans-serif;
        color: #f2e9e4;
        margin: 0;
        padding: 0;
        min-height: 100vh;
    }
    /* ... (Rest of your CSS) ... */
    .gr-box {
        background-color: rgba(255, 255, 255, 0.1) !important;
        border: 1px solid rgba(255, 255, 255, 0.2) !important;
        border-radius: 0.5em !important;
    }

    .gr-input, .gr-button, .gr-dropdown, .gr-slider {
        background-color: rgba(255, 255, 255, 0.1) !important;
        color: #f2e9e4 !important;
        border: 1px solid rgba(255, 255, 255, 0.2) !important;
        border-radius: 0.5em !important;
    }
    /* ... (Rest of your CSS) ... */

</style>
<div id="app-header">
    <div class="artifact large"></div>
    <div class="artifact medium"></div>
    <div class="artifact small"></div>
    <h1>Dreamscape Creator</h1>
    <p>Unleash your imagination with AI-powered generative art</p>
    <div class="concept-container">
        <div class="concept"><div class="concept-emoji">🎨</div><div class="concept-description">Illustration Style</div></div>
        <div class="concept"><div class="concept-emoji">✏️</div><div class="concept-description">Line Art</div></div>
        <div class="concept"><div class="concept-emoji">🌌</div><div class="concept-description">Midjourney Style</div></div>
        <div class="concept"><div class="concept-emoji">πŸ‘˜</div><div class="concept-description">Hanfu Anime</div></div>
    </div>
</div>
"""

with gr.Blocks() as demo:  # Styling is injected via the <style> block rendered by gr.HTML below
    gr.HTML(css_and_html)

    with gr.Row():
        text = gr.Textbox(label="Prompt", placeholder="Describe your dreamscape...")
        style = gr.Dropdown(label="Style", choices=list(style_token_dict.keys()), value="Illustration Style")

    with gr.Row():
        inference_step = gr.Slider(1, 50, 20, step=1, label="Inference steps")
        guidance_scale = gr.Slider(1, 10, 7.5, step=0.1, label="Guidance scale")
        seed = gr.Slider(0, 10000, 42, step=1, label="Seed", randomize=True)  # Pick a random seed on each page load

    with gr.Row():
        guidance_method = gr.Dropdown(label="Guidance method", choices=['Grayscale', 'Bright', 'Contrast', 'Symmetry', 'Saturation'], value="Grayscale")
        loss_scale = gr.Slider(100, 10000, 200, step=100, label="Loss scale")

    with gr.Row():
        image_size = gr.Radio(["256x256", "512x512"], label="Image Size", value="256x256")

    with gr.Row():
        generate_button = gr.Button("Create Dreamscape", variant="primary")

    with gr.Row():
        output_image = gr.Image(label="Your Dreamscape", interactive=False)  # Disable interaction
        output_image_guided = gr.Image(label="Guided Dreamscape", interactive=False) # Disable interaction

    generate_button.click(
        inference,
        inputs=[text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale, image_size],
        outputs=[output_image, output_image_guided]
    )

    gr.Examples(
        examples=[
            ["Magical Forest with Glowing Trees", 'Birb Style', 40, 7.5, 42, 'Grayscale', 200, "256x256"],
            ["Ancient Temple Ruins at Sunset", 'Midjourney', 30, 8.0, 123, 'Bright', 5678, "256x256"],
            ["Japanese garden with cherry blossoms", 'Hitokomoru Style', 40, 7.0, 789, 'Contrast', 250, "256x256"],
        ],
        inputs=[text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale, image_size],
        outputs=[output_image, output_image_guided],
        fn=inference,
        cache_examples=False,  # Example caching can exhaust RAM on Spaces, so keep it off
        examples_per_page=5
    )

if __name__ == "__main__":
    demo.launch()