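"""Flask backend for a CPU-only text-to-image generation demo.

Serves a static frontend and exposes a /generate endpoint that runs either
segmind/SSD-1B (StableDiffusionXLPipeline) or Stable Diffusion v1.5
(StableDiffusionPipeline) pulled from the Hugging Face Hub.
"""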
from flask import Flask, request, jsonify
from flask_cors import CORS
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    UniPCMultistepScheduler,
    DPMSolverMultistepScheduler,
)
import torch
import os
from PIL import Image
import base64
import io
import time
import logging
from huggingface_hub import list_repo_files
import psutil

# Disable GPU detection (remove these lines if GPU is available)
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["CUDA_DEVICE_ORDER"] = ""
os.environ["TORCH_CUDA_ARCH_LIST"] = ""
torch.set_default_device("cpu")

app = Flask(__name__, static_folder='static')
CORS(app)

# Configure logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Log device and memory info
logger.info(f"Device in use: {torch.device('cpu')}")
logger.info(f"Available memory: {psutil.virtual_memory().available / (1024 ** 3):.2f} GB")

# Model cache
model_cache = {}
model_paths = {
    "ssd-1b": "segmind/SSD-1B",  # Using segmind/SSD-1B for testing
    "sd-v1-5": "remiai3/stable-diffusion-v1-5"
}
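
# Note: from_pretrained pulls each checkpoint from the Hugging Face Hub on
# first use; afterwards the pipeline stays in model_cache, so only the first
# request per model pays the download and load cost.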

# Image ratio to dimensions (optimized for CPU, multiple of 8)
ratio_to_dims = {
    "1:1": (512, 512),  # Default for SSD-1B
    "3:4": (384, 512),
    "16:9": (512, 288)
}

def load_model(model_id):
    if model_id not in model_cache:
        logger.info(f"Loading model {model_id} from {model_paths[model_id]}")
        logger.info(f"HF_TOKEN present: {os.getenv('HF_TOKEN') is not None}")
        try:
            # Log repository files for debugging
            repo_files = list_repo_files(model_paths[model_id], token=os.getenv("HF_TOKEN"))
            logger.info(f"Files in {model_paths[model_id]}: {repo_files}")

            # Choose pipeline based on model
            pipe_class = StableDiffusionXLPipeline if model_id == "ssd-1b" else StableDiffusionPipeline
            pipe = pipe_class.from_pretrained(
                model_paths[model_id],
                torch_dtype=torch.float32,
                token=os.getenv("HF_TOKEN"),  # 'use_auth_token' is deprecated in recent diffusers
                use_safetensors=True,
                low_cpu_mem_usage=True
            )
            # Use UniPCMultistepScheduler for SSD-1B, DPMSolverMultistep for SD-v1-5
            scheduler_cls = UniPCMultistepScheduler if model_id == "ssd-1b" else DPMSolverMultistepScheduler
            pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config)
            pipe.enable_attention_slicing()
            pipe.to(torch.device("cpu"))  # Change to "cuda" if GPU is available
            model_cache[model_id] = pipe
            logger.info(f"Model {model_id} loaded successfully")
        except Exception as e:
            logger.error(f"Error loading model {model_id}: {str(e)}")
            raise
    return model_cache[model_id]

@app.route('/')
def index():
    return app.send_static_file('index.html')

@app.route('/assets/<path:filename>')
def serve_assets(filename):
    return app.send_static_file(f'assets/{filename}')  # Flask expects POSIX-style paths

@app.route('/generate', methods=['POST'])
def generate():
    try:
        data = request.json
        model_id = data.get('model', 'ssd-1b')
        prompt = data.get('prompt', '')
        ratio = data.get('ratio', '1:1')
        num_images = max(1, min(int(data.get('num_images', 1)), 4))  # Clamp to 1-4
        
        # Handle guidance_scale with explicit type conversion
        guidance_scale_raw = data.get('guidance_scale', 7.5)
        logger.info(f"Raw guidance_scale: {guidance_scale_raw} (type: {type(guidance_scale_raw)})")
        try:
            guidance_scale = float(guidance_scale_raw)
            guidance_scale = min(max(guidance_scale, 1.0), 20.0)  # Clamp between 1.0 and 20.0
        except (ValueError, TypeError):
            logger.error(f"Invalid guidance_scale value: {guidance_scale_raw}")
            return jsonify({"error": "guidance_scale must be a valid number"}), 400

        # Log input parameters
        logger.info(f"Generating with model: {model_id}, prompt: {prompt}, ratio: {ratio}, num_images: {num_images}, guidance_scale: {guidance_scale}")

        if not prompt:
            return jsonify({"error": "Prompt is required"}), 400
        if len(prompt) > 512:
            return jsonify({"error": "Prompt is too long (max 512 characters)"}), 400
        if model_id == 'ssd-1b' and num_images > 1:
            return jsonify({"error": "SSD-1B allows only 1 image per generation"}), 400
        if model_id == 'ssd-1b' and ratio != '1:1':
            return jsonify({"error": "SSD-1B supports only 1:1 ratio"}), 400
        if model_id == 'sd-v1-5' and len(prompt.split()) > 77:  # word count as a rough proxy for the 77-token CLIP limit
            return jsonify({"error": "Prompt exceeds 77 tokens for Stable Diffusion v1.5"}), 400

        width, height = ratio_to_dims.get(ratio, (512, 512))
        if width % 8 != 0 or height % 8 != 0:
            return jsonify({"error": "Width and height must be multiples of 8"}), 400

        # Log memory before generation
        logger.info(f"Memory before generation: {psutil.virtual_memory().available / (1024 ** 3):.2f} GB")

        pipe = load_model(model_id)
        pipe.to(torch.device("cpu"))  # Change to "cuda" if GPU is available

        images = []
        num_inference_steps = 30  # Unified step count for both models for CPU stability
        try:
            for _ in range(num_images):
                image = pipe(
                    prompt=prompt,
                    height=height,
                    width=width,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale
                ).images[0]
                images.append(image)
        except IndexError as e:
            logger.error(f"IndexError during generation: {str(e)}")
            return jsonify({"error": f"Generation failed due to invalid index access: {str(e)}"}), 500
        except Exception as e:
            logger.error(f"Unexpected error during generation: {str(e)}")
            return jsonify({"error": f"Generation failed: {str(e)}"}), 500

        output_dir = "outputs"
        os.makedirs(output_dir, exist_ok=True)
        image_urls = []
        for i, img in enumerate(images):
            img_path = os.path.join(output_dir, f"generated_{int(time.time())}_{i}.png")
            img.save(img_path)
            with open(img_path, "rb") as f:
                img_data = base64.b64encode(f.read()).decode('utf-8')
            image_urls.append(f"data:image/png;base64,{img_data}")
            os.remove(img_path)

        logger.info(f"Generation successful, returning {len(image_urls)} images")
        return jsonify({"images": image_urls})

    except Exception as e:
        logger.error(f"Image generation failed: {str(e)}")
        return jsonify({"error": f"Image generation failed: {str(e)}"}), 500

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
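
# Example request (sketch; assumes the server is reachable at localhost:7860):
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"model": "sd-v1-5", "prompt": "a watercolor fox", "ratio": "3:4",
#             "num_images": 1, "guidance_scale": 7.5}'
# A successful response is JSON of the form:
#   {"images": ["data:image/png;base64,..."]}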