Upload /bck/hugging.py with huggingface_hub
bck/hugging.py +523 -0
ADDED
@@ -0,0 +1,523 @@
#!/usr/bin/env python3.11
import os
import sqlite3
import replicate
import argparse
import requests
from datetime import datetime
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request, Form, Query, Response
from fastapi.templating import Jinja2Templates
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from typing import Optional, List
import uvicorn
from asyncio import gather, Semaphore, create_task
from mistralai import Mistral
from contextlib import contextmanager
from io import BytesIO
import zipfile

import sys
print(f"Arguments: {sys.argv}")


token = os.getenv("HF_TOKEN")
api_key = os.getenv("MISTRAL_API_KEY")
agent_id = os.getenv("MISTRAL_FLUX_AGENT")


HEADER = "\033[38;2;255;255;153m"
TITLE = "\033[38;2;255;255;153m"
MENU = "\033[38;2;255;165;0m"
SUCCESS = "\033[38;2;153;255;153m"
ERROR = "\033[38;2;255;69;0m"
MAIN = "\033[38;2;204;204;255m"
SPEAKER1 = "\033[38;2;173;216;230m"
SPEAKER2 = "\033[38;2;255;179;102m"
RESET = "\033[0m"

#os.system("clear")

#print(f"{HEADER}--------------------\nMY FLUX CREATOR v1.0\n--------------------{RESET}\n")

DOWNLOAD_DIR = "/mnt/d/ai/dialog/2/flux-pics"
DATABASE_PATH = "flux_logs_neu.db"
TIMEOUT_DURATION = 900  # Timeout duration in seconds

IMAGE_STORAGE_PATH = DOWNLOAD_DIR  # Serve images from the flux-pics directory
app = FastAPI()

app.mount("/static", StaticFiles(directory="static"), name="static")
app.mount("/flux-pics", StaticFiles(directory=DOWNLOAD_DIR), name="flux-pics")
templates = Jinja2Templates(directory="templates")

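# Database helpers: a context-managed SQLite connection and one-time schema setup
# for the generation_logs, albums, categories, pictures, and picture_categories tables.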
@contextmanager
def get_db_connection(db_path=DATABASE_PATH):
    conn = sqlite3.connect(db_path)
    try:
        yield conn
    finally:
        conn.close()

def initialize_database(db_path=DATABASE_PATH):
    with get_db_connection(db_path) as conn:
        cursor = conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS generation_logs (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TEXT,
                prompt TEXT,
                optimized_prompt TEXT,
                hf_lora TEXT,
                lora_scale REAL,
                aspect_ratio TEXT,
                guidance_scale REAL,
                output_quality INTEGER,
                prompt_strength REAL,
                num_inference_steps INTEGER,
                output_file TEXT,
                album_id INTEGER,
                category_id INTEGER
            )
        """)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS albums (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL
            )
        """)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS categories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL
            )
        """)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS pictures (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TEXT,
                file_path TEXT,
                file_name TEXT,
                album_id INTEGER,
                FOREIGN KEY (album_id) REFERENCES albums(id)
            )
        """)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS picture_categories (
                picture_id INTEGER,
                category_id INTEGER,
                FOREIGN KEY (picture_id) REFERENCES pictures(id),
                FOREIGN KEY (category_id) REFERENCES categories(id),
                PRIMARY KEY (picture_id, category_id)
            )
        """)
        conn.commit()

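# Persist one generated image: a row in generation_logs, a row in pictures,
# and one picture_categories link per selected category.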
def log_generation(args, optimized_prompt, image_file):
    file_path, file_name = os.path.split(image_file)
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                INSERT INTO generation_logs (
                    timestamp, prompt, optimized_prompt, hf_lora, lora_scale, aspect_ratio, guidance_scale,
                    output_quality, prompt_strength, num_inference_steps, output_file, album_id, category_id
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                args.prompt,
                optimized_prompt,
                args.hf_lora,
                args.lora_scale,
                args.aspect_ratio,
                args.guidance_scale,
                args.output_quality,
                args.prompt_strength,
                args.num_inference_steps,
                image_file,
                args.album_id,
                getattr(args, 'category_id', None)  # not set by the websocket handler, which passes category_ids instead
            ))
            cursor.execute("""
                INSERT INTO pictures (
                    timestamp, file_path, file_name, album_id
                ) VALUES (?, ?, ?, ?)
            """, (
                datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                file_path,
                file_name,
                args.album_id
            ))
            picture_id = cursor.lastrowid

            # Insert multiple categories
            for category_id in args.category_ids:
                cursor.execute("""
                    INSERT INTO picture_categories (picture_id, category_id)
                    VALUES (?, ?)
                """, (picture_id, category_id))

            conn.commit()
    except sqlite3.Error as e:
        print(f"Error logging generation: {e}")

@app.on_event("startup")
def startup_event():
    initialize_database()

@app.get("/")
def read_root(request: Request):
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT id, name FROM albums")
        albums = cursor.fetchall()
        cursor.execute("SELECT id, name FROM categories")
        categories = cursor.fetchall()
    return templates.TemplateResponse("index.html", {"request": request, "albums": albums, "categories": categories})

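# Archive view: optional filtering by album, one or more categories, and a free-text
# search over prompts, with LIMIT/OFFSET pagination.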
@app.get("/archive")
def read_archive(
    request: Request,
    album: Optional[str] = Query(None),
    category: Optional[List[str]] = Query(None),
    search: Optional[str] = None,
    items_per_page: int = Query(30),
    page: int = Query(1)
):
    album_id = int(album) if album and album.isdigit() else None
    category_ids = [int(cat) for cat in category] if category else []
    offset = (page - 1) * items_per_page

    with get_db_connection() as conn:
        cursor = conn.cursor()
        query = """
            SELECT gl.timestamp, gl.prompt, gl.optimized_prompt, gl.output_file, a.name as album, c.name as category
            FROM generation_logs gl
            LEFT JOIN albums a ON gl.album_id = a.id
            LEFT JOIN categories c ON gl.category_id = c.id
            WHERE 1=1
        """
        params = []

        if album_id is not None:
            query += " AND gl.album_id = ?"
            params.append(album_id)

        if category_ids:
            query += " AND gl.category_id IN ({})".format(','.join('?' for _ in category_ids))
            params.extend(category_ids)

        if search:
            query += " AND (gl.prompt LIKE ? OR gl.optimized_prompt LIKE ?)"
            params.append(f'%{search}%')
            params.append(f'%{search}%')

        query += " ORDER BY gl.timestamp DESC LIMIT ? OFFSET ?"
        params.extend([items_per_page, offset])
        cursor.execute(query, params)
        logs = cursor.fetchall()

        logs = [{
            "timestamp": log[0],
            "prompt": log[1],
            "optimized_prompt": log[2],
            "output_file": log[3],
            "album": log[4],
            "category": log[5]
        } for log in logs]

        cursor.execute("SELECT id, name FROM albums")
        albums = cursor.fetchall()

        cursor.execute("SELECT id, name FROM categories")
        categories = cursor.fetchall()

    return templates.TemplateResponse("archive.html", {
        "request": request,
        "logs": logs,
        "albums": albums,
        "categories": categories,
        "selected_album": album,
        "selected_categories": category_ids,
        "search_query": search,
        "items_per_page": items_per_page,
        "page": page
    })


@app.get("/backend")
def read_backend(request: Request):
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT id, name FROM albums")
        albums = cursor.fetchall()
        cursor.execute("SELECT id, name FROM categories")
        categories = cursor.fetchall()
    return templates.TemplateResponse("backend.html", {"request": request, "albums": albums, "categories": categories})

@app.post("/create_album")
def create_album(name: str = Form(...)):
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("INSERT INTO albums (name) VALUES (?)", (name,))
            conn.commit()
        return {"message": "Album erstellt"}
    except sqlite3.Error as e:
        raise HTTPException(status_code=500, detail=f"Error creating album: {e}")

@app.post("/create_category")
def create_category(name: str = Form(...)):
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("INSERT INTO categories (name) VALUES (?)", (name,))
            conn.commit()
        return {"message": "Kategorie erstellt"}
    except sqlite3.Error as e:
        raise HTTPException(status_code=500, detail=f"Error creating category: {e}")

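# Bulk download: the frontend POSTs a JSON body with a "selectedImages" list of file
# names; the selected files are returned as a single ZIP archive.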
@app.post("/flux-pics")
async def download_images(request: Request):
    try:
        body = await request.json()
        print(f"Received request body: {body}")  # Debug log

        image_files = body.get("selectedImages", [])
        if not image_files:
            raise HTTPException(status_code=400, detail="Keine Bilder ausgewählt.")

        print(f"Processing image files: {image_files}")  # Debug log

        # Check that the download directory exists
        if not os.path.exists(IMAGE_STORAGE_PATH):
            print(f"Storage path not found: {IMAGE_STORAGE_PATH}")  # Debug log
            raise HTTPException(status_code=500, detail="Storage path not found")

        zip_buffer = BytesIO()
        with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
            for image_file in image_files:
                image_path = os.path.join(IMAGE_STORAGE_PATH, image_file)
                print(f"Processing file: {image_path}")  # Debug log

                if os.path.exists(image_path):
                    zip_file.write(image_path, arcname=image_file)
                else:
                    print(f"File not found: {image_path}")  # Debug log
                    raise HTTPException(status_code=404, detail=f"Bild {image_file} nicht gefunden.")

        zip_buffer.seek(0)

        # Return the ZIP buffer as the response
        return Response(
            content=zip_buffer.getvalue(),
            media_type="application/zip",
            headers={
                "Content-Disposition": "attachment; filename=images.zip"
            }
        )

    except HTTPException:
        # Re-raise HTTP errors unchanged instead of converting them to 500s
        raise
    except Exception as e:
        print(f"Error in download_images: {str(e)}")  # Debug log
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/flux-pics/single")
async def download_single_image(request: Request):
    try:
        data = await request.json()
        filename = data.get("filename")
        print(f"Requested file download: {filename}")  # Debug log

        if not filename:
            print("No filename provided")  # Debug log
            raise HTTPException(status_code=400, detail="Kein Dateiname angegeben")

        file_path = os.path.join(IMAGE_STORAGE_PATH, filename)
        print(f"Full file path: {file_path}")  # Debug log

        if not os.path.exists(file_path):
            print(f"File not found: {file_path}")  # Debug log
            raise HTTPException(status_code=404, detail=f"Datei {filename} nicht gefunden")

        # Determine MIME type
        file_extension = filename.lower().split('.')[-1]
        mime_types = {
            'png': 'image/png',
            'jpg': 'image/jpeg',
            'jpeg': 'image/jpeg',
            'gif': 'image/gif',
            'webp': 'image/webp'
        }
        media_type = mime_types.get(file_extension, 'application/octet-stream')
        print(f"Serving file with media type: {media_type}")  # Debug log

        return FileResponse(
            path=file_path,
            filename=filename,
            media_type=media_type,
            headers={
                "Content-Disposition": f"attachment; filename={filename}"
            }
        )
    except HTTPException:
        # Re-raise HTTP errors unchanged instead of converting them to 500s
        raise
    except Exception as e:
        print(f"Error in download_single_image: {str(e)}")  # Debug log
        raise HTTPException(status_code=500, detail=str(e))

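# Generation websocket: the client sends either a single prompt object or
# {"prompts": [...]}. Each entry carries the Flux parameters (lora_scale,
# guidance_scale, prompt_strength, num_inference_steps, num_outputs,
# output_quality, aspect_ratio, output_format), an album_id and category_ids that
# may be numeric IDs or new names to create, and optional "agent"/"optimize_only" flags.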
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        data = await websocket.receive_json()
        prompts = data.get("prompts", [data])

        for prompt_data in prompts:
            prompt_data['lora_scale'] = float(prompt_data['lora_scale'])
            prompt_data['guidance_scale'] = float(prompt_data['guidance_scale'])
            prompt_data['prompt_strength'] = float(prompt_data['prompt_strength'])
            prompt_data['num_inference_steps'] = int(prompt_data['num_inference_steps'])
            prompt_data['num_outputs'] = int(prompt_data['num_outputs'])
            prompt_data['output_quality'] = int(prompt_data['output_quality'])

            # Handle new album and category creation
            album_name = prompt_data.get('album_id')
            category_names = prompt_data.get('category_ids', [])

            if album_name and not album_name.isdigit():
                with get_db_connection() as conn:
                    cursor = conn.cursor()
                    cursor.execute("INSERT INTO albums (name) VALUES (?)", (album_name,))
                    conn.commit()
                    prompt_data['album_id'] = cursor.lastrowid
            else:
                prompt_data['album_id'] = int(album_name) if album_name else None

            category_ids = []
            for category_name in category_names:
                if not category_name.isdigit():
                    with get_db_connection() as conn:
                        cursor = conn.cursor()
                        cursor.execute("INSERT INTO categories (name) VALUES (?)", (category_name,))
                        conn.commit()
                        category_ids.append(cursor.lastrowid)
                else:
                    category_ids.append(int(category_name) if category_name else None)
            prompt_data['category_ids'] = category_ids

            args = argparse.Namespace(**prompt_data)

            await websocket.send_json({"message": "Optimiere Prompt..."})
            optimized_prompt = optimize_prompt(args.prompt) if getattr(args, 'agent', False) else args.prompt
            await websocket.send_json({"optimized_prompt": optimized_prompt})

            if prompt_data.get("optimize_only"):
                continue

            await generate_and_download_image(websocket, args, optimized_prompt)
    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        await websocket.send_json({"message": str(e)})
        raise e
    finally:
        await websocket.close()

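# Download one output URL to DOWNLOAD_DIR and report progress over the websocket.
# Note: requests.get() is synchronous, so each download blocks the event loop while
# it runs; the semaphore only caps how many download tasks are admitted at once.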
async def fetch_image(item, index, args, filenames, semaphore, websocket, timestamp):
    async with semaphore:
        try:
            response = requests.get(item, timeout=TIMEOUT_DURATION)
            if response.status_code == 200:
                filename = f"{DOWNLOAD_DIR}/image_{timestamp}_{index}.{args.output_format}"
                with open(filename, "wb") as file:
                    file.write(response.content)
                filenames.append(f"/flux-pics/image_{timestamp}_{index}.{args.output_format}")
                progress = int((index + 1) / args.num_outputs * 100)
                await websocket.send_json({"progress": progress})
            else:
                await websocket.send_json({"message": f"Fehler beim Herunterladen des Bildes {index + 1}: {response.status_code}"})
        except requests.exceptions.Timeout:
            await websocket.send_json({"message": f"Timeout beim Herunterladen des Bildes {index + 1}"})

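# Run the lucataco/flux-dev-lora model on Replicate with the collected parameters,
# download every returned image concurrently, then log each file to the database.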
async def generate_and_download_image(websocket: WebSocket, args, optimized_prompt):
    try:
        input_data = {
            "prompt": optimized_prompt,
            "hf_lora": getattr(args, 'hf_lora', None),  # Use getattr to safely access hf_lora
            "lora_scale": args.lora_scale,
            "num_outputs": args.num_outputs,
            "aspect_ratio": args.aspect_ratio,
            "output_format": args.output_format,
            "guidance_scale": args.guidance_scale,
            "output_quality": args.output_quality,
            "prompt_strength": args.prompt_strength,
            "num_inference_steps": args.num_inference_steps,
            "disable_safety_checker": False
        }

        await websocket.send_json({"message": "Generiere Bilder..."})

        # Debug: Log the start of the replication process
        print(f"Starting replication process for {args.num_outputs} outputs with timeout {TIMEOUT_DURATION}")

        output = replicate.run(
            "lucataco/flux-dev-lora:091495765fa5ef2725a175a57b276ec30dc9d39c22d30410f2ede68a3eab66b3",
            input=input_data,
            timeout=TIMEOUT_DURATION
        )

        if not os.path.exists(DOWNLOAD_DIR):
            os.makedirs(DOWNLOAD_DIR)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filenames = []
        semaphore = Semaphore(3)  # Limit concurrent downloads

        tasks = [create_task(fetch_image(item, index, args, filenames, semaphore, websocket, timestamp)) for index, item in enumerate(output)]
        await gather(*tasks)

        for file in filenames:
            log_generation(args, optimized_prompt, file)

        await websocket.send_json({"message": "Bilder erfolgreich generiert", "generated_files": filenames})
    except requests.exceptions.Timeout:
        await websocket.send_json({"message": "Fehler bei der Bildgenerierung: Timeout überschritten"})
    except Exception as e:
        await websocket.send_json({"message": f"Fehler bei der Bildgenerierung: {str(e)}"})
        raise Exception(f"Fehler bei der Bildgenerierung: {str(e)}")

def optimize_prompt(prompt):
    api_key = os.environ.get("MISTRAL_API_KEY")
    agent_id = os.environ.get("MISTRAL_FLUX_AGENT")

    if not api_key or not agent_id:
        raise ValueError("MISTRAL_API_KEY oder MISTRAL_FLUX_AGENT nicht gesetzt")

    client = Mistral(api_key=api_key)
    chat_response = client.agents.complete(
        agent_id=agent_id,
        messages=[{"role": "user", "content": f"Optimiere folgenden Prompt für Flux Lora: {prompt}"}]
    )

    return chat_response.choices[0].message.content

if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Beschreibung")
    parser.add_argument('--hf_lora', default=None, help='HF LoRA Model')
    args = parser.parse_args()

    # Pass arguments to the FastAPI application
    app.state.args = args

    # Run the Uvicorn server
    # uvicorn.run(app, host="0.0.0.0", port=8000, timeout_keep_alive=900)

    # Run server
    # Note: the "main:app" import string assumes this module is saved as main.py;
    # adjust the string if the file is deployed under a different name.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
        log_level="debug"
    )
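# Illustrative client payload for the /ws endpoint (field names are taken from the
# handler above; the concrete values are only an example):
#
#   {
#     "prompts": [{
#       "prompt": "a portrait photo",
#       "hf_lora": null,
#       "lora_scale": 0.8,
#       "guidance_scale": 3.5,
#       "prompt_strength": 0.8,
#       "num_inference_steps": 28,
#       "num_outputs": 1,
#       "output_quality": 90,
#       "aspect_ratio": "1:1",
#       "output_format": "png",
#       "album_id": "1",
#       "category_ids": ["2"],
#       "agent": false,
#       "optimize_only": false
#     }]
#   }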