# VOIDER's picture
# Update app.py
# d4c04db verified
import gradio as gr
import torch
import torch.nn as nn
from transformers import CLIPProcessor, CLIPVisionModel
from PIL import Image
import pandas as pd
from huggingface_hub import hf_hub_download
import os
from collections import OrderedDict
# --- 1. Define the Full Model Architecture ---
class AestheticScorerModel(nn.Module):
    """CLIP vision backbone followed by seven independent single-output linear heads.

    Each head maps the backbone's ``pooler_output`` to one scalar, and the
    scalars are concatenated along dim 1 in a fixed order: aesthetic, quality,
    composition, light, color, depth of field, content.
    """

    # Head attribute names in output-column order. The names (and the
    # Sequential wrapper) determine the state-dict keys, e.g.
    # ``aesthetic_head.0.weight``, so they must stay in sync with the checkpoint.
    _HEAD_NAMES = (
        "aesthetic_head",
        "quality_head",
        "composition_head",
        "light_head",
        "color_head",
        "dof_head",
        "content_head",
    )

    def __init__(self, clip_model_id, embedding_dim):
        """Load the CLIP vision backbone and create one Linear(embedding_dim, 1) head per score.

        Args:
            clip_model_id: Hugging Face model id passed to ``CLIPVisionModel.from_pretrained``.
            embedding_dim: Width of the backbone's ``pooler_output`` fed to every head.
        """
        super().__init__()
        self.backbone = CLIPVisionModel.from_pretrained(clip_model_id)
        for head_name in self._HEAD_NAMES:
            setattr(self, head_name, nn.Sequential(nn.Linear(embedding_dim, 1)))

    def forward(self, pixel_values):
        """Return the concatenated head outputs for a batch of preprocessed images.

        Presumably ``pixel_values`` is the processor's (batch, C, H, W) tensor,
        giving a (batch, 7) result; TODO confirm against the processor output.
        """
        pooled = self.backbone(pixel_values=pixel_values).pooler_output
        head_outputs = [getattr(self, head_name)(pooled) for head_name in self._HEAD_NAMES]
        return torch.cat(head_outputs, dim=1)
# --- 2. Model & Processor Loading ---
print("Loading model and processor...")
CACHE_DIR = "hf_cache"
MODEL_REPO_ID = "rsinema/aesthetic-scorer"
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
# Width of the backbone's pooler_output consumed by every scoring head; must
# match the hidden size of the CLIP vision model above.
EMBEDDING_DIM = 768


def _remap_backbone_keys(state_dict):
    """Return a copy of *state_dict* with backbone keys nested under ``vision_model``.

    The remap assumes the checkpoint stores backbone weights as
    ``backbone.<param>``, while ``CLIPVisionModel`` exposes them as
    ``backbone.vision_model.<param>``. Keys that already carry the
    ``backbone.vision_model.`` prefix are left alone, so they are not
    double-prefixed. All non-backbone keys (the scoring heads) are passed
    through unchanged.
    """
    remapped = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith("backbone.") and not key.startswith("backbone.vision_model."):
            key = "backbone.vision_model." + key[len("backbone."):]
        remapped[key] = value
    return remapped


processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID, cache_dir=CACHE_DIR)
model = AestheticScorerModel(clip_model_id=CLIP_MODEL_ID, embedding_dim=EMBEDDING_DIM)
model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="model.pt", cache_dir=CACHE_DIR)
# NOTE(review): torch.load unpickles the downloaded file. If the checkpoint is a
# plain tensor state dict, consider passing weights_only=True to restrict it.
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
corrected_state_dict = _remap_backbone_keys(state_dict)
# strict=False tolerates benign key drift between library versions. Report the
# mismatch so that a scoring head left at random init does not go unnoticed.
load_result = model.load_state_dict(corrected_state_dict, strict=False)
if load_result.missing_keys:
    print(f"Warning: {len(load_result.missing_keys)} model keys missing from checkpoint: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"Warning: {len(load_result.unexpected_keys)} checkpoint keys not used by the model: {load_result.unexpected_keys}")
model.eval()
print("Model and processor loaded successfully.")
# Display names for the model's output columns, in the same order as the heads.
AESTHETIC_CATEGORIES = ["Overall", "Quality", "Composition", "Lighting", "Color", "Depth of Field", "Content"]
TABLE_HEADERS = ["Preview", "File"] + AESTHETIC_CATEGORIES
# --- 3. Core Processing Function ---
def _score_single_image(file_path, filename):
    """Run the model on one image file and return its results-table row."""
    # The context manager closes the file handle. convert() returns an
    # independent in-memory image, so it stays usable after the file is closed.
    with Image.open(file_path) as img:
        image = img.convert("RGB")
    with torch.no_grad():
        inputs = processor(images=image, return_tensors="pt")
        scores = model(**inputs)[0]
    # The Preview cell carries a (filepath, alt_text) tuple for the table's image column.
    row = {"Preview": (file_path, filename), "File": filename}
    for category, score in zip(AESTHETIC_CATEGORIES, scores):
        row[category] = f"{score.item():.2f} / 5"
    return row


def score_images(files):
    """Score every uploaded image and return one results-table row per file.

    Args:
        files: Value of the multi-file upload component. It is an iterable
            whose items are either path strings / ``os.PathLike`` objects, or
            file wrappers exposing the path through ``.name``. ``None`` or
            empty yields an empty table.

    Returns:
        A ``pandas.DataFrame`` with columns ``TABLE_HEADERS``. A file that
        fails to load or score gets a row with ``None`` as its preview and
        ``"Processing Error"`` in every category column.
    """
    if not files:
        return pd.DataFrame(columns=TABLE_HEADERS)
    results_list = []
    for file_obj in files:
        # Accept plain paths as well as wrappers that expose the path as .name.
        if isinstance(file_obj, (str, os.PathLike)):
            file_path = os.fspath(file_obj)
        else:
            file_path = file_obj.name
        filename = os.path.basename(file_path)
        try:
            results_list.append(_score_single_image(file_path, filename))
        except Exception as e:
            # Best-effort per file: one unreadable upload must not abort the batch.
            print(f"Error processing {filename}: {e}")
            error_row = {"Preview": None, "File": filename, **{cat: "Processing Error" for cat in AESTHETIC_CATEGORIES}}
            results_list.append(error_row)
    return pd.DataFrame(results_list, columns=TABLE_HEADERS)
def clear_all():
    """Reset the UI by emptying the file uploader and the results table.

    Returns:
        A ``(None, None)`` pair, one value per output component wired to the
        Clear button (uploader, results table).
    """
    cleared_uploader = None
    cleared_table = None
    return cleared_uploader, cleared_table
# --- 4. Gradio Interface using Blocks for Layout Control ---
# Gradio interprets integer column widths as pixels. The widths are therefore
# given as percentages, so the columns keep their relative proportions (narrow
# preview, wide file name, equal score columns) and together fill the table.
_PREVIEW_WIDTH_PCT = 12
_FILE_WIDTH_PCT = 25
_SCORE_WIDTH_PCT = (100 - _PREVIEW_WIDTH_PCT - _FILE_WIDTH_PCT) / len(AESTHETIC_CATEGORIES)
_COLUMN_WIDTHS = [f"{_PREVIEW_WIDTH_PCT}%", f"{_FILE_WIDTH_PCT}%"] + [f"{_SCORE_WIDTH_PCT:g}%"] * len(AESTHETIC_CATEGORIES)

with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as interface:
    gr.Markdown("# Multi-Image Aesthetic Scorer")
    gr.Markdown("Upload one or more images to compare their aesthetic scores across seven distinct categories. This application uses the **rsinema/aesthetic-scorer** model.")
    with gr.Column():
        file_uploader = gr.File(
            label="Upload Images",
            file_count="multiple",
            file_types=["image"],
        )
        with gr.Row():
            clear_button = gr.Button("Clear")
            submit_button = gr.Button("Submit", variant="primary")
        output_df = gr.Dataframe(
            headers=TABLE_HEADERS,
            label="Aesthetic Scores Comparison",
            interactive=False,
            # The Preview column is rendered as an image; all other cells are text.
            datatype=["image", "str"] + ["str"] * len(AESTHETIC_CATEGORIES),
            column_widths=_COLUMN_WIDTHS,
        )
    # Event wiring has to happen inside the Blocks context.
    submit_button.click(fn=score_images, inputs=file_uploader, outputs=output_df)
    clear_button.click(fn=clear_all, inputs=None, outputs=[file_uploader, output_df])

if __name__ == "__main__":
    interface.launch()