Spaces:
Running
Running
| from time import time | |
| from statistics import mean | |
| from fastapi import BackgroundTasks, FastAPI | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| from modules.details import rand_details | |
| from modules.inference import generate_image | |
# Application instance. docs_url=None and redoc_url=None are passed so that,
# per FastAPI's documented behaviour for these arguments, the interactive
# documentation pages are not served.
app = FastAPI(docs_url=None, redoc_url=None)
# Serve files from the local "static" directory under the "/static" URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Module-level, in-memory store of task records; the functions in this file
# look records up in it by task id.
tasks = {}
def get_place_in_queue(task_id):
    """Return the 1-based position of a task among the pending tasks.

    Pending tasks are those whose status is "queued" or "processing",
    ordered by their "created_at" value (oldest first).

    Args:
        task_id: Identifier of the task to locate.

    Returns:
        The 1-based position of the task, or 0 when the task is not pending
        (unknown id, or a status other than "queued"/"processing").
    """
    pending = sorted(
        (
            task
            for task in tasks.values()
            if task["status"] in ("queued", "processing")
        ),
        key=lambda task: task["created_at"],
    )
    # Explicit search instead of list.index() inside a bare ``except``:
    # "not found" is an expected outcome here, and nothing else should be
    # silently swallowed.
    for place, task in enumerate(pending, start=1):
        if task["task_id"] == task_id:
            return place
    return 0
def calculate_eta(task_id, default_duration=40):
    """Estimate how long a task takes to finish, in seconds.

    The estimate is the task's "initial_place_in_queue" multiplied by the
    mean run time ("completed_at" - "started_at") of every task that has a
    "completed_at" entry. When no task has finished yet, ``default_duration``
    is used as the per-task run time instead.

    Args:
        task_id: Identifier of the task to estimate.
        default_duration: Per-task duration in seconds assumed while there
            are no finished tasks to average over.

    Returns:
        The estimate rounded to one decimal place.

    Raises:
        KeyError: If ``task_id`` is not present in ``tasks``.
    """
    durations = [
        task["completed_at"] - task["started_at"]
        for task in tasks.values()
        if "completed_at" in task
    ]
    per_task = mean(durations) if durations else default_duration
    initial_place_in_queue = tasks[task_id]["initial_place_in_queue"]
    return round(initial_place_in_queue * per_task, 1)
def process_task(task_id):
    """Run a task and then keep draining the queue, one task at a time.

    If any task currently has status "processing", this call returns
    immediately. Otherwise the given task is run: its status becomes
    "processing", ``generate_image`` is called with its prompt, and the
    result is stored under "value" (status "completed") or the exception's
    repr is stored under "error" (status "failed"). "started_at" and
    "completed_at" timestamps are recorded either way. Afterwards the first
    task still marked "queued" is picked up, until none remain.

    A task whose status is no longer "queued" is not run again; the loop
    simply moves on to the remaining queued tasks.

    Args:
        task_id: Identifier of the task to start with.

    Raises:
        KeyError: If ``task_id`` is not present in ``tasks``.
    """
    current_id = task_id
    # Iterative rather than recursive so the call stack does not grow with
    # the number of queued tasks.
    while True:
        # Only one task runs at a time; bail out if another one is running.
        if any(task["status"] == "processing" for task in tasks.values()):
            return

        task = tasks[current_id]
        # Guard against re-running a task that has already been handled.
        if task["status"] == "queued":
            task["status"] = "processing"
            task["started_at"] = time()
            try:
                task["value"] = generate_image(task["prompt"])
            except Exception as ex:
                # Record the failure on the task instead of propagating it,
                # so the rest of the queue still gets processed.
                task["status"] = "failed"
                task["error"] = repr(ex)
            else:
                task["status"] = "completed"
            finally:
                task["completed_at"] = time()

        queued_tasks = [
            queued for queued in tasks.values() if queued["status"] == "queued"
        ]
        if not queued_tasks:
            return
        print(f"Tasks remaining: {len(queued_tasks)}")
        current_id = queued_tasks[0]["task_id"]
def index():
    """Return static/index.html as an HTML file response.

    NOTE(review): no route decorator is visible in this listing; confirm
    how this handler is registered with the application.
    """
    page_path = "static/index.html"
    return FileResponse(path=page_path, media_type="text/html")
def generate_details():
    """Return the result of ``rand_details()`` unchanged.

    NOTE(review): no route decorator is visible in this listing; confirm
    how this handler is registered with the application.
    """
    details = rand_details()
    return details
def create_task(background_tasks: BackgroundTasks, prompt: str = "покемон"):
    """Register a new queued task and schedule it for background processing.

    The task id is the creation timestamp and the prompt joined by an
    underscore. The record is stored in ``tasks`` before its initial queue
    position is computed, so that the position lookup can see it.

    Args:
        background_tasks: Object whose ``add_task`` is used to schedule
            ``process_task`` for the new task.
        prompt: Text stored on the task under "prompt".

    Returns:
        The newly created task record.

    NOTE(review): no route decorator is visible in this listing; confirm
    how this handler is registered with the application.
    """
    now = time()
    new_id = f"{now!s}_{prompt}"
    record = dict(
        task_id=new_id,
        created_at=now,
        prompt=prompt,
        status="queued",
        poll_count=0,
    )
    tasks[new_id] = record
    record["initial_place_in_queue"] = get_place_in_queue(new_id)
    background_tasks.add_task(process_task, new_id)
    return record
def poll_task(task_id: str):
    """Refresh and return the record of an existing task.

    Updates the record's "place_in_queue" and "eta" entries and increments
    its "poll_count" before returning it.

    Args:
        task_id: Identifier of the task to report on.

    Returns:
        The updated task record.

    Raises:
        KeyError: If ``task_id`` is not present in ``tasks``.

    NOTE(review): no route decorator is visible in this listing; confirm
    how this handler is registered with the application.
    """
    record = tasks[task_id]
    record["place_in_queue"] = get_place_in_queue(task_id)
    record["eta"] = calculate_eta(task_id)
    record["poll_count"] += 1
    return record