"""FastAPI service that renders an SDXL image from a styled prompt.

POST / with JSON {"prompt": str, "style_name": str, "guidance_scale": number}
returns {"image_base64": <base64-encoded JPEG>}.
"""
import base64
import functools
import json
import logging
import os
import random
import re
import threading
import time
from io import BytesIO
from typing import Tuple

import requests
import torch
from diffusers import DiffusionPipeline
from fastapi import FastAPI, HTTPException, Request
from flask import Flask, jsonify, request
from PIL import Image

logger = logging.getLogger(__name__)

app = FastAPI()

MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
MAX_SEED = 4294967295

style_list = [
    {
        "name": "(No style)",
        "prompt": "{prompt}",
        "negative_prompt": "",
    },
    {
        "name": "Cinematic",
        "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
        "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
    },
    {
        "name": "Photographic",
        "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
        "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
    },
    {
        "name": "Anime",
        "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
        "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
    },
    {
        "name": "Manga",
        "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
        "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
    },
    {
        "name": "Digital Art",
        "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
        "negative_prompt": "photo, photorealistic, realism, ugly",
    },
    {
        "name": "Pixel art",
        "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
        "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
    },
    {
        "name": "Fantasy art",
        "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
        "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
    },
    {
        "name": "Neonpunk",
        "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
        "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
    },
    {
        "name": "3D Model",
        "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
        "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
    },
]

# Lookup table: style name -> (prompt template, style negative prompt).
STYLES = {s["name"]: (s["prompt"], s["negative_prompt"]) for s in style_list}
DEFAULT_STYLE_NAME = "(No style)"

# The cached pipeline is shared by all requests; FastAPI runs sync handlers in
# a threadpool, so GPU inference must be serialized.
_pipe_lock = threading.Lock()


def apply_style(style_name, positive: str, negative: str = "") -> Tuple[str, str]:
    """Expand *positive* through a style template and merge negative prompts.

    Args:
        style_name: Key of ``STYLES``; ``None`` or an unknown name falls back
            to ``DEFAULT_STYLE_NAME`` (no styling).
        positive: The user's prompt, substituted for ``{prompt}``.
        negative: Extra user negative prompt, appended after the style's own.

    Returns:
        ``(styled_prompt, combined_negative_prompt)``; the negative parts are
        joined with ", " and empty parts are skipped.
    """
    template, style_negative = STYLES.get(style_name, STYLES[DEFAULT_STYLE_NAME])
    styled_prompt = template.replace("{prompt}", positive)
    combined_negative = ", ".join(part for part in (style_negative, negative) if part)
    return styled_prompt, combined_negative


@functools.lru_cache(maxsize=1)
def _get_pipeline():
    """Load the SDXL pipeline onto CUDA once and reuse it for every request."""
    pipe = DiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    )
    pipe.to("cuda")
    return pipe


def infer(prompt, negative="low_quality", style_name=None, guidance_scale=None):
    """Generate one image for *prompt* and return it as a base64 JPEG string.

    Args:
        prompt: User prompt; styled via ``apply_style``.
        negative: Extra negative prompt appended to the style's negative prompt.
        style_name: Style key (see ``style_list``); ``None`` means no styling.
        guidance_scale: Classifier-free guidance scale; ``None`` keeps the
            pipeline's own default instead of passing ``None`` through.

    Returns:
        Base64 (UTF-8 str, no ``data:`` prefix) of the first generated image.

    Raises:
        RuntimeError: if the pipeline produced no images.
    """
    seed = random.randint(0, MAX_SEED)
    prompt, negative = apply_style(style_name, prompt, negative)
    logger.info("prompt=%r negative=%r seed=%d", prompt, negative, seed)

    call_kwargs = {}
    if guidance_scale is not None:
        call_kwargs["guidance_scale"] = guidance_scale

    with _pipe_lock:
        pipe = _get_pipeline()
        # The pipeline has no ``seed`` argument; seeding goes through a Generator.
        generator = torch.Generator(device="cuda").manual_seed(seed)
        images = pipe(
            prompt=prompt,
            negative_prompt=negative,
            generator=generator,
            **call_kwargs,
        ).images

    if not images:
        raise RuntimeError("pipeline returned no images")

    buffered = BytesIO()
    images[0].save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


@app.post("/")
def generate_image(data: dict):
    """Render an image from ``prompt``, ``style_name`` and ``guidance_scale``.

    Returns ``{"image_base64": <base64 JPEG>}``.

    Raises:
        HTTPException(400): a required key is missing, ``prompt`` is not a
            string, ``style_name`` is not a known style (``null`` is allowed and
            means no styling), or ``guidance_scale`` is not numeric.
    """
    if not all(key in data for key in ("prompt", "style_name", "guidance_scale")):
        raise HTTPException(status_code=400, detail="Missing required parameters")

    prompt = data["prompt"]
    if not isinstance(prompt, str):
        raise HTTPException(status_code=400, detail="prompt must be a string")

    style_name = data["style_name"]
    if style_name is not None and not (isinstance(style_name, str) and style_name in STYLES):
        raise HTTPException(
            status_code=400,
            detail=f"Unknown style_name; expected one of {list(STYLES)}",
        )

    try:
        guidance_scale = float(data["guidance_scale"])
    except (TypeError, ValueError):
        raise HTTPException(status_code=400, detail="guidance_scale must be a number") from None

    # infer() already returns bare base64 (no data-URL prefix), so no splitting.
    image_b64 = infer(prompt, style_name=style_name, guidance_scale=guidance_scale)
    return {"image_base64": image_b64}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)