Scalino84 commited on
Commit
4c7b6b6
·
verified ·
1 Parent(s): 609870d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +651 -635
app.py CHANGED
@@ -1,184 +1,189 @@
1
- #!/bin/env python3.11
2
- import gradio as gr
3
- import os
4
- import sqlite3
5
- import replicate
6
- import argparse
7
- import requests
8
- from datetime import datetime
9
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request, Form, Query,Response
10
- from fastapi.templating import Jinja2Templates
11
- from fastapi.responses import FileResponse
12
- from fastapi.staticfiles import StaticFiles
13
- from pydantic import BaseModel
14
- from typing import Optional, List
15
- import uvicorn
16
- from asyncio import gather, Semaphore, create_task
17
- from mistralai import Mistral
18
- from contextlib import contextmanager
19
- from io import BytesIO
20
- import zipfile
21
-
22
- import sys
23
- print(f"Arguments: {sys.argv}")
24
-
25
-
26
- token = os.getenv("HF_TOKEN")
27
- api_key = os.getenv("MISTRAL_API_KEY")
28
- agent_id = os.getenv("MISTRAL_FLUX_AGENT")
29
-
30
- # ANSI Escape Codes für farbige Ausgabe (kann entfernt werden, falls nicht benötigt)
31
- HEADER = "\033[38;2;255;255;153m"
32
- TITLE = "\033[38;2;255;255;153m"
33
- MENU = "\033[38;2;255;165;0m"
34
- SUCCESS = "\033[38;2;153;255;153m"
35
- ERROR = "\033[38;2;255;69;0m"
36
- MAIN = "\033[38;2;204;204;255m"
37
- SPEAKER1 = "\033[38;2;173;216;230m"
38
- SPEAKER2 = "\033[38;2;255;179;102m"
39
- RESET = "\033[0m"
40
-
41
- DOWNLOAD_DIR = "/home/user/app/flux-pics" # Pfad zu deinen Bildern (sollte korrekt sein)
42
- DATABASE_PATH = "/home/user/app/flux_logs.db" # Datenbank-Pfad
43
- TIMEOUT_DURATION = 900 # Timeout-Dauer in Sekunden (scheint angemessen)
44
-
45
- # WICHTIG: Stelle sicher, dass dieses Verzeichnis existiert und die Bilder enthält.
46
- IMAGE_STORAGE_PATH = DOWNLOAD_DIR
47
-
48
- app = FastAPI()
49
-
50
- # StaticFiles Middleware hinzufügen (korrekt und wichtig!)
51
- app.mount("/static", StaticFiles(directory="static"), name="static")
52
- app.mount("/flux-pics", StaticFiles(directory=IMAGE_STORAGE_PATH), name="flux-pics")
53
-
54
- templates = Jinja2Templates(directory="templates")
55
-
56
- # Datenbank-Hilfsfunktionen (sehen gut aus)
57
- @contextmanager
58
- def get_db_connection(db_path=DATABASE_PATH):
59
- conn = sqlite3.connect(db_path)
60
- try:
61
- yield conn
62
- finally:
63
- conn.close()
64
-
65
- def initialize_database(db_path=DATABASE_PATH):
66
- with get_db_connection(db_path) as conn:
67
- cursor = conn.cursor()
68
- # Tabellen-Erstellung (scheint korrekt, keine Auffälligkeiten)
69
- cursor.execute("""
70
- CREATE TABLE IF NOT EXISTS generation_logs (
71
- id INTEGER PRIMARY KEY AUTOINCREMENT,
72
- timestamp TEXT,
73
- prompt TEXT,
74
- optimized_prompt TEXT,
75
- hf_lora TEXT,
76
- lora_scale REAL,
77
- aspect_ratio TEXT,
78
- guidance_scale REAL,
79
- output_quality INTEGER,
80
- prompt_strength REAL,
81
- num_inference_steps INTEGER,
82
- output_file TEXT,
83
- album_id INTEGER,
84
- category_id INTEGER
85
- )
86
- """)
87
- cursor.execute("""
88
- CREATE TABLE IF NOT EXISTS albums (
89
- id INTEGER PRIMARY KEY AUTOINCREMENT,
90
- name TEXT NOT NULL
91
- )
92
- """)
93
- cursor.execute("""
94
- CREATE TABLE IF NOT EXISTS categories (
95
- id INTEGER PRIMARY KEY AUTOINCREMENT,
96
- name TEXT NOT NULL
97
- )
98
- """)
99
- cursor.execute("""
100
- CREATE TABLE IF NOT EXISTS pictures (
101
- id INTEGER PRIMARY KEY AUTOINCREMENT,
102
- timestamp TEXT,
103
- file_path TEXT,
104
- file_name TEXT,
105
- album_id INTEGER,
106
- FOREIGN KEY (album_id) REFERENCES albums(id)
107
- )
108
- """)
109
- cursor.execute("""
110
- CREATE TABLE IF NOT EXISTS picture_categories (
111
- picture_id INTEGER,
112
- category_id INTEGER,
113
- FOREIGN KEY (picture_id) REFERENCES pictures(id),
114
- FOREIGN KEY (category_id) REFERENCES categories(id),
115
- PRIMARY KEY (picture_id, category_id)
116
- )
117
- """)
118
- conn.commit()
119
- def log_generation(args, optimized_prompt, image_file):
120
- file_path, file_name = os.path.split(image_file)
121
- try:
122
- with get_db_connection() as conn:
123
- cursor = conn.cursor()
124
- cursor.execute("""
125
- INSERT INTO generation_logs (
126
- timestamp, prompt, optimized_prompt, hf_lora, lora_scale, aspect_ratio, guidance_scale,
127
- output_quality, prompt_strength, num_inference_steps, output_file, album_id, category_id
128
- ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
129
- """, (
130
- datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
131
- args.prompt,
132
- optimized_prompt,
133
- args.hf_lora,
134
- args.lora_scale,
135
- args.aspect_ratio,
136
- args.guidance_scale,
137
- args.output_quality,
138
- args.prompt_strength,
139
- args.num_inference_steps,
140
- image_file,
141
- args.album_id,
142
- args.category_ids[0] if args.category_ids else None # Hier auf erstes Element zugreifen
143
- ))
144
- picture_id = cursor.lastrowid # Dies scheint nicht korrekt zu sein, da die ID für die Tabelle pictures benötigt wird
145
- cursor.execute("""
146
- INSERT INTO pictures (
147
- timestamp, file_path, file_name, album_id
148
- ) VALUES (?, ?, ?, ?)
149
- """, (
150
- datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
151
- file_path,
152
- file_name,
153
- args.album_id
154
- ))
155
- picture_id = cursor.lastrowid # Korrekte Zeile
156
-
157
- # Insert multiple categories
158
- for category_id in args.category_ids:
159
- cursor.execute("""
160
- INSERT INTO picture_categories (picture_id, category_id)
161
- VALUES (?, ?)
162
- """, (picture_id, category_id))
163
-
164
- conn.commit()
165
- except sqlite3.Error as e:
166
- print(f"Error logging generation: {e}") # Sollte durch logger.error ersetzt werden.
167
-
168
- @app.on_event("startup")
169
- def startup_event():
170
- initialize_database()
171
-
172
- @app.get("/")
173
- def read_root(request: Request):
174
- with get_db_connection() as conn:
175
- cursor = conn.cursor()
176
- cursor.execute("SELECT id, name FROM albums")
177
- albums = cursor.fetchall()
178
- cursor.execute("SELECT id, name FROM categories")
179
- categories = cursor.fetchall()
180
- return templates.TemplateResponse("index.html", {"request": request, "albums": albums, "categories": categories})
 
 
 
 
181
 
 
182
  @app.get("/archive")
183
  def read_archive(
184
  request: Request,
@@ -186,7 +191,8 @@ def read_archive(
186
  category: Optional[List[str]] = Query(None),
187
  search: Optional[str] = None,
188
  items_per_page: int = Query(30),
189
- page: int = Query(1)
 
190
  ):
191
  album_id = int(album) if album and album.isdigit() else None
192
  category_ids = [int(cat) for cat in category] if category else []
@@ -208,7 +214,6 @@ def read_archive(
208
  params.append(album_id)
209
 
210
  if category_ids:
211
- # Hier ist die Verknüpfungstabelle picture_categories notwendig
212
  query = """
213
  SELECT gl.timestamp, gl.prompt, gl.optimized_prompt, gl.output_file, a.name as album, GROUP_CONCAT(c.name) as categories
214
  FROM generation_logs gl
@@ -218,10 +223,10 @@ def read_archive(
218
  WHERE 1=1
219
  """
220
  if album_id is not None:
221
- query += " AND gl.album_id = ?"
222
- params.append(album_id)
223
-
224
- query += " AND pc.category_id IN ({})".format(','.join('?' for _ in category_ids))
225
  params.extend(category_ids)
226
 
227
  if search:
@@ -258,456 +263,467 @@ def read_archive(
258
  "selected_categories": category_ids,
259
  "search_query": search,
260
  "items_per_page": items_per_page,
261
- "page": page
 
262
  })
263
 
264
-
265
- @app.get("/backend")
266
- def read_backend(request: Request):
267
- with get_db_connection() as conn:
268
- cursor = conn.cursor()
269
- cursor.execute("SELECT id, name FROM albums")
270
- albums = cursor.fetchall()
271
- cursor.execute("SELECT id, name FROM categories")
272
- categories = cursor.fetchall()
273
- return templates.TemplateResponse("backend.html", {"request": request, "albums": albums, "categories": categories})
274
-
275
- @app.get("/backend/stats")
276
- async def get_backend_stats():
277
- with get_db_connection() as conn:
278
- cursor = conn.cursor()
279
-
280
- # Anzahl der Bilder (aus der pictures-Tabelle)
281
- cursor.execute("SELECT COUNT(*) FROM pictures")
282
- total_images = cursor.fetchone()[0]
283
-
284
- # Alben-Statistiken (Anzahl)
285
- cursor.execute("SELECT COUNT(*) FROM albums")
286
- total_albums = cursor.fetchone()[0]
287
-
288
- # Kategorie-Statistiken (Anzahl)
289
- cursor.execute("SELECT COUNT(*) FROM categories")
290
- total_categories = cursor.fetchone()[0]
291
-
292
- # Monatliche Statistiken (Anzahl der Bilder pro Monat)
293
- cursor.execute("""
294
- SELECT strftime('%Y-%m', timestamp) as month, COUNT(*)
295
- FROM pictures
296
- GROUP BY month
297
- ORDER BY month
298
- """)
299
- monthly_stats = [{"month": row[0], "count": row[1]} for row in cursor.fetchall()]
300
-
301
- # Speicherplatzberechnung
302
- total_size = 0
303
- for filename in os.listdir(IMAGE_STORAGE_PATH):
304
- filepath = os.path.join(IMAGE_STORAGE_PATH, filename)
305
- if os.path.isfile(filepath):
306
- total_size += os.path.getsize(filepath)
307
- total_size_mb = total_size / (1024 * 1024)
308
-
309
- # Daten für die Kategorien-Statistik (Beispiel: Anzahl der Bilder pro Kategorie)
310
- cursor.execute("""
311
- SELECT c.name, COUNT(pc.picture_id)
312
- FROM categories c
313
- LEFT JOIN picture_categories pc ON c.id = pc.category_id
314
- GROUP BY c.name
315
- """)
316
- category_stats = [{"name": row[0], "count": row[1]} for row in cursor.fetchall()]
317
-
318
- return {
319
- "total_images": total_images,
320
- "albums": {
321
- "total": total_albums
322
- },
323
- "categories": {
324
- "total": total_categories,
325
- "data": category_stats
326
- },
327
- "storage_usage_mb": total_size_mb,
328
- "monthly": monthly_stats
329
- } # Hier war die Klammer falsch gesetzt
330
-
331
- # Neue Routen für Alben
332
- @app.get("/albums")
333
- async def get_albums():
334
- with get_db_connection() as conn:
335
- cursor = conn.cursor()
336
- cursor.execute("SELECT id, name FROM albums")
337
- result = cursor.fetchall()
338
- albums = [{"id": row[0], "name": row[1]} for row in result]
339
- return albums
340
-
341
- @app.post("/create_album")
342
- async def create_album_route(name: str = Form(...), description: Optional[str] = Form(None)):
343
- try:
344
- with get_db_connection() as conn:
345
- cursor = conn.cursor()
346
- cursor.execute("INSERT INTO albums (name) VALUES (?)", (name,))
347
- conn.commit()
348
- new_album_id = cursor.lastrowid
349
- return {"message": "Album erstellt", "id": new_album_id, "name": name}
350
- except sqlite3.Error as e:
351
- raise HTTPException(status_code=500, detail=f"Error creating album: {e}")
352
-
353
- @app.delete("/delete_album/{album_id}")
354
- async def delete_album(album_id: int):
355
- try:
356
- with get_db_connection() as conn:
357
- cursor = conn.cursor()
358
- # Lösche die Verknüpfungen in picture_categories
359
- cursor.execute("DELETE FROM picture_categories WHERE picture_id IN (SELECT id FROM pictures WHERE album_id = ?)", (album_id,))
360
- # Lösche die Bilder aus der pictures-Tabelle
361
- cursor.execute("DELETE FROM pictures WHERE album_id = ?", (album_id,))
362
- # Lösche die Einträge aus generation_logs
363
- cursor.execute("DELETE FROM generation_logs WHERE album_id = ?", (album_id,))
364
- # Lösche das Album aus der albums-Tabelle
365
- cursor.execute("DELETE FROM albums WHERE id = ?", (album_id,))
366
- conn.commit()
367
- return {"message": f"Album {album_id} und zugehörige Einträge gelöscht"}
368
- except sqlite3.Error as e:
369
- raise HTTPException(status_code=500, detail=f"Error deleting album: {e}")
370
-
371
- @app.put("/update_album/{album_id}")
372
- async def update_album(album_id: int, request: Request):
373
- data = await request.json()
374
- try:
375
- with get_db_connection() as conn:
376
- cursor = conn.cursor()
377
- cursor.execute("UPDATE albums SET name = ? WHERE id = ?", (data["name"], album_id))
378
- conn.commit()
379
- if cursor.rowcount == 0:
380
- raise HTTPException(status_code=404, detail=f"Album {album_id} nicht gefunden")
381
- return {"message": f"Album {album_id} aktualisiert"}
382
- except sqlite3.Error as e:
383
- raise HTTPException(status_code=500, detail=f"Error updating album: {e}")
384
-
385
- # Neue Routen für Kategorien
386
- @app.get("/categories")
387
- async def get_categories():
388
- with get_db_connection() as conn:
389
- cursor = conn.cursor()
390
- cursor.execute("SELECT id, name FROM categories")
391
- result = cursor.fetchall()
392
- categories = [{"id": row[0], "name": row[1]} for row in result]
393
- return categories
394
-
395
- @app.post("/create_category")
396
- async def create_category_route(name: str = Form(...)):
397
- try:
398
- with get_db_connection() as conn:
399
- cursor = conn.cursor()
400
- cursor.execute("INSERT INTO categories (name) VALUES (?)", (name,))
401
- conn.commit()
402
- new_category_id = cursor.lastrowid
403
- return {"message": "Kategorie erstellt", "id": new_category_id, "name": name}
404
- except sqlite3.Error as e:
405
- raise HTTPException(status_code=500, detail=f"Error creating category: {e}")
406
-
407
- @app.delete("/delete_category/{category_id}")
408
- async def delete_category(category_id: int):
409
- try:
410
  with get_db_connection() as conn:
411
  cursor = conn.cursor()
412
- # Lösche die Verknüpfungen in picture_categories
413
- cursor.execute("DELETE FROM picture_categories WHERE category_id = ?", (category_id,))
414
- # Lösche die Kategorie aus der categories-Tabelle
415
- cursor.execute("DELETE FROM categories WHERE id = ?", (category_id,))
416
- conn.commit()
417
- return {"message": f"Kategorie {category_id} und zugehörige Einträge gelöscht"}
418
- except sqlite3.Error as e:
419
- raise HTTPException(status_code=500, detail=f"Error deleting category: {e}")
420
-
421
- @app.put("/update_category/{category_id}")
422
- async def update_category(category_id: int, request: Request):
423
- data = await request.json()
424
- try:
425
- with get_db_connection() as conn:
426
- cursor = conn.cursor()
427
- cursor.execute("UPDATE categories SET name = ? WHERE id = ?", (data["name"], category_id))
428
- conn.commit()
429
- if cursor.rowcount == 0:
430
- raise HTTPException(status_code=404, detail=f"Kategorie {category_id} nicht gefunden")
431
- return {"message": f"Kategorie {category_id} aktualisiert"}
432
- except sqlite3.Error as e:
433
- raise HTTPException(status_code=500, detail=f"Error updating category: {e}")
434
-
435
- @app.post("/flux-pics")
436
- async def download_images(request: Request):
437
- try:
438
- body = await request.json()
439
- logger.info(f"Received request body: {body}")
440
-
441
- image_files = body.get("selectedImages", [])
442
- if not image_files:
443
- raise HTTPException(status_code=400, detail="Keine Bilder ausgewählt.")
444
-
445
- logger.info(f"Processing image files: {image_files}")
446
-
447
- # Überprüfe ob Download-Verzeichnis existiert
448
- if not os.path.exists(IMAGE_STORAGE_PATH):
449
- logger.error(f"Storage path not found: {IMAGE_STORAGE_PATH}")
450
- raise HTTPException(status_code=500, detail="Storage path not found")
451
-
452
- zip_buffer = BytesIO()
453
- with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
454
- for image_file in image_files:
455
- image_path = os.path.join(IMAGE_STORAGE_PATH, image_file)
456
- logger.info(f"Processing file: {image_path}")
457
-
458
- if os.path.exists(image_path):
459
- zip_file.write(image_path, arcname=image_file)
460
- else:
461
- logger.error(f"File not found: {image_path}")
462
- raise HTTPException(status_code=404, detail=f"Bild {image_file} nicht gefunden.")
463
-
464
- zip_buffer.seek(0)
465
-
466
- # Korrekter Response mit Buffer
467
- return Response(
468
- content=zip_buffer.getvalue(),
469
- media_type="application/zip",
470
- headers={
471
- "Content-Disposition": f"attachment; filename=images.zip"
472
- }
473
- )
474
-
475
- except Exception as e:
476
- logger.error(f"Error in download_images: {str(e)}")
477
- raise HTTPException(status_code=500, detail=str(e))
478
-
479
- @app.post("/flux-pics/single")
480
- async def download_single_image(request: Request):
481
- try:
482
- data = await request.json()
483
- filename = data.get("filename")
484
- logger.info(f"Requested file download: {filename}")
485
-
486
- if not filename:
487
- logger.error("No filename provided")
488
- raise HTTPException(status_code=400, detail="Kein Dateiname angegeben")
489
-
490
- file_path = os.path.join(IMAGE_STORAGE_PATH, filename)
491
- logger.info(f"Full file path: {file_path}")
492
-
493
- if not os.path.exists(file_path):
494
- logger.error(f"File not found: {file_path}")
495
- raise HTTPException(status_code=404, detail=f"Datei {filename} nicht gefunden")
496
-
497
- # Determine MIME type
498
- file_extension = filename.lower().split('.')[-1]
499
- mime_types = {
500
- 'png': 'image/png',
501
- 'jpg': 'image/jpeg',
502
- 'jpeg': 'image/jpeg',
503
- 'gif': 'image/gif',
504
- 'webp': 'image/webp'
505
- }
506
- media_type = mime_types.get(file_extension, 'application/octet-stream')
507
- logger.info(f"Serving file with media type: {media_type}")
508
-
509
- return FileResponse(
510
- path=file_path,
511
- filename=filename,
512
- media_type=media_type,
513
- headers={
514
- "Content-Disposition": f"attachment; filename={filename}"
515
- }
516
- )
517
- except Exception as e:
518
- logger.error(f"Error in download_single_image: {str(e)}")
519
- raise HTTPException(status_code=500, detail=str(e))
520
-
521
- @app.websocket("/ws")
522
- async def websocket_endpoint(websocket: WebSocket):
523
- await websocket.accept()
524
- try:
525
- data = await websocket.receive_json()
526
- prompts = data.get("prompts", [data])
527
-
528
- for prompt_data in prompts:
529
- prompt_data["lora_scale"] = float(prompt_data["lora_scale"])
530
- prompt_data["guidance_scale"] = float(prompt_data["guidance_scale"])
531
- prompt_data["prompt_strength"] = float(prompt_data["prompt_strength"])
532
- prompt_data["num_inference_steps"] = int(prompt_data["num_inference_steps"])
533
- prompt_data["num_outputs"] = int(prompt_data["num_outputs"])
534
- prompt_data["output_quality"] = int(prompt_data["output_quality"])
535
-
536
- # Handle new album and category creation
537
- album_name = prompt_data.get("album_id")
538
- category_names = prompt_data.get("category_ids", [])
539
-
540
- if album_name and not album_name.isdigit():
541
- with get_db_connection() as conn:
542
- cursor = conn.cursor()
543
- cursor.execute(
544
- "INSERT INTO albums (name) VALUES (?)", (album_name,)
545
- )
546
- conn.commit()
547
- prompt_data["album_id"] = cursor.lastrowid
548
- else:
549
- prompt_data["album_id"] = int(album_name) if album_name else None
550
-
551
- category_ids = []
552
- for category_name in category_names:
553
- if not category_name.isdigit():
554
- with get_db_connection() as conn:
555
- cursor = conn.cursor()
556
- cursor.execute(
557
- "INSERT INTO categories (name) VALUES (?)", (category_name,)
558
- )
559
- conn.commit()
560
- category_ids.append(cursor.lastrowid)
561
- else:
562
- category_ids.append(int(category_name) if category_name else None)
563
- prompt_data["category_ids"] = category_ids
564
-
565
- args = argparse.Namespace(**prompt_data)
566
-
567
- # await websocket.send_json({"message": "Optimiere Prompt..."})
568
- optimized_prompt = (
569
- optimize_prompt(args.prompt)
570
- if getattr(args, "agent", False)
571
- else args.prompt
572
- )
573
- await websocket.send_json({"optimized_prompt": optimized_prompt})
574
-
575
- if prompt_data.get("optimize_only"):
576
- continue
577
-
578
- await generate_and_download_image(websocket, args, optimized_prompt)
579
- except WebSocketDisconnect:
580
- print("Client disconnected")
581
- except Exception as e:
582
- await websocket.send_json({"message": str(e)})
583
- raise e
584
- finally:
585
- await websocket.close()
586
-
587
- async def fetch_image(item, index, args, filenames, semaphore, websocket, timestamp):
588
- async with semaphore:
589
- try:
590
- response = requests.get(item, timeout=TIMEOUT_DURATION)
591
- if response.status_code == 200:
592
- filename = (
593
- f"{DOWNLOAD_DIR}/image_{timestamp}_{index}.{args.output_format}"
594
- )
595
- with open(filename, "wb") as file:
596
- file.write(response.content)
597
- filenames.append(
598
- f"/flux-pics/image_{timestamp}_{index}.{args.output_format}"
599
- )
600
- progress = int((index + 1) / args.num_outputs * 100)
601
- await websocket.send_json({"progress": progress})
602
- else:
603
- await websocket.send_json(
604
- {
605
- "message": f"Fehler beim Herunterladen des Bildes {index + 1}: {response.status_code}"
606
- }
607
- )
608
- except requests.exceptions.Timeout:
609
- await websocket.send_json(
610
- {"message": f"Timeout beim Herunterladen des Bildes {index + 1}"}
611
- )
612
-
613
- async def generate_and_download_image(websocket: WebSocket, args, optimized_prompt):
614
- try:
615
- input_data = {
616
- "prompt": optimized_prompt,
617
- "hf_lora": getattr(
618
- args, "hf_lora", None
619
- ), # Use getattr to safely access hf_lora
620
- "lora_scale": args.lora_scale,
621
- "num_outputs": args.num_outputs,
622
- "aspect_ratio": args.aspect_ratio,
623
- "output_format": args.output_format,
624
- "guidance_scale": args.guidance_scale,
625
- "output_quality": args.output_quality,
626
- "prompt_strength": args.prompt_strength,
627
- "num_inference_steps": args.num_inference_steps,
628
- "disable_safety_checker": False,
629
- }
630
-
631
- # await websocket.send_json({"message": "Generiere Bilder..."})
632
-
633
- # Debug: Log the start of the replication process
634
- print(
635
- f"Starting replication process for {args.num_outputs} outputs with timeout {TIMEOUT_DURATION}"
636
- )
637
-
638
- output = replicate.run(
639
- "lucataco/flux-dev-lora:091495765fa5ef2725a175a57b276ec30dc9d39c22d30410f2ede68a3eab66b3",
640
- input=input_data,
641
- timeout=TIMEOUT_DURATION,
642
- )
643
-
644
- if not os.path.exists(DOWNLOAD_DIR):
645
- os.makedirs(DOWNLOAD_DIR)
646
-
647
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
648
- filenames = []
649
- semaphore = Semaphore(3) # Limit concurrent downloads
650
-
651
- tasks = [
652
- create_task(
653
- fetch_image(
654
- item, index, args, filenames, semaphore, websocket, timestamp
655
- )
656
- )
657
- for index, item in enumerate(output)
658
- ]
659
- await gather(*tasks)
660
-
661
- for file in filenames:
662
- log_generation(args, optimized_prompt, file)
663
-
664
- await websocket.send_json(
665
- {"message": "Bilder erfolgreich generiert", "generated_files": filenames}
666
- )
667
- except requests.exceptions.Timeout:
668
- await websocket.send_json(
669
- {"message": "Fehler bei der Bildgenerierung: Timeout überschritten"}
670
- )
671
- except Exception as e:
672
- await websocket.send_json(
673
- {"message": f"Fehler bei der Bildgenerierung: {str(e)}"}
674
- )
675
- raise Exception(f"Fehler bei der Bildgenerierung: {str(e)}")
676
-
677
- def optimize_prompt(prompt):
678
- api_key = os.environ.get("MISTRAL_API_KEY")
679
- agent_id = os.environ.get("MISTRAL_FLUX_AGENT")
680
-
681
- if not api_key or not agent_id:
682
- raise ValueError("MISTRAL_API_KEY oder MISTRAL_FLUX_AGENT nicht gesetzt")
683
-
684
- client = Mistral(api_key=api_key)
685
- chat_response = client.agents.complete(
686
- agent_id=agent_id,
687
- messages=[
688
- {
689
- "role": "user",
690
- "content": f"Optimiere folgenden Prompt für Flux Lora: {prompt}",
691
- }
692
- ],
693
- )
694
-
695
- return chat_response.choices[0].message.content
696
-
697
- if __name__ == "__main__":
698
- # Parse command line arguments
699
- parser = argparse.ArgumentParser(description="Beschreibung")
700
- parser.add_argument('--hf_lora', default=None, help='HF LoRA Model')
701
- args = parser.parse_args()
702
-
703
- # Pass arguments to the FastAPI application
704
- app.state.args = args
705
-
706
- # Run the Uvicorn server
707
- uvicorn.run(
708
- "app:app",
709
- host="0.0.0.0",
710
- port=7860,
711
- reload=True,
712
- log_level="error"
713
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/env python3.11
2
+ import gradio as gr
3
+ import os
4
+ import sqlite3
5
+ import replicate
6
+ import argparse
7
+ import requests
8
+ from datetime import datetime
9
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request, Form, Query,Response
10
+ from fastapi.templating import Jinja2Templates
11
+ from fastapi.responses import FileResponse
12
+ from fastapi.staticfiles import StaticFiles
13
+ from pydantic import BaseModel
14
+ from typing import Optional, List
15
+ import uvicorn
16
+ from asyncio import gather, Semaphore, create_task
17
+ from mistralai import Mistral
18
+ from contextlib import contextmanager
19
+ from io import BytesIO
20
+ import zipfile
21
+
22
+ import sys
23
+ print(f"Arguments: {sys.argv}")
24
+
25
+
26
+ token = os.getenv("HF_TOKEN")
27
+ api_key = os.getenv("MISTRAL_API_KEY")
28
+ agent_id = os.getenv("MISTRAL_FLUX_AGENT")
29
+
30
+ # ANSI Escape Codes für farbige Ausgabe (kann entfernt werden, falls nicht benötigt)
31
+ HEADER = "\033[38;2;255;255;153m"
32
+ TITLE = "\033[38;2;255;255;153m"
33
+ MENU = "\033[38;2;255;165;0m"
34
+ SUCCESS = "\033[38;2;153;255;153m"
35
+ ERROR = "\033[38;2;255;69;0m"
36
+ MAIN = "\033[38;2;204;204;255m"
37
+ SPEAKER1 = "\033[38;2;173;216;230m"
38
+ SPEAKER2 = "\033[38;2;255;179;102m"
39
+ RESET = "\033[0m"
40
+
41
+ DOWNLOAD_DIR = "/mnt/d/ai/dialog/2/flux-pics" # Pfad zu deinen Bildern (sollte korrekt sein)
42
+ DATABASE_PATH = "flux_logs_neu.db" # Datenbank-Pfad
43
+ TIMEOUT_DURATION = 900 # Timeout-Dauer in Sekunden (scheint angemessen)
44
+
45
+ # WICHTIG: Stelle sicher, dass dieses Verzeichnis existiert und die Bilder enthält.
46
+ IMAGE_STORAGE_PATH = DOWNLOAD_DIR
47
+
48
+ app = FastAPI()
49
+ security = HTTPBasic()
50
+
51
+ # Umgebungsvariablen für Benutzername und Passwort
52
+ USERNAME = os.getenv("CF_USER", "default_user")
53
+ PASSWORD = os.getenv("CF_PASSWORD", "default_password")
54
+ # StaticFiles Middleware hinzufügen (korrekt und wichtig!)
55
+ app.mount("/static", StaticFiles(directory="static"), name="static")
56
+ app.mount("/flux-pics", StaticFiles(directory=IMAGE_STORAGE_PATH), name="flux-pics")
57
+
58
+ templates = Jinja2Templates(directory="templates")
59
+
60
+ # Datenbank-Hilfsfunktionen (sehen gut aus)
61
+ @contextmanager
62
+ def get_db_connection(db_path=DATABASE_PATH):
63
+ conn = sqlite3.connect(db_path)
64
+ try:
65
+ yield conn
66
+ finally:
67
+ conn.close()
68
+
69
+ def initialize_database(db_path=DATABASE_PATH):
70
+ with get_db_connection(db_path) as conn:
71
+ cursor = conn.cursor()
72
+ # Tabellen-Erstellung (scheint korrekt, keine Auffälligkeiten)
73
+ cursor.execute("""
74
+ CREATE TABLE IF NOT EXISTS generation_logs (
75
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
76
+ timestamp TEXT,
77
+ prompt TEXT,
78
+ optimized_prompt TEXT,
79
+ hf_lora TEXT,
80
+ lora_scale REAL,
81
+ aspect_ratio TEXT,
82
+ guidance_scale REAL,
83
+ output_quality INTEGER,
84
+ prompt_strength REAL,
85
+ num_inference_steps INTEGER,
86
+ output_file TEXT,
87
+ album_id INTEGER,
88
+ category_id INTEGER
89
+ )
90
+ """)
91
+ cursor.execute("""
92
+ CREATE TABLE IF NOT EXISTS albums (
93
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
94
+ name TEXT NOT NULL
95
+ )
96
+ """)
97
+ cursor.execute("""
98
+ CREATE TABLE IF NOT EXISTS categories (
99
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
100
+ name TEXT NOT NULL
101
+ )
102
+ """)
103
+ cursor.execute("""
104
+ CREATE TABLE IF NOT EXISTS pictures (
105
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
106
+ timestamp TEXT,
107
+ file_path TEXT,
108
+ file_name TEXT,
109
+ album_id INTEGER,
110
+ FOREIGN KEY (album_id) REFERENCES albums(id)
111
+ )
112
+ """)
113
+ cursor.execute("""
114
+ CREATE TABLE IF NOT EXISTS picture_categories (
115
+ picture_id INTEGER,
116
+ category_id INTEGER,
117
+ FOREIGN KEY (picture_id) REFERENCES pictures(id),
118
+ FOREIGN KEY (category_id) REFERENCES categories(id),
119
+ PRIMARY KEY (picture_id, category_id)
120
+ )
121
+ """)
122
+ conn.commit()
123
+ def log_generation(args, optimized_prompt, image_file):
124
+ file_path, file_name = os.path.split(image_file)
125
+ try:
126
+ with get_db_connection() as conn:
127
+ cursor = conn.cursor()
128
+ cursor.execute("""
129
+ INSERT INTO generation_logs (
130
+ timestamp, prompt, optimized_prompt, hf_lora, lora_scale, aspect_ratio, guidance_scale,
131
+ output_quality, prompt_strength, num_inference_steps, output_file, album_id, category_id
132
+ ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
133
+ """, (
134
+ datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
135
+ args.prompt,
136
+ optimized_prompt,
137
+ args.hf_lora,
138
+ args.lora_scale,
139
+ args.aspect_ratio,
140
+ args.guidance_scale,
141
+ args.output_quality,
142
+ args.prompt_strength,
143
+ args.num_inference_steps,
144
+ image_file,
145
+ args.album_id,
146
+ args.category_ids[0] if args.category_ids else None # Hier auf erstes Element zugreifen
147
+ ))
148
+ picture_id = cursor.lastrowid # Dies scheint nicht korrekt zu sein, da die ID für die Tabelle pictures benötigt wird
149
+ cursor.execute("""
150
+ INSERT INTO pictures (
151
+ timestamp, file_path, file_name, album_id
152
+ ) VALUES (?, ?, ?, ?)
153
+ """, (
154
+ datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
155
+ file_path,
156
+ file_name,
157
+ args.album_id
158
+ ))
159
+ picture_id = cursor.lastrowid # Korrekte Zeile
160
+
161
+ # Insert multiple categories
162
+ for category_id in args.category_ids:
163
+ cursor.execute("""
164
+ INSERT INTO picture_categories (picture_id, category_id)
165
+ VALUES (?, ?)
166
+ """, (picture_id, category_id))
167
+
168
+ conn.commit()
169
+ except sqlite3.Error as e:
170
+ print(f"Error logging generation: {e}") # Sollte durch logger.error ersetzt werden.
171
+
172
+ @app.on_event("startup")
173
+ def startup_event():
174
+ initialize_database()
175
+
176
+ # Authentifizierungsfunktion
177
+ def authenticate(credentials: HTTPBasicCredentials = Depends(security)):
178
+ if credentials.username != USERNAME or credentials.password != PASSWORD:
179
+ raise HTTPException(
180
+ status_code=status.HTTP_401_UNAUTHORIZED,
181
+ detail="Ungültige Anmeldedaten",
182
+ headers={"WWW-Authenticate": "Basic"},
183
+ )
184
+ return credentials.username
185
 
186
+ # Geschützte Route für archive.html
187
  @app.get("/archive")
188
  def read_archive(
189
  request: Request,
 
191
  category: Optional[List[str]] = Query(None),
192
  search: Optional[str] = None,
193
  items_per_page: int = Query(30),
194
+ page: int = Query(1),
195
+ username: str = Depends(authenticate),
196
  ):
197
  album_id = int(album) if album and album.isdigit() else None
198
  category_ids = [int(cat) for cat in category] if category else []
 
214
  params.append(album_id)
215
 
216
  if category_ids:
 
217
  query = """
218
  SELECT gl.timestamp, gl.prompt, gl.optimized_prompt, gl.output_file, a.name as album, GROUP_CONCAT(c.name) as categories
219
  FROM generation_logs gl
 
223
  WHERE 1=1
224
  """
225
  if album_id is not None:
226
+ query += " AND gl.album_id = ?"
227
+ params.append(album_id)
228
+
229
+ query += " AND pc.category_id IN ({})".format(",".join("?" for _ in category_ids))
230
  params.extend(category_ids)
231
 
232
  if search:
 
263
  "selected_categories": category_ids,
264
  "search_query": search,
265
  "items_per_page": items_per_page,
266
+ "page": page,
267
+ "username": username,
268
  })
269
 
270
+ # Öffentliche Route
271
+ @app.get("/")
272
+ def read_root(request: Request):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  with get_db_connection() as conn:
274
  cursor = conn.cursor()
275
+ cursor.execute("SELECT id, name FROM albums")
276
+ albums = cursor.fetchall()
277
+ cursor.execute("SELECT id, name FROM categories")
278
+ categories = cursor.fetchall()
279
+ return templates.TemplateResponse("index.html", {"request": request, "albums": albums, "categories": categories})
280
+
281
+ @app.get("/backend")
282
+ def read_backend(request: Request):
283
+ with get_db_connection() as conn:
284
+ cursor = conn.cursor()
285
+ cursor.execute("SELECT id, name FROM albums")
286
+ albums = cursor.fetchall()
287
+ cursor.execute("SELECT id, name FROM categories")
288
+ categories = cursor.fetchall()
289
+ return templates.TemplateResponse("backend.html", {"request": request, "albums": albums, "categories": categories})
290
+
291
+ @app.get("/backend/stats")
292
+ async def get_backend_stats():
293
+ with get_db_connection() as conn:
294
+ cursor = conn.cursor()
295
+
296
+ # Anzahl der Bilder (aus der pictures-Tabelle)
297
+ cursor.execute("SELECT COUNT(*) FROM pictures")
298
+ total_images = cursor.fetchone()[0]
299
+
300
+ # Alben-Statistiken (Anzahl)
301
+ cursor.execute("SELECT COUNT(*) FROM albums")
302
+ total_albums = cursor.fetchone()[0]
303
+
304
+ # Kategorie-Statistiken (Anzahl)
305
+ cursor.execute("SELECT COUNT(*) FROM categories")
306
+ total_categories = cursor.fetchone()[0]
307
+
308
+ # Monatliche Statistiken (Anzahl der Bilder pro Monat)
309
+ cursor.execute("""
310
+ SELECT strftime('%Y-%m', timestamp) as month, COUNT(*)
311
+ FROM pictures
312
+ GROUP BY month
313
+ ORDER BY month
314
+ """)
315
+ monthly_stats = [{"month": row[0], "count": row[1]} for row in cursor.fetchall()]
316
+
317
+ # Speicherplatzberechnung
318
+ total_size = 0
319
+ for filename in os.listdir(IMAGE_STORAGE_PATH):
320
+ filepath = os.path.join(IMAGE_STORAGE_PATH, filename)
321
+ if os.path.isfile(filepath):
322
+ total_size += os.path.getsize(filepath)
323
+ total_size_mb = total_size / (1024 * 1024)
324
+
325
+ # Daten für die Kategorien-Statistik (Beispiel: Anzahl der Bilder pro Kategorie)
326
+ cursor.execute("""
327
+ SELECT c.name, COUNT(pc.picture_id)
328
+ FROM categories c
329
+ LEFT JOIN picture_categories pc ON c.id = pc.category_id
330
+ GROUP BY c.name
331
+ """)
332
+ category_stats = [{"name": row[0], "count": row[1]} for row in cursor.fetchall()]
333
+
334
+ return {
335
+ "total_images": total_images,
336
+ "albums": {
337
+ "total": total_albums
338
+ },
339
+ "categories": {
340
+ "total": total_categories,
341
+ "data": category_stats
342
+ },
343
+ "storage_usage_mb": total_size_mb,
344
+ "monthly": monthly_stats
345
+ } # Hier war die Klammer falsch gesetzt
346
+
347
+ # Neue Routen für Alben
348
+ @app.get("/albums")
349
+ async def get_albums():
350
+ with get_db_connection() as conn:
351
+ cursor = conn.cursor()
352
+ cursor.execute("SELECT id, name FROM albums")
353
+ result = cursor.fetchall()
354
+ albums = [{"id": row[0], "name": row[1]} for row in result]
355
+ return albums
356
+
357
+ @app.post("/create_album")
358
+ async def create_album_route(name: str = Form(...), description: Optional[str] = Form(None)):
359
+ try:
360
+ with get_db_connection() as conn:
361
+ cursor = conn.cursor()
362
+ cursor.execute("INSERT INTO albums (name) VALUES (?)", (name,))
363
+ conn.commit()
364
+ new_album_id = cursor.lastrowid
365
+ return {"message": "Album erstellt", "id": new_album_id, "name": name}
366
+ except sqlite3.Error as e:
367
+ raise HTTPException(status_code=500, detail=f"Error creating album: {e}")
368
+
369
+ @app.delete("/delete_album/{album_id}")
370
+ async def delete_album(album_id: int):
371
+ try:
372
+ with get_db_connection() as conn:
373
+ cursor = conn.cursor()
374
+ # Lösche die Verknüpfungen in picture_categories
375
+ cursor.execute("DELETE FROM picture_categories WHERE picture_id IN (SELECT id FROM pictures WHERE album_id = ?)", (album_id,))
376
+ # Lösche die Bilder aus der pictures-Tabelle
377
+ cursor.execute("DELETE FROM pictures WHERE album_id = ?", (album_id,))
378
+ # Lösche die Einträge aus generation_logs
379
+ cursor.execute("DELETE FROM generation_logs WHERE album_id = ?", (album_id,))
380
+ # Lösche das Album aus der albums-Tabelle
381
+ cursor.execute("DELETE FROM albums WHERE id = ?", (album_id,))
382
+ conn.commit()
383
+ return {"message": f"Album {album_id} und zugehörige Einträge gelöscht"}
384
+ except sqlite3.Error as e:
385
+ raise HTTPException(status_code=500, detail=f"Error deleting album: {e}")
386
+
387
+ @app.put("/update_album/{album_id}")
388
+ async def update_album(album_id: int, request: Request):
389
+ data = await request.json()
390
+ try:
391
+ with get_db_connection() as conn:
392
+ cursor = conn.cursor()
393
+ cursor.execute("UPDATE albums SET name = ? WHERE id = ?", (data["name"], album_id))
394
+ conn.commit()
395
+ if cursor.rowcount == 0:
396
+ raise HTTPException(status_code=404, detail=f"Album {album_id} nicht gefunden")
397
+ return {"message": f"Album {album_id} aktualisiert"}
398
+ except sqlite3.Error as e:
399
+ raise HTTPException(status_code=500, detail=f"Error updating album: {e}")
400
+
401
+ # Neue Routen für Kategorien
402
+ @app.get("/categories")
403
+ async def get_categories():
404
+ with get_db_connection() as conn:
405
+ cursor = conn.cursor()
406
+ cursor.execute("SELECT id, name FROM categories")
407
+ result = cursor.fetchall()
408
+ categories = [{"id": row[0], "name": row[1]} for row in result]
409
+ return categories
410
+
411
+ @app.post("/create_category")
412
+ async def create_category_route(name: str = Form(...)):
413
+ try:
414
+ with get_db_connection() as conn:
415
+ cursor = conn.cursor()
416
+ cursor.execute("INSERT INTO categories (name) VALUES (?)", (name,))
417
+ conn.commit()
418
+ new_category_id = cursor.lastrowid
419
+ return {"message": "Kategorie erstellt", "id": new_category_id, "name": name}
420
+ except sqlite3.Error as e:
421
+ raise HTTPException(status_code=500, detail=f"Error creating category: {e}")
422
+
423
+ @app.delete("/delete_category/{category_id}")
424
+ async def delete_category(category_id: int):
425
+ try:
426
+ with get_db_connection() as conn:
427
+ cursor = conn.cursor()
428
+ # Lösche die Verknüpfungen in picture_categories
429
+ cursor.execute("DELETE FROM picture_categories WHERE category_id = ?", (category_id,))
430
+ # Lösche die Kategorie aus der categories-Tabelle
431
+ cursor.execute("DELETE FROM categories WHERE id = ?", (category_id,))
432
+ conn.commit()
433
+ return {"message": f"Kategorie {category_id} und zugehörige Einträge gelöscht"}
434
+ except sqlite3.Error as e:
435
+ raise HTTPException(status_code=500, detail=f"Error deleting category: {e}")
436
+
437
+ @app.put("/update_category/{category_id}")
438
+ async def update_category(category_id: int, request: Request):
439
+ data = await request.json()
440
+ try:
441
+ with get_db_connection() as conn:
442
+ cursor = conn.cursor()
443
+ cursor.execute("UPDATE categories SET name = ? WHERE id = ?", (data["name"], category_id))
444
+ conn.commit()
445
+ if cursor.rowcount == 0:
446
+ raise HTTPException(status_code=404, detail=f"Kategorie {category_id} nicht gefunden")
447
+ return {"message": f"Kategorie {category_id} aktualisiert"}
448
+ except sqlite3.Error as e:
449
+ raise HTTPException(status_code=500, detail=f"Error updating category: {e}")
450
+
451
+ @app.post("/flux-pics")
452
+ async def download_images(request: Request):
453
+ try:
454
+ body = await request.json()
455
+ logger.info(f"Received request body: {body}")
456
+
457
+ image_files = body.get("selectedImages", [])
458
+ if not image_files:
459
+ raise HTTPException(status_code=400, detail="Keine Bilder ausgewählt.")
460
+
461
+ logger.info(f"Processing image files: {image_files}")
462
+
463
+ # Überprüfe ob Download-Verzeichnis existiert
464
+ if not os.path.exists(IMAGE_STORAGE_PATH):
465
+ logger.error(f"Storage path not found: {IMAGE_STORAGE_PATH}")
466
+ raise HTTPException(status_code=500, detail="Storage path not found")
467
+
468
+ zip_buffer = BytesIO()
469
+ with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
470
+ for image_file in image_files:
471
+ image_path = os.path.join(IMAGE_STORAGE_PATH, image_file)
472
+ logger.info(f"Processing file: {image_path}")
473
+
474
+ if os.path.exists(image_path):
475
+ zip_file.write(image_path, arcname=image_file)
476
+ else:
477
+ logger.error(f"File not found: {image_path}")
478
+ raise HTTPException(status_code=404, detail=f"Bild {image_file} nicht gefunden.")
479
+
480
+ zip_buffer.seek(0)
481
+
482
+ # Korrekter Response mit Buffer
483
+ return Response(
484
+ content=zip_buffer.getvalue(),
485
+ media_type="application/zip",
486
+ headers={
487
+ "Content-Disposition": f"attachment; filename=images.zip"
488
+ }
489
+ )
490
+
491
+ except Exception as e:
492
+ logger.error(f"Error in download_images: {str(e)}")
493
+ raise HTTPException(status_code=500, detail=str(e))
494
+
495
+ @app.post("/flux-pics/single")
496
+ async def download_single_image(request: Request):
497
+ try:
498
+ data = await request.json()
499
+ filename = data.get("filename")
500
+ logger.info(f"Requested file download: {filename}")
501
+
502
+ if not filename:
503
+ logger.error("No filename provided")
504
+ raise HTTPException(status_code=400, detail="Kein Dateiname angegeben")
505
+
506
+ file_path = os.path.join(IMAGE_STORAGE_PATH, filename)
507
+ logger.info(f"Full file path: {file_path}")
508
+
509
+ if not os.path.exists(file_path):
510
+ logger.error(f"File not found: {file_path}")
511
+ raise HTTPException(status_code=404, detail=f"Datei {filename} nicht gefunden")
512
+
513
+ # Determine MIME type
514
+ file_extension = filename.lower().split('.')[-1]
515
+ mime_types = {
516
+ 'png': 'image/png',
517
+ 'jpg': 'image/jpeg',
518
+ 'jpeg': 'image/jpeg',
519
+ 'gif': 'image/gif',
520
+ 'webp': 'image/webp'
521
+ }
522
+ media_type = mime_types.get(file_extension, 'application/octet-stream')
523
+ logger.info(f"Serving file with media type: {media_type}")
524
+
525
+ return FileResponse(
526
+ path=file_path,
527
+ filename=filename,
528
+ media_type=media_type,
529
+ headers={
530
+ "Content-Disposition": f"attachment; filename={filename}"
531
+ }
532
+ )
533
+ except Exception as e:
534
+ logger.error(f"Error in download_single_image: {str(e)}")
535
+ raise HTTPException(status_code=500, detail=str(e))
536
+
537
+ @app.websocket("/ws")
538
+ async def websocket_endpoint(websocket: WebSocket):
539
+ await websocket.accept()
540
+ try:
541
+ data = await websocket.receive_json()
542
+ prompts = data.get("prompts", [data])
543
+
544
+ for prompt_data in prompts:
545
+ prompt_data["lora_scale"] = float(prompt_data["lora_scale"])
546
+ prompt_data["guidance_scale"] = float(prompt_data["guidance_scale"])
547
+ prompt_data["prompt_strength"] = float(prompt_data["prompt_strength"])
548
+ prompt_data["num_inference_steps"] = int(prompt_data["num_inference_steps"])
549
+ prompt_data["num_outputs"] = int(prompt_data["num_outputs"])
550
+ prompt_data["output_quality"] = int(prompt_data["output_quality"])
551
+
552
+ # Handle new album and category creation
553
+ album_name = prompt_data.get("album_id")
554
+ category_names = prompt_data.get("category_ids", [])
555
+
556
+ if album_name and not album_name.isdigit():
557
+ with get_db_connection() as conn:
558
+ cursor = conn.cursor()
559
+ cursor.execute(
560
+ "INSERT INTO albums (name) VALUES (?)", (album_name,)
561
+ )
562
+ conn.commit()
563
+ prompt_data["album_id"] = cursor.lastrowid
564
+ else:
565
+ prompt_data["album_id"] = int(album_name) if album_name else None
566
+
567
+ category_ids = []
568
+ for category_name in category_names:
569
+ if not category_name.isdigit():
570
+ with get_db_connection() as conn:
571
+ cursor = conn.cursor()
572
+ cursor.execute(
573
+ "INSERT INTO categories (name) VALUES (?)", (category_name,)
574
+ )
575
+ conn.commit()
576
+ category_ids.append(cursor.lastrowid)
577
+ else:
578
+ category_ids.append(int(category_name) if category_name else None)
579
+ prompt_data["category_ids"] = category_ids
580
+
581
+ args = argparse.Namespace(**prompt_data)
582
+
583
+ # await websocket.send_json({"message": "Optimiere Prompt..."})
584
+ optimized_prompt = (
585
+ optimize_prompt(args.prompt)
586
+ if getattr(args, "agent", False)
587
+ else args.prompt
588
+ )
589
+ await websocket.send_json({"optimized_prompt": optimized_prompt})
590
+
591
+ if prompt_data.get("optimize_only"):
592
+ continue
593
+
594
+ await generate_and_download_image(websocket, args, optimized_prompt)
595
+ except WebSocketDisconnect:
596
+ print("Client disconnected")
597
+ except Exception as e:
598
+ await websocket.send_json({"message": str(e)})
599
+ raise e
600
+ finally:
601
+ await websocket.close()
602
+
603
+ async def fetch_image(item, index, args, filenames, semaphore, websocket, timestamp):
604
+ async with semaphore:
605
+ try:
606
+ response = requests.get(item, timeout=TIMEOUT_DURATION)
607
+ if response.status_code == 200:
608
+ filename = (
609
+ f"{DOWNLOAD_DIR}/image_{timestamp}_{index}.{args.output_format}"
610
+ )
611
+ with open(filename, "wb") as file:
612
+ file.write(response.content)
613
+ filenames.append(
614
+ f"/flux-pics/image_{timestamp}_{index}.{args.output_format}"
615
+ )
616
+ progress = int((index + 1) / args.num_outputs * 100)
617
+ await websocket.send_json({"progress": progress})
618
+ else:
619
+ await websocket.send_json(
620
+ {
621
+ "message": f"Fehler beim Herunterladen des Bildes {index + 1}: {response.status_code}"
622
+ }
623
+ )
624
+ except requests.exceptions.Timeout:
625
+ await websocket.send_json(
626
+ {"message": f"Timeout beim Herunterladen des Bildes {index + 1}"}
627
+ )
628
+
629
+ async def generate_and_download_image(websocket: WebSocket, args, optimized_prompt):
630
+ try:
631
+ input_data = {
632
+ "prompt": optimized_prompt,
633
+ "hf_lora": getattr(
634
+ args, "hf_lora", None
635
+ ), # Use getattr to safely access hf_lora
636
+ "lora_scale": args.lora_scale,
637
+ "num_outputs": args.num_outputs,
638
+ "aspect_ratio": args.aspect_ratio,
639
+ "output_format": args.output_format,
640
+ "guidance_scale": args.guidance_scale,
641
+ "output_quality": args.output_quality,
642
+ "prompt_strength": args.prompt_strength,
643
+ "num_inference_steps": args.num_inference_steps,
644
+ "disable_safety_checker": False,
645
+ }
646
+
647
+ # await websocket.send_json({"message": "Generiere Bilder..."})
648
+
649
+ # Debug: Log the start of the replication process
650
+ print(
651
+ f"Starting replication process for {args.num_outputs} outputs with timeout {TIMEOUT_DURATION}"
652
+ )
653
+
654
+ output = replicate.run(
655
+ "lucataco/flux-dev-lora:091495765fa5ef2725a175a57b276ec30dc9d39c22d30410f2ede68a3eab66b3",
656
+ input=input_data,
657
+ timeout=TIMEOUT_DURATION,
658
+ )
659
+
660
+ if not os.path.exists(DOWNLOAD_DIR):
661
+ os.makedirs(DOWNLOAD_DIR)
662
+
663
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
664
+ filenames = []
665
+ semaphore = Semaphore(3) # Limit concurrent downloads
666
+
667
+ tasks = [
668
+ create_task(
669
+ fetch_image(
670
+ item, index, args, filenames, semaphore, websocket, timestamp
671
+ )
672
+ )
673
+ for index, item in enumerate(output)
674
+ ]
675
+ await gather(*tasks)
676
+
677
+ for file in filenames:
678
+ log_generation(args, optimized_prompt, file)
679
+
680
+ await websocket.send_json(
681
+ {"message": "Bilder erfolgreich generiert", "generated_files": filenames}
682
+ )
683
+ except requests.exceptions.Timeout:
684
+ await websocket.send_json(
685
+ {"message": "Fehler bei der Bildgenerierung: Timeout überschritten"}
686
+ )
687
+ except Exception as e:
688
+ await websocket.send_json(
689
+ {"message": f"Fehler bei der Bildgenerierung: {str(e)}"}
690
+ )
691
+ raise Exception(f"Fehler bei der Bildgenerierung: {str(e)}")
692
+
693
+ def optimize_prompt(prompt):
694
+ api_key = os.environ.get("MISTRAL_API_KEY")
695
+ agent_id = os.environ.get("MISTRAL_FLUX_AGENT")
696
+
697
+ if not api_key or not agent_id:
698
+ raise ValueError("MISTRAL_API_KEY oder MISTRAL_FLUX_AGENT nicht gesetzt")
699
+
700
+ client = Mistral(api_key=api_key)
701
+ chat_response = client.agents.complete(
702
+ agent_id=agent_id,
703
+ messages=[
704
+ {
705
+ "role": "user",
706
+ "content": f"Optimiere folgenden Prompt für Flux Lora: {prompt}",
707
+ }
708
+ ],
709
+ )
710
+
711
+ return chat_response.choices[0].message.content
712
+
713
+ if __name__ == "__main__":
714
+ # Parse command line arguments
715
+ parser = argparse.ArgumentParser(description="Beschreibung")
716
+ parser.add_argument('--hf_lora', default=None, help='HF LoRA Model')
717
+ args = parser.parse_args()
718
+
719
+ # Pass arguments to the FastAPI application
720
+ app.state.args = args
721
+
722
+ # Run the Uvicorn server
723
+ uvicorn.run(
724
+ "app:app",
725
+ host="0.0.0.0",
726
+ port=7860,
727
+ reload=True,
728
+ log_level="debug"
729
+ )