from flask import Flask, request, jsonify
from flask_cors import CORS
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
import torch
import os
import io
import base64
import logging

# Force CPU-only execution: hide any GPUs from torch
os.environ["CUDA_VISIBLE_DEVICES"] = ""
torch.set_default_device("cpu")

app = Flask(__name__, static_folder='static')
CORS(app)

# Configure logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Log device in use
logger.info(f"Device in use: {torch.device('cpu')}")

# Model cache
model_cache = {}
model_paths = {
    "ssd-1b": "remiai3/ssd-1b",
    "sd-v1-5": "remiai3/stable-diffusion-v1-5"
}

# Image ratio to dimensions (optimized for CPU)
ratio_to_dims = {
    "1:1": (256, 256),
    "3:4": (192, 256),
    "16:9": (256, 144)
}

def load_model(model_id):
    """Load a pipeline on first use and cache it; later requests reuse the cached copy."""
    if model_id not in model_cache:
        logger.info(f"Loading model {model_id}...")
        try:
            pipe = StableDiffusionPipeline.from_pretrained(
                model_paths[model_id],
                torch_dtype=torch.float32,
                token=os.getenv("HF_TOKEN"),  # 'use_auth_token' is deprecated in recent diffusers
                use_safetensors=True,
                low_cpu_mem_usage=True
            )
            pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
            pipe.enable_attention_slicing()
            pipe.to(torch.device("cpu"))
            model_cache[model_id] = pipe
            logger.info(f"Model {model_id} loaded successfully")
        except Exception as e:
            logger.error(f"Error loading model {model_id}: {str(e)}")
            raise
    return model_cache[model_id]
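
# Optional (an assumption, not part of the original flow): pre-load the default
# model at startup so the first request doesn't pay the initialization cost:
#
#   load_model("ssd-1b")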

@app.route('/')
def index():
    return app.send_static_file('index.html')

@app.route('/assets/<path:filename>')
def serve_assets(filename):
    return app.send_static_file(f'assets/{filename}')  # URL-style path; os.path.join breaks on Windows

@app.route('/generate', methods=['POST'])
def generate():
    """Generate one or more images from a text prompt and return them as data URLs."""
    try:
        data = request.json
        model_id = data.get('model', 'ssd-1b')
        prompt = data.get('prompt', '')
        ratio = data.get('ratio', '1:1')
        num_images = max(1, min(int(data.get('num_images', 1)), 4))  # clamp to 1-4 images per request
        guidance_scale = float(data.get('guidance_scale', 7.5))

        if not prompt:
            return jsonify({"error": "Prompt is required"}), 400

        if model_id == 'ssd-1b' and num_images > 1:
            return jsonify({"error": "SSD-1B allows only 1 image per generation"}), 400
        if model_id == 'ssd-1b' and ratio != '1:1':
            return jsonify({"error": "SSD-1B supports only 1:1 ratio"}), 400
        if model_id == 'sd-v1-5' and len(prompt.split()) > 77:
            # Word count is a rough proxy for CLIP's 77-token limit
            return jsonify({"error": "Prompt exceeds 77 tokens for Stable Diffusion v1.5"}), 400

        width, height = ratio_to_dims.get(ratio, (256, 256))
        pipe = load_model(model_id)  # cached pipeline, already moved to CPU in load_model

        images = []
        num_inference_steps = 30 if model_id == 'ssd-1b' else 40
        for _ in range(num_images):
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale
            ).images[0]
            images.append(image)

        # Encode each image as a base64 data URL in memory; the original
        # temp-file round trip risked name collisions between concurrent requests
        image_urls = []
        for img in images:
            buf = io.BytesIO()
            img.save(buf, format="PNG")
            img_data = base64.b64encode(buf.getvalue()).decode('utf-8')
            image_urls.append(f"data:image/png;base64,{img_data}")

        return jsonify({"images": image_urls})

    except Exception as e:
        logger.error(f"Image generation failed: {str(e)}")
        return jsonify({"error": f"Image generation failed: {str(e)}"}), 500

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
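
# Example request (assumes the server is running locally on the port above):
#
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"model": "sd-v1-5", "prompt": "a watercolor fox", "ratio": "16:9", "num_images": 2}'
#
# The response is JSON of the form {"images": ["data:image/png;base64,..."]},
# so each entry can be used directly as an <img> src on the client.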