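"""Comprehensive image evaluation tool.

Loads several aesthetic-scoring models (Aesthetic Shadow, Waifu Scorer,
Aesthetic Predictor V2.5, and an anime-aesthetic ONNX model) and serves
batch image evaluation through a Gradio interface.
"""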
import os
import shutil
import tempfile
import base64
import asyncio
from io import BytesIO

import cv2
import numpy as np
import torch
import onnxruntime as rt
from PIL import Image
import gradio as gr
from transformers import pipeline
from huggingface_hub import hf_hub_download

from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip

class MLP(torch.nn.Module):
    """A simple multi-layer perceptron for image feature regression."""

    def __init__(self, input_size: int, batch_norm: bool = True):
        super().__init__()
        self.input_size = input_size
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(self.input_size, 2048),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(2048) if batch_norm else torch.nn.Identity(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(2048, 512),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(512) if batch_norm else torch.nn.Identity(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(256) if batch_norm else torch.nn.Identity(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(256, 128),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(128) if batch_norm else torch.nn.Identity(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(128, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)

class WaifuScorer:
    """WaifuScorer model that uses CLIP for feature extraction and a custom MLP for scoring."""

    def __init__(self, model_path: str = None, device: str = 'cuda', cache_dir: str = None, verbose: bool = False):
        self.verbose = verbose
        self.device = device
        self.dtype = torch.float32
        self.available = False

        try:
            import clip

            if model_path is None:
                model_path = "Eugeoter/waifu-scorer-v3/model.pth"
                if self.verbose:
                    print(f"Model path not provided. Using default: {model_path}")

            if not os.path.isfile(model_path):
                username, repo_id, model_name = model_path.split("/")[-3:]
                model_path = hf_hub_download(f"{username}/{repo_id}", model_name, cache_dir=cache_dir)

            if self.verbose:
                print(f"Loading WaifuScorer model from: {model_path}")

            self.mlp = MLP(input_size=768)

            if model_path.endswith(".safetensors"):
                from safetensors.torch import load_file
                state_dict = load_file(model_path)
            else:
                state_dict = torch.load(model_path, map_location=device)
            self.mlp.load_state_dict(state_dict)
            self.mlp.to(device)
            self.mlp.eval()

            self.clip_model, self.preprocess = clip.load("ViT-L/14", device=device)
            self.available = True
        except Exception as e:
            print(f"Unable to initialize WaifuScorer: {e}")

    @torch.no_grad()
    def __call__(self, images):
        if not self.available:
            return [None] * (len(images) if isinstance(images, list) else 1)
        if isinstance(images, Image.Image):
            images = [images]
        n = len(images)
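        # Duplicate a lone image so the model always sees a batch of at least
        # two samples; only the first n scores are returned below.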
        if n == 1:
            images = images * 2

        image_tensors = [self.preprocess(img).unsqueeze(0) for img in images]
        image_batch = torch.cat(image_tensors).to(self.device)
        image_features = self.clip_model.encode_image(image_batch)

        # L2-normalize the CLIP features, guarding against zero-length vectors
        norm = image_features.norm(2, dim=-1, keepdim=True)
        norm[norm == 0] = 1
        im_emb = (image_features / norm).to(device=self.device, dtype=self.dtype)
        predictions = self.mlp(im_emb)
        scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
        return scores[:n]
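
# Example usage (hypothetical file path; assumes the default Hugging Face
# checkpoint and a CUDA device are available):
#   scorer = WaifuScorer(verbose=True)
#   scores = scorer([Image.open("sample.jpg")])
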
def load_aesthetic_predictor_v2_5():
    """Load and return an instance of Aesthetic Predictor V2.5 with batch processing support."""
    class AestheticPredictorV2_5_Impl:
        def __init__(self):
            print("Loading Aesthetic Predictor V2.5...")
            self.model, self.preprocessor = convert_v2_5_from_siglip(
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
            if torch.cuda.is_available():
                # Run in bfloat16 on GPU to reduce memory use; CPU stays float32
                self.model = self.model.to(torch.bfloat16).cuda()

        def inference(self, image):
            if isinstance(image, list):
                images_rgb = [img.convert("RGB") for img in image]
                pixel_values = self.preprocessor(images=images_rgb, return_tensors="pt").pixel_values
                if torch.cuda.is_available():
                    pixel_values = pixel_values.to(torch.bfloat16).cuda()
                with torch.inference_mode():
                    scores = self.model(pixel_values).logits.squeeze().float().cpu().numpy()
                # A single-image batch squeezes to a 0-d array; normalize to 1-d
                if scores.ndim == 0:
                    scores = np.array([scores])
                return scores.tolist()
            else:
                pixel_values = self.preprocessor(images=image.convert("RGB"), return_tensors="pt").pixel_values
                if torch.cuda.is_available():
                    pixel_values = pixel_values.to(torch.bfloat16).cuda()
                with torch.inference_mode():
                    score = self.model(pixel_values).logits.squeeze().float().cpu().numpy()
                return score

    return AestheticPredictorV2_5_Impl()

def load_anime_aesthetic_model():
    """Load and return the Anime Aesthetic ONNX model."""
    model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx")
    return rt.InferenceSession(model_path, providers=['CPUExecutionProvider'])

def predict_anime_aesthetic(img, model):
    """Predict Anime Aesthetic score for a single image."""
    img_np = np.array(img).astype(np.float32) / 255.0
    # Resize so the longer side is 768 px, preserving aspect ratio
    s = 768
    h, w = img_np.shape[:2]
    if h > w:
        new_h, new_w = s, int(s * w / h)
    else:
        new_h, new_w = int(s * h / w), s
    resized = cv2.resize(img_np, (new_w, new_h))

    # Letterbox onto a square canvas, centering the image
    canvas = np.zeros((s, s, 3), dtype=np.float32)
    pad_h = (s - new_h) // 2
    pad_w = (s - new_w) // 2
    canvas[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = resized

    # NCHW layout for the ONNX model
    input_tensor = np.transpose(canvas, (2, 0, 1))[np.newaxis, :]
    pred = model.run(None, {"img": input_tensor})[0].item()
    return pred
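
# Example usage (hypothetical file name):
#   session = load_anime_aesthetic_model()
#   score = predict_anime_aesthetic(Image.open("frame.png"), session)
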
class ModelManager:
    """Manages model loading and processing requests using a queue."""

    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")
        print("Loading models... This may take some time.")

        print("Loading Aesthetic Shadow model...")
        self.aesthetic_shadow_model = pipeline("image-classification", model="NeoChen1024/aesthetic-shadow-v2-backup", device=self.device)
        print("Loading Waifu Scorer model...")
        self.waifu_scorer_model = WaifuScorer(device=self.device, verbose=True)
        print("Loading Aesthetic Predictor V2.5...")
        self.aesthetic_predictor_model = load_aesthetic_predictor_v2_5()
        print("Loading Anime Aesthetic model...")
        self.anime_aesthetic_model = load_anime_aesthetic_model()
        print("All models loaded successfully!")

        # 'selected' defaults to True so the manager is usable even before the
        # UI syncs per-request selections.
        self.available_models = {
            "aesthetic_shadow": {"name": "Aesthetic Shadow", "selected": True, "process": self._process_aesthetic_shadow, "model": self.aesthetic_shadow_model},
            "waifu_scorer": {"name": "Waifu Scorer", "selected": True, "process": self._process_waifu_scorer, "model": self.waifu_scorer_model},
            "aesthetic_predictor_v2_5": {"name": "Aesthetic V2.5", "selected": True, "process": self._process_aesthetic_predictor_v2_5, "model": self.aesthetic_predictor_model},
            "anime_aesthetic": {"name": "Anime Score", "selected": True, "process": self._process_anime_aesthetic, "model": self.anime_aesthetic_model},
        }
        self.processing_queue: asyncio.Queue = asyncio.Queue()
        self.worker_task = None
        self.temp_dir = tempfile.mkdtemp()

    async def start_worker(self):
        """Start the background worker task."""
        if self.worker_task is None:
            self.worker_task = asyncio.create_task(self._worker())

    async def _worker(self):
        """Background worker to process image evaluation requests from the queue."""
        while True:
            request = await self.processing_queue.get()
            if request is None:
                self.processing_queue.task_done()
                break
            try:
                results = await self._process_request(request)
                request['results_future'].set_result(results)
            except Exception as e:
                request['results_future'].set_exception(e)
            finally:
                self.processing_queue.task_done()

    async def submit_request(self, request_data):
        """Submit a new image processing request to the queue."""
        results_future = asyncio.Future()
        request = {**request_data, 'results_future': results_future}
        await self.processing_queue.put(request)
        return await results_future

    async def _process_request(self, request):
        """Process a single image evaluation request."""
        file_paths = request['file_paths']
        auto_batch = request['auto_batch']
        manual_batch_size = request['manual_batch_size']
        selected_models = request['selected_models']
        log_events = []
        images = []
        file_names = []
        final_results = []

        total_files = len(file_paths)
        log_events.append(f"Starting to load {total_files} images...")
        for f in file_paths:
            try:
                img = Image.open(f).convert("RGB")
                images.append(img)
                file_names.append(os.path.basename(f))
            except Exception as e:
                log_events.append(f"Error opening {f}: {e}")

        if not images:
            log_events.append("No valid images loaded.")
            return [], log_events, 0, manual_batch_size

        log_events.append("Images loaded. Determining batch size...")

        try:
            manual_batch_size = int(manual_batch_size) if manual_batch_size is not None else 1
        except ValueError:
            manual_batch_size = 1
            log_events.append("Invalid manual batch size. Defaulting to 1.")

        optimal_batch = self.auto_tune_batch_size(images) if auto_batch else manual_batch_size
        log_events.append(f"Using batch size: {optimal_batch}")

        total_images = len(images)
        for i in range(0, total_images, optimal_batch):
            batch_images = images[i:i + optimal_batch]
            batch_file_names = file_names[i:i + optimal_batch]
            batch_index = i // optimal_batch + 1
            log_events.append(f"Processing batch {batch_index}: images {i + 1} to {min(i + optimal_batch, total_images)}")

            batch_results = {}

            # Honor the per-request selection rather than the manager-wide flags
            for model_key in selected_models:
                if selected_models[model_key]['selected']:
                    batch_results[model_key] = await self.available_models[model_key]['process'](batch_images, log_events)
                else:
                    batch_results[model_key] = [None] * len(batch_images)

            for j in range(len(batch_images)):
                scores_to_average = []
                for model_key in selected_models:
                    if selected_models[model_key]['selected']:
                        score = batch_results[model_key][j]
                        if score is not None:
                            scores_to_average.append(score)

                final_score = float(np.clip(np.mean(scores_to_average), 0.0, 10.0)) if scores_to_average else None
                thumbnail = batch_images[j].copy()
                thumbnail.thumbnail((200, 200))
                result = {
                    'file_name': batch_file_names[j],
                    'img_data': self.image_to_base64(thumbnail),
                    'final_score': final_score,
                }
                for model_key in selected_models:
                    if selected_models[model_key]['selected']:
                        result[model_key] = batch_results[model_key][j]
                final_results.append(result)

        log_events.append("All images processed.")
        return final_results, log_events, 100, optimal_batch

    def image_to_base64(self, image: Image.Image) -> str:
        """Convert PIL Image to base64 encoded JPEG string."""
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

    def auto_tune_batch_size(self, images: list) -> int:
        """Automatically determine the optimal batch size for processing."""
        batch_size = 1
        max_batch = len(images)
        test_image = images[0:1]
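        # Double the batch size until a model call fails (typically CUDA OOM),
        # then back off to the largest size that succeeded.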
        while batch_size <= max_batch:
            try:
                if "aesthetic_shadow" in self.available_models and self.available_models["aesthetic_shadow"]['selected']:
                    _ = self.available_models["aesthetic_shadow"]['model'](test_image * batch_size)
                if "waifu_scorer" in self.available_models and self.available_models["waifu_scorer"]['selected']:
                    _ = self.available_models["waifu_scorer"]['model'](test_image * batch_size)
                if "aesthetic_predictor_v2_5" in self.available_models and self.available_models["aesthetic_predictor_v2_5"]['selected']:
                    _ = self.available_models["aesthetic_predictor_v2_5"]['model'].inference(test_image * batch_size)
                batch_size *= 2
                if batch_size > max_batch:
                    break
            except Exception:
                break
        optimal = max(1, batch_size // 2)
        if optimal > 64:
            optimal = 64
        print(f"Optimal batch size determined: {optimal}")
        return optimal

    async def _process_aesthetic_shadow(self, batch_images, log_events):
        try:
            shadow_results = self.available_models["aesthetic_shadow"]['model'](batch_images)
            log_events.append("Aesthetic Shadow processed for batch.")
        except Exception as e:
            log_events.append(f"Error in Aesthetic Shadow: {e}")
            shadow_results = [None] * len(batch_images)
        aesthetic_shadow_scores = []
        for res in shadow_results:
            try:
                # Scale the 'hq' class probability to a 0-10 score
                hq_score = next(p for p in res if p['label'] == 'hq')['score']
                score = float(np.clip(hq_score * 10.0, 0.0, 10.0))
            except Exception:
                score = None
            aesthetic_shadow_scores.append(score)
        log_events.append("Aesthetic Shadow scores computed for batch.")
        return aesthetic_shadow_scores

    async def _process_waifu_scorer(self, batch_images, log_events):
        try:
            waifu_scores = self.available_models["waifu_scorer"]['model'](batch_images)
            waifu_scores = [float(np.clip(s, 0.0, 10.0)) if s is not None else None for s in waifu_scores]
            log_events.append("Waifu Scorer processed for batch.")
        except Exception as e:
            log_events.append(f"Error in Waifu Scorer: {e}")
            waifu_scores = [None] * len(batch_images)
        return waifu_scores

    async def _process_aesthetic_predictor_v2_5(self, batch_images, log_events):
        try:
            v2_5_scores = self.available_models["aesthetic_predictor_v2_5"]['model'].inference(batch_images)
            v2_5_scores = [float(np.round(np.clip(s, 0.0, 10.0), 4)) if s is not None else None for s in v2_5_scores]
            log_events.append("Aesthetic Predictor V2.5 processed for batch.")
        except Exception as e:
            log_events.append(f"Error in Aesthetic Predictor V2.5: {e}")
            v2_5_scores = [None] * len(batch_images)
        return v2_5_scores

    async def _process_anime_aesthetic(self, batch_images, log_events):
        # The ONNX model scores one image at a time, so iterate over the batch
        anime_scores = []
        for j, img in enumerate(batch_images):
            try:
                score = predict_anime_aesthetic(img, self.available_models["anime_aesthetic"]['model'])
                anime_scores.append(float(np.clip(score * 10.0, 0.0, 10.0)))
                log_events.append(f"Anime Aesthetic processed for image {j + 1}.")
            except Exception as e:
                log_events.append(f"Error in Anime Aesthetic for image {j + 1}: {e}")
                anime_scores.append(None)
        return anime_scores

    def _generate_progress_html(self, percentage: float) -> str:
        """Generate HTML for a progress bar given a percentage."""
        return f"""
        <div style="width:100%;background-color:#ddd; border-radius:5px;">
          <div style="width:{percentage:.1f}%; background-color:#4CAF50; text-align:center; padding:5px 0; border-radius:5px;">
            {percentage:.1f}%
          </div>
        </div>
        """

    def _format_logs(self, logs: list) -> str:
        """Format log events into an HTML string."""
        return "<div style='max-height:300px; overflow-y:auto;'>" + "<br>".join(logs) + "</div>"

    def sort_results(self, results, sort_by: str = "Final Score") -> list:
        """Sort results based on the specified column."""
        key_map = {
            "Final Score": "final_score",
            "File Name": "file_name",
            "Aesthetic Shadow": "aesthetic_shadow",
            "Waifu Scorer": "waifu_scorer",
            "Aesthetic V2.5": "aesthetic_predictor_v2_5",
            "Anime Score": "anime_aesthetic"
        }
        key = key_map.get(sort_by, "final_score")
        # Scores sort descending; file names sort ascending
        reverse = sort_by != "File Name"
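        # Entries without a score should sort to the end in either direction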
        results.sort(key=lambda r: r.get(key) if r.get(key) is not None else (float('inf') if not reverse else -float('inf')), reverse=reverse)
        return results

    def generate_html_table(self, results: list, selected_models) -> str:
        """Generate an HTML table to display the evaluation results."""
        table_html = """
        <style>
            .results-table { width: 100%; border-collapse: collapse; margin: 20px 0; font-family: Arial, sans-serif; }
            .results-table th, .results-table td { color: #eee; border: 1px solid #ddd; padding: 8px; text-align: center; }
            .results-table th { font-weight: bold; }
            .results-table tr:nth-child(even) { background-color: transparent; }
            .results-table tr:hover { background-color: rgba(255, 255, 255, 0.1); }
            .image-preview { max-width: 150px; max-height: 150px; display: block; margin: 0 auto; }
            .good-score { color: #0f0; font-weight: bold; }
            .bad-score { color: #f00; font-weight: bold; }
            .medium-score { color: orange; font-weight: bold; }
        </style>
        <table class="results-table">
            <thead>
                <tr>
                    <th>Image</th>
                    <th>File Name</th>
        """
        visible_models = []
        if "aesthetic_shadow" in selected_models:
            table_html += "<th>Aesthetic Shadow</th>"
            visible_models.append("aesthetic_shadow")
        if "waifu_scorer" in selected_models:
            table_html += "<th>Waifu Scorer</th>"
            visible_models.append("waifu_scorer")
        if "aesthetic_predictor_v2_5" in selected_models:
            table_html += "<th>Aesthetic V2.5</th>"
            visible_models.append("aesthetic_predictor_v2_5")
        if "anime_aesthetic" in selected_models:
            table_html += "<th>Anime Score</th>"
            visible_models.append("anime_aesthetic")
        table_html += "<th>Final Score</th>"
        table_html += "</tr></thead><tbody>"

        for result in results:
            table_html += "<tr>"
            table_html += f'<td><img src="data:image/jpeg;base64,{result["img_data"]}" class="image-preview"></td>'
            table_html += f'<td>{result["file_name"]}</td>'
            for model_key in visible_models:
                score = result.get(model_key)
                table_html += self._format_score_cell(score)

            score = result.get("final_score")
            table_html += self._format_score_cell(score)
            table_html += "</tr>"
        table_html += "</tbody></table>"
        return table_html

    def _format_score_cell(self, score):
        """Render a table cell, color-coded by score (>=7 good, >=5 medium, else bad)."""
        score_str = f"{score:.4f}" if isinstance(score, (int, float)) else "N/A"
        score_class = ""
        if isinstance(score, (int, float)):
            if score >= 7:
                score_class = "good-score"
            elif score >= 5:
                score_class = "medium-score"
            else:
                score_class = "bad-score"
        return f'<td class="{score_class}">{score_str}</td>'

    def cleanup(self):
        """Clean up temporary directories and shut down the worker."""
        if os.path.exists(self.temp_dir):
            shutil.rmtree(self.temp_dir)
        if self.worker_task is not None:
            asyncio.run(self.shutdown())

    async def shutdown(self):
        """Send shutdown signal to worker and wait for it to finish."""
        if self.worker_task is not None:
            await self.processing_queue.put(None)
            await self.worker_task
            await self.processing_queue.join()
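
# Example of driving the manager directly (hypothetical, inside an async context):
#   manager = ModelManager()
#   await manager.start_worker()
#   results, logs, progress, batch = await manager.submit_request({
#       'file_paths': ["a.jpg"], 'auto_batch': True, 'manual_batch_size': 1,
#       'selected_models': {k: {'selected': True} for k in manager.available_models},
#   })
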
model_manager = ModelManager()


def create_interface():
    sort_options = ["Final Score", "File Name", "Aesthetic Shadow", "Waifu Scorer", "Aesthetic V2.5", "Anime Score"]
    model_options = ["aesthetic_shadow", "waifu_scorer", "aesthetic_predictor_v2_5", "anime_aesthetic"]

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # Comprehensive Image Evaluation Tool

        Upload images to evaluate them using multiple aesthetic and quality prediction models.

        **New features:**
        - **Dynamic Final Score:** Final score recalculates on model selection changes.
        - **Model Selection:** Choose which models to use for evaluation.
        - **Dynamic Table Updates:** Table updates automatically based on model selection.
        - **Automatic Sorting:** Table is automatically sorted by 'Final Score'.
        - **Detailed Logs:** See major processing events (limited to the last 10).
        - **Progress Bar:** Visual indication of processing status.
        - **Asynchronous Updates:** Streaming status and logs during processing.
        - **Batch Size Controls:** Choose manual batch size or let the tool auto-detect it.
        - **Download Results:** Export the evaluation results as CSV.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                input_images = gr.Files(label="Upload Images", file_count="multiple")
                model_checkboxes = gr.CheckboxGroup(model_options, label="Select Models", value=model_options, info="Choose models for evaluation.")
                auto_batch_checkbox = gr.Checkbox(label="Automatic Batch Size Detection", value=False, info="Enable to automatically determine the optimal batch size.")
                batch_size_input = gr.Number(label="Batch Size", value=1, interactive=True, info="Manually specify the batch size if auto-detection is disabled.")
                sort_dropdown = gr.Dropdown(sort_options, value="Final Score", label="Sort by", info="Select the column to sort results by.")
                process_btn = gr.Button("Evaluate Images", variant="primary")
                clear_btn = gr.Button("Clear Results")
                download_csv = gr.Button("Download CSV", variant="secondary")

            with gr.Column(scale=2):
                progress_bar = gr.HTML(label="Progress Bar", value="""
                <div style='width:100%;background-color:#ddd;'>
                  <div style='width:0%;background-color:#4CAF50;padding:5px 0;text-align:center;'>0%</div>
                </div>
                """)
                log_window = gr.HTML(label="Detailed Logs", value="<div style='max-height:300px; overflow-y:auto;'>Logs will appear here...</div>")
                status_html = gr.HTML(label="Status")
                output_html = gr.HTML(label="Evaluation Results")
                download_file_output = gr.File()
                global_results_state = gr.State([])

        def results_to_csv(results, selected_models):
            import csv
            import io
            if not results:
                return None
            output = io.StringIO()
            fieldnames = ['file_name', 'final_score'] + list(selected_models)

            writer = csv.DictWriter(output, fieldnames=fieldnames)
            writer.writeheader()
            for res in results:
                row_dict = {'file_name': res['file_name'], 'final_score': res['final_score']}
                for model_key in selected_models:
                    row_dict[model_key] = res.get(model_key, 'N/A')
                writer.writerow(row_dict)
            return output.getvalue()

        def update_batch_size_interactivity(auto_batch):
            return gr.update(interactive=not auto_batch)

        async def process_images_and_update(files, auto_batch, manual_batch, selected_models, current_results):
            if not files:
                return "<p>No files uploaded.</p>", gr.update(), gr.update(), gr.update(), gr.update(), current_results
            # gr.Files may yield tempfile-like objects or plain path strings
            # depending on the Gradio version
            file_paths = [f.name if hasattr(f, "name") else f for f in files]

            # Package the per-request model selection for the worker
            request_data = {
                'file_paths': file_paths,
                'auto_batch': auto_batch,
                'manual_batch_size': manual_batch,
                'selected_models': {model: {'selected': model in selected_models} for model in model_options}
            }

            results, logs, progress_percent, updated_batch = await model_manager.submit_request(request_data)

            updated_results = current_results + results

            html_table = model_manager.generate_html_table(updated_results, selected_models)
            progress_html = model_manager._generate_progress_html(progress_percent)
            log_html = model_manager._format_logs(logs[-10:])

            status_message = "<p>Processing complete.</p>"
            return status_message, html_table, log_html, progress_html, gr.update(value=updated_batch, interactive=not auto_batch), updated_results

        def update_table_sort(sort_by_column, selected_models, current_results):
            sorted_results = model_manager.sort_results(current_results, sort_by_column)
            return model_manager.generate_html_table(sorted_results, selected_models), sorted_results

        def update_table_model_selection(selected_models, current_results):
            # Recompute each final score from only the currently selected models
            for result in current_results:
                scores_to_average = []
                for model_key in model_options:
                    if model_key in selected_models:
                        score = result.get(model_key)
                        if score is not None:
                            scores_to_average.append(score)
                final_score = float(np.clip(np.mean(scores_to_average), 0.0, 10.0)) if scores_to_average else None
                result['final_score'] = final_score

            sorted_results = model_manager.sort_results(current_results, "Final Score")
            return model_manager.generate_html_table(sorted_results, selected_models), sorted_results

        def clear_results():
            return (gr.update(value=""),
                    gr.update(value=""),
                    gr.update(value=""),
                    gr.update(value="""
                    <div style='width:100%;background-color:#ddd;'>
                      <div style='width:0%;background-color:#4CAF50;padding:5px 0;text-align:center;'>0%</div>
                    </div>
                    """),
                    gr.update(value=1),
                    [])

        def download_results_csv_trigger(selected_models, current_results):
            csv_content = results_to_csv(current_results, selected_models)
            if csv_content is None:
                return None

            # Write the CSV to a named temp file so Gradio can serve it for download
            with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as tmp_file:
                tmp_file.write(csv_content.encode())
                temp_file_path = tmp_file.name

            return temp_file_path

        auto_batch_checkbox.change(
            update_batch_size_interactivity,
            inputs=[auto_batch_checkbox],
            outputs=[batch_size_input]
        )

        process_btn.click(
            process_images_and_update,
            inputs=[input_images, auto_batch_checkbox, batch_size_input, model_checkboxes, global_results_state],
            outputs=[status_html, output_html, log_window, progress_bar, batch_size_input, global_results_state]
        )
        sort_dropdown.change(
            update_table_sort,
            inputs=[sort_dropdown, model_checkboxes, global_results_state],
            outputs=[output_html, global_results_state]
        )
        model_checkboxes.change(
            update_table_model_selection,
            inputs=[model_checkboxes, global_results_state],
            outputs=[output_html, global_results_state]
        )
        clear_btn.click(
            clear_results,
            inputs=[],
            outputs=[status_html, output_html, log_window, progress_bar, batch_size_input, global_results_state]
        )
        download_csv.click(
            download_results_csv_trigger,
            inputs=[model_checkboxes, global_results_state],
            outputs=[download_file_output]
        )
        # Render the empty table and start the background worker on page load
        demo.load(lambda: update_table_sort("Final Score", model_options, []), inputs=None, outputs=[output_html, global_results_state])
        demo.load(model_manager.start_worker)

        gr.Markdown("""
        ### Notes
        - Select models to use for evaluation using the checkboxes.
        - The 'Final Score' recalculates dynamically when models are selected/deselected.
        - The table updates automatically when models are selected/deselected and is always sorted by 'Final Score'.
        - The log window displays the most recent 10 events.
        - The progress bar shows overall processing status.
        - When 'Automatic Batch Size Detection' is enabled, the batch size field becomes disabled.
        - Use the download button to export your evaluation results as CSV.
        """)

    return demo

if __name__ == "__main__":
    demo = create_interface()
    demo.queue().launch()