File size: 3,692 Bytes
338d95d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# compi_phase1_text2image.py

import os
import sys
import torch
from datetime import datetime
from diffusers import StableDiffusionPipeline
from PIL import Image

# Add the project root (two directories above this file) to sys.path so
# project-level packages can be imported.
sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))

# ------------------ 1. SETUP AND CHECKS ------------------

# Pick the compute device once and report the choice to the user.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    print("CUDA GPU detected. Running on GPU for best performance.")
else:
    print("No CUDA GPU detected. Running on CPU. Generation will be slow.")

# Generated images are written to <project root>/outputs; the directory is
# created here, at import time, if it does not already exist.
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, "outputs")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Logging helper: timestamped progress messages on stdout.
def log(msg):
    """Print *msg* to stdout, prefixed by the local time as ``[YYYY-mm-dd HH:MM:SS]``."""
    stamp = datetime.now().strftime("[%Y-%m-%d %H:%M:%S]")
    line = "{} {}".format(stamp, msg)
    print(line)

# ------------------ 2. LOAD MODEL ------------------

MODEL_NAME = "runwayml/stable-diffusion-v1-5"
log(f"Loading model: {MODEL_NAME} (this may take a minute on first run)")


# Optionally, disable the safety checker for pure creative exploration.
def dummy_safety_checker(images, **kwargs):
    """Stand-in safety checker: return *images* unchanged and flag none of them."""
    nsfw_flags = [False] * len(images)
    return images, nsfw_flags


# Load the pipeline with float16 weights on CUDA and float32 on CPU; any
# failure is logged and terminates the script with exit status 1.
try:
    pipe = StableDiffusionPipeline.from_pretrained(
        MODEL_NAME,
        torch_dtype=(torch.float16 if device == "cuda" else torch.float32),
        safety_checker=dummy_safety_checker,  # Remove for production!
    )
except Exception as load_error:
    log(f"Error loading model: {load_error}")
    sys.exit(1)

# Move the pipeline to the selected device, then enable attention slicing to
# reduce VRAM use.
pipe = pipe.to(device)
pipe.enable_attention_slicing()

log("Model loaded successfully.")

# ------------------ 3. PROMPT HANDLING ------------------

def _prompt_slug(prompt, max_words=6, max_len=40):
    """Build a filesystem-safe filename fragment from the start of *prompt*.

    Takes the first ``max_words`` whitespace-separated words (lower-cased),
    removes every character that is not a word character or ``-`` from each,
    joins the non-empty results with ``_`` and truncates to ``max_len``
    characters. Dropping path separators, dots and characters such as ':' or
    '?' keeps the result a single valid path component, so the saved file
    always lands directly in OUTPUT_DIR. Returns ``"image"`` when nothing
    usable is left (e.g. a prompt made only of punctuation).
    """
    import re

    cleaned = (re.sub(r"[^\w-]", "", word) for word in prompt.lower().split()[:max_words])
    slug = "_".join(word for word in cleaned if word)
    return slug[:max_len] or "image"


def main():
    """Main function for command-line execution.

    Reads a prompt, generates one image with the module-level ``pipe`` on
    ``device``, and saves it as a PNG under ``OUTPUT_DIR``.

    The prompt comes from the command-line arguments (joined with spaces) when
    any are given, otherwise from an interactive ``input()`` prompt. An empty
    prompt (including EOF on stdin) logs a message and exits with status 0.

    The saved filename is ``<slug>_<YYYYmmdd_HHMMSS>_seed<seed>.png``. The seed
    shown there and in the log is exactly the seed the generator was given, so
    the image can be regenerated by re-using it (same device and settings).
    """
    if len(sys.argv) > 1:
        prompt = " ".join(sys.argv[1:]).strip()
        log(f"Prompt taken from command line: {prompt}")
    else:
        try:
            prompt = input("Enter your prompt (e.g. 'A magical forest, digital art'): ").strip()
        except EOFError:
            # stdin closed / not interactive: fall through to the "no prompt"
            # exit below instead of crashing with a traceback.
            prompt = ""
        log(f"Prompt entered: {prompt}")

    if not prompt:
        log("No prompt provided. Exiting.")
        sys.exit(0)

    # ------------------ 4. GENERATION PARAMETERS ------------------

    # torch.seed() returns a fresh non-deterministic seed; use a fixed integer
    # (e.g. 1234) instead for reproducible output.
    seed = torch.seed()
    # Seed a dedicated generator on the target device with that SAME seed on
    # every device, so the seed that is logged and put in the filename is the
    # one that actually drove generation.
    generator = torch.Generator(device=device).manual_seed(seed)

    num_inference_steps = 30   # More steps = better quality, slower (default 50)
    guidance_scale = 7.5       # Higher = follow prompt more strictly

    # Output image size (SDv1.5 default 512x512)
    height = 512
    width = 512

    # ------------------ 5. IMAGE GENERATION ------------------

    log(f"Generating image for prompt: {prompt}")
    log(f"Params: steps={num_inference_steps}, guidance_scale={guidance_scale}, seed={seed}")

    # Autocast on CUDA; on CPU just run without autograd tracking.
    with torch.autocast(device) if device == "cuda" else torch.no_grad():
        result = pipe(
            prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )

    image: Image.Image = result.images[0]

    # ------------------ 6. SAVE OUTPUT ------------------

    # Filename: sanitized prompt slug, timestamp, seed
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{_prompt_slug(prompt)}_{timestamp}_seed{seed}.png"
    filepath = os.path.join(OUTPUT_DIR, filename)
    image.save(filepath)
    log(f"Image saved to {filepath}")

    # Optionally, show image (uncomment next line if running locally)
    # image.show()

    log("Generation complete.")

# Run the CLI only when executed as a script. Importing this module still runs
# the module-level setup and model loading, but does not prompt or generate.
if __name__ == "__main__":
    main()