Scalino84 commited on
Commit
cf21e3f
·
verified ·
1 Parent(s): 4c7b6b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +598 -599
app.py CHANGED
@@ -1,177 +1,176 @@
1
- #!/bin/env python3.11
2
- import gradio as gr
3
- import os
4
- import sqlite3
5
- import replicate
6
- import argparse
7
- import requests
8
- from datetime import datetime
9
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request, Form, Query,Response
10
- from fastapi.templating import Jinja2Templates
11
- from fastapi.responses import FileResponse
12
- from fastapi.staticfiles import StaticFiles
13
- from pydantic import BaseModel
14
- from typing import Optional, List
15
- import uvicorn
16
- from asyncio import gather, Semaphore, create_task
17
- from mistralai import Mistral
18
- from contextlib import contextmanager
19
- from io import BytesIO
20
- import zipfile
21
-
22
- import sys
23
- print(f"Arguments: {sys.argv}")
24
-
25
-
26
- token = os.getenv("HF_TOKEN")
27
- api_key = os.getenv("MISTRAL_API_KEY")
28
- agent_id = os.getenv("MISTRAL_FLUX_AGENT")
29
-
30
- # ANSI Escape Codes für farbige Ausgabe (kann entfernt werden, falls nicht benötigt)
31
- HEADER = "\033[38;2;255;255;153m"
32
- TITLE = "\033[38;2;255;255;153m"
33
- MENU = "\033[38;2;255;165;0m"
34
- SUCCESS = "\033[38;2;153;255;153m"
35
- ERROR = "\033[38;2;255;69;0m"
36
- MAIN = "\033[38;2;204;204;255m"
37
- SPEAKER1 = "\033[38;2;173;216;230m"
38
- SPEAKER2 = "\033[38;2;255;179;102m"
39
- RESET = "\033[0m"
40
-
41
- DOWNLOAD_DIR = "/mnt/d/ai/dialog/2/flux-pics" # Pfad zu deinen Bildern (sollte korrekt sein)
42
- DATABASE_PATH = "flux_logs_neu.db" # Datenbank-Pfad
43
- TIMEOUT_DURATION = 900 # Timeout-Dauer in Sekunden (scheint angemessen)
44
-
45
- # WICHTIG: Stelle sicher, dass dieses Verzeichnis existiert und die Bilder enthält.
46
- IMAGE_STORAGE_PATH = DOWNLOAD_DIR
47
-
48
- app = FastAPI()
49
- security = HTTPBasic()
50
-
51
- # Umgebungsvariablen für Benutzername und Passwort
52
- USERNAME = os.getenv("CF_USER", "default_user")
53
- PASSWORD = os.getenv("CF_PASSWORD", "default_password")
54
- # StaticFiles Middleware hinzufügen (korrekt und wichtig!)
55
- app.mount("/static", StaticFiles(directory="static"), name="static")
56
- app.mount("/flux-pics", StaticFiles(directory=IMAGE_STORAGE_PATH), name="flux-pics")
57
-
58
- templates = Jinja2Templates(directory="templates")
59
-
60
- # Datenbank-Hilfsfunktionen (sehen gut aus)
61
- @contextmanager
62
- def get_db_connection(db_path=DATABASE_PATH):
63
- conn = sqlite3.connect(db_path)
64
- try:
65
- yield conn
66
- finally:
67
- conn.close()
68
-
69
- def initialize_database(db_path=DATABASE_PATH):
70
- with get_db_connection(db_path) as conn:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  cursor = conn.cursor()
72
- # Tabellen-Erstellung (scheint korrekt, keine Auffälligkeiten)
73
- cursor.execute("""
74
- CREATE TABLE IF NOT EXISTS generation_logs (
75
- id INTEGER PRIMARY KEY AUTOINCREMENT,
76
- timestamp TEXT,
77
- prompt TEXT,
78
- optimized_prompt TEXT,
79
- hf_lora TEXT,
80
- lora_scale REAL,
81
- aspect_ratio TEXT,
82
- guidance_scale REAL,
83
- output_quality INTEGER,
84
- prompt_strength REAL,
85
- num_inference_steps INTEGER,
86
- output_file TEXT,
87
- album_id INTEGER,
88
- category_id INTEGER
89
- )
90
- """)
91
- cursor.execute("""
92
- CREATE TABLE IF NOT EXISTS albums (
93
- id INTEGER PRIMARY KEY AUTOINCREMENT,
94
- name TEXT NOT NULL
95
- )
96
- """)
97
  cursor.execute("""
98
- CREATE TABLE IF NOT EXISTS categories (
99
- id INTEGER PRIMARY KEY AUTOINCREMENT,
100
- name TEXT NOT NULL
101
- )
102
- """)
103
- cursor.execute("""
104
- CREATE TABLE IF NOT EXISTS pictures (
105
- id INTEGER PRIMARY KEY AUTOINCREMENT,
106
- timestamp TEXT,
107
- file_path TEXT,
108
- file_name TEXT,
109
- album_id INTEGER,
110
- FOREIGN KEY (album_id) REFERENCES albums(id)
111
- )
112
- """)
 
 
 
 
 
113
  cursor.execute("""
114
- CREATE TABLE IF NOT EXISTS picture_categories (
115
- picture_id INTEGER,
116
- category_id INTEGER,
117
- FOREIGN KEY (picture_id) REFERENCES pictures(id),
118
- FOREIGN KEY (category_id) REFERENCES categories(id),
119
- PRIMARY KEY (picture_id, category_id)
120
- )
121
- """)
122
- conn.commit()
123
- def log_generation(args, optimized_prompt, image_file):
124
- file_path, file_name = os.path.split(image_file)
125
- try:
126
- with get_db_connection() as conn:
127
- cursor = conn.cursor()
128
  cursor.execute("""
129
- INSERT INTO generation_logs (
130
- timestamp, prompt, optimized_prompt, hf_lora, lora_scale, aspect_ratio, guidance_scale,
131
- output_quality, prompt_strength, num_inference_steps, output_file, album_id, category_id
132
- ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
133
- """, (
134
- datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
135
- args.prompt,
136
- optimized_prompt,
137
- args.hf_lora,
138
- args.lora_scale,
139
- args.aspect_ratio,
140
- args.guidance_scale,
141
- args.output_quality,
142
- args.prompt_strength,
143
- args.num_inference_steps,
144
- image_file,
145
- args.album_id,
146
- args.category_ids[0] if args.category_ids else None # Hier auf erstes Element zugreifen
147
- ))
148
- picture_id = cursor.lastrowid # Dies scheint nicht korrekt zu sein, da die ID für die Tabelle pictures benötigt wird
149
- cursor.execute("""
150
- INSERT INTO pictures (
151
- timestamp, file_path, file_name, album_id
152
- ) VALUES (?, ?, ?, ?)
153
- """, (
154
- datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
155
- file_path,
156
- file_name,
157
- args.album_id
158
- ))
159
- picture_id = cursor.lastrowid # Korrekte Zeile
160
-
161
- # Insert multiple categories
162
- for category_id in args.category_ids:
163
- cursor.execute("""
164
- INSERT INTO picture_categories (picture_id, category_id)
165
- VALUES (?, ?)
166
- """, (picture_id, category_id))
167
-
168
- conn.commit()
169
- except sqlite3.Error as e:
170
- print(f"Error logging generation: {e}") # Sollte durch logger.error ersetzt werden.
171
-
172
- @app.on_event("startup")
173
- def startup_event():
174
- initialize_database()
175
 
176
  # Authentifizierungsfunktion
177
  def authenticate(credentials: HTTPBasicCredentials = Depends(security)):
@@ -267,463 +266,463 @@ def read_archive(
267
  "username": username,
268
  })
269
 
270
- # Öffentliche Route
271
- @app.get("/")
272
- def read_root(request: Request):
273
- with get_db_connection() as conn:
274
- cursor = conn.cursor()
275
- cursor.execute("SELECT id, name FROM albums")
276
- albums = cursor.fetchall()
277
- cursor.execute("SELECT id, name FROM categories")
278
- categories = cursor.fetchall()
279
- return templates.TemplateResponse("index.html", {"request": request, "albums": albums, "categories": categories})
280
-
281
- @app.get("/backend")
282
- def read_backend(request: Request):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  with get_db_connection() as conn:
284
  cursor = conn.cursor()
285
- cursor.execute("SELECT id, name FROM albums")
286
- albums = cursor.fetchall()
287
- cursor.execute("SELECT id, name FROM categories")
288
- categories = cursor.fetchall()
289
- return templates.TemplateResponse("backend.html", {"request": request, "albums": albums, "categories": categories})
290
-
291
- @app.get("/backend/stats")
292
- async def get_backend_stats():
 
 
293
  with get_db_connection() as conn:
294
  cursor = conn.cursor()
295
-
296
- # Anzahl der Bilder (aus der pictures-Tabelle)
297
- cursor.execute("SELECT COUNT(*) FROM pictures")
298
- total_images = cursor.fetchone()[0]
299
-
300
- # Alben-Statistiken (Anzahl)
301
- cursor.execute("SELECT COUNT(*) FROM albums")
302
- total_albums = cursor.fetchone()[0]
303
-
304
- # Kategorie-Statistiken (Anzahl)
305
- cursor.execute("SELECT COUNT(*) FROM categories")
306
- total_categories = cursor.fetchone()[0]
307
-
308
- # Monatliche Statistiken (Anzahl der Bilder pro Monat)
309
- cursor.execute("""
310
- SELECT strftime('%Y-%m', timestamp) as month, COUNT(*)
311
- FROM pictures
312
- GROUP BY month
313
- ORDER BY month
314
- """)
315
- monthly_stats = [{"month": row[0], "count": row[1]} for row in cursor.fetchall()]
316
-
317
- # Speicherplatzberechnung
318
- total_size = 0
319
- for filename in os.listdir(IMAGE_STORAGE_PATH):
320
- filepath = os.path.join(IMAGE_STORAGE_PATH, filename)
321
- if os.path.isfile(filepath):
322
- total_size += os.path.getsize(filepath)
323
- total_size_mb = total_size / (1024 * 1024)
324
-
325
- # Daten für die Kategorien-Statistik (Beispiel: Anzahl der Bilder pro Kategorie)
326
- cursor.execute("""
327
- SELECT c.name, COUNT(pc.picture_id)
328
- FROM categories c
329
- LEFT JOIN picture_categories pc ON c.id = pc.category_id
330
- GROUP BY c.name
331
- """)
332
- category_stats = [{"name": row[0], "count": row[1]} for row in cursor.fetchall()]
333
-
334
- return {
335
- "total_images": total_images,
336
- "albums": {
337
- "total": total_albums
338
- },
339
- "categories": {
340
- "total": total_categories,
341
- "data": category_stats
342
- },
343
- "storage_usage_mb": total_size_mb,
344
- "monthly": monthly_stats
345
- } # Hier war die Klammer falsch gesetzt
346
-
347
- # Neue Routen für Alben
348
- @app.get("/albums")
349
- async def get_albums():
350
  with get_db_connection() as conn:
351
  cursor = conn.cursor()
352
- cursor.execute("SELECT id, name FROM albums")
353
- result = cursor.fetchall()
354
- albums = [{"id": row[0], "name": row[1]} for row in result]
355
- return albums
356
-
357
- @app.post("/create_album")
358
- async def create_album_route(name: str = Form(...), description: Optional[str] = Form(None)):
359
- try:
360
- with get_db_connection() as conn:
361
- cursor = conn.cursor()
362
- cursor.execute("INSERT INTO albums (name) VALUES (?)", (name,))
363
- conn.commit()
364
- new_album_id = cursor.lastrowid
365
- return {"message": "Album erstellt", "id": new_album_id, "name": name}
366
- except sqlite3.Error as e:
367
- raise HTTPException(status_code=500, detail=f"Error creating album: {e}")
368
-
369
- @app.delete("/delete_album/{album_id}")
370
- async def delete_album(album_id: int):
371
- try:
372
- with get_db_connection() as conn:
373
- cursor = conn.cursor()
374
- # Lösche die Verknüpfungen in picture_categories
375
- cursor.execute("DELETE FROM picture_categories WHERE picture_id IN (SELECT id FROM pictures WHERE album_id = ?)", (album_id,))
376
- # Lösche die Bilder aus der pictures-Tabelle
377
- cursor.execute("DELETE FROM pictures WHERE album_id = ?", (album_id,))
378
- # Lösche die Einträge aus generation_logs
379
- cursor.execute("DELETE FROM generation_logs WHERE album_id = ?", (album_id,))
380
- # Lösche das Album aus der albums-Tabelle
381
- cursor.execute("DELETE FROM albums WHERE id = ?", (album_id,))
382
- conn.commit()
383
- return {"message": f"Album {album_id} und zugehörige Einträge gelöscht"}
384
- except sqlite3.Error as e:
385
- raise HTTPException(status_code=500, detail=f"Error deleting album: {e}")
386
-
387
- @app.put("/update_album/{album_id}")
388
- async def update_album(album_id: int, request: Request):
389
- data = await request.json()
390
- try:
391
- with get_db_connection() as conn:
392
- cursor = conn.cursor()
393
- cursor.execute("UPDATE albums SET name = ? WHERE id = ?", (data["name"], album_id))
394
- conn.commit()
395
- if cursor.rowcount == 0:
396
- raise HTTPException(status_code=404, detail=f"Album {album_id} nicht gefunden")
397
- return {"message": f"Album {album_id} aktualisiert"}
398
- except sqlite3.Error as e:
399
- raise HTTPException(status_code=500, detail=f"Error updating album: {e}")
400
-
401
- # Neue Routen für Kategorien
402
- @app.get("/categories")
403
- async def get_categories():
404
  with get_db_connection() as conn:
405
  cursor = conn.cursor()
406
- cursor.execute("SELECT id, name FROM categories")
407
- result = cursor.fetchall()
408
- categories = [{"id": row[0], "name": row[1]} for row in result]
409
- return categories
410
-
411
- @app.post("/create_category")
412
- async def create_category_route(name: str = Form(...)):
413
- try:
414
- with get_db_connection() as conn:
415
- cursor = conn.cursor()
416
- cursor.execute("INSERT INTO categories (name) VALUES (?)", (name,))
417
- conn.commit()
418
- new_category_id = cursor.lastrowid
419
- return {"message": "Kategorie erstellt", "id": new_category_id, "name": name}
420
- except sqlite3.Error as e:
421
- raise HTTPException(status_code=500, detail=f"Error creating category: {e}")
422
-
423
- @app.delete("/delete_category/{category_id}")
424
- async def delete_category(category_id: int):
425
- try:
426
- with get_db_connection() as conn:
427
- cursor = conn.cursor()
428
- # Lösche die Verknüpfungen in picture_categories
429
- cursor.execute("DELETE FROM picture_categories WHERE category_id = ?", (category_id,))
430
- # Lösche die Kategorie aus der categories-Tabelle
431
- cursor.execute("DELETE FROM categories WHERE id = ?", (category_id,))
432
- conn.commit()
433
- return {"message": f"Kategorie {category_id} und zugehörige Einträge gelöscht"}
434
- except sqlite3.Error as e:
435
- raise HTTPException(status_code=500, detail=f"Error deleting category: {e}")
436
-
437
- @app.put("/update_category/{category_id}")
438
- async def update_category(category_id: int, request: Request):
439
- data = await request.json()
440
- try:
441
- with get_db_connection() as conn:
442
- cursor = conn.cursor()
443
- cursor.execute("UPDATE categories SET name = ? WHERE id = ?", (data["name"], category_id))
444
- conn.commit()
445
- if cursor.rowcount == 0:
446
- raise HTTPException(status_code=404, detail=f"Kategorie {category_id} nicht gefunden")
447
- return {"message": f"Kategorie {category_id} aktualisiert"}
448
- except sqlite3.Error as e:
449
- raise HTTPException(status_code=500, detail=f"Error updating category: {e}")
450
-
451
- @app.post("/flux-pics")
452
- async def download_images(request: Request):
453
- try:
454
- body = await request.json()
455
- logger.info(f"Received request body: {body}")
456
-
457
- image_files = body.get("selectedImages", [])
458
- if not image_files:
459
- raise HTTPException(status_code=400, detail="Keine Bilder ausgewählt.")
460
-
461
- logger.info(f"Processing image files: {image_files}")
462
-
463
- # Überprüfe ob Download-Verzeichnis existiert
464
- if not os.path.exists(IMAGE_STORAGE_PATH):
465
- logger.error(f"Storage path not found: {IMAGE_STORAGE_PATH}")
466
- raise HTTPException(status_code=500, detail="Storage path not found")
467
-
468
- zip_buffer = BytesIO()
469
- with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
470
- for image_file in image_files:
471
- image_path = os.path.join(IMAGE_STORAGE_PATH, image_file)
472
- logger.info(f"Processing file: {image_path}")
473
-
474
- if os.path.exists(image_path):
475
- zip_file.write(image_path, arcname=image_file)
476
- else:
477
- logger.error(f"File not found: {image_path}")
478
- raise HTTPException(status_code=404, detail=f"Bild {image_file} nicht gefunden.")
479
-
480
- zip_buffer.seek(0)
481
-
482
- # Korrekter Response mit Buffer
483
- return Response(
484
- content=zip_buffer.getvalue(),
485
- media_type="application/zip",
486
- headers={
487
- "Content-Disposition": f"attachment; filename=images.zip"
488
- }
489
- )
490
 
491
- except Exception as e:
492
- logger.error(f"Error in download_images: {str(e)}")
493
- raise HTTPException(status_code=500, detail=str(e))
494
 
495
- @app.post("/flux-pics/single")
496
- async def download_single_image(request: Request):
497
- try:
498
- data = await request.json()
499
- filename = data.get("filename")
500
- logger.info(f"Requested file download: {filename}")
501
-
502
- if not filename:
503
- logger.error("No filename provided")
504
- raise HTTPException(status_code=400, detail="Kein Dateiname angegeben")
505
-
506
- file_path = os.path.join(IMAGE_STORAGE_PATH, filename)
507
- logger.info(f"Full file path: {file_path}")
508
-
509
- if not os.path.exists(file_path):
510
- logger.error(f"File not found: {file_path}")
511
- raise HTTPException(status_code=404, detail=f"Datei {filename} nicht gefunden")
512
-
513
- # Determine MIME type
514
- file_extension = filename.lower().split('.')[-1]
515
- mime_types = {
516
- 'png': 'image/png',
517
- 'jpg': 'image/jpeg',
518
- 'jpeg': 'image/jpeg',
519
- 'gif': 'image/gif',
520
- 'webp': 'image/webp'
521
  }
522
- media_type = mime_types.get(file_extension, 'application/octet-stream')
523
- logger.info(f"Serving file with media type: {media_type}")
524
-
525
- return FileResponse(
526
- path=file_path,
527
- filename=filename,
528
- media_type=media_type,
529
- headers={
530
- "Content-Disposition": f"attachment; filename={filename}"
531
- }
532
- )
533
- except Exception as e:
534
- logger.error(f"Error in download_single_image: {str(e)}")
535
- raise HTTPException(status_code=500, detail=str(e))
536
 
537
- @app.websocket("/ws")
538
- async def websocket_endpoint(websocket: WebSocket):
539
- await websocket.accept()
540
- try:
541
- data = await websocket.receive_json()
542
- prompts = data.get("prompts", [data])
543
-
544
- for prompt_data in prompts:
545
- prompt_data["lora_scale"] = float(prompt_data["lora_scale"])
546
- prompt_data["guidance_scale"] = float(prompt_data["guidance_scale"])
547
- prompt_data["prompt_strength"] = float(prompt_data["prompt_strength"])
548
- prompt_data["num_inference_steps"] = int(prompt_data["num_inference_steps"])
549
- prompt_data["num_outputs"] = int(prompt_data["num_outputs"])
550
- prompt_data["output_quality"] = int(prompt_data["output_quality"])
551
-
552
- # Handle new album and category creation
553
- album_name = prompt_data.get("album_id")
554
- category_names = prompt_data.get("category_ids", [])
555
-
556
- if album_name and not album_name.isdigit():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
557
  with get_db_connection() as conn:
558
  cursor = conn.cursor()
559
  cursor.execute(
560
- "INSERT INTO albums (name) VALUES (?)", (album_name,)
561
  )
562
  conn.commit()
563
- prompt_data["album_id"] = cursor.lastrowid
564
  else:
565
- prompt_data["album_id"] = int(album_name) if album_name else None
566
-
567
- category_ids = []
568
- for category_name in category_names:
569
- if not category_name.isdigit():
570
- with get_db_connection() as conn:
571
- cursor = conn.cursor()
572
- cursor.execute(
573
- "INSERT INTO categories (name) VALUES (?)", (category_name,)
574
- )
575
- conn.commit()
576
- category_ids.append(cursor.lastrowid)
577
- else:
578
- category_ids.append(int(category_name) if category_name else None)
579
- prompt_data["category_ids"] = category_ids
580
-
581
- args = argparse.Namespace(**prompt_data)
582
-
583
- # await websocket.send_json({"message": "Optimiere Prompt..."})
584
- optimized_prompt = (
585
- optimize_prompt(args.prompt)
586
- if getattr(args, "agent", False)
587
- else args.prompt
588
- )
589
- await websocket.send_json({"optimized_prompt": optimized_prompt})
590
-
591
- if prompt_data.get("optimize_only"):
592
- continue
593
-
594
- await generate_and_download_image(websocket, args, optimized_prompt)
595
- except WebSocketDisconnect:
596
- print("Client disconnected")
597
- except Exception as e:
598
- await websocket.send_json({"message": str(e)})
599
- raise e
600
- finally:
601
- await websocket.close()
602
-
603
- async def fetch_image(item, index, args, filenames, semaphore, websocket, timestamp):
604
- async with semaphore:
605
- try:
606
- response = requests.get(item, timeout=TIMEOUT_DURATION)
607
- if response.status_code == 200:
608
- filename = (
609
- f"{DOWNLOAD_DIR}/image_{timestamp}_{index}.{args.output_format}"
610
- )
611
- with open(filename, "wb") as file:
612
- file.write(response.content)
613
- filenames.append(
614
- f"/flux-pics/image_{timestamp}_{index}.{args.output_format}"
615
- )
616
- progress = int((index + 1) / args.num_outputs * 100)
617
- await websocket.send_json({"progress": progress})
618
- else:
619
- await websocket.send_json(
620
- {
621
- "message": f"Fehler beim Herunterladen des Bildes {index + 1}: {response.status_code}"
622
- }
623
- )
624
- except requests.exceptions.Timeout:
625
- await websocket.send_json(
626
- {"message": f"Timeout beim Herunterladen des Bildes {index + 1}"}
627
- )
628
-
629
- async def generate_and_download_image(websocket: WebSocket, args, optimized_prompt):
630
- try:
631
- input_data = {
632
- "prompt": optimized_prompt,
633
- "hf_lora": getattr(
634
- args, "hf_lora", None
635
- ), # Use getattr to safely access hf_lora
636
- "lora_scale": args.lora_scale,
637
- "num_outputs": args.num_outputs,
638
- "aspect_ratio": args.aspect_ratio,
639
- "output_format": args.output_format,
640
- "guidance_scale": args.guidance_scale,
641
- "output_quality": args.output_quality,
642
- "prompt_strength": args.prompt_strength,
643
- "num_inference_steps": args.num_inference_steps,
644
- "disable_safety_checker": False,
645
- }
646
-
647
- # await websocket.send_json({"message": "Generiere Bilder..."})
648
 
649
- # Debug: Log the start of the replication process
650
- print(
651
- f"Starting replication process for {args.num_outputs} outputs with timeout {TIMEOUT_DURATION}"
652
- )
653
 
654
- output = replicate.run(
655
- "lucataco/flux-dev-lora:091495765fa5ef2725a175a57b276ec30dc9d39c22d30410f2ede68a3eab66b3",
656
- input=input_data,
657
- timeout=TIMEOUT_DURATION,
 
658
  )
659
-
660
- if not os.path.exists(DOWNLOAD_DIR):
661
- os.makedirs(DOWNLOAD_DIR)
662
-
663
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
664
- filenames = []
665
- semaphore = Semaphore(3) # Limit concurrent downloads
666
-
667
- tasks = [
668
- create_task(
669
- fetch_image(
670
- item, index, args, filenames, semaphore, websocket, timestamp
671
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
672
  )
673
- for index, item in enumerate(output)
674
- ]
675
- await gather(*tasks)
676
-
677
- for file in filenames:
678
- log_generation(args, optimized_prompt, file)
679
-
680
- await websocket.send_json(
681
- {"message": "Bilder erfolgreich generiert", "generated_files": filenames}
682
- )
683
  except requests.exceptions.Timeout:
684
  await websocket.send_json(
685
- {"message": "Fehler bei der Bildgenerierung: Timeout überschritten"}
686
  )
687
- except Exception as e:
688
- await websocket.send_json(
689
- {"message": f"Fehler bei der Bildgenerierung: {str(e)}"}
690
- )
691
- raise Exception(f"Fehler bei der Bildgenerierung: {str(e)}")
692
-
693
- def optimize_prompt(prompt):
694
- api_key = os.environ.get("MISTRAL_API_KEY")
695
- agent_id = os.environ.get("MISTRAL_FLUX_AGENT")
696
-
697
- if not api_key or not agent_id:
698
- raise ValueError("MISTRAL_API_KEY oder MISTRAL_FLUX_AGENT nicht gesetzt")
699
-
700
- client = Mistral(api_key=api_key)
701
- chat_response = client.agents.complete(
702
- agent_id=agent_id,
703
- messages=[
704
- {
705
- "role": "user",
706
- "content": f"Optimiere folgenden Prompt für Flux Lora: {prompt}",
707
- }
708
- ],
 
 
709
  )
710
 
711
- return chat_response.choices[0].message.content
 
 
 
 
712
 
713
- if __name__ == "__main__":
714
- # Parse command line arguments
715
- parser = argparse.ArgumentParser(description="Beschreibung")
716
- parser.add_argument('--hf_lora', default=None, help='HF LoRA Model')
717
- args = parser.parse_args()
718
 
719
- # Pass arguments to the FastAPI application
720
- app.state.args = args
 
721
 
722
- # Run the Uvicorn server
723
- uvicorn.run(
724
- "app:app",
725
- host="0.0.0.0",
726
- port=7860,
727
- reload=True,
728
- log_level="debug"
 
 
 
 
 
 
 
 
 
 
 
 
729
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import sqlite3
4
+ import replicate
5
+ import argparse
6
+ import requests
7
+ from datetime import datetime
8
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request, Form, Query,Response
9
+ from fastapi.templating import Jinja2Templates
10
+ from fastapi.responses import FileResponse
11
+ from fastapi.staticfiles import StaticFiles
12
+ from pydantic import BaseModel
13
+ from typing import Optional, List
14
+ import uvicorn
15
+ from asyncio import gather, Semaphore, create_task
16
+ from mistralai import Mistral
17
+ from contextlib import contextmanager
18
+ from io import BytesIO
19
+ import zipfile
20
+
21
+ import sys
22
+ print(f"Arguments: {sys.argv}")
23
+
24
+
25
+ token = os.getenv("HF_TOKEN")
26
+ api_key = os.getenv("MISTRAL_API_KEY")
27
+ agent_id = os.getenv("MISTRAL_FLUX_AGENT")
28
+
29
+ # ANSI Escape Codes für farbige Ausgabe (kann entfernt werden, falls nicht benötigt)
30
+ HEADER = "\033[38;2;255;255;153m"
31
+ TITLE = "\033[38;2;255;255;153m"
32
+ MENU = "\033[38;2;255;165;0m"
33
+ SUCCESS = "\033[38;2;153;255;153m"
34
+ ERROR = "\033[38;2;255;69;0m"
35
+ MAIN = "\033[38;2;204;204;255m"
36
+ SPEAKER1 = "\033[38;2;173;216;230m"
37
+ SPEAKER2 = "\033[38;2;255;179;102m"
38
+ RESET = "\033[0m"
39
+
40
+ DOWNLOAD_DIR = "/mnt/d/ai/dialog/2/flux-pics" # Pfad zu deinen Bildern (sollte korrekt sein)
41
+ DATABASE_PATH = "flux_logs_neu.db" # Datenbank-Pfad
42
+ TIMEOUT_DURATION = 900 # Timeout-Dauer in Sekunden (scheint angemessen)
43
+
44
+ # WICHTIG: Stelle sicher, dass dieses Verzeichnis existiert und die Bilder enthält.
45
+ IMAGE_STORAGE_PATH = DOWNLOAD_DIR
46
+
47
+ app = FastAPI()
48
+ security = HTTPBasic()
49
+
50
+ # Umgebungsvariablen für Benutzername und Passwort
51
+ USERNAME = os.getenv("CF_USER", "default_user")
52
+ PASSWORD = os.getenv("CF_PASSWORD", "default_password")
53
+ # StaticFiles Middleware hinzufügen (korrekt und wichtig!)
54
+ app.mount("/static", StaticFiles(directory="static"), name="static")
55
+ app.mount("/flux-pics", StaticFiles(directory=IMAGE_STORAGE_PATH), name="flux-pics")
56
+
57
+ templates = Jinja2Templates(directory="templates")
58
+
59
+ # Datenbank-Hilfsfunktionen (sehen gut aus)
60
+ @contextmanager
61
+ def get_db_connection(db_path=DATABASE_PATH):
62
+ conn = sqlite3.connect(db_path)
63
+ try:
64
+ yield conn
65
+ finally:
66
+ conn.close()
67
+
68
+ def initialize_database(db_path=DATABASE_PATH):
69
+ with get_db_connection(db_path) as conn:
70
+ cursor = conn.cursor()
71
+ # Tabellen-Erstellung (scheint korrekt, keine Auffälligkeiten)
72
+ cursor.execute("""
73
+ CREATE TABLE IF NOT EXISTS generation_logs (
74
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
75
+ timestamp TEXT,
76
+ prompt TEXT,
77
+ optimized_prompt TEXT,
78
+ hf_lora TEXT,
79
+ lora_scale REAL,
80
+ aspect_ratio TEXT,
81
+ guidance_scale REAL,
82
+ output_quality INTEGER,
83
+ prompt_strength REAL,
84
+ num_inference_steps INTEGER,
85
+ output_file TEXT,
86
+ album_id INTEGER,
87
+ category_id INTEGER
88
+ )
89
+ """)
90
+ cursor.execute("""
91
+ CREATE TABLE IF NOT EXISTS albums (
92
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
93
+ name TEXT NOT NULL
94
+ )
95
+ """)
96
+ cursor.execute("""
97
+ CREATE TABLE IF NOT EXISTS categories (
98
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
99
+ name TEXT NOT NULL
100
+ )
101
+ """)
102
+ cursor.execute("""
103
+ CREATE TABLE IF NOT EXISTS pictures (
104
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
105
+ timestamp TEXT,
106
+ file_path TEXT,
107
+ file_name TEXT,
108
+ album_id INTEGER,
109
+ FOREIGN KEY (album_id) REFERENCES albums(id)
110
+ )
111
+ """)
112
+ cursor.execute("""
113
+ CREATE TABLE IF NOT EXISTS picture_categories (
114
+ picture_id INTEGER,
115
+ category_id INTEGER,
116
+ FOREIGN KEY (picture_id) REFERENCES pictures(id),
117
+ FOREIGN KEY (category_id) REFERENCES categories(id),
118
+ PRIMARY KEY (picture_id, category_id)
119
+ )
120
+ """)
121
+ conn.commit()
122
+ def log_generation(args, optimized_prompt, image_file):
123
+ file_path, file_name = os.path.split(image_file)
124
+ try:
125
+ with get_db_connection() as conn:
126
  cursor = conn.cursor()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  cursor.execute("""
128
+ INSERT INTO generation_logs (
129
+ timestamp, prompt, optimized_prompt, hf_lora, lora_scale, aspect_ratio, guidance_scale,
130
+ output_quality, prompt_strength, num_inference_steps, output_file, album_id, category_id
131
+ ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
132
+ """, (
133
+ datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
134
+ args.prompt,
135
+ optimized_prompt,
136
+ args.hf_lora,
137
+ args.lora_scale,
138
+ args.aspect_ratio,
139
+ args.guidance_scale,
140
+ args.output_quality,
141
+ args.prompt_strength,
142
+ args.num_inference_steps,
143
+ image_file,
144
+ args.album_id,
145
+ args.category_ids[0] if args.category_ids else None # Hier auf erstes Element zugreifen
146
+ ))
147
+ picture_id = cursor.lastrowid # Dies scheint nicht korrekt zu sein, da die ID für die Tabelle pictures benötigt wird
148
  cursor.execute("""
149
+ INSERT INTO pictures (
150
+ timestamp, file_path, file_name, album_id
151
+ ) VALUES (?, ?, ?, ?)
152
+ """, (
153
+ datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
154
+ file_path,
155
+ file_name,
156
+ args.album_id
157
+ ))
158
+ picture_id = cursor.lastrowid # Korrekte Zeile
159
+
160
+ # Insert multiple categories
161
+ for category_id in args.category_ids:
 
162
  cursor.execute("""
163
+ INSERT INTO picture_categories (picture_id, category_id)
164
+ VALUES (?, ?)
165
+ """, (picture_id, category_id))
166
+
167
+ conn.commit()
168
+ except sqlite3.Error as e:
169
+ print(f"Error logging generation: {e}") # Sollte durch logger.error ersetzt werden.
170
+
171
+ @app.on_event("startup")
172
+ def startup_event():
173
+ initialize_database()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
  # Authentifizierungsfunktion
176
  def authenticate(credentials: HTTPBasicCredentials = Depends(security)):
 
266
  "username": username,
267
  })
268
 
269
+ # Öffentliche Route
270
+ @app.get("/")
271
+ def read_root(request: Request):
272
+ with get_db_connection() as conn:
273
+ cursor = conn.cursor()
274
+ cursor.execute("SELECT id, name FROM albums")
275
+ albums = cursor.fetchall()
276
+ cursor.execute("SELECT id, name FROM categories")
277
+ categories = cursor.fetchall()
278
+ return templates.TemplateResponse("index.html", {"request": request, "albums": albums, "categories": categories})
279
+
280
+ @app.get("/backend")
281
+ def read_backend(request: Request):
282
+ with get_db_connection() as conn:
283
+ cursor = conn.cursor()
284
+ cursor.execute("SELECT id, name FROM albums")
285
+ albums = cursor.fetchall()
286
+ cursor.execute("SELECT id, name FROM categories")
287
+ categories = cursor.fetchall()
288
+ return templates.TemplateResponse("backend.html", {"request": request, "albums": albums, "categories": categories})
289
+
290
+ @app.get("/backend/stats")
291
+ async def get_backend_stats():
292
+ with get_db_connection() as conn:
293
+ cursor = conn.cursor()
294
+
295
+ # Anzahl der Bilder (aus der pictures-Tabelle)
296
+ cursor.execute("SELECT COUNT(*) FROM pictures")
297
+ total_images = cursor.fetchone()[0]
298
+
299
+ # Alben-Statistiken (Anzahl)
300
+ cursor.execute("SELECT COUNT(*) FROM albums")
301
+ total_albums = cursor.fetchone()[0]
302
+
303
+ # Kategorie-Statistiken (Anzahl)
304
+ cursor.execute("SELECT COUNT(*) FROM categories")
305
+ total_categories = cursor.fetchone()[0]
306
+
307
+ # Monatliche Statistiken (Anzahl der Bilder pro Monat)
308
+ cursor.execute("""
309
+ SELECT strftime('%Y-%m', timestamp) as month, COUNT(*)
310
+ FROM pictures
311
+ GROUP BY month
312
+ ORDER BY month
313
+ """)
314
+ monthly_stats = [{"month": row[0], "count": row[1]} for row in cursor.fetchall()]
315
+
316
+ # Speicherplatzberechnung
317
+ total_size = 0
318
+ for filename in os.listdir(IMAGE_STORAGE_PATH):
319
+ filepath = os.path.join(IMAGE_STORAGE_PATH, filename)
320
+ if os.path.isfile(filepath):
321
+ total_size += os.path.getsize(filepath)
322
+ total_size_mb = total_size / (1024 * 1024)
323
+
324
+ # Daten für die Kategorien-Statistik (Beispiel: Anzahl der Bilder pro Kategorie)
325
+ cursor.execute("""
326
+ SELECT c.name, COUNT(pc.picture_id)
327
+ FROM categories c
328
+ LEFT JOIN picture_categories pc ON c.id = pc.category_id
329
+ GROUP BY c.name
330
+ """)
331
+ category_stats = [{"name": row[0], "count": row[1]} for row in cursor.fetchall()]
332
+
333
+ return {
334
+ "total_images": total_images,
335
+ "albums": {
336
+ "total": total_albums
337
+ },
338
+ "categories": {
339
+ "total": total_categories,
340
+ "data": category_stats
341
+ },
342
+ "storage_usage_mb": total_size_mb,
343
+ "monthly": monthly_stats
344
+ } # Hier war die Klammer falsch gesetzt
345
+
346
+ # Neue Routen für Alben
347
+ @app.get("/albums")
348
+ async def get_albums():
349
+ with get_db_connection() as conn:
350
+ cursor = conn.cursor()
351
+ cursor.execute("SELECT id, name FROM albums")
352
+ result = cursor.fetchall()
353
+ albums = [{"id": row[0], "name": row[1]} for row in result]
354
+ return albums
355
+
356
+ @app.post("/create_album")
357
+ async def create_album_route(name: str = Form(...), description: Optional[str] = Form(None)):
358
+ try:
359
  with get_db_connection() as conn:
360
  cursor = conn.cursor()
361
+ cursor.execute("INSERT INTO albums (name) VALUES (?)", (name,))
362
+ conn.commit()
363
+ new_album_id = cursor.lastrowid
364
+ return {"message": "Album erstellt", "id": new_album_id, "name": name}
365
+ except sqlite3.Error as e:
366
+ raise HTTPException(status_code=500, detail=f"Error creating album: {e}")
367
+
368
+ @app.delete("/delete_album/{album_id}")
369
+ async def delete_album(album_id: int):
370
+ try:
371
  with get_db_connection() as conn:
372
  cursor = conn.cursor()
373
+ # Lösche die Verknüpfungen in picture_categories
374
+ cursor.execute("DELETE FROM picture_categories WHERE picture_id IN (SELECT id FROM pictures WHERE album_id = ?)", (album_id,))
375
+ # Lösche die Bilder aus der pictures-Tabelle
376
+ cursor.execute("DELETE FROM pictures WHERE album_id = ?", (album_id,))
377
+ # Lösche die Einträge aus generation_logs
378
+ cursor.execute("DELETE FROM generation_logs WHERE album_id = ?", (album_id,))
379
+ # Lösche das Album aus der albums-Tabelle
380
+ cursor.execute("DELETE FROM albums WHERE id = ?", (album_id,))
381
+ conn.commit()
382
+ return {"message": f"Album {album_id} und zugehörige Einträge gelöscht"}
383
+ except sqlite3.Error as e:
384
+ raise HTTPException(status_code=500, detail=f"Error deleting album: {e}")
385
+
386
+ @app.put("/update_album/{album_id}")
387
+ async def update_album(album_id: int, request: Request):
388
+ data = await request.json()
389
+ try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
  with get_db_connection() as conn:
391
  cursor = conn.cursor()
392
+ cursor.execute("UPDATE albums SET name = ? WHERE id = ?", (data["name"], album_id))
393
+ conn.commit()
394
+ if cursor.rowcount == 0:
395
+ raise HTTPException(status_code=404, detail=f"Album {album_id} nicht gefunden")
396
+ return {"message": f"Album {album_id} aktualisiert"}
397
+ except sqlite3.Error as e:
398
+ raise HTTPException(status_code=500, detail=f"Error updating album: {e}")
399
+
400
+ # Neue Routen für Kategorien
401
+ @app.get("/categories")
402
+ async def get_categories():
403
+ with get_db_connection() as conn:
404
+ cursor = conn.cursor()
405
+ cursor.execute("SELECT id, name FROM categories")
406
+ result = cursor.fetchall()
407
+ categories = [{"id": row[0], "name": row[1]} for row in result]
408
+ return categories
409
+
410
+ @app.post("/create_category")
411
+ async def create_category_route(name: str = Form(...)):
412
+ try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
413
  with get_db_connection() as conn:
414
  cursor = conn.cursor()
415
+ cursor.execute("INSERT INTO categories (name) VALUES (?)", (name,))
416
+ conn.commit()
417
+ new_category_id = cursor.lastrowid
418
+ return {"message": "Kategorie erstellt", "id": new_category_id, "name": name}
419
+ except sqlite3.Error as e:
420
+ raise HTTPException(status_code=500, detail=f"Error creating category: {e}")
421
+
422
+ @app.delete("/delete_category/{category_id}")
423
+ async def delete_category(category_id: int):
424
+ try:
425
+ with get_db_connection() as conn:
426
+ cursor = conn.cursor()
427
+ # Lösche die Verknüpfungen in picture_categories
428
+ cursor.execute("DELETE FROM picture_categories WHERE category_id = ?", (category_id,))
429
+ # Lösche die Kategorie aus der categories-Tabelle
430
+ cursor.execute("DELETE FROM categories WHERE id = ?", (category_id,))
431
+ conn.commit()
432
+ return {"message": f"Kategorie {category_id} und zugehörige Einträge gelöscht"}
433
+ except sqlite3.Error as e:
434
+ raise HTTPException(status_code=500, detail=f"Error deleting category: {e}")
435
+
436
+ @app.put("/update_category/{category_id}")
437
+ async def update_category(category_id: int, request: Request):
438
+ data = await request.json()
439
+ try:
440
+ with get_db_connection() as conn:
441
+ cursor = conn.cursor()
442
+ cursor.execute("UPDATE categories SET name = ? WHERE id = ?", (data["name"], category_id))
443
+ conn.commit()
444
+ if cursor.rowcount == 0:
445
+ raise HTTPException(status_code=404, detail=f"Kategorie {category_id} nicht gefunden")
446
+ return {"message": f"Kategorie {category_id} aktualisiert"}
447
+ except sqlite3.Error as e:
448
+ raise HTTPException(status_code=500, detail=f"Error updating category: {e}")
449
+
450
+ @app.post("/flux-pics")
451
+ async def download_images(request: Request):
452
+ try:
453
+ body = await request.json()
454
+ logger.info(f"Received request body: {body}")
455
+
456
+ image_files = body.get("selectedImages", [])
457
+ if not image_files:
458
+ raise HTTPException(status_code=400, detail="Keine Bilder ausgewählt.")
459
+
460
+ logger.info(f"Processing image files: {image_files}")
461
+
462
+ # Überprüfe ob Download-Verzeichnis existiert
463
+ if not os.path.exists(IMAGE_STORAGE_PATH):
464
+ logger.error(f"Storage path not found: {IMAGE_STORAGE_PATH}")
465
+ raise HTTPException(status_code=500, detail="Storage path not found")
466
+
467
+ zip_buffer = BytesIO()
468
+ with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
469
+ for image_file in image_files:
470
+ image_path = os.path.join(IMAGE_STORAGE_PATH, image_file)
471
+ logger.info(f"Processing file: {image_path}")
472
+
473
+ if os.path.exists(image_path):
474
+ zip_file.write(image_path, arcname=image_file)
475
+ else:
476
+ logger.error(f"File not found: {image_path}")
477
+ raise HTTPException(status_code=404, detail=f"Bild {image_file} nicht gefunden.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
478
 
479
+ zip_buffer.seek(0)
 
 
480
 
481
+ # Korrekter Response mit Buffer
482
+ return Response(
483
+ content=zip_buffer.getvalue(),
484
+ media_type="application/zip",
485
+ headers={
486
+ "Content-Disposition": f"attachment; filename=images.zip"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
487
  }
488
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
489
 
490
+ except Exception as e:
491
+ logger.error(f"Error in download_images: {str(e)}")
492
+ raise HTTPException(status_code=500, detail=str(e))
493
+
494
+ @app.post("/flux-pics/single")
495
+ async def download_single_image(request: Request):
496
+ try:
497
+ data = await request.json()
498
+ filename = data.get("filename")
499
+ logger.info(f"Requested file download: {filename}")
500
+
501
+ if not filename:
502
+ logger.error("No filename provided")
503
+ raise HTTPException(status_code=400, detail="Kein Dateiname angegeben")
504
+
505
+ file_path = os.path.join(IMAGE_STORAGE_PATH, filename)
506
+ logger.info(f"Full file path: {file_path}")
507
+
508
+ if not os.path.exists(file_path):
509
+ logger.error(f"File not found: {file_path}")
510
+ raise HTTPException(status_code=404, detail=f"Datei {filename} nicht gefunden")
511
+
512
+ # Determine MIME type
513
+ file_extension = filename.lower().split('.')[-1]
514
+ mime_types = {
515
+ 'png': 'image/png',
516
+ 'jpg': 'image/jpeg',
517
+ 'jpeg': 'image/jpeg',
518
+ 'gif': 'image/gif',
519
+ 'webp': 'image/webp'
520
+ }
521
+ media_type = mime_types.get(file_extension, 'application/octet-stream')
522
+ logger.info(f"Serving file with media type: {media_type}")
523
+
524
+ return FileResponse(
525
+ path=file_path,
526
+ filename=filename,
527
+ media_type=media_type,
528
+ headers={
529
+ "Content-Disposition": f"attachment; filename={filename}"
530
+ }
531
+ )
532
+ except Exception as e:
533
+ logger.error(f"Error in download_single_image: {str(e)}")
534
+ raise HTTPException(status_code=500, detail=str(e))
535
+
536
+ @app.websocket("/ws")
537
+ async def websocket_endpoint(websocket: WebSocket):
538
+ await websocket.accept()
539
+ try:
540
+ data = await websocket.receive_json()
541
+ prompts = data.get("prompts", [data])
542
+
543
+ for prompt_data in prompts:
544
+ prompt_data["lora_scale"] = float(prompt_data["lora_scale"])
545
+ prompt_data["guidance_scale"] = float(prompt_data["guidance_scale"])
546
+ prompt_data["prompt_strength"] = float(prompt_data["prompt_strength"])
547
+ prompt_data["num_inference_steps"] = int(prompt_data["num_inference_steps"])
548
+ prompt_data["num_outputs"] = int(prompt_data["num_outputs"])
549
+ prompt_data["output_quality"] = int(prompt_data["output_quality"])
550
+
551
+ # Handle new album and category creation
552
+ album_name = prompt_data.get("album_id")
553
+ category_names = prompt_data.get("category_ids", [])
554
+
555
+ if album_name and not album_name.isdigit():
556
+ with get_db_connection() as conn:
557
+ cursor = conn.cursor()
558
+ cursor.execute(
559
+ "INSERT INTO albums (name) VALUES (?)", (album_name,)
560
+ )
561
+ conn.commit()
562
+ prompt_data["album_id"] = cursor.lastrowid
563
+ else:
564
+ prompt_data["album_id"] = int(album_name) if album_name else None
565
+
566
+ category_ids = []
567
+ for category_name in category_names:
568
+ if not category_name.isdigit():
569
  with get_db_connection() as conn:
570
  cursor = conn.cursor()
571
  cursor.execute(
572
+ "INSERT INTO categories (name) VALUES (?)", (category_name,)
573
  )
574
  conn.commit()
575
+ category_ids.append(cursor.lastrowid)
576
  else:
577
+ category_ids.append(int(category_name) if category_name else None)
578
+ prompt_data["category_ids"] = category_ids
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
579
 
580
+ args = argparse.Namespace(**prompt_data)
 
 
 
581
 
582
+ # await websocket.send_json({"message": "Optimiere Prompt..."})
583
+ optimized_prompt = (
584
+ optimize_prompt(args.prompt)
585
+ if getattr(args, "agent", False)
586
+ else args.prompt
587
  )
588
+ await websocket.send_json({"optimized_prompt": optimized_prompt})
589
+
590
+ if prompt_data.get("optimize_only"):
591
+ continue
592
+
593
+ await generate_and_download_image(websocket, args, optimized_prompt)
594
+ except WebSocketDisconnect:
595
+ print("Client disconnected")
596
+ except Exception as e:
597
+ await websocket.send_json({"message": str(e)})
598
+ raise e
599
+ finally:
600
+ await websocket.close()
601
+
602
+ async def fetch_image(item, index, args, filenames, semaphore, websocket, timestamp):
603
+ async with semaphore:
604
+ try:
605
+ response = requests.get(item, timeout=TIMEOUT_DURATION)
606
+ if response.status_code == 200:
607
+ filename = (
608
+ f"{DOWNLOAD_DIR}/image_{timestamp}_{index}.{args.output_format}"
609
+ )
610
+ with open(filename, "wb") as file:
611
+ file.write(response.content)
612
+ filenames.append(
613
+ f"/flux-pics/image_{timestamp}_{index}.{args.output_format}"
614
+ )
615
+ progress = int((index + 1) / args.num_outputs * 100)
616
+ await websocket.send_json({"progress": progress})
617
+ else:
618
+ await websocket.send_json(
619
+ {
620
+ "message": f"Fehler beim Herunterladen des Bildes {index + 1}: {response.status_code}"
621
+ }
622
  )
 
 
 
 
 
 
 
 
 
 
623
  except requests.exceptions.Timeout:
624
  await websocket.send_json(
625
+ {"message": f"Timeout beim Herunterladen des Bildes {index + 1}"}
626
  )
627
+
628
+ async def generate_and_download_image(websocket: WebSocket, args, optimized_prompt):
629
+ try:
630
+ input_data = {
631
+ "prompt": optimized_prompt,
632
+ "hf_lora": getattr(
633
+ args, "hf_lora", None
634
+ ), # Use getattr to safely access hf_lora
635
+ "lora_scale": args.lora_scale,
636
+ "num_outputs": args.num_outputs,
637
+ "aspect_ratio": args.aspect_ratio,
638
+ "output_format": args.output_format,
639
+ "guidance_scale": args.guidance_scale,
640
+ "output_quality": args.output_quality,
641
+ "prompt_strength": args.prompt_strength,
642
+ "num_inference_steps": args.num_inference_steps,
643
+ "disable_safety_checker": False,
644
+ }
645
+
646
+ # await websocket.send_json({"message": "Generiere Bilder..."})
647
+
648
+ # Debug: Log the start of the replication process
649
+ print(
650
+ f"Starting replication process for {args.num_outputs} outputs with timeout {TIMEOUT_DURATION}"
651
  )
652
 
653
+ output = replicate.run(
654
+ "lucataco/flux-dev-lora:091495765fa5ef2725a175a57b276ec30dc9d39c22d30410f2ede68a3eab66b3",
655
+ input=input_data,
656
+ timeout=TIMEOUT_DURATION,
657
+ )
658
 
659
+ if not os.path.exists(DOWNLOAD_DIR):
660
+ os.makedirs(DOWNLOAD_DIR)
 
 
 
661
 
662
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
663
+ filenames = []
664
+ semaphore = Semaphore(3) # Limit concurrent downloads
665
 
666
+ tasks = [
667
+ create_task(
668
+ fetch_image(
669
+ item, index, args, filenames, semaphore, websocket, timestamp
670
+ )
671
+ )
672
+ for index, item in enumerate(output)
673
+ ]
674
+ await gather(*tasks)
675
+
676
+ for file in filenames:
677
+ log_generation(args, optimized_prompt, file)
678
+
679
+ await websocket.send_json(
680
+ {"message": "Bilder erfolgreich generiert", "generated_files": filenames}
681
+ )
682
+ except requests.exceptions.Timeout:
683
+ await websocket.send_json(
684
+ {"message": "Fehler bei der Bildgenerierung: Timeout überschritten"}
685
  )
686
+ except Exception as e:
687
+ await websocket.send_json(
688
+ {"message": f"Fehler bei der Bildgenerierung: {str(e)}"}
689
+ )
690
+ raise Exception(f"Fehler bei der Bildgenerierung: {str(e)}")
691
+
692
+ def optimize_prompt(prompt):
693
+ api_key = os.environ.get("MISTRAL_API_KEY")
694
+ agent_id = os.environ.get("MISTRAL_FLUX_AGENT")
695
+
696
+ if not api_key or not agent_id:
697
+ raise ValueError("MISTRAL_API_KEY oder MISTRAL_FLUX_AGENT nicht gesetzt")
698
+
699
+ client = Mistral(api_key=api_key)
700
+ chat_response = client.agents.complete(
701
+ agent_id=agent_id,
702
+ messages=[
703
+ {
704
+ "role": "user",
705
+ "content": f"Optimiere folgenden Prompt für Flux Lora: {prompt}",
706
+ }
707
+ ],
708
+ )
709
+
710
+ return chat_response.choices[0].message.content
711
+
712
+ if __name__ == "__main__":
713
+ # Parse command line arguments
714
+ parser = argparse.ArgumentParser(description="Beschreibung")
715
+ parser.add_argument('--hf_lora', default=None, help='HF LoRA Model')
716
+ args = parser.parse_args()
717
+
718
+ # Pass arguments to the FastAPI application
719
+ app.state.args = args
720
+
721
+ # Run the Uvicorn server
722
+ uvicorn.run(
723
+ "app:app",
724
+ host="0.0.0.0",
725
+ port=7860,
726
+ reload=True,
727
+ log_level="trace"
728
+ )