Spaces:
Sleeping
Sleeping
File size: 4,169 Bytes
9d8c3ac c6498ec 9d8c3ac c6498ec 84e9994 c6498ec 9d8c3ac c6498ec 9d8c3ac 84e9994 9d8c3ac c6498ec 9d8c3ac c6498ec 9d8c3ac c6498ec 9d8c3ac c6498ec 9d8c3ac c6498ec 5c955c9 c6498ec 84e9994 c6498ec 9d8c3ac c6498ec 9d8c3ac 5c955c9 9d8c3ac c6498ec 84e9994 c6498ec 61df085 c6498ec 9d8c3ac c6498ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import base64
import io
import logging
import os
import threading
import time

import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from flask import Flask, request, jsonify
from flask_cors import CORS
from PIL import Image
# Hide every CUDA device from torch and make the CPU the default device, so
# that all model work in this process runs on the CPU.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "",
    "CUDA_DEVICE_ORDER": "",
    "TORCH_CUDA_ARCH_LIST": "",
})
torch.set_default_device("cpu")

# Web application; static assets and the front-end page live in ./static.
app = Flask(__name__, static_folder="static")
CORS(app)

# Root logging: DEBUG level, one timestamped line per record.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
logger.info(f"Device in use: {torch.device('cpu')}")

# model_id -> loaded pipeline; populated lazily by load_model().
model_cache = {}

# Model ids accepted by the API, mapped to the repository id / path that is
# handed to from_pretrained() when the model is first loaded.
model_paths = {
    "ssd-1b": "remiai3/ssd-1b",
    "sd-v1-5": "remiai3/stable-diffusion-v1-5",
}

# Aspect-ratio label -> (width, height) in pixels. Sizes are kept small
# because inference runs on the CPU.
ratio_to_dims = {
    "1:1": (256, 256),
    "3:4": (192, 256),
    "16:9": (256, 144),
}
def load_model(model_id):
    """Return the pipeline for ``model_id``, building and caching it on first use.

    Loaded pipelines are stored in the module-level ``model_cache`` dict, and
    later calls with the same id return the cached object. A failed load
    caches nothing.

    Args:
        model_id: a key of ``model_paths``.

    Returns:
        The ``StableDiffusionPipeline`` stored in ``model_cache[model_id]``.

    Raises:
        Whatever the lookup or the load raises is logged and then re-raised
        unchanged. This includes ``KeyError`` for an id missing from
        ``model_paths``.
    """
    if model_id in model_cache:
        return model_cache[model_id]

    logger.info(f"Loading model {model_id}...")
    try:
        repo_id = model_paths[model_id]
        hf_token = os.getenv("HF_TOKEN")
        pipe = StableDiffusionPipeline.from_pretrained(
            repo_id,
            torch_dtype=torch.float32,
            use_auth_token=hf_token,
            use_safetensors=True,
            low_cpu_mem_usage=True,
        )
        # Replace the checkpoint's scheduler with DPMSolverMultistepScheduler,
        # built from the same scheduler configuration.
        scheduler_config = pipe.scheduler.config
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config)
        # Attention slicing: diffusers' option for lowering peak memory use.
        pipe.enable_attention_slicing()
        cpu = torch.device("cpu")
        pipe.to(cpu)
        model_cache[model_id] = pipe
        logger.info(f"Model {model_id} loaded successfully")
    except Exception as exc:
        logger.error(f"Error loading model {model_id}: {str(exc)}")
        raise
    return model_cache[model_id]
@app.route("/")
def index():
    """Serve the front-end entry page, ``index.html``, from the static folder."""
    entry_page = "index.html"
    return app.send_static_file(entry_page)
@app.route('/assets/<path:filename>')
def serve_assets(filename):
    """Serve ``filename`` from the ``assets`` sub-directory of the static folder.

    ``send_static_file`` takes a URL-style, '/'-separated path. Building it
    with ``os.path.join`` yields a backslash on Windows, which Werkzeug's
    ``safe_join`` (used by ``send_static_file``) rejects, so every asset
    request would 404 there. ``safe_join`` still blocks ``..`` escapes out
    of the static folder.
    """
    return app.send_static_file(f"assets/{filename}")
# app.run() serves requests on multiple threads by default (Flask >= 1.0).
# However, one pipeline object is shared by every request, and a diffusers
# pipeline is not safe to run from several threads at once: its scheduler
# keeps per-run step state. This lock serialises model loading and
# inference. It also stops two concurrent first requests from each loading
# the same model.
_GENERATION_LOCK = threading.Lock()


def _image_to_data_url(image):
    """Return ``image`` PNG-encoded as a ``data:image/png;base64,...`` URL.

    Encoding happens in memory. Temporary files named by timestamp could
    collide, and be overwritten or deleted, between concurrent requests.
    """
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{encoded}"


@app.route('/generate', methods=['POST'])
def generate():
    """Generate images for a text prompt and return them as PNG data URLs.

    The request body must be a JSON object with these fields:
        prompt: non-empty string (required).
        model: a key of ``model_paths``; default ``'ssd-1b'``.
        ratio: a key of ``ratio_to_dims``; default ``'1:1'``. Unknown ratio
            strings fall back to 256x256.
        num_images: integer >= 1, silently capped at 4; default 1.
        guidance_scale: number passed to the pipeline; default 7.5.

    Returns:
        HTTP 200 with JSON ``{"images": [data_url, ...]}`` on success.
        HTTP 400 with ``{"error": message}`` for an invalid request.
        HTTP 500 with ``{"error": message}`` if loading or inference fails.
    """
    # silent=True: a missing or malformed JSON body yields None (-> 400)
    # instead of an exception.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    model_id = data.get('model', 'ssd-1b')
    prompt = data.get('prompt', '')
    ratio = data.get('ratio', '1:1')
    try:
        num_images = min(int(data.get('num_images', 1)), 4)
        guidance_scale = float(data.get('guidance_scale', 7.5))
    except (TypeError, ValueError):
        return jsonify({"error": "num_images must be an integer and guidance_scale a number"}), 400

    if not isinstance(prompt, str) or not prompt.strip():
        return jsonify({"error": "Prompt is required"}), 400
    if not isinstance(model_id, str) or model_id not in model_paths:
        supported = ", ".join(model_paths)
        return jsonify({"error": f"Unknown model {model_id!r}; expected one of: {supported}"}), 400
    if not isinstance(ratio, str):
        return jsonify({"error": "ratio must be a string"}), 400
    if num_images < 1:
        return jsonify({"error": "num_images must be at least 1"}), 400
    if model_id == 'ssd-1b' and num_images > 1:
        return jsonify({"error": "SSD-1B allows only 1 image per generation"}), 400
    if model_id == 'ssd-1b' and ratio != '1:1':
        return jsonify({"error": "SSD-1B supports only 1:1 ratio"}), 400
    # NOTE(review): this counts whitespace-separated words, which only
    # approximates the tokenizer's token count that the message refers to.
    # Confirm whether exact token counting is wanted.
    if model_id == 'sd-v1-5' and len(prompt.split()) > 77:
        return jsonify({"error": "Prompt exceeds 77 tokens for Stable Diffusion v1.5"}), 400

    width, height = ratio_to_dims.get(ratio, (256, 256))
    num_inference_steps = 30 if model_id == 'ssd-1b' else 40
    try:
        with _GENERATION_LOCK:
            # load_model() already moves the pipeline to the CPU.
            pipe = load_model(model_id)
            images = [
                pipe(
                    prompt=prompt,
                    height=height,
                    width=width,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                ).images[0]
                for _ in range(num_images)
            ]
        image_urls = [_image_to_data_url(image) for image in images]
    except Exception as e:
        # Request boundary: log the full traceback and report a JSON 500.
        logger.exception("Image generation failed: %s", e)
        return jsonify({"error": f"Image generation failed: {str(e)}"}), 500
    return jsonify({"images": image_urls})
if __name__ == '__main__':
    # Listen on all interfaces. The port defaults to 7860 and can be
    # overridden with the PORT environment variable.
    app.run(host='0.0.0.0', port=int(os.environ.get("PORT", "7860")))