Spaces:
Running
Running
import os | |
import io | |
import uuid | |
import base64 | |
from typing import Dict, List, Optional, Any, Union | |
from pathlib import Path | |
import aiohttp | |
from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends, Header, Request | |
from fastapi.responses import StreamingResponse, JSONResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
from pydantic import BaseModel, Field | |
import asyncio | |
import uvicorn | |
from datetime import datetime | |
import time | |
# Import all TTI providers | |
from webscout.Provider.TTI import ( | |
# Import all image providers | |
BlackboxAIImager, AsyncBlackboxAIImager, | |
DeepInfraImager, AsyncDeepInfraImager, | |
AiForceimager, AsyncAiForceimager, | |
NexraImager, AsyncNexraImager, | |
FreeAIImager, AsyncFreeAIImager, | |
NinjaImager, AsyncNinjaImager, | |
TalkaiImager, AsyncTalkaiImager, | |
PiclumenImager, AsyncPiclumenImager, | |
ArtbitImager, AsyncArtbitImager, | |
HFimager, AsyncHFimager, | |
) | |
try: | |
from webscout.Provider.TTI import AIArtaImager, AsyncAIArtaImager | |
AIARTA_AVAILABLE = True | |
except ImportError: | |
AIARTA_AVAILABLE = False | |
# The FastAPI application object that every endpoint and middleware attaches to.
app = FastAPI(
    title="WebScout TTI API Server",
    version="1.0.0",
    description=(
        "API server for Text-to-Image generation using various providers "
        "with OpenAI-compatible interface"
    ),
)

# Fully open CORS policy so browser clients on any origin can call the API.
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site issue credentialed requests; auth here is a Bearer header, so
# credentials may not be needed at all — confirm before exposing publicly.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# In-process store of generation tasks, keyed by task id. Each value holds a
# "status" ("processing" / "completed" / "failed"), an "images" list and, on
# failure, an "error" message. Nothing here evicts entries, so it only grows;
# a database or file store would be needed for production use.
IMAGE_STORAGE: Dict[str, Dict[str, Any]] = dict()

# Demo-only API key table mapping an accepted key to an owner label.
# Replace with a real credential store before any production deployment.
API_KEYS: Dict[str, str] = dict([("sk-demo-key", "demo")])
# Provider registry: public provider id -> async imager class plus a short
# human-readable blurb, built from one row per provider (order is preserved).
PROVIDER_MAP = {
    provider_id: {"class": imager_cls, "description": blurb}
    for provider_id, imager_cls, blurb in (
        ("blackbox", AsyncBlackboxAIImager,
         "High-performance image generation with advanced retry mechanisms"),
        ("deepinfra", AsyncDeepInfraImager,
         "Powerful image generation using FLUX-1-schnell and other models"),
        ("aiforce", AsyncAiForceimager,
         "Advanced AI image generation with 12 specialized models"),
        ("nexra", AsyncNexraImager,
         "Next-gen image creation with 19+ models"),
        ("freeai", AsyncFreeAIImager,
         "Premium image generation with DALL-E 3 and Flux series models"),
        ("ninja", AsyncNinjaImager,
         "Ninja-fast image generation with cyberpunk-themed logging"),
        ("talkai", AsyncTalkaiImager,
         "Fast and reliable image generation with comprehensive error handling"),
        ("piclumen", AsyncPiclumenImager,
         "Professional photorealistic image generation with advanced processing"),
        ("artbit", AsyncArtbitImager,
         "Bit-perfect AI art creation with precise control over parameters"),
        ("huggingface", AsyncHFimager,
         "Direct integration with HuggingFace's powerful models"),
    )
}

# AIArta is optional: register it only when its import succeeded.
if AIARTA_AVAILABLE:
    PROVIDER_MAP["aiarta"] = {
        "class": AsyncAIArtaImager,
        "description": "Generate stunning AI art with AI Arta with 45+ artistic styles",
    }
# Provider model info: for each provider, the default model id, the list of
# accepted model ids, and default generation parameters that create_image
# merges into every request for that provider.
PROVIDER_MODEL_INFO = {
    "blackbox": {
        "default": "blackbox-default",
        "models": ["blackbox-default"],
        "default_params": {},
    },
    "deepinfra": {
        "default": "flux-1-schnell",
        "models": ["flux-1-schnell"],
        "default_params": {
            "num_inference_steps": 25,
            "guidance_scale": 7.5,
            "width": 1024,
            "height": 1024,
        },
    },
    "aiforce": {
        "default": "flux-1-pro",
        "models": [
            "stable-diffusion-xl-lightning",
            "stable-diffusion-xl-base",
            "flux-1-pro",
            "ideogram",
            "flux",
            "flux-realism",
            "flux-anime",
            "flux-3d",
            "flux-disney",
            "flux-pixel",
            "flux-4o",
            "any-dark",
        ],
        "default_params": {
            "width": 768,
            "height": 768,
        },
    },
    "nexra": {
        "default": "midjourney",
        "models": [
            "emi",
            "stablediffusion-1-5",
            "stablediffusion-2-1",
            "sdxl-lora",
            "dalle",
            "dalle2",
            "dalle-mini",
            "flux",
            "midjourney",
            "dreamshaper-xl",
            "dynavision-xl",
            "juggernaut-xl",
            "realism-engine-sdxl",
            "sd-xl-base-1-0",
            "animagine-xl-v3",
            "sd-xl-base-inpainting",
            "turbovision-xl",
            "devlish-photorealism-sdxl",
            "realvis-xl-v4",
        ],
        "default_params": {},
    },
    "freeai": {
        "default": "dall-e-3",
        "models": [
            "dall-e-3",
            "flux-pro-ultra",
            "flux-pro",
            "flux-pro-ultra-raw",
            "flux-schnell",
            "flux-realism",
            "grok-2-aurora",
        ],
        "default_params": {
            "size": "1024x1024",
            "quality": "standard",
            "style": "vivid",
        },
    },
    "ninja": {
        "default": "flux-dev",
        "models": ["stable-diffusion", "flux-dev"],
        "default_params": {},
    },
    "talkai": {
        "default": "talkai-default",
        "models": ["talkai-default"],
        "default_params": {},
    },
    "piclumen": {
        "default": "piclumen-default",
        "models": ["piclumen-default"],
        "default_params": {},
    },
    "artbit": {
        "default": "sdxl",
        "models": ["sdxl", "sd"],
        "default_params": {
            "selected_ratio": "1024",
        },
    },
    "huggingface": {
        "default": "stable-diffusion-xl-base-1-0",
        "models": ["stable-diffusion-xl-base-1-0", "stable-diffusion-v1-5"],
        "default_params": {
            "guidance_scale": 7.5,
            "num_inference_steps": 30,
        },
    },
}

# Normalize model names to OpenAI-like ids: '/', '.' and '_' all become '-'
# and the result is lower-cased (e.g. "FLUX.1_dev" -> "flux-1-dev").
# Done with a comprehension rather than a module-level for-loop so no loop
# variables leak into the module namespace; provider and key order are kept.
_MODEL_ID_SEPARATORS = str.maketrans({"/": "-", ".": "-", "_": "-"})
PROVIDER_MODEL_INFO = {
    provider_id: {
        **entry,
        "default": entry["default"].translate(_MODEL_ID_SEPARATORS).lower(),
        "models": [
            model_id.translate(_MODEL_ID_SEPARATORS).lower()
            for model_id in entry["models"]
        ],
    }
    for provider_id, entry in PROVIDER_MODEL_INFO.items()
}
# Register AIArta's models only when its optional import succeeded.
if AIARTA_AVAILABLE:
    # Ids are written whitespace-separated and split into a list; they are
    # already lower-case and '-'-separated, matching the other providers' ids.
    PROVIDER_MODEL_INFO["aiarta"] = dict(
        default="flux",
        models="""
            flux medieval vincent-van-gogh f-dev low-poly
            dreamshaper-xl anima-pencil-xl biomech trash-polka
            no-style cheyenne-xl chicano embroidery-tattoo
            red-and-black fantasy-art watercolor dotwork
            old-school-colored realistic-tattoo japanese-2
            realistic-stock-xl f-pro revanimated katayama-mix-xl
            sdxl-l cor-epica-xl anime-tattoo new-school
            death-metal old-school juggernaut-xl photographic
            sdxl-1-0 graffiti mini-tattoo surrealism
            neo-traditional on-limbs-black yamers-realistic-xl
            pony-xl playground-xl anything-xl flame-design
            kawaii cinematic-art professional flux-black-ink
        """.split(),
        default_params=dict(
            negative_prompt="blurry, deformed hands, ugly",
            guidance_scale=7,
            num_inference_steps=30,
            aspect_ratio="1:1",
        ),
    )
# Pydantic models for request and response validation (OpenAI-compatible).


# Explicit width/height pair; the request model carries size as a
# "WIDTHxHEIGHT" string instead.
class ImageSize(BaseModel):
    width: int = Field(default=1024, description="Image width")
    height: int = Field(default=1024, description="Image height")


# Body of POSTed image-generation requests; field names follow OpenAI's API.
class ImageGenerationRequest(BaseModel):
    # Required fields.
    model: str = Field(..., description="The model to use for image generation")
    prompt: str = Field(..., description="The prompt to generate images from")
    # Optional fields; being Optional[...], each also accepts an explicit null.
    n: Optional[int] = Field(default=1, ge=1, le=10, description="Number of images to generate")
    size: Optional[str] = Field(default="1024x1024", description="Image size in format WIDTHxHEIGHT")
    # NOTE(review): pydantic treats `enum=` as extra JSON-schema metadata, not a
    # validation constraint, so other strings are accepted here — confirm
    # against the installed pydantic version.
    response_format: Optional[str] = Field(
        default="url",
        enum=["url", "b64_json"],
        description="The format in which the generated images are returned",
    )
    user: Optional[str] = Field(default=None, description="A unique identifier for the user")
    style: Optional[str] = Field(default=None, description="Style for the generation")
    quality: Optional[str] = Field(default=None, description="Quality level for the generation")
    negative_prompt: Optional[str] = Field(default=None, description="What to avoid in the generated image")


# One generated image: either a URL or base64 data, plus the prompt used.
class ImageData(BaseModel):
    url: Optional[str] = Field(default=None, description="The URL of the generated image")
    b64_json: Optional[str] = Field(default=None, description="Base64 encoded JSON string of the image")
    revised_prompt: Optional[str] = Field(default=None, description="The prompt after any revisions")


# Envelope returned by the image-generation endpoint.
class ImageGenerationResponse(BaseModel):
    created: int = Field(..., description="Unix timestamp for when the request was created")
    data: List[ImageData] = Field(..., description="List of generated images")


# Envelope returned by the model-listing endpoint.
class ModelsListResponse(BaseModel):
    object: str = Field(default="list", description="Object type")
    data: List[Dict[str, Any]] = Field(..., description="List of available models")


# OpenAI-style error envelope: {"error": {...}}.
class ErrorResponse(BaseModel):
    error: Dict[str, Any] = Field(..., description="Error details")
# Error handling
class APIError(Exception):
    """Request error rendered by the handlers as an OpenAI-style error body.

    Attributes:
        message: Human-readable description; also the exception's ``str()``.
        code: HTTP status used for the response (and echoed as ``"code"``).
        param: Name of the offending request parameter, if any.
        type: Error category string, e.g. ``"model_not_found"``.
    """

    def __init__(self, message, code=400, param=None, type="invalid_request_error"):
        # `type` shadows the builtin, but callers already pass it by keyword,
        # so the name is part of the interface.
        # Pass the message to Exception so str(e), e.args, logging and
        # tracebacks show it (previously they were empty).
        super().__init__(message)
        self.message = message
        self.code = code
        self.param = param
        self.type = type
# Authentication dependency
def _authentication_error(message: str, code: str) -> HTTPException:
    """Build the 401 HTTPException used for every failed API-key check."""
    return HTTPException(
        status_code=401,
        detail={
            "error": {
                "message": message,
                "type": "authentication_error",
                "param": None,
                "code": code,
            }
        },
    )


async def verify_api_key(authorization: Optional[str] = Header(None)):
    """Validate an ``Authorization: Bearer <key>`` header against API_KEYS.

    Returns:
        The API key when it is known.

    Raises:
        HTTPException: 401 with code ``no_api_key`` (header missing),
        ``invalid_auth_format`` (not exactly ``Bearer <key>``; the scheme is
        matched case-insensitively) or ``invalid_api_key`` (unknown key).
    """
    # NOTE(review): no handler in this view declares this as a dependency
    # (e.g. via Depends) — confirm it is actually wired in.
    if authorization is None:
        raise _authentication_error("No API key provided", "no_api_key")

    tokens = authorization.split()
    well_formed = len(tokens) == 2 and tokens[0].lower() == "bearer"
    if not well_formed:
        raise _authentication_error(
            "Invalid authentication format. Use 'Bearer YOUR_API_KEY'",
            "invalid_auth_format",
        )

    _, api_key = tokens
    # Demo-grade check: plain membership in the in-memory key table.
    if api_key in API_KEYS:
        return api_key
    raise _authentication_error("Invalid API key", "invalid_api_key")
# Find provider from model ID; accepts both "model" and "provider/model" ids.
def normalize_model_name(name: str) -> str:
    """Return *name* in the OpenAI-style id form used by PROVIDER_MODEL_INFO.

    '/', '.' and '_' are replaced by '-' and the result is lower-cased, so
    ``"Stable_Diffusion.XL"`` becomes ``"stable-diffusion-xl"``.
    """
    return name.replace("/", "-").replace(".", "-").replace("_", "-").lower()


def get_provider_for_model(model: str):
    """Resolve a requested model id to ``(provider_name, model_name)``.

    Two forms are accepted:

    * ``"provider/model"``: the provider is looked up directly and the model
      must be listed for that provider.
    * ``"model"``: providers are searched in PROVIDER_MODEL_INFO order and the
      first one listing the model wins. Several providers share ids such as
      ``"flux"``; use the qualified form to pick a specific one.

    In both forms the model part is normalized with ``normalize_model_name``,
    so ``"FLUX_Dev"`` and ``"flux-dev"`` resolve identically, and the returned
    model name is the normalized id.

    Raises:
        APIError: 404 with type ``provider_not_found`` or ``model_not_found``.
    """
    model = model.lower()

    if "/" in model:
        provider_name, model_name = model.split("/", 1)
        model_name = normalize_model_name(model_name)
        if provider_name not in PROVIDER_MAP:
            raise APIError(
                message=f"Provider '{provider_name}' not found",
                code=404,
                type="provider_not_found",
            )
        # Compare on normalized ids so entries registered in a non-normalized
        # spelling (e.g. with '_') still match; this subsumes the old
        # '-' vs '_' fallback.
        known = {
            normalize_model_name(m)
            for m in PROVIDER_MODEL_INFO[provider_name]["models"]
        }
        if model_name not in known:
            raise APIError(
                message=f"Model '{model_name}' not found for provider '{provider_name}'",
                code=404,
                type="model_not_found",
            )
        return provider_name, model_name

    # Bare model id: apply the same normalization as the qualified form.
    model_name = normalize_model_name(model)
    for provider_name, provider_info in PROVIDER_MODEL_INFO.items():
        known = {normalize_model_name(m) for m in provider_info["models"]}
        if model_name in known or model_name == normalize_model_name(provider_info["default"]):
            return provider_name, model_name

    raise APIError(
        message=f"Model '{model}' not found",
        code=404,
        type="model_not_found",
    )
# Health check endpoint
# NOTE(review): no route decorator is visible for this handler here — confirm
# it is registered on `app`.
async def health_check():
    """Liveness probe: report that the server process is up."""
    status = {"status": "ok"}
    return status
# OpenAI-compatible endpoints
# List available models
async def list_models():
    """Return every configured model in OpenAI ``/v1/models`` list format.

    One entry is produced per (provider, model) pair, in PROVIDER_MODEL_INFO
    order. ``owned_by`` is the provider id and ``description`` combines the
    provider blurb with whether the model is that provider's default.

    NOTE(review): ids such as "flux" appear under several providers, so the
    listing can contain duplicate ids; a bare id resolves to the first
    provider that lists it.
    """
    # One timestamp for the whole listing: it is loop-invariant, and a single
    # value keeps every entry's "created" consistent within one response.
    created = int(time.time())
    models_data = []
    for provider_name, provider_info in PROVIDER_MODEL_INFO.items():
        provider_description = PROVIDER_MAP.get(provider_name, {}).get("description", "")
        default_model = provider_info["default"]
        for model_name in provider_info["models"]:
            role = "Default model" if model_name == default_model else "Alternative model"
            models_data.append({
                "id": model_name,
                "object": "model",
                "created": created,
                "owned_by": provider_name,
                "permission": [],
                "root": model_name,
                "parent": None,
                "description": f"{provider_description} - {role}",
            })
    return {
        "object": "list",
        "data": models_data,
    }
# Get model information
async def get_model(model_id: str):
    """Describe a single model in OpenAI ``/v1/models/{id}`` format.

    *model_id* may be a bare model id or ``"provider/model"``. An unknown
    provider or model yields a JSON error response whose status is the
    APIError's code (404 from get_provider_for_model).
    """
    try:
        provider_name, model = get_provider_for_model(model_id)
    except APIError as exc:
        error_body = {
            "message": exc.message,
            "type": exc.type,
            "param": exc.param,
            "code": exc.code,
        }
        return JSONResponse(status_code=exc.code, content={"error": error_body})

    description = PROVIDER_MAP.get(provider_name, {}).get("description", "")
    return dict(
        id=model,
        object="model",
        created=int(time.time()),
        owned_by=provider_name,
        permission=[],
        root=model,
        parent=None,
        description=description,
    )
# Generate images

# NOTE(review): Hugging Face repo ids for the normalized "huggingface" model
# ids. They are assumed from the public Hub names those ids appear to have been
# normalized from — confirm they are what AsyncHFimager expects. (The previous
# `model.replace("-", "/")` produced ids like "stable/diffusion/xl/base/1/0",
# which is not a valid owner/name repo id.)
_HF_MODEL_IDS = {
    "stable-diffusion-xl-base-1-0": "stabilityai/stable-diffusion-xl-base-1.0",
    "stable-diffusion-v1-5": "runwayml/stable-diffusion-v1-5",
}


def _parse_size(size: Optional[str], default: tuple = (1024, 1024)) -> tuple:
    """Parse an OpenAI-style size string into ``(width, height)``.

    ``"WIDTHxHEIGHT"`` (the ``x`` is matched case-insensitively) gives both
    dimensions; a bare ``"N"`` gives an N x N square. Empty, malformed or
    non-positive sizes fall back to *default*.
    """
    if not size:
        return default
    parts = size.lower().split("x")
    try:
        if len(parts) == 2:
            width, height = int(parts[0]), int(parts[1])
        elif len(parts) == 1:
            width = height = int(parts[0])
        else:
            return default
    except ValueError:
        return default
    if width <= 0 or height <= 0:
        return default
    return width, height


def _make_provider(provider_name, provider_class, model):
    """Instantiate a provider, passing a model only where the constructor needs one."""
    if provider_name == "freeai":
        return provider_class(model=model)
    if provider_name == "deepinfra" and "-flux-" in model:
        # NOTE(review): the only deepinfra id listed is "flux-1-schnell", which
        # does not contain "-flux-", so this branch does not run and the
        # provider's own default model is used. The hard-coded id below also
        # looks like a mangled "FLUX.1-schnell" — confirm before enabling.
        return provider_class(model="black-forest-labs/FLUX-1-schnell")
    return provider_class()


async def _run_provider(provider_name, provider_instance, model, prompt, amount, width, height, params):
    """Call ``generate`` with the keyword arguments each provider expects.

    *params* is the provider's default_params merged with request overrides
    (including "width"/"height"). Returns whatever the provider returns:
    image bytes and/or image URL strings.
    """
    # NOTE(review): aiforce, nexra and ninja receive ids with '-' turned back
    # into '_'; whether that matches each provider's own model naming is not
    # visible here — confirm.
    if provider_name == "aiforce":
        return await provider_instance.generate(
            prompt=prompt,
            amount=amount,
            model=model.replace("-", "_"),
            width=params.get("width", 768),
            height=params.get("height", 768),
            seed=params.get("seed", None),
        )
    if provider_name == "deepinfra":
        return await provider_instance.generate(
            prompt=prompt,
            amount=amount,
            num_inference_steps=params.get("num_inference_steps", 25),
            guidance_scale=params.get("guidance_scale", 7.5),
            width=params.get("width", 1024),
            height=params.get("height", 1024),
            seed=params.get("seed", None),
        )
    if provider_name == "nexra":
        return await provider_instance.generate(
            prompt=prompt,
            amount=amount,
            model=model.replace("-", "_"),
            additional_params=params,
        )
    if provider_name == "freeai":
        return await provider_instance.generate(
            prompt=prompt,
            amount=amount,
            size=f"{width}x{height}",
            quality=params.get("quality", "standard"),
            style=params.get("style", "vivid"),
        )
    if provider_name == "ninja":
        return await provider_instance.generate(
            prompt=prompt,
            amount=amount,
            model=model.replace("-", "_"),
        )
    if provider_name == "artbit":
        return await provider_instance.generate(
            prompt=prompt,
            amount=amount,
            caption_model=model,
            selected_ratio=params.get("selected_ratio", "1024"),
            negative_prompt=params.get("negative_prompt", ""),
        )
    if provider_name == "huggingface":
        return await provider_instance.generate(
            prompt=prompt,
            amount=amount,
            model=_HF_MODEL_IDS.get(model, model),
            guidance_scale=params.get("guidance_scale", 7.5),
            negative_prompt=params.get("negative_prompt", None),
            num_inference_steps=params.get("num_inference_steps", 30),
            width=width,
            height=height,
        )
    if provider_name == "aiarta" and AIARTA_AVAILABLE:
        return await provider_instance.generate(
            prompt=prompt,
            amount=amount,
            model=model,
            negative_prompt=params.get("negative_prompt", "blurry, deformed hands, ugly"),
            guidance_scale=params.get("guidance_scale", 7),
            num_inference_steps=params.get("num_inference_steps", 30),
            aspect_ratio=params.get("aspect_ratio", "1:1"),
        )
    # Default case for providers with simpler interfaces.
    return await provider_instance.generate(prompt=prompt, amount=amount)


async def _store_images(task_id, images, response_format):
    """Append each provider result to ``IMAGE_STORAGE[task_id]["images"]``.

    URL results are downloaded first (one shared HTTP session). Data is stored
    base64-encoded for ``"b64_json"`` and as raw bytes otherwise, with the
    image's position as its id.
    """
    async with aiohttp.ClientSession() as session:
        for index, img in enumerate(images):
            if isinstance(img, str):
                async with session.get(img) as resp:
                    resp.raise_for_status()
                    img_data = await resp.read()
            else:
                img_data = img
            image_id = f"{index}"
            if response_format == "b64_json":
                stored = base64.b64encode(img_data).decode("utf-8")
            else:
                stored = img_data
            IMAGE_STORAGE[task_id]["images"].append({
                "image_id": image_id,
                "data": stored,
                "url": f"/v1/images/{task_id}/{image_id}",
            })


async def create_image(request: ImageGenerationRequest, background_tasks: BackgroundTasks):
    """Handle an OpenAI-style image generation request.

    * ``response_format="b64_json"``: generation runs inline and the response
      carries the actual base64 data, or a 500 error if generation failed.
      Previously this field was always ``""``, because the background task only
      runs after the response has been sent.
    * Otherwise (default ``"url"``): generation is scheduled as a FastAPI
      background task, which runs after this response is sent. The response
      lists ``n`` URLs under ``/v1/images/{task_id}/{i}``; they answer 202
      until the task finishes.

    Errors are returned as OpenAI-style JSON bodies, never raised.
    """
    try:
        provider_name, model = get_provider_for_model(request.model)
        provider_class = PROVIDER_MAP[provider_name]["class"]

        # `n` and `response_format` are Optional in the request model, so an
        # explicit JSON null must fall back to the documented defaults
        # (n=None previously crashed at range(None)).
        amount = request.n or 1
        response_format = request.response_format or "url"
        width, height = _parse_size(request.size)

        task_id = str(uuid.uuid4())
        IMAGE_STORAGE[task_id] = {"status": "processing", "images": []}

        # Provider defaults, overridden by whatever the request supplied.
        params = PROVIDER_MODEL_INFO[provider_name].get("default_params", {}).copy()
        if request.negative_prompt:
            params["negative_prompt"] = request.negative_prompt
        if request.quality:
            params["quality"] = request.quality
        if request.style:
            params["style"] = request.style
        params["width"] = width
        params["height"] = height

        async def generate_images():
            """Run the provider and record the outcome in IMAGE_STORAGE[task_id]."""
            try:
                provider_instance = _make_provider(provider_name, provider_class, model)
                images = await _run_provider(
                    provider_name, provider_instance, model,
                    request.prompt, amount, width, height, params,
                )
                await _store_images(task_id, images, response_format)
                IMAGE_STORAGE[task_id]["status"] = "completed"
            except Exception as e:
                IMAGE_STORAGE[task_id]["status"] = "failed"
                IMAGE_STORAGE[task_id]["error"] = str(e)

        created_timestamp = int(time.time())

        if response_format == "b64_json":
            await generate_images()
            # The b64 response never exposes task_id, so the stored entry could
            # never be fetched again; drop it instead of leaking it.
            task = IMAGE_STORAGE.pop(task_id)
            if task["status"] == "failed":
                raise APIError(
                    message=f"Image generation failed: {task.get('error', 'Unknown error')}",
                    code=500,
                    type="server_error",
                )
            image_data = [
                {"b64_json": img["data"], "revised_prompt": request.prompt}
                for img in task["images"]
            ]
        else:
            background_tasks.add_task(generate_images)
            image_data = [
                {"url": f"/v1/images/{task_id}/{i}", "revised_prompt": request.prompt}
                for i in range(amount)
            ]

        return {
            "created": created_timestamp,
            "data": image_data,
        }
    except APIError as e:
        return JSONResponse(
            status_code=e.code,
            content={"error": {"message": e.message, "type": e.type, "param": e.param, "code": e.code}},
        )
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"error": {"message": str(e), "type": "server_error", "param": None, "code": 500}},
        )
# Image retrieval endpoint
def guess_image_media_type(data) -> str:
    """Guess the MIME type of encoded image bytes from their leading signature.

    Recognizes JPEG (FF D8 FF), GIF (GIF87a/GIF89a) and WebP (RIFF....WEBP).
    Anything else, including PNG and empty input, is reported as
    ``"image/png"``.
    """
    head = bytes(data[:12])
    if head.startswith(b"\xff\xd8\xff"):
        return "image/jpeg"
    if head.startswith((b"GIF87a", b"GIF89a")):
        return "image/gif"
    if head[:4] == b"RIFF" and head[8:12] == b"WEBP":
        return "image/webp"
    return "image/png"


async def get_image(task_id: str, image_id: str):
    """Serve one generated image from IMAGE_STORAGE.

    Responses:
        404: unknown task, or no image with *image_id* in a completed task.
        500: the generation task failed (its error message is included).
        202: the task is still processing.
        200: JSON ``{"b64_json": ...}`` when the image was stored base64
            encoded; otherwise the raw bytes, labelled with the MIME type
            sniffed from their signature instead of always ``image/png``.
    """
    not_found = {"error": {"message": "Image not found", "type": "not_found_error"}}

    task_data = IMAGE_STORAGE.get(task_id)
    if task_data is None:
        return JSONResponse(status_code=404, content=not_found)

    if task_data["status"] == "failed":
        return JSONResponse(
            status_code=500,
            content={"error": {"message": f"Image generation failed: {task_data.get('error', 'Unknown error')}", "type": "processing_error"}},
        )
    if task_data["status"] == "processing":
        return JSONResponse(
            status_code=202,
            content={"status": "processing", "message": "Image is still being generated"},
        )

    for img in task_data["images"]:
        if img["image_id"] != image_id:
            continue
        data = img["data"]
        # Base64 strings were stored for b64_json requests; return them as JSON.
        if isinstance(data, str):
            return JSONResponse(content={"b64_json": data})
        return StreamingResponse(io.BytesIO(data), media_type=guess_image_media_type(data))

    return JSONResponse(status_code=404, content=not_found)
# Legacy endpoints for backward compatibility
async def list_providers_legacy():
    """Describe every registered provider in the pre-OpenAI response shape.

    Returns a dict keyed by provider id (PROVIDER_MAP order). Each value has
    ``description``, ``default_model``, ``models`` and ``default_params``,
    with placeholder fallbacks for providers that have no
    PROVIDER_MODEL_INFO entry.
    """
    def summarize(name, entry):
        model_info = PROVIDER_MODEL_INFO.get(name, {})
        return {
            "description": entry.get("description", ""),
            "default_model": model_info.get("default", "default"),
            "models": model_info.get("models", ["default"]),
            "default_params": model_info.get("default_params", {}),
        }

    return {name: summarize(name, entry) for name, entry in PROVIDER_MAP.items()}
# Main entry point
if __name__ == "__main__":
    # uvicorn needs an import string ("module:attribute") rather than the app
    # object when reload=True, so this assumes the file is importable as `app`.
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)