Spaces:
Runtime error
Runtime error
#!/bin/env python3.11 | |
import argparse
import logging
import os
import sqlite3
import sys
import zipfile
from asyncio import Semaphore, create_task, gather, to_thread
from contextlib import contextmanager
from datetime import datetime
from io import BytesIO
from typing import List, Optional

import gradio as gr
import replicate
import requests
import uvicorn
from fastapi import FastAPI, Form, HTTPException, Query, Request, Response, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from mistralai import Mistral
from pydantic import BaseModel
print(f"Arguments: {sys.argv}")

# Module-level logger used by the download handlers.
logger = logging.getLogger(__name__)

# Credentials and agent configuration from the environment (None when unset).
token = os.getenv("HF_TOKEN")
api_key = os.getenv("MISTRAL_API_KEY")
agent_id = os.getenv("MISTRAL_FLUX_AGENT")

# ANSI escape codes for colored terminal output (can be removed if unused).
HEADER = "\033[38;2;255;255;153m"
TITLE = "\033[38;2;255;255;153m"
MENU = "\033[38;2;255;165;0m"
SUCCESS = "\033[38;2;153;255;153m"
ERROR = "\033[38;2;255;69;0m"
MAIN = "\033[38;2;204;204;255m"
SPEAKER1 = "\033[38;2;173;216;230m"
SPEAKER2 = "\033[38;2;255;179;102m"
RESET = "\033[0m"

DOWNLOAD_DIR = "/home/user/app/flux-pics"  # where generated images are written
DATABASE_PATH = "/home/user/app/flux_logs.db"  # SQLite database file
TIMEOUT_DURATION = 900  # timeout in seconds for generation and image downloads

# Images are served from the same directory they are downloaded into.
IMAGE_STORAGE_PATH = DOWNLOAD_DIR
# StaticFiles raises at construction time when its directory does not exist,
# so make sure the image directory is present before it is mounted.
os.makedirs(IMAGE_STORAGE_PATH, exist_ok=True)

app = FastAPI()
# Static assets and the generated images are served directly by FastAPI.
app.mount("/static", StaticFiles(directory="static"), name="static")
app.mount("/flux-pics", StaticFiles(directory=IMAGE_STORAGE_PATH), name="flux-pics")
templates = Jinja2Templates(directory="templates")
# Database helper functions.
@contextmanager
def get_db_connection(db_path=DATABASE_PATH):
    """Yield an SQLite connection to *db_path* and always close it afterwards.

    Use as ``with get_db_connection() as conn:``. Without the
    ``@contextmanager`` decorator this generator cannot be used in a ``with``
    statement at all.

    Nothing is committed automatically. Callers commit explicitly, and
    uncommitted changes are discarded when the connection is closed.
    """
    conn = sqlite3.connect(db_path)
    try:
        yield conn
    finally:
        conn.close()
def initialize_database(db_path=DATABASE_PATH):
    """Create every application table in *db_path* unless it already exists.

    Tables: generation_logs, albums, categories, pictures and the
    picture_categories link table. All statements are committed together.
    """
    schema = (
        """
        CREATE TABLE IF NOT EXISTS generation_logs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            timestamp TEXT,
            prompt TEXT,
            optimized_prompt TEXT,
            hf_lora TEXT,
            lora_scale REAL,
            aspect_ratio TEXT,
            guidance_scale REAL,
            output_quality INTEGER,
            prompt_strength REAL,
            num_inference_steps INTEGER,
            output_file TEXT,
            album_id INTEGER,
            category_id INTEGER
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS albums (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT NOT NULL
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS categories (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT NOT NULL
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS pictures (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            timestamp TEXT,
            file_path TEXT,
            file_name TEXT,
            album_id INTEGER,
            FOREIGN KEY (album_id) REFERENCES albums(id)
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS picture_categories (
            picture_id INTEGER,
            category_id INTEGER,
            FOREIGN KEY (picture_id) REFERENCES pictures(id),
            FOREIGN KEY (category_id) REFERENCES categories(id),
            PRIMARY KEY (picture_id, category_id)
        )
        """,
    )
    with get_db_connection(db_path) as connection:
        cur = connection.cursor()
        for ddl in schema:
            cur.execute(ddl)
        connection.commit()
def log_generation(args, optimized_prompt, image_file):
    """Record one generated image in the database.

    Inserts, in a single transaction:
    - a ``generation_logs`` row with the generation parameters;
    - a ``pictures`` row with the file's directory and name;
    - one ``picture_categories`` link per category id.

    Database errors are printed and swallowed, so a logging failure never
    aborts image delivery.

    Args:
        args: Namespace carrying prompt, hf_lora, lora_scale, aspect_ratio,
            guidance_scale, output_quality, prompt_strength,
            num_inference_steps, album_id and category_ids.
        optimized_prompt: The prompt that was actually sent to the model.
        image_file: Web path of the stored image, as built by fetch_image
            (e.g. ``/flux-pics/image_<timestamp>_<index>.<ext>``).
    """
    file_path, file_name = os.path.split(image_file)
    # One timestamp for both rows so they describe the same moment.
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Tolerate a missing/None list and drop placeholder None entries.
    category_ids = [c for c in (getattr(args, "category_ids", None) or []) if c is not None]
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                INSERT INTO generation_logs (
                    timestamp, prompt, optimized_prompt, hf_lora, lora_scale, aspect_ratio, guidance_scale,
                    output_quality, prompt_strength, num_inference_steps, output_file, album_id, category_id
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                timestamp,
                args.prompt,
                optimized_prompt,
                args.hf_lora,
                args.lora_scale,
                args.aspect_ratio,
                args.guidance_scale,
                args.output_quality,
                args.prompt_strength,
                args.num_inference_steps,
                image_file,
                args.album_id,
                # generation_logs has a single category column: keep the first id.
                category_ids[0] if category_ids else None,
            ))
            cursor.execute("""
                INSERT INTO pictures (
                    timestamp, file_path, file_name, album_id
                ) VALUES (?, ?, ?, ?)
            """, (timestamp, file_path, file_name, args.album_id))
            picture_id = cursor.lastrowid
            # (picture_id, category_id) is the primary key. OR IGNORE keeps a
            # repeated category id from failing, and thereby discarding, the
            # whole transaction.
            # NOTE(review): picture_id here is pictures.id, while read_archive
            # joins picture_categories against generation_logs.id. That only
            # lines up while both tables receive rows in lockstep.
            cursor.executemany(
                "INSERT OR IGNORE INTO picture_categories (picture_id, category_id) VALUES (?, ?)",
                [(picture_id, category_id) for category_id in category_ids],
            )
            conn.commit()
    except sqlite3.Error as e:
        print(f"Error logging generation: {e}")
@app.on_event("startup")
def startup_event():
    """Create the database tables when the application starts.

    Without registration as a startup handler this function was never called,
    so the tables never existed and every query failed.
    """
    initialize_database()
def read_root(request: Request):
    """Render ``index.html`` with all albums and categories as ``(id, name)`` rows."""
    with get_db_connection() as db:
        cur = db.cursor()
        album_rows = cur.execute("SELECT id, name FROM albums").fetchall()
        category_rows = cur.execute("SELECT id, name FROM categories").fetchall()
    context = {"request": request, "albums": album_rows, "categories": category_rows}
    return templates.TemplateResponse("index.html", context)
def _build_archive_query(album_id, category_ids, search, limit, offset):
    """Return ``(sql, params)`` for one page of generation logs.

    Each optional filter appends one SQL condition and exactly its own bind
    parameters, so placeholders and parameters always stay in step.
    """
    if category_ids:
        # Category filtering goes through the picture_categories link table.
        # NOTE(review): pc.picture_id is written as pictures.id by
        # log_generation but joined against generation_logs.id here. This
        # only matches while both tables are filled in lockstep.
        sql = """
            SELECT gl.timestamp, gl.prompt, gl.optimized_prompt, gl.output_file, a.name as album, GROUP_CONCAT(c.name) as categories
            FROM generation_logs gl
            LEFT JOIN albums a ON gl.album_id = a.id
            LEFT JOIN picture_categories pc ON gl.id = pc.picture_id
            LEFT JOIN categories c ON pc.category_id = c.id
            WHERE 1=1
        """
    else:
        sql = """
            SELECT gl.timestamp, gl.prompt, gl.optimized_prompt, gl.output_file, a.name as album, c.name as category
            FROM generation_logs gl
            LEFT JOIN albums a ON gl.album_id = a.id
            LEFT JOIN categories c ON gl.category_id = c.id
            WHERE 1=1
        """
    params = []
    if album_id is not None:
        sql += " AND gl.album_id = ?"
        params.append(album_id)
    if category_ids:
        sql += " AND pc.category_id IN ({})".format(", ".join("?" * len(category_ids)))
        params.extend(category_ids)
    if search:
        sql += " AND (gl.prompt LIKE ? OR gl.optimized_prompt LIKE ?)"
        pattern = f"%{search}%"
        params.extend([pattern, pattern])
    sql += (
        " GROUP BY gl.id, gl.timestamp, gl.prompt, gl.optimized_prompt, gl.output_file, a.name"
        " ORDER BY gl.timestamp DESC LIMIT ? OFFSET ?"
    )
    params.extend([limit, offset])
    return sql, params


def read_archive(
    request: Request,
    album: Optional[str] = Query(None),
    category: Optional[List[str]] = Query(None),
    search: Optional[str] = None,
    items_per_page: int = Query(30),
    page: int = Query(1)
):
    """Render ``archive.html`` with one page of generation logs.

    Query parameters:
        album: Album id as a numeric string; anything else means "all albums".
        category: Repeated category ids; non-numeric values are ignored.
        search: Substring matched against the prompt and the optimized prompt.
        items_per_page, page: Paging. Both are clamped to at least 1.
    """
    album_id = int(album) if album and album.isdigit() else None
    # Skip non-numeric category values instead of failing with a 500.
    category_ids = [int(cat) for cat in category if cat.isdigit()] if category else []
    # SQLite treats a negative LIMIT as "no limit", so clamp the paging input.
    items_per_page = max(1, items_per_page)
    page = max(1, page)
    offset = (page - 1) * items_per_page

    query, params = _build_archive_query(album_id, category_ids, search, items_per_page, offset)

    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(query, params)
        logs = [
            {
                "timestamp": row[0],
                "prompt": row[1],
                "optimized_prompt": row[2],
                "output_file": row[3],
                "album": row[4],
                "category": row[5],
            }
            for row in cursor.fetchall()
        ]
        cursor.execute("SELECT id, name FROM albums")
        albums = cursor.fetchall()
        cursor.execute("SELECT id, name FROM categories")
        categories = cursor.fetchall()

    return templates.TemplateResponse("archive.html", {
        "request": request,
        "logs": logs,
        "albums": albums,
        "categories": categories,
        "selected_album": album,
        "selected_categories": category_ids,
        "search_query": search,
        "items_per_page": items_per_page,
        "page": page
    })
def read_backend(request: Request):
    """Render the ``backend.html`` page with all albums and categories as ``(id, name)`` rows."""
    with get_db_connection() as db:
        album_rows = db.execute("SELECT id, name FROM albums").fetchall()
        category_rows = db.execute("SELECT id, name FROM categories").fetchall()
    return templates.TemplateResponse(
        "backend.html",
        {"request": request, "albums": album_rows, "categories": category_rows},
    )
def _directory_size_bytes(path):
    """Return the total size in bytes of the regular files directly inside *path*.

    Subdirectories are not descended into. A missing directory counts as 0,
    and files removed while scanning are skipped.
    """
    total = 0
    try:
        with os.scandir(path) as entries:
            for entry in entries:
                try:
                    if entry.is_file():
                        total += entry.stat().st_size
                except FileNotFoundError:
                    continue  # deleted between listing and stat
    except FileNotFoundError:
        return 0
    return total


async def get_backend_stats():
    """Collect the statistics shown on the backend page.

    Returns a dict with:
    - the picture count;
    - album and category totals;
    - pictures per category;
    - the size of IMAGE_STORAGE_PATH in megabytes (1024 * 1024 bytes);
    - pictures per ``YYYY-MM`` month, ascending.
    """
    with get_db_connection() as conn:
        cursor = conn.cursor()
        total_images = cursor.execute("SELECT COUNT(*) FROM pictures").fetchone()[0]
        total_albums = cursor.execute("SELECT COUNT(*) FROM albums").fetchone()[0]
        total_categories = cursor.execute("SELECT COUNT(*) FROM categories").fetchone()[0]

        cursor.execute("""
            SELECT strftime('%Y-%m', timestamp) as month, COUNT(*)
            FROM pictures
            GROUP BY month
            ORDER BY month
        """)
        monthly_stats = [{"month": month, "count": count} for month, count in cursor.fetchall()]

        # Group by id as well as name: category names are not UNIQUE, and
        # grouping by name alone would merge distinct categories.
        cursor.execute("""
            SELECT c.name, COUNT(pc.picture_id)
            FROM categories c
            LEFT JOIN picture_categories pc ON c.id = pc.category_id
            GROUP BY c.id, c.name
        """)
        category_stats = [{"name": name, "count": count} for name, count in cursor.fetchall()]

    # Scan the image directory only after the DB connection is released. A
    # missing directory is reported as 0 MB instead of failing the request.
    total_size_mb = _directory_size_bytes(IMAGE_STORAGE_PATH) / (1024 * 1024)

    return {
        "total_images": total_images,
        "albums": {
            "total": total_albums
        },
        "categories": {
            "total": total_categories,
            "data": category_stats
        },
        "storage_usage_mb": total_size_mb,
        "monthly": monthly_stats
    }
# Album endpoints
async def get_albums():
    """Return every album as a list of ``{"id": ..., "name": ...}`` dicts."""
    with get_db_connection() as db:
        rows = db.execute("SELECT id, name FROM albums").fetchall()
    return [dict(id=album_id, name=album_name) for album_id, album_name in rows]
async def create_album_route(name: str = Form(...), description: Optional[str] = Form(None)):
    """Insert a new album called *name* and respond with its id and name.

    ``description`` is accepted from the form but not stored; the albums
    table has no column for it.

    Raises:
        HTTPException: 500 if the insert fails.
    """
    try:
        with get_db_connection() as db:
            new_id = db.execute("INSERT INTO albums (name) VALUES (?)", (name,)).lastrowid
            db.commit()
    except sqlite3.Error as exc:
        raise HTTPException(status_code=500, detail=f"Error creating album: {exc}")
    return {"message": "Album erstellt", "id": new_id, "name": name}
async def delete_album(album_id: int):
    """Delete album *album_id* together with its pictures, their category links
    and its generation-log rows.

    Image files on disk are not touched. The call succeeds even when no album
    has that id.

    Raises:
        HTTPException: 500 on database errors.
    """
    # Children first, so no row is left pointing at an already-deleted parent.
    statements = (
        "DELETE FROM picture_categories WHERE picture_id IN (SELECT id FROM pictures WHERE album_id = ?)",
        "DELETE FROM pictures WHERE album_id = ?",
        "DELETE FROM generation_logs WHERE album_id = ?",
        "DELETE FROM albums WHERE id = ?",
    )
    try:
        with get_db_connection() as db:
            cur = db.cursor()
            for sql in statements:
                cur.execute(sql, (album_id,))
            db.commit()
    except sqlite3.Error as exc:
        raise HTTPException(status_code=500, detail=f"Error deleting album: {exc}")
    return {"message": f"Album {album_id} und zugehörige Einträge gelöscht"}
async def update_album(album_id: int, request: Request):
    """Rename album *album_id* to the ``name`` given in the JSON request body.

    Raises:
        HTTPException: 400 if the body is not valid JSON or lacks a string
            ``name`` field; 404 if no album has that id; 500 on database errors.
    """
    try:
        data = await request.json()
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        raise HTTPException(status_code=400, detail="Ungültiger JSON-Body")
    name = data.get("name") if isinstance(data, dict) else None
    if not isinstance(name, str):
        # Previously a missing key surfaced as an unhandled KeyError (500).
        raise HTTPException(status_code=400, detail="Feld 'name' fehlt oder ist kein Text")

    try:
        with get_db_connection() as conn:
            cursor = conn.execute("UPDATE albums SET name = ? WHERE id = ?", (name, album_id))
            conn.commit()
            updated = cursor.rowcount
    except sqlite3.Error as e:
        raise HTTPException(status_code=500, detail=f"Error updating album: {e}")

    if updated == 0:
        raise HTTPException(status_code=404, detail=f"Album {album_id} nicht gefunden")
    return {"message": f"Album {album_id} aktualisiert"}
# Category endpoints
async def get_categories():
    """Return every category as a list of ``{"id": ..., "name": ...}`` dicts."""
    with get_db_connection() as db:
        rows = db.execute("SELECT id, name FROM categories").fetchall()
    return [dict(id=cat_id, name=cat_name) for cat_id, cat_name in rows]
async def create_category_route(name: str = Form(...)):
    """Insert a new category called *name* and respond with its id and name.

    Raises:
        HTTPException: 500 if the insert fails.
    """
    try:
        with get_db_connection() as db:
            new_id = db.execute("INSERT INTO categories (name) VALUES (?)", (name,)).lastrowid
            db.commit()
    except sqlite3.Error as exc:
        raise HTTPException(status_code=500, detail=f"Error creating category: {exc}")
    return {"message": "Kategorie erstellt", "id": new_id, "name": name}
async def delete_category(category_id: int):
    """Delete category *category_id* and its picture links.

    Pictures themselves, and generation_logs rows whose category_id refers to
    it, are left in place. The call succeeds even when no category has that id.

    Raises:
        HTTPException: 500 on database errors.
    """
    try:
        with get_db_connection() as db:
            # Link rows first, then the category itself.
            db.execute("DELETE FROM picture_categories WHERE category_id = ?", (category_id,))
            db.execute("DELETE FROM categories WHERE id = ?", (category_id,))
            db.commit()
    except sqlite3.Error as exc:
        raise HTTPException(status_code=500, detail=f"Error deleting category: {exc}")
    return {"message": f"Kategorie {category_id} und zugehörige Einträge gelöscht"}
async def update_category(category_id: int, request: Request):
    """Rename category *category_id* to the ``name`` given in the JSON request body.

    Raises:
        HTTPException: 400 if the body is not valid JSON or lacks a string
            ``name`` field; 404 if no category has that id; 500 on database errors.
    """
    try:
        data = await request.json()
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        raise HTTPException(status_code=400, detail="Ungültiger JSON-Body")
    name = data.get("name") if isinstance(data, dict) else None
    if not isinstance(name, str):
        # Previously a missing key surfaced as an unhandled KeyError (500).
        raise HTTPException(status_code=400, detail="Feld 'name' fehlt oder ist kein Text")

    try:
        with get_db_connection() as conn:
            cursor = conn.execute("UPDATE categories SET name = ? WHERE id = ?", (name, category_id))
            conn.commit()
            updated = cursor.rowcount
    except sqlite3.Error as e:
        raise HTTPException(status_code=500, detail=f"Error updating category: {e}")

    if updated == 0:
        raise HTTPException(status_code=404, detail=f"Kategorie {category_id} nicht gefunden")
    return {"message": f"Kategorie {category_id} aktualisiert"}
async def download_images(request: Request):
    """Return the requested stored images bundled as one ZIP attachment.

    Expects a JSON body ``{"selectedImages": [name, ...]}``. Only the final
    path component of each entry is used. Values like
    ``/flux-pics/image_x.png`` therefore resolve to the file stored in
    IMAGE_STORAGE_PATH, and ``../`` or absolute paths cannot escape it.

    Raises:
        HTTPException: 400 if nothing is selected, 404 if an image is missing,
            500 for any other failure.
    """
    logger = logging.getLogger(__name__)
    try:
        body = await request.json()
        logger.info("Received request body: %s", body)
        image_files = body.get("selectedImages", [])
        if not image_files:
            raise HTTPException(status_code=400, detail="Keine Bilder ausgewählt.")
        logger.info("Processing image files: %s", image_files)

        if not os.path.exists(IMAGE_STORAGE_PATH):
            logger.error("Storage path not found: %s", IMAGE_STORAGE_PATH)
            raise HTTPException(status_code=500, detail="Storage path not found")

        zip_buffer = BytesIO()
        with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
            for image_file in image_files:
                # Untrusted client input: keep only the file name. os.path.join
                # would discard the base directory for an absolute path.
                safe_name = os.path.basename(str(image_file))
                image_path = os.path.join(IMAGE_STORAGE_PATH, safe_name)
                logger.info("Processing file: %s", image_path)
                if not safe_name or not os.path.isfile(image_path):
                    logger.error("File not found: %s", image_path)
                    raise HTTPException(status_code=404, detail=f"Bild {image_file} nicht gefunden.")
                zip_file.write(image_path, arcname=safe_name)

        return Response(
            content=zip_buffer.getvalue(),
            media_type="application/zip",
            headers={"Content-Disposition": "attachment; filename=images.zip"},
        )
    except HTTPException:
        # Preserve the intended 4xx/5xx status instead of rewrapping it as 500.
        raise
    except Exception as e:
        logger.error("Error in download_images: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
async def download_single_image(request: Request):
    """Serve one stored image from IMAGE_STORAGE_PATH as a download.

    Expects a JSON body ``{"filename": name}``. Only the final path component
    is used, so the request cannot reach files outside the image directory.
    The media type is chosen from the file extension, with
    ``application/octet-stream`` as the fallback.

    Raises:
        HTTPException: 400 if no filename is given, 404 if the file is
            missing, 500 for any other failure.
    """
    logger = logging.getLogger(__name__)
    try:
        data = await request.json()
        filename = data.get("filename")
        logger.info("Requested file download: %s", filename)
        if not filename:
            logger.error("No filename provided")
            raise HTTPException(status_code=400, detail="Kein Dateiname angegeben")

        # Untrusted client input: strip any directory part (incl. "../" and
        # absolute paths, which os.path.join would otherwise honour).
        safe_name = os.path.basename(str(filename))
        file_path = os.path.join(IMAGE_STORAGE_PATH, safe_name)
        logger.info("Full file path: %s", file_path)
        if not safe_name or not os.path.isfile(file_path):
            logger.error("File not found: %s", file_path)
            raise HTTPException(status_code=404, detail=f"Datei {filename} nicht gefunden")

        extension = os.path.splitext(safe_name)[1].lower().lstrip(".")
        mime_types = {
            "png": "image/png",
            "jpg": "image/jpeg",
            "jpeg": "image/jpeg",
            "gif": "image/gif",
            "webp": "image/webp",
        }
        media_type = mime_types.get(extension, "application/octet-stream")
        logger.info("Serving file with media type: %s", media_type)

        # FileResponse derives a properly quoted "attachment" Content-Disposition
        # from ``filename``. The previous hand-written header overrode it and
        # broke on names containing spaces or non-ASCII characters.
        return FileResponse(path=file_path, filename=safe_name, media_type=media_type)
    except HTTPException:
        # Preserve the intended 4xx status instead of rewrapping it as 500.
        raise
    except Exception as e:
        logger.error("Error in download_single_image: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
def _resolve_album_id(value):
    """Turn the client's album value into an album id, creating the album if needed.

    Empty values give None. Digit strings (or ints) are taken as existing ids.
    Any other text is inserted as a new album, and its new id is returned.
    """
    text = str(value).strip() if value else ""
    if not text:
        return None
    if text.isdigit():
        return int(text)
    with get_db_connection() as conn:
        cursor = conn.execute("INSERT INTO albums (name) VALUES (?)", (text,))
        conn.commit()
        return cursor.lastrowid


def _resolve_category_ids(values):
    """Turn the client's category values into ids, creating new categories as needed.

    Digit strings (or ints) are existing ids, and other non-blank text becomes
    a new category. Empty or None entries are skipped instead of being stored
    as nameless categories.
    """
    ids = []
    for value in values or []:
        text = "" if value is None else str(value).strip()
        if not text:
            continue
        if text.isdigit():
            ids.append(int(text))
            continue
        with get_db_connection() as conn:
            cursor = conn.execute("INSERT INTO categories (name) VALUES (?)", (text,))
            conn.commit()
            ids.append(cursor.lastrowid)
    return ids


async def websocket_endpoint(websocket: WebSocket):
    """Handle one image-generation session over a WebSocket.

    Receives a single JSON message containing either one prompt object or
    ``{"prompts": [...]}``. For each prompt it:
    - coerces the numeric fields;
    - resolves album and category values (numbers are ids, other text creates
      a new entry);
    - optionally optimizes the prompt and sends ``{"optimized_prompt": ...}``;
    - unless ``optimize_only`` is set, generates and downloads the images.

    Errors are sent to the client as ``{"message": ...}`` and re-raised. The
    socket is always closed at the end.
    """
    await websocket.accept()
    try:
        data = await websocket.receive_json()
        prompts = data.get("prompts", [data])
        for prompt_data in prompts:
            for key in ("lora_scale", "guidance_scale", "prompt_strength"):
                prompt_data[key] = float(prompt_data[key])
            for key in ("num_inference_steps", "num_outputs", "output_quality"):
                prompt_data[key] = int(prompt_data[key])

            # JSON numbers arrive as int, so the helpers accept ints as well as
            # strings (calling .isdigit() on an int used to crash here).
            prompt_data["album_id"] = _resolve_album_id(prompt_data.get("album_id"))
            prompt_data["category_ids"] = _resolve_category_ids(prompt_data.get("category_ids"))

            args = argparse.Namespace(**prompt_data)
            optimized_prompt = (
                optimize_prompt(args.prompt)
                if getattr(args, "agent", False)
                else args.prompt
            )
            await websocket.send_json({"optimized_prompt": optimized_prompt})
            if prompt_data.get("optimize_only"):
                continue
            await generate_and_download_image(websocket, args, optimized_prompt)
    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        await websocket.send_json({"message": str(e)})
        raise
    finally:
        await websocket.close()
async def fetch_image(item, index, args, filenames, semaphore, websocket, timestamp):
    """Download one generated image into DOWNLOAD_DIR.

    Args:
        item: One entry of the Replicate output, passed to ``requests.get``.
        index: 0-based position of the image; used in the file name.
        args: Needs ``output_format`` and ``num_outputs``.
        filenames: Shared list; the image's web path (``/flux-pics/<name>``)
            is appended on success.
        semaphore: Limits how many downloads run at once.
        websocket: Receives ``{"progress": pct}`` or ``{"message": error}``.
        timestamp: Batch timestamp used in the file name.

    Failures are reported to the client instead of being raised, so one bad
    image does not abort the other downloads.
    """
    async with semaphore:
        try:
            # requests is blocking. Run it in a worker thread so the event loop,
            # and the other downloads admitted by the semaphore, keep running.
            response = await to_thread(requests.get, item, timeout=TIMEOUT_DURATION)
        except requests.exceptions.Timeout:
            await websocket.send_json(
                {"message": f"Timeout beim Herunterladen des Bildes {index + 1}"}
            )
            return
        except requests.exceptions.RequestException as e:
            await websocket.send_json(
                {"message": f"Fehler beim Herunterladen des Bildes {index + 1}: {e}"}
            )
            return

        if response.status_code != 200:
            await websocket.send_json(
                {
                    "message": f"Fehler beim Herunterladen des Bildes {index + 1}: {response.status_code}"
                }
            )
            return

        name = f"image_{timestamp}_{index}.{args.output_format}"
        with open(f"{DOWNLOAD_DIR}/{name}", "wb") as file:
            file.write(response.content)
        filenames.append(f"/flux-pics/{name}")
        # Report completed downloads rather than this task's index. Concurrent
        # downloads finish out of order, and index-based progress would jump
        # backwards.
        progress = int(len(filenames) / max(args.num_outputs, 1) * 100)
        await websocket.send_json({"progress": progress})
async def generate_and_download_image(websocket: WebSocket, args, optimized_prompt):
    """Generate images on Replicate, download them, and log each one.

    Sends ``{"message": ..., "generated_files": [...]}`` on success.

    Raises:
        Exception: after sending an error message to the client, for any
            failure other than ``requests`` timeouts. The original error is
            chained as the cause.
    """
    try:
        input_data = {
            "prompt": optimized_prompt,
            # getattr: hf_lora may be absent from the client payload.
            "hf_lora": getattr(args, "hf_lora", None),
            "lora_scale": args.lora_scale,
            "num_outputs": args.num_outputs,
            "aspect_ratio": args.aspect_ratio,
            "output_format": args.output_format,
            "guidance_scale": args.guidance_scale,
            "output_quality": args.output_quality,
            "prompt_strength": args.prompt_strength,
            "num_inference_steps": args.num_inference_steps,
            "disable_safety_checker": False,
        }
        print(
            f"Starting replication process for {args.num_outputs} outputs with timeout {TIMEOUT_DURATION}"
        )
        # replicate.run is a blocking call that can take a long time. Running it
        # in a worker thread keeps the event loop free to serve other clients.
        output = await to_thread(
            replicate.run,
            "lucataco/flux-dev-lora:091495765fa5ef2725a175a57b276ec30dc9d39c22d30410f2ede68a3eab66b3",
            input=input_data,
            timeout=TIMEOUT_DURATION,
        )

        # exist_ok avoids the check-then-create race between concurrent sessions.
        os.makedirs(DOWNLOAD_DIR, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filenames = []
        semaphore = Semaphore(3)  # limit concurrent downloads
        tasks = [
            create_task(
                fetch_image(item, index, args, filenames, semaphore, websocket, timestamp)
            )
            for index, item in enumerate(output)
        ]
        await gather(*tasks)

        for file in filenames:
            log_generation(args, optimized_prompt, file)

        await websocket.send_json(
            {"message": "Bilder erfolgreich generiert", "generated_files": filenames}
        )
    except requests.exceptions.Timeout:
        await websocket.send_json(
            {"message": "Fehler bei der Bildgenerierung: Timeout überschritten"}
        )
    except Exception as e:
        await websocket.send_json(
            {"message": f"Fehler bei der Bildgenerierung: {str(e)}"}
        )
        raise Exception(f"Fehler bei der Bildgenerierung: {str(e)}") from e
def optimize_prompt(prompt):
    """Ask the configured Mistral agent to rewrite *prompt* for Flux LoRA.

    Reads MISTRAL_API_KEY and MISTRAL_FLUX_AGENT from the environment on every
    call and returns the message content of the agent's first choice.

    Raises:
        ValueError: if either environment variable is unset or empty.
    """
    key = os.environ.get("MISTRAL_API_KEY")
    agent = os.environ.get("MISTRAL_FLUX_AGENT")
    if not (key and agent):
        raise ValueError("MISTRAL_API_KEY oder MISTRAL_FLUX_AGENT nicht gesetzt")

    request_messages = [
        {
            "role": "user",
            "content": f"Optimiere folgenden Prompt für Flux Lora: {prompt}",
        }
    ]
    reply = Mistral(api_key=key).agents.complete(agent_id=agent, messages=request_messages)
    first_choice = reply.choices[0]
    return first_choice.message.content
if __name__ == "__main__":
    # Command-line options (currently only --hf_lora).
    cli = argparse.ArgumentParser(description="Beschreibung")
    cli.add_argument('--hf_lora', default=None, help='HF LoRA Model')

    # Expose the parsed options on the FastAPI application object.
    # NOTE(review): with reload=True, uvicorn imports "app:app" afresh in a
    # separate worker process, so this attribute is presumably not visible to
    # the app that actually serves requests. Confirm before relying on it.
    app.state.args = cli.parse_args()

    # Serve on all interfaces, port 7860, with auto-reload and trace logging.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True, log_level="trace")