Spaces:
Running
Running
import gradio as gr | |
import torch | |
import torch.nn as nn | |
from transformers import CLIPProcessor, CLIPVisionModel | |
from PIL import Image | |
import pandas as pd | |
from huggingface_hub import hf_hub_download | |
import os | |
from collections import OrderedDict | |
# --- 1. Define the Full Model Architecture ---
class AestheticScorerModel(nn.Module):
    """CLIP vision backbone feeding seven independent single-output linear heads.

    The heads are evaluated on the backbone's pooled output and their results
    are concatenated along dim 1, in the order listed in ``_HEAD_NAMES``.
    """

    # Attribute names of the heads, in the order their outputs are concatenated.
    _HEAD_NAMES = (
        "aesthetic_head",
        "quality_head",
        "composition_head",
        "light_head",
        "color_head",
        "dof_head",
        "content_head",
    )

    def __init__(self, clip_model_id, embedding_dim):
        """Load the pretrained backbone and create one linear head per name.

        Args:
            clip_model_id: Identifier passed to ``CLIPVisionModel.from_pretrained``.
            embedding_dim: Input feature size of every head.
        """
        super().__init__()
        self.backbone = CLIPVisionModel.from_pretrained(clip_model_id)
        # Each head is registered under its own attribute name so parameter
        # names in the state dict stay "<head_name>.0.weight" / ".0.bias".
        for head_name in self._HEAD_NAMES:
            setattr(self, head_name, nn.Sequential(nn.Linear(embedding_dim, 1)))

    def forward(self, pixel_values):
        """Return the head outputs concatenated along dim 1.

        Args:
            pixel_values: Forwarded to the backbone as ``pixel_values``.

        Returns:
            Tensor with one column per head, ordered as ``_HEAD_NAMES``
            (assumes the pooled output is 2-D, batch first -- TODO confirm).
        """
        pooled = self.backbone(pixel_values=pixel_values).pooler_output
        per_head = [getattr(self, head_name)(pooled) for head_name in self._HEAD_NAMES]
        return torch.cat(per_head, dim=1)
# --- 2. Model & Processor Loading ---
# Runs once at import time: downloads (or reads from CACHE_DIR) the CLIP
# processor and the fine-tuned checkpoint, then loads the weights into the model.
print("Loading model and processor...")
CACHE_DIR = "hf_cache"
MODEL_REPO_ID = "rsinema/aesthetic-scorer"
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
EMBEDDING_DIM = 768

processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID, cache_dir=CACHE_DIR)
model = AestheticScorerModel(clip_model_id=CLIP_MODEL_ID, embedding_dim=EMBEDDING_DIM)
model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="model.pt", cache_dir=CACHE_DIR)

# NOTE(review): torch.load unpickles the downloaded file. If the installed
# torch supports it, passing weights_only=True would restrict loading to
# plain tensors/containers -- confirm against the deployed torch version.
state_dict = torch.load(model_path, map_location=torch.device('cpu'))

# The checkpoint stores backbone weights directly under "backbone.", while
# this model keeps them one level deeper under "backbone.vision_model.".
# Keys that already carry the deeper prefix are left alone so the remap is
# idempotent instead of producing "backbone.vision_model.vision_model...".
corrected_state_dict = OrderedDict()
for key, value in state_dict.items():
    if key.startswith('backbone.') and not key.startswith('backbone.vision_model.'):
        key = 'backbone.vision_model.' + key[len('backbone.'):]
    corrected_state_dict[key] = value

# strict=False tolerates mismatched keys; report them so a checkpoint that
# failed to line up with the model does not go unnoticed.
load_result = model.load_state_dict(corrected_state_dict, strict=False)
if load_result.missing_keys:
    print(f"Warning: {len(load_result.missing_keys)} model key(s) missing from checkpoint: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"Warning: {len(load_result.unexpected_keys)} unexpected checkpoint key(s) ignored: {load_result.unexpected_keys}")

model.eval()
print("Model and processor loaded successfully.")

# Column labels for the seven scores, in the order the model concatenates them.
AESTHETIC_CATEGORIES = ["Overall", "Quality", "Composition", "Lighting", "Color", "Depth of Field", "Content"]
TABLE_HEADERS = ["Preview", "File"] + AESTHETIC_CATEGORIES
# --- 3. Core Processing Function ---
def score_images(files):
    """Score every uploaded image and return one table row per file.

    Args:
        files: Iterable of uploads. Each item is either an object exposing the
            file path as ``.name`` or a plain path string. ``None`` or an
            empty list yields an empty table.

    Returns:
        A ``pandas.DataFrame`` with ``TABLE_HEADERS`` as columns. A file that
        fails to process still gets a row, with ``None`` as its preview and
        "Processing Error" in every score column.
    """
    if not files:
        return pd.DataFrame(columns=TABLE_HEADERS)

    results_list = []
    for file_obj in files:
        # Accept both wrapper objects (path in `.name`) and plain path strings
        # so the function does not depend on which form the uploader delivers.
        file_path = getattr(file_obj, "name", file_obj)
        filename = os.path.basename(file_path)
        try:
            # Close the underlying file as soon as the RGB copy exists;
            # convert() returns a new image that does not need the handle.
            with Image.open(file_path) as raw_image:
                image = raw_image.convert("RGB")
            with torch.no_grad():
                inputs = processor(images=image, return_tensors="pt")
                scores = model(**inputs)[0]
            # Provide the image data as a (filepath, alt_text) tuple for rendering.
            image_scores = {"Preview": (file_path, filename), "File": filename}
            for category, score in zip(AESTHETIC_CATEGORIES, scores):
                image_scores[category] = f"{score.item():.2f} / 5"
            results_list.append(image_scores)
        except Exception as e:
            # Deliberate per-file boundary: one bad upload must not abort the
            # whole batch, so log it and emit an error row instead.
            print(f"Error processing {filename}: {e}")
            error_row = {"Preview": None, "File": filename, **{cat: "Processing Error" for cat in AESTHETIC_CATEGORIES}}
            results_list.append(error_row)
    return pd.DataFrame(results_list, columns=TABLE_HEADERS)
def clear_all():
    """Return a ``None`` reset value for each of the two outputs wired to the Clear button."""
    cleared_upload = None
    cleared_table = None
    return cleared_upload, cleared_table
# --- 4. Gradio Interface using Blocks for Layout Control ---
_soft_theme = gr.themes.Soft()
_score_column_count = len(AESTHETIC_CATEGORIES)

with gr.Blocks(theme=_soft_theme, css="footer {display: none !important}") as interface:
    # Page title followed by a short description of the app.
    gr.Markdown("# Multi-Image Aesthetic Scorer")
    gr.Markdown(
        "Upload one or more images to compare their aesthetic scores across seven distinct categories. This application uses the **rsinema/aesthetic-scorer** model."
    )

    with gr.Column():
        # Multi-file upload restricted to image files.
        file_uploader = gr.File(label="Upload Images", file_count="multiple", file_types=["image"])

        with gr.Row():
            clear_button = gr.Button("Clear")
            submit_button = gr.Button("Submit", variant="primary")

        # Read-only results table: an image preview column, the file name,
        # then one text column per aesthetic category.
        # NOTE(review): the bare integers in column_widths may be interpreted
        # as pixel widths rather than percentages -- confirm against the
        # Gradio Dataframe documentation for the installed version.
        output_df = gr.Dataframe(
            headers=TABLE_HEADERS,
            label="Aesthetic Scores Comparison",
            interactive=False,
            datatype=["image", "str", *(["str"] * _score_column_count)],
            column_widths=[15, 35, *([10] * _score_column_count)],
        )

    # Submit scores the uploaded files; Clear resets both the uploader and the table.
    submit_button.click(score_images, file_uploader, output_df)
    clear_button.click(clear_all, None, [file_uploader, output_df])

if __name__ == "__main__":
    interface.launch()