import os
import shutil
import tempfile
import base64
import asyncio
from io import BytesIO

import cv2
import numpy as np
import torch
import onnxruntime as rt
from PIL import Image
import gradio as gr
from transformers import pipeline
from huggingface_hub import hf_hub_download

# Import necessary function from aesthetic_predictor_v2_5
from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip


#####################################
#         Model Definitions         #
#####################################

class MLP(torch.nn.Module):
    """A simple multi-layer perceptron for image feature regression."""
    def __init__(self, input_size: int, batch_norm: bool = True):
        super().__init__()
        self.input_size = input_size
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(self.input_size, 2048),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(2048) if batch_norm else torch.nn.Identity(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(2048, 512),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(512) if batch_norm else torch.nn.Identity(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(256) if batch_norm else torch.nn.Identity(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(256, 128),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(128) if batch_norm else torch.nn.Identity(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(128, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)
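
# The MLP consumes 768-dimensional image embeddings (CLIP ViT-L/14, as wired
# up in WaifuScorer below) and regresses a single raw aesthetic score per image.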


class WaifuScorer:
    """WaifuScorer model that uses CLIP for feature extraction and a custom MLP for scoring."""
    def __init__(self, model_path: str = None, device: str = 'cuda', cache_dir: str = None, verbose: bool = False):
        self.verbose = verbose
        self.device = device
        self.dtype = torch.float32
        self.available = False

        try:
            import clip  # local import to avoid dependency issues
            # Set default model path if not provided
            if model_path is None:
                model_path = "Eugeoter/waifu-scorer-v3/model.pth"
                if self.verbose:
                    print(f"Model path not provided. Using default: {model_path}")

            # Download model if not found locally
            if not os.path.isfile(model_path):
                username, repo_id, model_name = model_path.split("/")[-3:]
                model_path = hf_hub_download(f"{username}/{repo_id}", model_name, cache_dir=cache_dir)

            if self.verbose:
                print(f"Loading WaifuScorer model from: {model_path}")

            # Initialize MLP model
            self.mlp = MLP(input_size=768)
            # Load state dict
            if model_path.endswith(".safetensors"):
                from safetensors.torch import load_file
                state_dict = load_file(model_path)
            else:
                state_dict = torch.load(model_path, map_location=device)
            self.mlp.load_state_dict(state_dict)
            self.mlp.to(device)
            self.mlp.eval()

            # Load CLIP model for image preprocessing and feature extraction
            self.clip_model, self.preprocess = clip.load("ViT-L/14", device=device)
            self.available = True
        except Exception as e:
            print(f"Unable to initialize WaifuScorer: {e}")

    @torch.no_grad()
    def __call__(self, images):
        if not self.available:
            return [None] * (len(images) if isinstance(images, list) else 1)
        if isinstance(images, Image.Image):
            images = [images]
        n = len(images)
        # Ensure at least two images for CLIP model compatibility
        if n == 1:
            images = images * 2

        image_tensors = [self.preprocess(img).unsqueeze(0) for img in images]
        image_batch = torch.cat(image_tensors).to(self.device)
        image_features = self.clip_model.encode_image(image_batch)
        # Normalize features
        norm = image_features.norm(2, dim=-1, keepdim=True)
        norm[norm == 0] = 1
        im_emb = (image_features / norm).to(device=self.device, dtype=self.dtype)
        predictions = self.mlp(im_emb)
        scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
        return scores[:n]
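
# Quick usage sketch (assumes a local "example.jpg", a hypothetical path; if
# CLIP or the checkpoint failed to load, `available` stays False and the
# scorer returns None scores):
#   scorer = WaifuScorer(device='cuda' if torch.cuda.is_available() else 'cpu')
#   scores = scorer(Image.open("example.jpg"))  # -> [score], clamped to 0-10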


#####################################
#   Aesthetic Predictor Functions   #
#####################################

def load_aesthetic_predictor_v2_5():
    """Load and return an instance of Aesthetic Predictor V2.5 with batch processing support."""
    class AestheticPredictorV2_5_Impl:
        def __init__(self):
            print("Loading Aesthetic Predictor V2.5...")
            self.model, self.preprocessor = convert_v2_5_from_siglip(
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
            if torch.cuda.is_available():
                self.model = self.model.to(torch.bfloat16).cuda()

        def inference(self, image):
            if isinstance(image, list):
                images_rgb = [img.convert("RGB") for img in image]
                pixel_values = self.preprocessor(images=images_rgb, return_tensors="pt").pixel_values
                if torch.cuda.is_available():
                    pixel_values = pixel_values.to(torch.bfloat16).cuda()
                with torch.inference_mode():
                    scores = self.model(pixel_values).logits.squeeze().float().cpu().numpy()
                if scores.ndim == 0:
                    scores = np.array([scores])
                return scores.tolist()
            else:
                pixel_values = self.preprocessor(images=image.convert("RGB"), return_tensors="pt").pixel_values
                if torch.cuda.is_available():
                    pixel_values = pixel_values.to(torch.bfloat16).cuda()
                with torch.inference_mode():
                    score = self.model(pixel_values).logits.squeeze().float().cpu().numpy()
                return score

    return AestheticPredictorV2_5_Impl()
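
# Usage sketch (hypothetical file names; inference() accepts a single PIL
# image or a list, returning a float or a list of floats accordingly):
#   predictor = load_aesthetic_predictor_v2_5()
#   scores = predictor.inference([Image.open("a.jpg"), Image.open("b.jpg")])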


def load_anime_aesthetic_model():
    """Load and return the Anime Aesthetic ONNX model."""
    model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx")
    return rt.InferenceSession(model_path, providers=['CPUExecutionProvider'])


def predict_anime_aesthetic(img, model):
    """Predict Anime Aesthetic score for a single image."""
    img_np = np.array(img).astype(np.float32) / 255.0
    s = 768
    h, w = img_np.shape[:2]
    if h > w:
        new_h, new_w = s, int(s * w / h)
    else:
        new_h, new_w = int(s * h / w), s
    resized = cv2.resize(img_np, (new_w, new_h))
    # Center the resized image in a square canvas
    canvas = np.zeros((s, s, 3), dtype=np.float32)
    pad_h = (s - new_h) // 2
    pad_w = (s - new_w) // 2
    canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = resized
    # Prepare input for model
    input_tensor = np.transpose(canvas, (2, 0, 1))[np.newaxis, :]
    pred = model.run(None, {"img": input_tensor})[0].item()
    return pred
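
# Usage sketch (hypothetical file name; the raw output is roughly in [0, 1]
# and is scaled to 0-10 downstream in ModelManager._process_anime_aesthetic):
#   session = load_anime_aesthetic_model()
#   raw = predict_anime_aesthetic(Image.open("frame.png"), session)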


#####################################
#      Image Evaluation Tool        #
#####################################

class ModelManager:
    """Manages model loading and processing requests using a queue."""
    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")
        print("Loading models... This may take some time.")

        # Load models once during initialization
        print("Loading Aesthetic Shadow model...")
        self.aesthetic_shadow_model = pipeline("image-classification", model="NeoChen1024/aesthetic-shadow-v2-backup", device=self.device)
        print("Loading Waifu Scorer model...")
        self.waifu_scorer_model = WaifuScorer(device=self.device, verbose=True)
        print("Loading Aesthetic Predictor V2.5...")
        self.aesthetic_predictor_model = load_aesthetic_predictor_v2_5()
        print("Loading Anime Aesthetic model...")
        self.anime_aesthetic_model = load_anime_aesthetic_model()
        print("All models loaded successfully!")

        self.available_models = {
            "aesthetic_shadow": {"name": "Aesthetic Shadow", "process": self._process_aesthetic_shadow, "model": self.aesthetic_shadow_model},
            "waifu_scorer": {"name": "Waifu Scorer", "process": self._process_waifu_scorer, "model": self.waifu_scorer_model},
            "aesthetic_predictor_v2_5": {"name": "Aesthetic V2.5", "process": self._process_aesthetic_predictor_v2_5, "model": self.aesthetic_predictor_model},
            "anime_aesthetic": {"name": "Anime Score", "process": self._process_anime_aesthetic, "model": self.anime_aesthetic_model},
        }
        self.processing_queue: asyncio.Queue = asyncio.Queue()
        self.worker_task = None  # Started lazily via start_worker()
        self.temp_dir = tempfile.mkdtemp()

    async def start_worker(self):
        """Start the background worker task."""
        if self.worker_task is None:
            self.worker_task = asyncio.create_task(self._worker())

    async def _worker(self):
        """Background worker to process image evaluation requests from the queue."""
        while True:
            request = await self.processing_queue.get()
            if request is None: # Shutdown signal
                self.processing_queue.task_done()
                break
            try:
                results = await self._process_request(request)
                request['results_future'].set_result(results) # Fulfill the future with results
            except Exception as e:
                request['results_future'].set_exception(e) # Set exception if processing fails
            finally:
                self.processing_queue.task_done()

    async def submit_request(self, request_data):
        """Submit a new image processing request to the queue."""
        results_future = asyncio.Future() # Future to hold the results
        request = {**request_data, 'results_future': results_future}
        await self.processing_queue.put(request)
        return await results_future # Wait for and return results
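
    # A request is a plain dict; the shape expected by _process_request is:
    #   {'file_paths': [str, ...], 'auto_batch': bool, 'manual_batch_size': int,
    #    'selected_models': {model_key: {'selected': bool}, ...}}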

    async def _process_request(self, request):
        """Process a single image evaluation request."""
        file_paths = request['file_paths']
        auto_batch = request['auto_batch']
        manual_batch_size = request['manual_batch_size']
        selected_models = request['selected_models']
        log_events = []
        images = []
        file_names = []
        final_results = []

        # Prepare images and file names
        total_files = len(file_paths)
        log_events.append(f"Starting to load {total_files} images...")
        for f in file_paths:
            try:
                img = Image.open(f).convert("RGB")
                images.append(img)
                file_names.append(os.path.basename(f))
            except Exception as e:
                log_events.append(f"Error opening {f}: {e}")

        if not images:
            log_events.append("No valid images loaded.")
            return [], log_events, 0, manual_batch_size

        log_events.append("Images loaded. Determining batch size...")

        try:
            manual_batch_size = int(manual_batch_size) if manual_batch_size is not None else 1
        except ValueError:
            manual_batch_size = 1
            log_events.append("Invalid manual batch size. Defaulting to 1.")

        optimal_batch = self.auto_tune_batch_size(images) if auto_batch else manual_batch_size
        log_events.append(f"Using batch size: {optimal_batch}")

        total_images = len(images)
        for i in range(0, total_images, optimal_batch):
            batch_images = images[i:i+optimal_batch]
            batch_file_names = file_names[i:i+optimal_batch]
            batch_index = i // optimal_batch + 1
            log_events.append(f"Processing batch {batch_index}: images {i+1} to {min(i+optimal_batch, total_images)}")

            batch_results = {}

            # Process selected models; selection flags come from the request,
            # so UI checkbox changes take effect per request
            for model_key in selected_models:
                if selected_models[model_key]['selected']:
                    batch_results[model_key] = await self.available_models[model_key]['process'](batch_images, log_events)
                else:
                    batch_results[model_key] = [None] * len(batch_images)

            # Combine results and create final results list
            for j in range(len(batch_images)):
                scores_to_average = []
                for model_key in selected_models:
                    if selected_models[model_key]['selected']:
                        score = batch_results[model_key][j]
                        if score is not None:
                            scores_to_average.append(score)

                final_score = float(np.clip(np.mean(scores_to_average), 0.0, 10.0)) if scores_to_average else None
                thumbnail = batch_images[j].copy()
                thumbnail.thumbnail((200, 200))
                result = {
                    'file_name': batch_file_names[j],
                    'img_data': self.image_to_base64(thumbnail),  # Keep this for the HTML display
                    'final_score': final_score,
                }
                for model_key in selected_models: # Add model scores to result
                    if selected_models[model_key]['selected']:
                        result[model_key] = batch_results[model_key][j]
                final_results.append(result)

        log_events.append("All images processed.")
        return final_results, log_events, 100, optimal_batch


    def image_to_base64(self, image: Image.Image) -> str:
        """Convert PIL Image to base64 encoded JPEG string."""
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

    def auto_tune_batch_size(self, images: list) -> int:
        """Automatically determine the optimal batch size for processing."""
        batch_size = 1
        max_batch = len(images)
        test_image = images[0:1]
        while batch_size <= max_batch:
            try:
                if "aesthetic_shadow" in self.available_models and self.available_models["aesthetic_shadow"]['selected']: # Check if model is available and selected
                    _ = self.available_models["aesthetic_shadow"]['model'](test_image * batch_size)
                if "waifu_scorer" in self.available_models and self.available_models["waifu_scorer"]['selected']: # Check if model is available and selected
                    _ = self.available_models["waifu_scorer"]['model'](test_image * batch_size)
                if "aesthetic_predictor_v2_5" in self.available_models and self.available_models["aesthetic_predictor_v2_5"]['selected']: # Check if model is available and selected
                    _ = self.available_models["aesthetic_predictor_v2_5"]['model'].inference(test_image * batch_size)
                batch_size *= 2
                if batch_size > max_batch:
                    break
            except Exception:
                break
        # Back off to the last size that succeeded, capped at 64
        optimal = min(max(1, batch_size // 2), 64)
        print(f"Optimal batch size determined: {optimal}")
        return optimal

    async def _process_aesthetic_shadow(self, batch_images, log_events):
        try:
            shadow_results = self.available_models["aesthetic_shadow"]['model'](batch_images)
            log_events.append("Aesthetic Shadow processed for batch.")
        except Exception as e:
            log_events.append(f"Error in Aesthetic Shadow: {e}")
            shadow_results = [None] * len(batch_images)
        aesthetic_shadow_scores = []
        for res in shadow_results:
            try:
                hq_score = next(p for p in res if p['label'] == 'hq')['score']
                score = float(np.clip(hq_score * 10.0, 0.0, 10.0))
            except Exception:
                score = None
            aesthetic_shadow_scores.append(score)
        log_events.append("Aesthetic Shadow scores computed for batch.")
        return aesthetic_shadow_scores

    async def _process_waifu_scorer(self, batch_images, log_events):
        try:
            waifu_scores = self.available_models["waifu_scorer"]['model'](batch_images)
            waifu_scores = [float(np.clip(s, 0.0, 10.0)) if s is not None else None for s in waifu_scores]
            log_events.append("Waifu Scorer processed for batch.")
        except Exception as e:
            log_events.append(f"Error in Waifu Scorer: {e}")
            waifu_scores = [None] * len(batch_images)
        return waifu_scores

    async def _process_aesthetic_predictor_v2_5(self, batch_images, log_events):
        try:
            v2_5_scores = self.available_models["aesthetic_predictor_v2_5"]['model'].inference(batch_images)
            v2_5_scores = [float(np.round(np.clip(s, 0.0, 10.0), 4)) if s is not None else None for s in v2_5_scores]
            log_events.append("Aesthetic Predictor V2.5 processed for batch.")
        except Exception as e:
            log_events.append(f"Error in Aesthetic Predictor V2.5: {e}")
            v2_5_scores = [None] * len(batch_images)
        return v2_5_scores

    async def _process_anime_aesthetic(self, batch_images, log_events):
        anime_scores = []
        for j, img in enumerate(batch_images):
            try:
                score = predict_anime_aesthetic(img, self.available_models["anime_aesthetic"]['model'])
                anime_scores.append(float(np.clip(score * 10.0, 0.0, 10.0)))
                log_events.append(f"Anime Aesthetic processed for image {j + 1}.")
            except Exception as e:
                log_events.append(f"Error in Anime Aesthetic for image {j + 1}: {e}")
                anime_scores.append(None)
        return anime_scores


    def _generate_progress_html(self, percentage: float) -> str:
        """Generate HTML for a progress bar given a percentage."""
        return f"""
            <div style="width:100%;background-color:#ddd; border-radius:5px;">
              <div style="width:{percentage:.1f}%; background-color:#4CAF50; text-align:center; padding:5px 0; border-radius:5px;">
                {percentage:.1f}%
              </div>
            </div>
        """

    def _format_logs(self, logs: list) -> str:
        """Format log events into an HTML string."""
        return "<div style='max-height:300px; overflow-y:auto;'>" + "<br>".join(logs) + "</div>"

    def sort_results(self, results, sort_by: str = "Final Score") -> list:
        """Sort results based on the specified column."""
        key_map = {
            "Final Score": "final_score",
            "File Name": "file_name",
            "Aesthetic Shadow": "aesthetic_shadow",
            "Waifu Scorer": "waifu_scorer",
            "Aesthetic V2.5": "aesthetic_predictor_v2_5",
            "Anime Score": "anime_aesthetic"
        }
        key = key_map.get(sort_by, "final_score")
        reverse = sort_by != "File Name"
        # Map missing scores so they always sort to the bottom, regardless of direction
        results.sort(key=lambda r: r.get(key) if r.get(key) is not None else (-float('inf') if reverse else float('inf')), reverse=reverse)
        return results

    def generate_html_table(self, results: list, selected_models) -> str:
        """Generate an HTML table to display the evaluation results."""
        table_html = """
        <style>
            .results-table { width: 100%; border-collapse: collapse; margin: 20px 0; font-family: Arial, sans-serif; }
            .results-table th, .results-table td { color: #eee; border: 1px solid #ddd; padding: 8px; text-align: center; }
            .results-table th { font-weight: bold; }
            .results-table tr:nth-child(even) { background-color: transparent; }
            .results-table tr:hover { background-color: rgba(255, 255, 255, 0.1); }
            .image-preview { max-width: 150px; max-height: 150px; display: block; margin: 0 auto; }
            .good-score { color: #0f0; font-weight: bold; }
            .bad-score { color: #f00; font-weight: bold; }
            .medium-score { color: orange; font-weight: bold; }
        </style>
        <table class="results-table">
            <thead>
                <tr>
                    <th>Image</th>
                    <th>File Name</th>
        """
        visible_models = [] # Keep track of visible model columns
        if "aesthetic_shadow" in selected_models:
            table_html += "<th>Aesthetic Shadow</th>"
            visible_models.append("aesthetic_shadow")
        if "waifu_scorer" in selected_models:
            table_html += "<th>Waifu Scorer</th>"
            visible_models.append("waifu_scorer")
        if "aesthetic_predictor_v2_5" in selected_models:
            table_html += "<th>Aesthetic V2.5</th>"
            visible_models.append("aesthetic_predictor_v2_5")
        if "anime_aesthetic" in selected_models:
            table_html += "<th>Anime Score</th>"
            visible_models.append("anime_aesthetic")
        table_html += "<th>Final Score</th>"
        table_html += "</tr></thead><tbody>"

        for result in results:
            table_html += "<tr>"
            table_html += f'<td><img src="data:image/jpeg;base64,{result["img_data"]}" class="image-preview"></td>'
            table_html += f'<td>{result["file_name"]}</td>'
            for model_key in visible_models: # Iterate through visible models only
                score = result.get(model_key)
                table_html += self._format_score_cell(score)

            score = result.get("final_score")
            table_html += self._format_score_cell(score)
            table_html += "</tr>"
        table_html += """</tbody></table>"""
        return table_html

    def _format_score_cell(self, score):
        score_str = f"{score:.4f}" if isinstance(score, (int, float)) else "N/A"
        score_class = ""
        if isinstance(score, (int, float)):
            if score >= 7:
                score_class = "good-score"
            elif score >= 5:
                score_class = "medium-score"
            else:
                score_class = "bad-score"
        return f'<td class="{score_class}">{score_str}</td>'


    def cleanup(self):
        """Clean up temporary directories and shutdown worker."""
        if os.path.exists(self.temp_dir):
            shutil.rmtree(self.temp_dir)
        if self.worker_task is not None:
            # asyncio.run requires that no event loop is already running, so
            # cleanup() must be called from synchronous teardown code
            asyncio.run(self.shutdown())

    async def shutdown(self):
        """Send shutdown signal to worker and wait for it to finish."""
        if self.worker_task is not None: # Check if worker_task was started
            await self.processing_queue.put(None) # Send shutdown signal
            await self.worker_task # Wait for worker task to complete
            await self.processing_queue.join() # Wait for queue to be empty


#####################################
#             Interface             #
#####################################

model_manager = ModelManager() # Initialize ModelManager once outside the interface function

def create_interface():
    sort_options = ["Final Score", "File Name", "Aesthetic Shadow", "Waifu Scorer", "Aesthetic V2.5", "Anime Score"]
    model_options = ["aesthetic_shadow", "waifu_scorer", "aesthetic_predictor_v2_5", "anime_aesthetic"]

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # Comprehensive Image Evaluation Tool

        Upload images to evaluate them using multiple aesthetic and quality prediction models.

        **New features:**
        - **Dynamic Final Score:** Final score recalculates on model selection changes.
        - **Model Selection:** Choose which models to use for evaluation.
        - **Dynamic Table Updates:** Table updates automatically based on model selection.
        - **Automatic Sorting:** Table is automatically sorted by 'Final Score'.
        - **Detailed Logs:** See major processing events (limited to the last 10).
        - **Progress Bar:** Visual indication of processing status.
        - **Asynchronous Updates:** Streaming status and logs during processing.
        - **Batch Size Controls:** Choose manual batch size or let the tool auto-detect it.
        - **Download Results:** Export the evaluation results as CSV.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                input_images = gr.Files(label="Upload Images", file_count="multiple")
                model_checkboxes = gr.CheckboxGroup(model_options, label="Select Models", value=model_options, info="Choose models for evaluation.")
                auto_batch_checkbox = gr.Checkbox(label="Automatic Batch Size Detection", value=False, info="Enable to automatically determine the optimal batch size.")
                batch_size_input = gr.Number(label="Batch Size", value=1, interactive=True, info="Manually specify the batch size if auto-detection is disabled.")
                sort_dropdown = gr.Dropdown(sort_options, value="Final Score", label="Sort by", info="Select the column to sort results by.")
                process_btn = gr.Button("Evaluate Images", variant="primary")
                clear_btn = gr.Button("Clear Results")
                download_csv = gr.Button("Download CSV", variant="secondary")

            with gr.Column(scale=2):
                progress_bar = gr.HTML(label="Progress Bar", value="""
                <div style='width:100%;background-color:#ddd;'>
                  <div style='width:0%;background-color:#4CAF50;padding:5px 0;text-align:center;'>0%</div>
                </div>
                """)
                log_window = gr.HTML(label="Detailed Logs", value="<div style='max-height:300px; overflow-y:auto;'>Logs will appear here...</div>")
                status_html = gr.HTML(label="Status")
                output_html = gr.HTML(label="Evaluation Results")
                download_file_output = gr.File() # Initialize gr.File component without filename
                global_results_state = gr.State([]) # Initialize a global state to hold results

        # Function to convert results to CSV format, excluding 'img_data'.
        def results_to_csv(results, selected_models): # Take results as input
            import csv
            import io
            if not results:
                return None # Return None when no results are available
            output = io.StringIO()
            fieldnames = ['file_name', 'final_score'] + list(selected_models)

            writer = csv.DictWriter(output, fieldnames=fieldnames)
            writer.writeheader()
            for res in results:
                row_dict = {'file_name': res['file_name'], 'final_score': res['final_score']}
                for model_key in selected_models:
                    row_dict[model_key] = res.get(model_key, 'N/A')  # 'N/A' if a score is missing
                writer.writerow(row_dict)
            return output.getvalue()


        def update_batch_size_interactivity(auto_batch):
            return gr.update(interactive=not auto_batch)

        async def process_images_and_update(files, auto_batch, manual_batch, selected_models, current_results):
            if not files:
                return "<p>No files uploaded.</p>", gr.update(), gr.update(), gr.update(), gr.update(), current_results
            file_paths = [f.name for f in files]

            # Prepare request data for the ModelManager
            request_data = {
                'file_paths': file_paths,
                'auto_batch': auto_batch,
                'manual_batch_size': manual_batch,
                'selected_models': {model: {'selected': model in selected_models} for model in model_options} # Pass model selections
            }
            # Submit request and get results from ModelManager
            results, logs, progress_percent, updated_batch = await model_manager.submit_request(request_data)

            updated_results = current_results + results # Append new results to current results

            html_table = model_manager.generate_html_table(updated_results, selected_models)
            progress_html = model_manager._generate_progress_html(progress_percent)
            log_html = model_manager._format_logs(logs[-10:])

            status_message = f"<p>Processed {len(results)} image(s).</p>"
            return status_message, html_table, log_html, progress_html, gr.update(value=updated_batch, interactive=not auto_batch), updated_results


        def update_table_sort(sort_by_column, selected_models, current_results):
            sorted_results = model_manager.sort_results(current_results, sort_by_column)
            return model_manager.generate_html_table(sorted_results, selected_models), sorted_results # Return sorted results

        def update_table_model_selection(selected_models, current_results):
            # Recalculate final scores based on selected models
            for result in current_results:
                scores_to_average = []
                for model_key in model_options:
                    if model_key in selected_models:  # only models currently checked in the UI
                        score = result.get(model_key)
                        if score is not None:
                            scores_to_average.append(score)
                final_score = float(np.clip(np.mean(scores_to_average), 0.0, 10.0)) if scores_to_average else None
                result['final_score'] = final_score

            sorted_results = model_manager.sort_results(current_results, "Final Score") # Keep sorting by Final Score when models change
            return model_manager.generate_html_table(sorted_results, selected_models), sorted_results


        def clear_results():
            return (gr.update(value=""),
                    gr.update(value=""),
                    gr.update(value=""),
                    gr.update(value="""
                    <div style='width:100%;background-color:#ddd;'>
                      <div style='width:0%;background-color:#4CAF50;padding:5px 0;text-align:center;'>0%</div>
                    </div>
                    """),
                    gr.update(value=1),
                    []) # Clear results state

        def download_results_csv_trigger(selected_models, current_results):
            csv_content = results_to_csv(current_results, selected_models)
            if csv_content is None:
                return None # Indicate no file to download

            # Create a temporary file to save the CSV data
            with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as tmp_file:
                tmp_file.write(csv_content.encode())
                temp_file_path = tmp_file.name # Get the path to the temporary file

            return temp_file_path # Return the path to the temporary file


        # Seed the selection flags in ModelManager (used by auto_tune_batch_size);
        # per-request selections are passed explicitly with each processing request
        for model_key in model_options:
            model_manager.available_models[model_key]['selected'] = True  # all selected by default

        auto_batch_checkbox.change(
            update_batch_size_interactivity,
            inputs=[auto_batch_checkbox],
            outputs=[batch_size_input]
        )

        process_btn.click(
            process_images_and_update,
            inputs=[input_images, auto_batch_checkbox, batch_size_input, model_checkboxes, global_results_state],
            outputs=[status_html, output_html, log_window, progress_bar, batch_size_input, global_results_state]
        )
        sort_dropdown.change(
            update_table_sort,
            inputs=[sort_dropdown, model_checkboxes, global_results_state],
            outputs=[output_html, global_results_state]
        )
        model_checkboxes.change( # Added change event for model checkboxes
            update_table_model_selection,
            inputs=[model_checkboxes, global_results_state],
            outputs=[output_html, global_results_state]
        )
        clear_btn.click(
            clear_results,
            inputs=[],
            outputs=[status_html, output_html, log_window, progress_bar, batch_size_input, global_results_state]
        )
        download_csv.click(
            download_results_csv_trigger, # Call the trigger function
            inputs=[model_checkboxes, global_results_state],
            outputs=[download_file_output] # Output is now the gr.File component
        )
        demo.load(lambda: update_table_sort("Final Score", model_options, []), inputs=None, outputs=[output_html, global_results_state]) # Initial sort and table render, pass empty initial results
        demo.load(model_manager.start_worker) # Start the worker task on demo load

        gr.Markdown("""
        ### Notes
        - Select models to use for evaluation using the checkboxes.
        - The 'Final Score' recalculates dynamically when models are selected/deselected.
        - The table updates automatically when models are selected/deselected and is always sorted by 'Final Score'.
        - The log window displays the most recent 10 events.
        - The progress bar shows overall processing status.
        - When 'Automatic Batch Size Detection' is enabled, the batch size field becomes disabled.
        - Use the download button to export your evaluation results as CSV.
        """)

    return demo

if __name__ == "__main__":
    demo = create_interface()
    demo.queue().launch()
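    # To expose the app on a public URL, Gradio's standard share option can be
    # used instead: demo.queue().launch(share=True)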